diff --git a/README.md b/README.md index 05fcb23f7edd657f2ea495d848fadc226e56b524..16d354ca7b150814f11fd825d6a22c84cebc2a01 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,8 @@ organization for the purposes of conducting machine learning and deep neural networks research. The system is general enough to be applicable in a wide variety of other domains, as well. +TensorFlow provides stable Python API and C APIs as well as without API backwards compatibility guarantee like C++, Go, Java, JavaScript and Swift. + Keep up to date with release announcements and security updates by subscribing to [announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce). @@ -81,13 +83,13 @@ The TensorFlow project strives to abide by generally accepted best practices in | Build Type | Status | Artifacts | | --- | --- | --- | -| **Linux CPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.png) | [pypi](https://pypi.org/project/tf-nightly/) | -| **Linux GPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-cc.png) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | -| **Linux XLA** | TBA | TBA | -| **MacOS** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.png) | [pypi](https://pypi.org/project/tf-nightly/) | -| **Windows CPU** | [![Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [pypi](https://pypi.org/project/tf-nightly/) | -| **Windows GPU** | [![Status](http://ci.tensorflow.org/job/tf-master-win-gpu-cmake/badge/icon)](http://ci.tensorflow.org/job/tf-master-win-gpu-cmake/) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | -| **Android** | [![Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/) [build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/) | +| **Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [pypi](https://pypi.org/project/tf-nightly/) | +| **Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | +| **Linux XLA** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA | +| **MacOS** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [pypi](https://pypi.org/project/tf-nightly/) | +| **Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [pypi](https://pypi.org/project/tf-nightly/) | +| **Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | +| **Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) | ### Community Supported Builds @@ -97,17 +99,20 @@ The TensorFlow project strives to abide by generally accepted best practices in | **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA | | **IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA | | **IBM ppc64le GPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/) | TBA | -| **Linux CPU with Intel® MKL-DNN®** | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | TBA | +| **Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) | +| **Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild)|[1.9.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp27-cp27mu-linux_x86_64.whl)
[1.9.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp35-cp35m-linux_x86_64.whl)
[1.9.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp36-cp36m-linux_x86_64.whl) | ## For more information - +* [Tensorflow Blog](https://medium.com/tensorflow) +* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si) +* [TensorFlow Model Zoo](https://github.com/tensorflow/models) +* [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730) +* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap) +* [Tensorflow Twitter](https://twitter.com/tensorflow) * [TensorFlow Website](https://www.tensorflow.org) * [TensorFlow White Papers](https://www.tensorflow.org/about/bib) * [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ) -* [TensorFlow Model Zoo](https://github.com/tensorflow/models) -* [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730) -* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si) Learn more about the TensorFlow community at the [community page of tensorflow.org](https://www.tensorflow.org/community) for a few ways to participate. diff --git a/RELEASE.md b/RELEASE.md index 6b67072f8ecafa08c747f8296c7c2a59eb2350fa..763ef3b279dde209ed387534032deae40a33a9e4 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,68 @@ +# Release 1.10.0 + +## Major Features And Improvements + +* The `tf.lite` runtime now supports `complex64`. +* Initial [Google Cloud Bigtable integration](https://github.com/tensorflow/tensorflow/tree/r1.10/tensorflow/contrib/bigtable) for `tf.data`. +* Improved local run behavior in `tf.estimator.train_and_evaluate` which does not reload checkpoints for evaluation. +* `RunConfig` now sets device_filters to restrict how workers and PS can communicate. This can speed up training and ensure clean shutdowns in some situations. But if you have jobs that require communication between workers, you will have to set custom session_options in your `RunConfig`. +* Moved Distributions and Bijectors from `tf.contrib.distributions` to [Tensorflow Probability (TFP)](https://github.com/tensorflow/probability). `tf.contrib.distributions` is now deprecated and will be removed by the end of 2018. +* Adding new endpoints for existing tensorflow symbols. These endpoints are going to be the preferred endpoints going forward and may replace some of the existing endpoints in the future. See below for the complete list. New symbols have been added to the following modules: [`tf.debugging`](https://www.tensorflow.org/versions/master/api_docs/python/tf/debugging), [`tf.dtypes`](https://www.tensorflow.org/versions/master/api_docs/python/tf/dtypes), [`tf.image`](https://www.tensorflow.org/versions/master/api_docs/python/tf/image), [`tf.io`](https://www.tensorflow.org/versions/master/api_docs/python/tf/io), [`tf.linalg`](https://www.tensorflow.org/versions/master/api_docs/python/tf/linalg), [`tf.manip`](https://www.tensorflow.org/versions/master/api_docs/python/tf/manip), [`tf.math`](https://www.tensorflow.org/versions/master/api_docs/python/tf/math), [`tf.quantization`](https://www.tensorflow.org/versions/master/api_docs/python/tf/quantization), [`tf.strings`](https://www.tensorflow.org/versions/master/api_docs/python/tf/strings) + +## Breaking Changes + +* Prebuilt binaries are now (as of TensorFlow 1.10) built against NCCL 2.2 and no longer include NCCL in the binary install. TensorFlow usage with multiple GPUs and NCCL requires upgrade to [NCCL 2.2](https://developer.nvidia.com/nccl). See updated install guides: [Installing TensorFlow on Ubuntu](https://www.tensorflow.org/install/install_linux#tensorflow_gpu_support) and [Install TensorFlow from Sources](https://www.tensorflow.org/install/install_sources#optional_install_tensorflow_for_gpu_prerequisites). +* Starting from TensorFlow 1.11, Windows builds will use Bazel. Therefore, we will drop official support for cmake. + +## Bug Fixes and Other Changes + +* `tf.data`: + * `tf.contrib.data.group_by_reducer()` is now available via the public API. + * `tf.contrib.data.choose_from_datasets()` is now available via the public API. + * Adding `drop_remainder` argument to `tf.data.Dataset.batch()` and `tf.data.Dataset.padded_batch()`, deprecating `tf.contrib.data.batch_and_drop_remainder()` and `tf.contrib.data.padded_batch_and_drop_remainder()`. +* `tf.estimator`: + * `Estimator`s now use custom savers included in `EstimatorSpec` scaffolds for saving SavedModels during export. + * `EstimatorSpec` will now add a default prediction output for export if no `export_output` is provided, eliminating the need to explicitly include a `PredictOutput` object in the `model_fn` for simple use-cases. + * Support sparse_combiner in canned Linear Estimators. + * Added batch normalization to `DNNClassifier`, `DNNRegressor`, and `DNNEstimator`. + * Adding ranking support for boosted trees. + * Adding center bias option for boosted trees. +* Add `synchronization` and `aggregation` args to get_variable(). These args will be used for distributed variables. +* Add `synchronization` and `aggregation` args to the layer `add_weight()` API. These args will be used for distributed variables. +* `tf.losses.*` do not add to the global collection when executing eagerly (to avoid leaking memory). +* Support different summary and checkpoint directories in `tf.train.MonitoredTrainingSession()`. +* Added IndRNN, IndyGRU, and IndyLSTM cells to `tf.contrib.rnn`. +* Add safe static factory functions for SparseTensor and convert all CHECKs to DCHECKs. Using the constructor directly is unsafe and deprecated. +* Make the Bigtable client connection pool configurable & increase the default # of connections for performance. +* Added derivative of `tf.random_gamma` with respect to the alpha parameter. +* Added derivative of `tf.igamma(a, x)` and `tf.igammac(a, x)` with respect to a. +* Modified Bessel functions of order zero and one. +* Add FillTriangular Bijector to create triangular matrices. +* Added support for Type III DCT, and `tf.spectral.idct(type=2|3)`. +* Correctly handle CuDNN RNN weight loaded when nest in `TimeDistributed`. +* Adding per-element weight support for `WALSComputePartialLhsAndRhsOp`. +* ZerosLike and OnesLike ops treated as constants by Graph Transform Tool. +* Gamma distribution and the derived distributions (Beta, Dirichlet, Student's t, inverse Gamma) now fully reparameterized. +* Java: Experimental wrapper classes to make graph generation easier. Thanks @karllessard and @kbsriram +* Build & link in secure gRPC components (switch from the insecure grpc dependency to secure grpc dependency). +* Adding new endpoints for existing tensorflow symbols. These endpoints are going to be the preferred endpoints going forward and may replace some of the existing endpoints in the future. List of new endpoints: + * New endpoints in `tf.image` namespace: `tf.image.extract_image_patches` + * New endpoints in `tf.debugging` namespace: `tf.debugging.check_numerics`, `tf.debugging.is_finite`, `tf.debugging.is_inf`, `tf.debugging.is_nan`. + * New endpoints in `tf.dtypes` namespace: `tf.dtypes.as_string`. + * New endpoints in `tf.io` namespace: `tf.io.decode_base64`, `tf.io.decode_compressed`, `tf.io.decode_json_example`, `tf.io.decode_raw`, `tf.io.encode_base64`, `tf.io.matching_files`, `tf.io.parse_tensor`, `tf.io.read_file, `tf.io.write_file`. + * New endpoints in tf.linalg namespace: `tf.linalg.cross`, `tf.linalg.tensor_diag` (corresponds to `tf.diag`), `tf.linalg.tensor_diag_part` (corresponds to `tf.diag_part`). + * New endpoints in tf.manip namespace: `tf.manip.batch_to_space_nd`, `tf.manip.gather_nd`, `tf.manip.reshape`, `tf.manip.reverse`, `tf.manip.scatter_nd`, `tf.manip.space_to_batch_nd`, `tf.manip.tile` + * New endpoints in tf.math namespace: `tf.math.acos`, `tf.math.acosh`, `tf.math.add`, `tf.math.asin`, `tf.math.asinh`, `tf.math.atan`, `tf.math.atan2`, `tf.math.atanh`, `tf.math.betainc`, `tf.math.ceil`, `tf.math.cos`, `tf.math.cosh`, `tf.math.digamma`, `tf.math.equal`, `tf.math.erfc`, `tf.math.exp`, `tf.math.expm1`, `tf.math.floor`, `tf.math.greater`, `tf.math.greater_equal`, `tf.math.igamma`, `tf.math.igammac`, `tf.math.invert_permutation`, `tf.math.less`, `tf.math.less_equal`, `tf.math.lgamma`, `tf.math.log`, `tf.math.log1p`, `tf.math.logical_and`, `tf.math.logical_not`, `tf.math.logical_or`, `tf.math.maximum`, `tf.math.minimum`, `tf.math.not_equal`, `tf.math.polygamma`, `tf.math.reciprocal`, `tf.math.rint`, `tf.math.rsqrt`, `tf.math.segment_max`, `tf.math.segment_mean`, `tf.math.segment_min`, `tf.math.segment_prod`, `tf.math.segment_sum`, `tf.math.sin`, `tf.math.sinh`, `tf.math.softplus`, `tf.math.softsign`, `tf.math.squared_difference`, `tf.math.tan`, `tf.math.unsorted_segment_max`, `tf.math.unsorted_segment_min`, `tf.math.unsorted_segment_prod`, `tf.math.unsorted_segment_sum`, `tf.math.zeta`. + * New endpoints in `tf.quantization` namespace: `tf.quantization.dequantize`, `tf.quantization.fake_quant_with_min_max_args`, `tf.quantization.fake_quant_with_min_max_args_gradient`, `tf.quantization.fake_quant_with_min_max_vars`, `tf.quantization.fake_quant_with_min_max_vars_gradient`, `tf.quantization.fake_quant_with_min_max_vars_per_channel`, `tf.quantization.fake_quant_with_min_max_vars_per_channel_gradient`. + * New endpoints in tf.strings namespace: `tf.strings.join` (corresponds to `tf.string_join`), `tf.strings.regex_replace`, `tf.strings.to_number` (corresponds to `tf.string_to_number`), `tf.strings.strip` (corresponds to `tf.string_strip`), `tf.strings.substr`, `tf.strings.to_hash_bucket` (corresponds to `tf.string_to_hash_bucket`), `tf.strings.to_hash_bucket_fast` (corresponds to `tf.string_to_hash_bucket_fast`), `tf.strings.to_hash_bucket_strong` (corresponds to `tf.string_to_hash_bucket_strong`). + + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +Ag Ramesh, Alex Wiltschko, Alexander Pantyukhin, Amogh Mannekote, An Jiaoyang, Andrei Nigmatulin, Andrew Ginns, BjøRn Moholt, Brett Koonce, Chengzhi Chen, Chinmay Das, Christian Ertler, Christoph Boeddeker, Clayne Robison, Courtial Florian, ctiijima, Dan Douthit, Dan J, Dan Ringwalt, EFanZh, Emanuele Ballarin, eqy, Evgeniy Zheltonozhskiy, Freedom" Koan-Sin Tan, FréDéRic Branchaud-Charron, G K, gracehoney, Guillaume Klein, Guozhong Zhuang, Hsien-Yang Li, hsm207, ImSheridan, Jayaram Bobba, Jiandong Ruan, Jie, Joel Shor, Jonas Rauber, Jongmin Baek, jsawruk, Karan Kaw, Karl Lessard, karl@kubx.ca, Kb Sriram, KinmanLam, leiiwang, Li, Yiqiang, Loo Rong Jie, Mahmoud Abuzaina, Mahmoud Aslan, ManHyuk, Martin Patz, Martin Zeitler, mktozk, Mohammad Ashraf Bhuiyan, mrTsjolder, Naman Bhalla, Nick Felt, Nicolas Lopez, Niranjan Hasabnis, Nishidha Panpaliya, Nitish, nrstott, Nutti, Parag Jain, PeterLee, Philipp Jund, Rach L, Rafal Wojdyla, Roland Zimmermann, Sergei Lebedev, SneakyFish5, Soila Kavulya, Sriram Veturi, Steven Schmatz, Taehoon Lee, Tang, Wenyi, Taras Sereda, Ted Chang, Tim Zaman, Tristan Rice, tucan, vchigrin, Vikram Tiwari, Vincent, WeberXie, William D. Irons, Yan Facai (颜发才), Yong Tang, Yu Yi, Yuxin Wu, Zé ViníCius + # Release 1.9.0 ## Major Features And Improvements diff --git a/configure.py b/configure.py index f97bf8a66836a6647ba6aca625cb1526e11b39af..b6285cfc385836b14f1f58ba9675de3070d3fde0 100644 --- a/configure.py +++ b/configure.py @@ -839,14 +839,15 @@ def set_tf_cuda_version(environ_cp): cuda_toolkit_path = cygpath(cuda_toolkit_path) if is_windows(): - cuda_rt_lib_path = 'lib/x64/cudart.lib' + cuda_rt_lib_paths = ['lib/x64/cudart.lib'] elif is_linux(): - cuda_rt_lib_path = 'lib64/libcudart.so.%s' % tf_cuda_version + cuda_rt_lib_paths = ['%s/libcudart.so.%s' % (x, tf_cuda_version) + for x in ['lib64', 'lib/x86_64-linux-gnu']] elif is_macos(): - cuda_rt_lib_path = 'lib/libcudart.%s.dylib' % tf_cuda_version + cuda_rt_lib_paths = ['lib/libcudart.%s.dylib' % tf_cuda_version] - cuda_toolkit_path_full = os.path.join(cuda_toolkit_path, cuda_rt_lib_path) - if os.path.exists(cuda_toolkit_path_full): + cuda_toolkit_paths_full = [os.path.join(cuda_toolkit_path, x) for x in cuda_rt_lib_paths] + if any([os.path.exists(x) for x in cuda_toolkit_paths_full]): break # Reset and retry @@ -1398,8 +1399,11 @@ def set_grpc_build_flags(): write_to_bazelrc('build --define grpc_no_ares=true') -def set_build_strip_flag(): - write_to_bazelrc('build --strip=always') +def set_system_libs_flag(environ_cp): + syslibs = environ_cp.get('TF_SYSTEM_LIBS', '') + syslibs = ','.join(sorted(syslibs.split(','))) + if syslibs and syslibs != '': + write_action_env_to_bazelrc('TF_SYSTEM_LIBS', syslibs) def set_windows_build_flags(environ_cp): @@ -1558,7 +1562,7 @@ def main(): set_grpc_build_flags() set_cc_opt_flags(environ_cp) - set_build_strip_flag() + set_system_libs_flag(environ_cp) if is_windows(): set_windows_build_flags(environ_cp) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 388ca3f293ebfa120037b75fe70c66b9d715c051..94e059b9148bd1a84d7bda1c79bde79f8c8324ad 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -123,12 +123,6 @@ config_setting( visibility = ["//visibility:public"], ) -config_setting( - name = "windows_msvc", - values = {"cpu": "x64_windows_msvc"}, - visibility = ["//visibility:public"], -) - config_setting( name = "no_tensorflow_py_deps", define_values = {"no_tensorflow_py_deps": "true"}, @@ -381,6 +375,15 @@ config_setting( }, ) +# Setting to use when loading kernels dynamically +config_setting( + name = "dynamic_loaded_kernels", + define_values = { + "dynamic_loaded_kernels": "true", + }, + visibility = ["//visibility:public"], +) + config_setting( name = "using_cuda_nvcc", define_values = { @@ -408,14 +411,6 @@ config_setting( visibility = ["//visibility:public"], ) -# TODO(laigd): consider removing this option and make TensorRT enabled -# automatically when CUDA is enabled. -config_setting( - name = "with_tensorrt_support", - values = {"define": "with_tensorrt_support=true"}, - visibility = ["//visibility:public"], -) - package_group( name = "internal", packages = [ @@ -429,23 +424,18 @@ package_group( load( "//third_party/mkl:build_defs.bzl", - "if_mkl", + "if_mkl_ml", ) filegroup( name = "intel_binary_blob", - data = if_mkl( + data = if_mkl_ml( [ "//third_party/mkl:intel_binary_blob", ], ), ) -filegroup( - name = "docs_src", - data = glob(["docs_src/**/*.md"]), -) - cc_library( name = "grpc", deps = select({ @@ -492,7 +482,6 @@ tf_cc_shared_object( linkopts = select({ "//tensorflow:darwin": [], "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": [ "-Wl,--version-script", # This line must be directly followed by the version_script.lds file "$(location //tensorflow:tf_framework_version_script.lds)", @@ -534,7 +523,6 @@ tf_cc_shared_object( "-Wl,-install_name,@rpath/libtensorflow.so", ], "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": [ "-z defs", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file @@ -559,7 +547,6 @@ tf_cc_shared_object( "$(location //tensorflow:tf_exported_symbols.lds)", ], "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": [ "-z defs", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file @@ -589,6 +576,7 @@ exports_files( gen_api_init_files( name = "tensorflow_python_api_gen", srcs = ["api_template.__init__.py"], + api_version = 1, root_init_template = "api_template.__init__.py", ) diff --git a/tensorflow/__init__.py b/tensorflow/__init__.py index 440e9f8dbd2f4b2a2ab78eaaf26408584e7c1446..21677512b63828fa2035527ed573bf4dc4603085 100644 --- a/tensorflow/__init__.py +++ b/tensorflow/__init__.py @@ -28,7 +28,8 @@ contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') del LazyLoader from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top -app.flags = flags # pylint: disable=undefined-variable +from tensorflow.python.platform import app # pylint: disable=g-import-not-at-top +app.flags = flags del absolute_import del division diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 10bc8cdbee5a9df6d2084c10adab4ed6e5e6f0d3..b8adf6c1279e72d0c2056368253aa0cb470216e5 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -52,6 +52,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/mutex.h" @@ -201,7 +202,8 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims, buf->len_ = len; if (dtype != TF_STRING && dtype != TF_RESOURCE && tensorflow::DataTypeCanUseMemcpy(static_cast(dtype)) && - reinterpret_cast(data) % EIGEN_MAX_ALIGN_BYTES != 0) { + reinterpret_cast(data) % std::max(1, EIGEN_MAX_ALIGN_BYTES) != + 0) { // TF_STRING and TF_RESOURCE tensors have a different representation in // TF_Tensor than they do in tensorflow::Tensor. So a copy here is a waste // (any alignment requirements will be taken care of by TF_TensorToTensor @@ -2389,6 +2391,12 @@ void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); } void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, TF_Output* dx, TF_Status* status, TF_Output* dy) { + TF_AddGradientsWithPrefix(g, nullptr, y, ny, x, nx, dx, status, dy); +} + +void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y, + int ny, TF_Output* x, int nx, TF_Output* dx, + TF_Status* status, TF_Output* dy) { #ifdef __ANDROID__ status->status = tensorflow::errors::Unimplemented( "Adding gradients is not supported in Android. File a bug at " @@ -2405,9 +2413,29 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, const int first_new_node_id = g->graph.num_node_ids(); + string prefix_cmp; + const char* child_scope_name; + if (prefix == nullptr) { + child_scope_name = "gradients"; + } else { + prefix_cmp = string(prefix) + "/"; + // The operation should fail if the provided name prefix has already been + // used in this graph + for (const auto& pair : g->name_map) { + const string& name = pair.first; + if (name.compare(prefix) == 0 || + tensorflow::str_util::StartsWith(name, prefix_cmp)) { + status->status = InvalidArgument( + "prefix [", prefix, + "] conflicts with existing node in the graph named [", name, "]"); + return; + } + } + child_scope_name = prefix; + } tensorflow::Scope scope = NewInternalScope(&g->graph, &status->status, &g->refiner) - .NewSubScope("gradients"); + .NewSubScope(child_scope_name); if (dx != nullptr) { std::vector dx_arg = OutputsFromTFOutputs(dx, ny); @@ -2422,6 +2450,18 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) { Node* n = g->graph.FindNodeId(i); if (n == nullptr) continue; + + // Adding the gradients to the graph can alter the prefix to prevent + // name collisions only if this prefix has not been provided explicitly + // by the user. If it was provided, assert that it remained intact. + if (prefix != nullptr && + !tensorflow::str_util::StartsWith(n->name(), prefix_cmp)) { + status->status = tensorflow::errors::Internal( + "BUG: The gradients prefix have been unexpectedly altered when " + "adding the nodes to the graph. This is a bug. Please file an " + "issue at https://github.com/tensorflow/tensorflow/issues."); + return; + } // We have a convoluted scheme here: Using the C++ graph construction API // to add potentially many nodes to the graph without running the checks // (such as uniqueness of the names of nodes) we run with other functions diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index c8ae6f2dd1780c4fe50ff1924be8d2e9a7502cf0..850f6ecd637d768bca99720e0add07680829e17a 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1131,6 +1131,7 @@ TF_CAPI_EXPORT extern void TF_AbortWhile(const TF_WhileParams* params); // Adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s, // i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2... +// // `dx` are used as initial gradients (which represent the symbolic partial // derivatives of some loss function `L` w.r.t. `y`). // `dx` must be nullptr or have size `ny`. @@ -1139,6 +1140,12 @@ TF_CAPI_EXPORT extern void TF_AbortWhile(const TF_WhileParams* params); // The partial derivatives are returned in `dy`. `dy` should be allocated to // size `nx`. // +// Gradient nodes are automatically named under the "gradients/" prefix. To +// guarantee name uniqueness, subsequent calls to the same graph will +// append an incremental tag to the prefix: "gradients_1/", "gradients_2/", ... +// See TF_AddGradientsWithPrefix, which provides a means to specify a custom +// name prefix for operations added to a graph to compute the gradients. +// // WARNING: This function does not yet support all the gradients that python // supports. See // https://www.tensorflow.org/code/tensorflow/cc/gradients/README.md @@ -1147,6 +1154,33 @@ TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, TF_Output* dx, TF_Status* status, TF_Output* dy); +// Adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s, +// i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2... +// This is a variant of TF_AddGradients that allows to caller to pass a custom +// name prefix to the operations added to a graph to compute the gradients. +// +// `dx` are used as initial gradients (which represent the symbolic partial +// derivatives of some loss function `L` w.r.t. `y`). +// `dx` must be nullptr or have size `ny`. +// If `dx` is nullptr, the implementation will use dx of `OnesLike` for all +// shapes in `y`. +// The partial derivatives are returned in `dy`. `dy` should be allocated to +// size `nx`. +// `prefix` names the scope into which all gradients operations are being added. +// `prefix` must be unique within the provided graph otherwise this operation +// will fail. If `prefix` is nullptr, the default prefixing behaviour takes +// place, see TF_AddGradients for more details. +// +// WARNING: This function does not yet support all the gradients that python +// supports. See +// https://www.tensorflow.org/code/tensorflow/cc/gradients/README.md +// for instructions on how to add C++ more gradients. +TF_CAPI_EXPORT void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, + TF_Output* y, int ny, + TF_Output* x, int nx, + TF_Output* dx, TF_Status* status, + TF_Output* dy); + // Create a TF_Function from a TF_Graph // // Params: @@ -1236,6 +1270,11 @@ TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunction( int noutputs, const TF_Output* outputs, const char* const* output_names, const TF_FunctionOptions* opts, const char* description, TF_Status* status); +// Returns the name of the graph function. +// The return value points to memory that is only usable until the next +// mutation to *func. +TF_CAPI_EXPORT extern const char* TF_FunctionName(TF_Function* func); + // Write out a serialized representation of `func` (as a FunctionDef protocol // message) to `output_func_def` (allocated by TF_NewBuffer()). // `output_func_def`'s underlying buffer will be freed when TF_DeleteBuffer() diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 170046c8024dc85c899108b254cd3a95a3be4096..69b3ffe2a1f620e346405607ecf742fb863aa644 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -84,6 +84,18 @@ TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation, return ret; } +TF_Buffer* TF_CreateRunOptions(unsigned char enable_full_trace) { + tensorflow::RunOptions options; + if (enable_full_trace) { + options.set_trace_level(tensorflow::RunOptions::FULL_TRACE); + } else { + options.set_trace_level(tensorflow::RunOptions::NO_TRACE); + } + TF_Buffer* ret = TF_NewBuffer(); + TF_CHECK_OK(MessageToBuffer(options, ret)); + return ret; +} + const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) { tensorflow::mutex_lock c(graph->mu); const auto& debug_str = graph->graph.ToGraphDefDebug().DebugString(); diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 2d81c01e0dd056e9beb3b45f24809381554a7924..6617c5a572e90e78369f73d714f39942f213040f 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -70,6 +70,12 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_CreateConfig( unsigned char enable_xla_compilation, unsigned char gpu_memory_allow_growth); +// Create a serialized tensorflow.RunOptions proto, where RunOptions.trace_level +// is set to FULL_TRACE if `enable_full_trace` is non-zero, and NO_TRACE +// otherwise. +TF_CAPI_EXPORT extern TF_Buffer* TF_CreateRunOptions( + unsigned char enable_full_trace); + // Returns the graph content in a human-readable format, with length set in // `len`. The format is subject to change in the future. // The returned string is heap-allocated, and caller should call free() on it. diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index 384e6c8cb97022264c5327da5ca5861057608fbe..a2c5a42c11361779de61b515e0f08dcc45e609b9 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -536,6 +536,10 @@ TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, return tf_function; } +const char* TF_FunctionName(TF_Function* func) { + return func->fdef.signature().name().c_str(); +} + void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* func, const TF_Function* grad, TF_Status* status) { if (func == nullptr) { diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index f7ca219c896b2a7c07fc4d0739c70f2666652672..73fe73769bc1219ce865149d67d333c53371ccc5 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -193,6 +193,7 @@ class CApiFunctionTest : public ::testing::Test { ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); ASSERT_NE(func_, nullptr); + ASSERT_EQ(std::string(func_name_), std::string(TF_FunctionName(func_))); TF_GraphCopyFunction(host_graph_, func_, nullptr, s_); ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); } @@ -1618,5 +1619,66 @@ TEST_F(CApiFunctionTest, GetFunctionsFromGraph) { TF_DeleteFunction(func1); } +// This test only works when the TF build includes XLA compiler. One way to set +// this up is via bazel build option "--define with_xla_support=true". +// +// FIXME: generalize the macro name TENSORFLOW_EAGER_USE_XLA to +// something like TENSORFLOW_CAPI_USE_XLA. +#ifdef TENSORFLOW_EAGER_USE_XLA +TEST_F(CApiFunctionTest, StatelessIf_XLA) { + TF_Function* func; + const std::string funcName = "BranchFunc"; + DefineFunction(funcName.c_str(), &func); + TF_GraphCopyFunction(host_graph_, func, nullptr, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_Operation* feed = Placeholder(host_graph_, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_Operation* true_cond = ScalarConst(true, host_graph_, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_OperationDescription* desc = + TF_NewOperation(host_graph_, "StatelessIf", "IfNode"); + TF_AddInput(desc, {true_cond, 0}); + TF_Output inputs[] = {{feed, 0}}; + TF_AddInputList(desc, inputs, TF_ARRAYSIZE(inputs)); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_SetAttrType(desc, "Tcond", TF_BOOL); + TF_DataType inputType = TF_INT32; + TF_SetAttrTypeList(desc, "Tin", &inputType, 1); + TF_SetAttrTypeList(desc, "Tout", &inputType, 1); + TF_SetAttrFuncName(desc, "then_branch", funcName.data(), funcName.size()); + TF_SetAttrFuncName(desc, "else_branch", funcName.data(), funcName.size()); + TF_SetDevice(desc, "/device:XLA_CPU:0"); + auto op = TF_FinishOperation(desc, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + ASSERT_NE(op, nullptr); + + // Create a session for this graph. + CSession csession(host_graph_, s_, /*use_XLA*/ true); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Run the graph. + csession.SetInputs({{feed, Int32Tensor(17)}}); + csession.SetOutputs({op}); + csession.Run(s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Tensor* out = csession.output_tensor(0); + ASSERT_TRUE(out != nullptr); + EXPECT_EQ(TF_INT32, TF_TensorType(out)); + EXPECT_EQ(0, TF_NumDims(out)); // scalar + ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out)); + int32* output_contents = static_cast(TF_TensorData(out)); + EXPECT_EQ(-17, *output_contents); + + // Clean up + csession.CloseAndDelete(s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_DeleteFunction(func); +} +#endif // TENSORFLOW_EAGER_USE_XLA + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index e674b1623cf540eb8024d9be5ed8d77aa2fe17ba..aa2a537f03be31ae45ff3d6f7815b449d661cf9c 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -1483,8 +1483,8 @@ class CApiGradientsTest : public ::testing::Test { BuildSuccessGraph(inputs, outputs); BuildExpectedGraph(grad_inputs_provided, expected_grad_outputs); - AddGradients(grad_inputs_provided, inputs, 2, outputs, 1, grad_outputs); - + AddGradients(grad_inputs_provided, nullptr, inputs, 2, outputs, 1, + grad_outputs); EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); // Compare that the graphs match. @@ -1505,7 +1505,8 @@ class CApiGradientsTest : public ::testing::Test { BuildErrorGraph(inputs, outputs); - AddGradients(grad_inputs_provided, inputs, 1, outputs, 1, grad_outputs); + AddGradients(grad_inputs_provided, nullptr, inputs, 1, outputs, 1, + grad_outputs); string expected_msg = "No gradient defined for op: TestOpWithNoGradient. Please see " @@ -1549,19 +1550,20 @@ class CApiGradientsTest : public ::testing::Test { EXPECT_EQ(*a_data, *b_data); } - void AddGradients(bool grad_inputs_provided, TF_Output* inputs, int ninputs, - TF_Output* outputs, int noutputs, TF_Output* grad_outputs) { + void AddGradients(bool grad_inputs_provided, const char* prefix, + TF_Output* inputs, int ninputs, TF_Output* outputs, + int noutputs, TF_Output* grad_outputs) { if (grad_inputs_provided) { TF_Output grad_inputs[1]; const float grad_inputs_val[] = {1.0, 1.0, 1.0, 1.0}; TF_Operation* grad_inputs_op = FloatConst2x2(graph_, s_, grad_inputs_val, "GradInputs"); grad_inputs[0] = TF_Output{grad_inputs_op, 0}; - TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, grad_inputs, - s_, grad_outputs); + TF_AddGradientsWithPrefix(graph_, prefix, outputs, noutputs, inputs, + ninputs, grad_inputs, s_, grad_outputs); } else { - TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, nullptr, s_, - grad_outputs); + TF_AddGradientsWithPrefix(graph_, prefix, outputs, noutputs, inputs, + ninputs, nullptr, s_, grad_outputs); } } @@ -1706,6 +1708,20 @@ class CApiGradientsTest : public ::testing::Test { return op; } + void BuildGraphAndAddGradientsWithPrefixes(const char* prefix1, + const char* prefix2 = nullptr) { + TF_Output inputs[2]; + TF_Output outputs[1]; + TF_Output grad_outputs[2]; + + BuildSuccessGraph(inputs, outputs); + + AddGradients(false, prefix1, inputs, 2, outputs, 1, grad_outputs); + if (prefix2 != nullptr) { + AddGradients(false, prefix2, inputs, 2, outputs, 1, grad_outputs); + } + } + TF_Status* s_; TF_Graph* graph_; TF_Graph* expected_graph_; @@ -1725,6 +1741,56 @@ TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) { TestGradientsError(false); } +TEST_F(CApiGradientsTest, GradientsPrefix_PrefixIsOk) { + BuildGraphAndAddGradientsWithPrefixes("gradients"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsWithDistinctPrefixes) { + BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients_1"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsInSameScope) { + BuildGraphAndAddGradientsWithPrefixes("scope/gradients", "scope/gradients_1"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsInDifferentScopes) { + BuildGraphAndAddGradientsWithPrefixes("scope/gradients", "scope_1/gradients"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsAsSubScopeOf1st) { + BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients/sub"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_PrefixMatchesExistingNodeName) { + BuildGraphAndAddGradientsWithPrefixes("Const_0"); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsWithIdenticalPrefixes) { + BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients"); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsMatchingNodeOf1st) { + BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients/MatMul"); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_1stGradientsMatchingNodeOf2nd) { + BuildGraphAndAddGradientsWithPrefixes("gradients/MatMul", "gradients"); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsAsParentScopeOf1st) { + BuildGraphAndAddGradientsWithPrefixes("gradients/sub", "gradients"); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); +} + void ScalarFloatFromTensor(const TF_Tensor* t, float* f) { ASSERT_TRUE(t != nullptr); ASSERT_EQ(TF_FLOAT, TF_TensorType(t)); diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc index 24eb6c069b21349fce288db3e79fbf14e824ad11..f15d9ee20adb31a0b76e2cd0d1e67f17a9deff05 100644 --- a/tensorflow/c/c_test_util.cc +++ b/tensorflow/c/c_test_util.cc @@ -26,6 +26,10 @@ limitations under the License. using tensorflow::GraphDef; using tensorflow::NodeDef; +static void BoolDeallocator(void* data, size_t, void* arg) { + delete[] static_cast(data); +} + static void Int32Deallocator(void* data, size_t, void* arg) { delete[] static_cast(data); } @@ -38,6 +42,14 @@ static void FloatDeallocator(void* data, size_t, void* arg) { delete[] static_cast(data); } +TF_Tensor* BoolTensor(bool v) { + const int num_bytes = sizeof(bool); + bool* values = new bool[1]; + values[0] = v; + return TF_NewTensor(TF_BOOL, nullptr, 0, values, num_bytes, &BoolDeallocator, + nullptr); +} + TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) { int64_t num_values = 1; for (int i = 0; i < num_dims; ++i) { @@ -131,6 +143,12 @@ TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, return op; } +TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s, + const char* name) { + unique_tensor_ptr tensor(BoolTensor(v), TF_DeleteTensor); + return Const(tensor.get(), graph, s, name); +} + TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, const char* name) { unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor); diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index 38313d647ca93d4779bb1325f8ed7bde4b743879..7eeb1ee5e17ad7e5644f8bc8a18ca967b108475d 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -31,6 +31,8 @@ using ::tensorflow::string; typedef std::unique_ptr unique_tensor_ptr; +TF_Tensor* BoolTensor(int32_t v); + // Create a tensor with values of type TF_INT8 provided by `values`. TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values); @@ -55,6 +57,9 @@ TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, const char* name = "const"); +TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s, + const char* name = "scalar"); + TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, const char* name = "scalar"); diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 6c510536d6f2a586b91baf96fa41b779db2c8d35..dfb1c9a37644c726e1eabab775593596d5b556b9 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -110,7 +110,7 @@ tensorflow::Status GetAllRemoteDevices( tensorflow::Status CreateRemoteContexts( const std::vector& remote_workers, int64 rendezvous_id, - const tensorflow::ServerDef& server_def, + int keep_alive_secs, const tensorflow::ServerDef& server_def, tensorflow::eager::EagerClientCache* remote_eager_workers, bool async, tensorflow::gtl::FlatMap* remote_contexts) { for (int i = 0; i < remote_workers.size(); i++) { @@ -129,6 +129,7 @@ tensorflow::Status CreateRemoteContexts( request.mutable_server_def()->set_job_name(parsed_name.job); request.mutable_server_def()->set_task_index(parsed_name.task); request.set_async(async); + request.set_keep_alive_secs(keep_alive_secs); auto* eager_client = remote_eager_workers->GetClient(remote_worker); if (eager_client == nullptr) { return tensorflow::errors::Internal( @@ -150,8 +151,9 @@ tensorflow::Status CreateRemoteContexts( return tensorflow::Status::OK(); } -tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, - TFE_Context** ctx) { +tensorflow::Status UpdateTFE_ContextWithServerDef( + int keep_alive_secs, const tensorflow::ServerDef& server_def, + TFE_Context* ctx) { // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the // server object (which currently CHECK-fails) and we miss the error, instead, // we log the error, and then return to allow the user to see the error @@ -165,12 +167,12 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, } \ } while (0); - string worker_name = tensorflow::strings::StrCat( - "/job:", opts->server_def.job_name(), - "/replica:0/task:", opts->server_def.task_index()); + string worker_name = + tensorflow::strings::StrCat("/job:", server_def.job_name(), + "/replica:0/task:", server_def.task_index()); std::unique_ptr server; - LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(opts->server_def, &server)); + LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &server)); tensorflow::GrpcServer* grpc_server = dynamic_cast(server.get()); @@ -202,15 +204,15 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, // Initialize remote eager workers. tensorflow::gtl::FlatMap remote_contexts; LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts( - remote_workers, rendezvous_id, opts->server_def, - remote_eager_workers.get(), opts->async, &remote_contexts)); + remote_workers, rendezvous_id, keep_alive_secs, server_def, + remote_eager_workers.get(), ctx->context.Async(), &remote_contexts)); tensorflow::RemoteRendezvous* r = grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id); auto session_name = tensorflow::strings::StrCat("eager_", rendezvous_id); TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession( - session_name, opts->server_def, true)); + session_name, server_def, true)); std::shared_ptr worker_session; TF_RETURN_IF_ERROR( @@ -221,10 +223,11 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, TF_RETURN_IF_ERROR(r->Initialize(worker_session.get())); auto* device_mgr = grpc_server->worker_env()->device_mgr; - *ctx = new TFE_Context(opts->session_options.options, opts->policy, - opts->async, device_mgr, r, std::move(server), - std::move(remote_eager_workers), - std::move(remote_device_mgr), remote_contexts); + + ctx->context.InitializeRemote(std::move(server), + std::move(remote_eager_workers), + std::move(remote_device_mgr), remote_contexts, + r, device_mgr, keep_alive_secs); return tensorflow::Status::OK(); #undef LOG_AND_RETURN_IF_ERROR @@ -249,15 +252,6 @@ void TFE_ContextOptionsSetDevicePlacementPolicy( options->policy = policy; } -TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef( - TFE_ContextOptions* options, const void* proto, size_t proto_len, - TF_Status* status) { - if (!options->server_def.ParseFromArray(proto, proto_len)) { - status->status = tensorflow::errors::InvalidArgument( - "Invalid tensorflow.ServerDef protocol buffer"); - } -} - TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, unsigned char async, TF_Status* status) { @@ -267,12 +261,6 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { - if (!opts->server_def.job_name().empty()) { - TFE_Context* ctx = nullptr; - status->status = NewRemoteAwareTFE_Context(opts, &ctx); - return ctx; - } - std::vector devices; status->status = tensorflow::DeviceFactory::AddDevices( opts->session_options.options, "/job:localhost/replica:0/task:0", @@ -288,7 +276,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { opts->async, std::move(device_mgr), r); } -void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { delete ctx; } +void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; } TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { TF_DeviceList* list = new TF_DeviceList; @@ -301,6 +289,22 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); } +// Set server_def on the context, possibly updating it. +TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, + int keep_alive_secs, + const void* proto, + size_t proto_len, + TF_Status* status) { + tensorflow::ServerDef server_def; + if (!server_def.ParseFromArray(proto, proto_len)) { + status->status = tensorflow::errors::InvalidArgument( + "Invalid tensorflow.ServerDef protocol buffer"); + return; + } + status->status = + UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def, ctx); +} + void TFE_ContextSetThreadLocalDevicePlacementPolicy( TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { ctx->context.SetThreadLocalDevicePlacementPolicy( @@ -336,7 +340,7 @@ TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { } void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { - DCHECK(h); + if (h == nullptr) return; if (h->handle) { h->handle->Unref(); } @@ -348,6 +352,11 @@ TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) { } int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return -1; + } int result; status->status = h->handle->NumDims(&result); return result; @@ -355,12 +364,22 @@ int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return -1; + } tensorflow::int64 result; status->status = h->handle->Dim(dim_index, &result); return result; } const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return nullptr; + } tensorflow::Device* d = nullptr; status->status = h->handle->OpDevice(&d); return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0" @@ -368,6 +387,11 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { } TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return nullptr; + } // TODO(agarwal): move this implementation inside TFE_TensorHandle. tensorflow::Device* d = nullptr; tensorflow::Device* op_device = nullptr; @@ -700,6 +724,10 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, } } // namespace +void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context.StartStep(); } + +void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context.EndStep(); } + namespace tensorflow { void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value, diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index fdbd5374b2afe815c3a81b453930eb8f1fa351d3..a0ebc6fa0a22ed61be91c2974352c2988fb4cd92 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -81,16 +81,6 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*, TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy( TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy); -// A tensorflow.ServerDef specifies remote workers (in addition to the current -// workers name). Operations created on this context can then be executed on -// any of these remote workers by setting an appropriate device. -// -// If the following is set, all servers identified by the -// ServerDef must be up when the context is created. -TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef( - TFE_ContextOptions* options, const void* proto, size_t proto_len, - TF_Status* status); - // Destroy an options object. TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*); @@ -102,8 +92,7 @@ typedef struct TFE_Context TFE_Context; TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext( const TFE_ContextOptions* opts, TF_Status* status); -TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx, - TF_Status* status); +TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx); TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status); @@ -128,6 +117,18 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*, unsigned char async, TF_Status* status); +// A tensorflow.ServerDef specifies remote workers (in addition to the current +// workers name). Operations created on this context can then be executed on +// any of these remote workers by setting an appropriate device. +// +// If the following is set, all servers identified by the +// ServerDef must be up when the context is created. +TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, + int keep_alive_secs, + const void* proto, + size_t proto_len, + TF_Status* status); + // Causes the calling thread to block till all ops dispatched in async mode // have been executed. Note that "execution" here refers to kernel execution / // scheduling of copies, etc. Similar to sync execution, it doesn't guarantee @@ -380,6 +381,16 @@ TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status); +// Some TF ops need a step container to be set to limit the lifetime of some +// resources (mostly TensorArray and Stack, used in while loop gradients in +// graph mode). Calling this on a context tells it to start a step. +TF_CAPI_EXPORT extern void TFE_ContextStartStep(TFE_Context* ctx); + +// Ends a step. When there is no active step (that is, every started step has +// been ended) step containers will be cleared. Note: it is not safe to call +// TFE_ContextEndStep while ops which rely on the step container may be running. +TF_CAPI_EXPORT extern void TFE_ContextEndStep(TFE_Context* ctx); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 4c5077023d5bb3b83808bf3908e7110dd026e3ad..a5c0681e2e4eddae08954d9d0178ca96a3f8f29a 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -59,7 +59,6 @@ struct TFE_ContextOptions { // true if async execution is enabled. bool async = false; TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_SILENT}; - tensorflow::ServerDef server_def; }; struct TFE_Context { @@ -73,23 +72,6 @@ struct TFE_Context { default_policy), async, std::move(device_mgr), rendezvous) {} - explicit TFE_Context( - const tensorflow::SessionOptions& opts, - TFE_ContextDevicePlacementPolicy default_policy, bool async, - tensorflow::DeviceMgr* local_device_mgr, - tensorflow::Rendezvous* rendezvous, - std::unique_ptr server, - std::unique_ptr remote_eager_workers, - std::unique_ptr remote_device_mgr, - const tensorflow::gtl::FlatMap& - remote_contexts) - : context(opts, - static_cast( - default_policy), - async, local_device_mgr, rendezvous, std::move(server), - std::move(remote_eager_workers), std::move(remote_device_mgr), - remote_contexts) {} - tensorflow::EagerContext context; }; diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 3504a8b5e78480732d3454097c1b2197ac2b2e17..7126227cf529023eadf38984668a40118641bb1b 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -49,7 +49,7 @@ void BM_InitOp(int iters) { } tensorflow::testing::StopTiming(); TFE_DeleteTensorHandle(m); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } @@ -80,7 +80,7 @@ void BM_Execute(int iters, int async) { tensorflow::testing::StopTiming(); TFE_DeleteOp(matmul); TFE_DeleteTensorHandle(m); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } @@ -95,7 +95,7 @@ TEST(CAPI, Context) { TF_DeviceList* devices = TFE_ContextListDevices(ctx, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); const int num_devices = TF_DeviceListCount(devices); @@ -108,14 +108,14 @@ TEST(CAPI, Context) { TF_DeleteStatus(status); } -tensorflow::ServerDef GetServerDef(int num_tasks) { +tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) { tensorflow::ServerDef server_def; server_def.set_protocol("grpc"); - server_def.set_job_name("localhost"); + server_def.set_job_name(job_name); server_def.set_task_index(0); tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster(); tensorflow::JobDef* job_def = cluster_def->add_job(); - job_def->set_name("localhost"); + job_def->set_name(job_name); for (int i = 0; i < num_tasks; i++) { int port = tensorflow::testing::PickUnusedPortOrDie(); job_def->mutable_tasks()->insert( @@ -124,6 +124,10 @@ tensorflow::ServerDef GetServerDef(int num_tasks) { return server_def; } +tensorflow::ServerDef GetServerDef(int num_tasks) { + return GetServerDef("localhost", num_tasks); +} + void TestRemoteExecute(bool async) { tensorflow::ServerDef server_def = GetServerDef(2); @@ -140,9 +144,6 @@ void TestRemoteExecute(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(), - status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_EXPLICIT); @@ -150,6 +151,9 @@ void TestRemoteExecute(bool async) { EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(); const char remote_device_name[] = @@ -195,8 +199,8 @@ void TestRemoteExecute(bool async) { TFE_DeleteOp(matmul); TFE_ContextAsyncWait(ctx, status); - TFE_DeleteContext(ctx, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContext(ctx); TF_DeleteStatus(status); @@ -229,15 +233,15 @@ void TestRemoteExecuteSilentCopies(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(), - status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); TFE_Context* ctx = TFE_NewContext(opts, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(); const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0"; @@ -281,7 +285,7 @@ void TestRemoteExecuteSilentCopies(bool async) { TFE_DeleteOp(matmul); TFE_ContextAsyncWait(ctx, status); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); @@ -296,6 +300,147 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsync) { TestRemoteExecuteSilentCopies(true); } +void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle, + const std::vector& expected_values) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_Tensor* t = TFE_TensorHandleResolve(handle, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + std::unique_ptr actual_values(new float[expected_values.size()]); + EXPECT_EQ(sizeof(float) * expected_values.size(), TF_TensorByteSize(t)); + memcpy(actual_values.get(), TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + + for (int i = 0; i < expected_values.size(); i++) { + EXPECT_EQ(expected_values[i], actual_values[i]) + << "Mismatch in expected values at (zero-based) index " << i; + } +} + +void CheckRemoteMatMulExecutesOK(TFE_Context* ctx, + const char* remote_device_name, + const char* local_device_name) { + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); + + TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0); + TFE_OpSetDevice(matmul, remote_device_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + auto* retval_task0 = + TFE_TensorHandleCopyToDevice(retvals[0], ctx, local_device_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + CheckTFE_TensorHandleHasFloats(retval_task0, {7, 10, 15, 22}); + + TFE_DeleteTensorHandle(retval_task0); + TFE_DeleteTensorHandle(h0_task0); + TFE_DeleteTensorHandle(retvals[0]); + + TFE_DeleteOp(matmul); + + TFE_ContextAsyncWait(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); +} + +void TestRemoteExecuteChangeServerDef(bool async) { + tensorflow::ServerDef server_def = GetServerDef(2); + + // This server def has the task index set to 0. + string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + + std::unique_ptr worker_server; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + const char remote_device_name[] = + "/job:localhost/replica:0/task:1/device:CPU:0"; + const char local_device_name[] = + "/job:localhost/replica:0/task:0/device:CPU:0"; + CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name); + + TFE_ContextAsyncWait(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // TODO(nareshmodi): Figure out how to correctly shut the server down. + worker_server.release(); + + // Update the server def with a new set of names (worker instead of + // localhost). + tensorflow::ServerDef updated_server_def = GetServerDef("worker", 2); + serialized = updated_server_def.SerializeAsString(); + + updated_server_def.set_task_index(1); + tensorflow::Status s = tensorflow::GrpcServer::Create( + updated_server_def, tensorflow::Env::Default(), &worker_server); + ASSERT_TRUE(s.ok()) << s.error_message(); + ASSERT_TRUE(worker_server->Start().ok()); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // Create a new tensor_handle. + TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle(); + + // Check that copying it to the old remote device (named localhost) fails. + TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status); + EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // Copying and executing on the new remote device works. + const char new_remote_device_name[] = + "/job:worker/replica:0/task:1/device:CPU:0"; + const char new_local_device_name[] = + "/job:worker/replica:0/task:0/device:CPU:0"; + + auto* h0_task1_new = TFE_TensorHandleCopyToDevice( + h0_task0_new, ctx, new_remote_device_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_DeleteTensorHandle(h0_task0_new); + TFE_DeleteTensorHandle(h0_task1_new); + + CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name, + new_local_device_name); + + TFE_ContextAsyncWait(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_DeleteStatus(status); + + TFE_DeleteContext(ctx); + + // TODO(nareshmodi): Figure out how to correctly shut the server down. + worker_server.release(); +} + +TEST(CAPI, RemoteExecuteChangeServerDef) { + TestRemoteExecuteChangeServerDef(false); +} +TEST(CAPI, RemoteExecuteChangeServerDefAsync) { + TestRemoteExecuteChangeServerDef(true); +} + TEST(CAPI, TensorHandle) { TFE_TensorHandle* h = TestMatrixTensorHandle(); EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); @@ -380,8 +525,7 @@ void TensorHandleCopyBetweenDevices(bool async) { TF_DeleteDeviceList(devices); TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); - TFE_DeleteContext(ctx, status.get()); - EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContext(ctx); } TEST(CAPI, TensorHandleCopyBetweenDevices) { @@ -418,7 +562,7 @@ void TensorHandleCopyBetweenDevicesError(bool async) { TFE_DeleteTensorHandle(hcopy); TFE_DeleteTensorHandle(hcpu); if (hdevice != nullptr) TFE_DeleteTensorHandle(hdevice); - TFE_DeleteContext(ctx, status.get()); + TFE_DeleteContext(ctx); } TEST(CAPI, TensorHandleCopyBetweenDevicesError) { @@ -451,7 +595,7 @@ void TensorHandleCopyBetweenTwoGPUDevices(bool async) { TF_DeleteDeviceList(devices); TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); - TFE_DeleteContext(ctx, status.get()); + TFE_DeleteContext(ctx); return; } const string gpu_1_name(TF_DeviceListName(devices, 1, status.get())); @@ -484,8 +628,7 @@ void TensorHandleCopyBetweenTwoGPUDevices(bool async) { TF_DeleteDeviceList(devices); TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); - TFE_DeleteContext(ctx, status.get()); - EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContext(ctx); } TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) { @@ -533,8 +676,7 @@ void TensorHandleSilentCopy(bool async) { TFE_DeleteTensorHandle(hcpu); TFE_ContextAsyncWait(ctx, status.get()); EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TFE_DeleteContext(ctx, status.get()); - EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContext(ctx); } TEST(CAPI, TensorHandleSilentCopy) { TensorHandleSilentCopy(false); } @@ -580,8 +722,7 @@ void TensorHandleSilentCopyLocal(bool async) { TFE_DeleteTensorHandle(hcpu); TFE_ContextAsyncWait(ctx, status.get()); EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TFE_DeleteContext(ctx, status.get()); - EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContext(ctx); } TEST(CAPI, TensorHandleSilentCopyLocal) { TensorHandleSilentCopyLocal(false); } TEST(CAPI, TensorHandleSilentCopyLocalAsync) { @@ -614,11 +755,47 @@ void SetAndGetOpDevices(bool async) { TFE_DeleteOp(matmul); TFE_DeleteTensorHandle(m); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } +TEST(CAPI, TensorHandleNullptr) { + TFE_TensorHandle* h = nullptr; + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + + TF_Tensor* t = TFE_TensorHandleResolve(h, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + ASSERT_EQ(t, nullptr); + ASSERT_EQ("The passed in handle is a nullptr", + string(TF_Message(status.get()))); + + TF_SetStatus(status.get(), TF_OK, ""); + + const char* device_name = TFE_TensorHandleDeviceName(h, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + ASSERT_EQ(device_name, nullptr); + ASSERT_EQ("The passed in handle is a nullptr", + string(TF_Message(status.get()))); + + TF_SetStatus(status.get(), TF_OK, ""); + + int num_dims = TFE_TensorHandleNumDims(h, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + ASSERT_EQ(num_dims, -1); + ASSERT_EQ("The passed in handle is a nullptr", + string(TF_Message(status.get()))); + + TF_SetStatus(status.get(), TF_OK, ""); + + int dim = TFE_TensorHandleDim(h, 0, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + ASSERT_EQ(dim, -1); + ASSERT_EQ("The passed in handle is a nullptr", + string(TF_Message(status.get()))); +} + void Execute_MatMul_CPU(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -640,7 +817,7 @@ void Execute_MatMul_CPU(bool async) { TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteTensorHandle(retvals[0]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); float product[4] = {0}; EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); @@ -712,7 +889,7 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) { TFE_DeleteTensorHandle(m1); TFE_DeleteTensorHandle(m2); TFE_DeleteTensorHandle(retvals[0]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); TF_DeleteStatus(status); } TEST(CAPI, Execute_MatMul_CPU_Runtime_Error) { @@ -743,7 +920,7 @@ void Execute_MatMul_CPU_Type_Error(bool async) { if (retvals[0] != nullptr) { TFE_DeleteTensorHandle(retvals[0]); } - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); TF_DeleteStatus(status); } @@ -781,7 +958,7 @@ TEST(CAPI, Execute_Min_CPU) { TF_DeleteTensor(t); EXPECT_EQ(1, output[0]); EXPECT_EQ(3, output[1]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } @@ -823,7 +1000,7 @@ void Execute_MatMul_XLA_CPU(bool async) { EXPECT_EQ(10, product[1]); EXPECT_EQ(15, product[2]); EXPECT_EQ(22, product[3]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); TF_DeleteStatus(status); } TEST(CAPI, Execute_MatMul_XLA_CPU) { Execute_MatMul_XLA_CPU(false); } @@ -862,7 +1039,7 @@ void Execute_Min_XLA_CPU(bool async) { TF_DeleteTensor(t); EXPECT_EQ(1, output[0]); EXPECT_EQ(3, output[1]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); TF_DeleteStatus(status); } TEST(CAPI, Execute_Min_XLA_CPU) { Execute_Min_XLA_CPU(false); } @@ -898,7 +1075,7 @@ void ExecuteWithTracing(bool async) { TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); TFE_DeleteTensorHandle(retvals[0]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); float product[4] = {0}; EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); @@ -974,7 +1151,7 @@ TEST(CAPI, Function_ident_CPU) { TF_DeleteTensor(r); TFE_DeleteTensorHandle(result[0]); } - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TF_DeleteStatus(status); } @@ -1044,7 +1221,7 @@ TEST(CAPI, Function_ident_XLA_CPU) { TF_DeleteTensor(r); TFE_DeleteTensorHandle(result[0]); } - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TF_DeleteStatus(status); } @@ -1120,7 +1297,7 @@ void FunctionDefAndExecute(bool async) { EXPECT_EQ(10, product[1]); EXPECT_EQ(15, product[2]); EXPECT_EQ(22, product[3]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } @@ -1161,7 +1338,7 @@ void BM_ExecuteFunction(int iters, int async) { tensorflow::testing::StopTiming(); TFE_DeleteTensorHandle(m); TFE_DeleteTensorHandle(retval[0]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } @@ -1249,7 +1426,7 @@ TEST(CAPI, Variables) { TFE_DeleteTensorHandle(var_handle); TFE_DeleteTensorHandle(value_handle); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } @@ -1288,10 +1465,67 @@ void BM_ReadVariable(int iters) { TFE_DeleteOp(op); TFE_DeleteTensorHandle(var_handle); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } BENCHMARK(BM_ReadVariable); +TEST(CAPI, StringAttributes) { + // Test that TFE_OpSetAttrString doesn't hold on to the value after it + // returns. + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + std::vector dims(4, 1); + TFE_Op* op = TFE_NewOp(ctx, "AvgPool", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Tensor* tensor = + TF_AllocateTensor(TF_FLOAT, dims.data(), dims.size(), sizeof(float)); + float tensor_data[] = {1}; + memcpy(TF_TensorData(tensor), tensor_data, TF_TensorByteSize(tensor)); + TFE_TensorHandle* tensor_handle = TFE_NewTensorHandle(tensor, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, tensor_handle, status); + TF_DeleteTensor(tensor); + TFE_DeleteTensorHandle(tensor_handle); + + std::vector values(4, 1); + TFE_OpSetAttrIntList(op, "ksize", values.data(), values.size()); + TFE_OpSetAttrIntList(op, "strides", values.data(), values.size()); + + const int BUFFER_SIZE = 10; + char buffer[BUFFER_SIZE]; + std::strncpy(buffer, "VALID", BUFFER_SIZE); + TFE_OpSetAttrString(op, "padding", buffer, std::strlen(buffer)); + // Overwriting value in "buffer", should be fine since TFE_Op + // shouldn't be holding on to it. + std::strncpy(buffer, "NHWC", BUFFER_SIZE); + TFE_OpSetAttrString(op, "data_format", buffer, std::strlen(buffer)); + + TFE_OpSetAttrType(op, "T", TF_FLOAT); + + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(op, &retvals[0], &num_retvals, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + + tensor = TFE_TensorHandleResolve(retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ(4, TF_TensorByteSize(tensor)); + TF_DeleteTensor(tensor); + TFE_DeleteTensorHandle(retvals[0]); + + TFE_DeleteOp(op); + + TFE_DeleteContext(ctx); + TF_DeleteStatus(status); +} } // namespace diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index a98f0b00b2c70055f697ed4f15cb14708384b62f..f56521dac0374849081fe94f16feb08e55647b56 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -121,6 +121,7 @@ cc_library( deps = [ ":array_grad", ":data_flow_grad", + ":image_grad", ":math_grad", ":nn_grad", ], @@ -331,6 +332,36 @@ tf_cc_test( ], ) +cc_library( + name = "image_grad", + srcs = ["gradients/image_grad.cc"], + deps = [ + ":cc_ops", + ":cc_ops_internal", + ":grad_op_registry", + ":gradients", + ], + alwayslink = 1, +) + +tf_cc_test( + name = "gradients_image_grad_test", + srcs = ["gradients/image_grad_test.cc"], + deps = [ + ":cc_ops", + ":client_session", + ":grad_op_registry", + ":grad_testutil", + ":gradient_checker", + ":image_grad", + ":testutil", + "//tensorflow/core:lib_internal", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + cc_library( name = "math_grad", srcs = ["gradients/math_grad.cc"], @@ -348,9 +379,11 @@ tf_cc_test( srcs = ["gradients/math_grad_test.cc"], deps = [ ":cc_ops", + ":client_session", ":grad_op_registry", ":grad_testutil", ":gradient_checker", + ":gradients", ":math_grad", ":testutil", "//tensorflow/core:lib_internal", @@ -595,7 +628,6 @@ tf_cc_binary( copts = tf_copts(), linkopts = select({ "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//tensorflow:darwin": [ "-lm", "-lpthread", diff --git a/tensorflow/cc/client/client_session.cc b/tensorflow/cc/client/client_session.cc index ba056a8f3a84910aebf5079573cb64c19f41469d..0e61089a5950ee894ad5489317757cff8a85e966 100644 --- a/tensorflow/cc/client/client_session.cc +++ b/tensorflow/cc/client/client_session.cc @@ -127,4 +127,22 @@ Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs, target_node_names, outputs, run_metadata); } +Status ClientSession::MakeCallable(const CallableOptions& callable_options, + CallableHandle* out_handle) { + TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph()); + return impl()->session_->MakeCallable(callable_options, out_handle); +} + +Status ClientSession::RunCallable(CallableHandle handle, + const std::vector& feed_tensors, + std::vector* fetch_tensors, + RunMetadata* run_metadata) { + return impl()->session_->RunCallable(handle, feed_tensors, fetch_tensors, + run_metadata); +} + +Status ClientSession::ReleaseCallable(CallableHandle handle) { + return impl()->session_->ReleaseCallable(handle); +} + } // end namespace tensorflow diff --git a/tensorflow/cc/client/client_session.h b/tensorflow/cc/client/client_session.h index 5fb4109f7d15d5997f745acd913e60a02855fd73..7dd653eec4ec729b652cb779d06e820bfb437b3c 100644 --- a/tensorflow/cc/client/client_session.h +++ b/tensorflow/cc/client/client_session.h @@ -87,7 +87,33 @@ class ClientSession { const std::vector& run_outputs, std::vector* outputs, RunMetadata* run_metadata) const; - // TODO(keveman): Add support for partial run. + /// \brief A handle to a subgraph, created with + /// `ClientSession::MakeCallable()`. + typedef int64 CallableHandle; + + /// \brief Creates a `handle` for invoking the subgraph defined by + /// `callable_options`. + /// NOTE: This API is still experimental and may change. + Status MakeCallable(const CallableOptions& callable_options, + CallableHandle* out_handle); + + /// \brief Invokes the subgraph named by `handle` with the given options and + /// input tensors. + /// + /// The order of tensors in `feed_tensors` must match the order of names in + /// `CallableOptions::feed()` and the order of tensors in `fetch_tensors` will + /// match the order of names in `CallableOptions::fetch()` when this subgraph + /// was created. + /// NOTE: This API is still experimental and may change. + Status RunCallable(CallableHandle handle, + const std::vector& feed_tensors, + std::vector* fetch_tensors, + RunMetadata* run_metadata); + + /// \brief Releases resources associated with the given `handle` in this + /// session. + /// NOTE: This API is still experimental and may change. + Status ReleaseCallable(CallableHandle handle); private: class Impl; diff --git a/tensorflow/cc/client/client_session_test.cc b/tensorflow/cc/client/client_session_test.cc index ea5cf5a1f12be316cc6e0d0a02cd3caf4d177400..559ffea7e817526e7f1396cd0e8187d01364f23b 100644 --- a/tensorflow/cc/client/client_session_test.cc +++ b/tensorflow/cc/client/client_session_test.cc @@ -95,5 +95,26 @@ TEST(ClientSessionTest, MultiThreaded) { test::ExpectTensorEqual(outputs[0], test::AsTensor({-1, 2}, {2})); } +TEST(ClientSessionTest, Callable) { + Scope root = Scope::NewRootScope(); + auto a = Placeholder(root, DT_INT32); + auto b = Placeholder(root, DT_INT32); + auto c = Add(root, a, b); + ClientSession session(root); + std::vector outputs; + + CallableOptions options; + options.add_feed(a.node()->name()); + options.add_feed(b.node()->name()); + options.add_fetch(c.node()->name()); + ClientSession::CallableHandle callable; + TF_CHECK_OK(session.MakeCallable(options, &callable)); + TF_EXPECT_OK(session.RunCallable( + callable, {test::AsTensor({1}, {}), test::AsTensor({41}, {})}, + &outputs, nullptr)); + test::ExpectTensorEqual(outputs[0], test::AsTensor({42}, {})); + TF_EXPECT_OK(session.ReleaseCallable(callable)); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index dfdef88945deca376368edd6f7aa322b1e1cbf94..c20ea95a15e3f53b9b26716ed7b624fa853017c9 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -508,15 +508,6 @@ bool HasOptionalAttrs( return false; } -const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) { - for (int i = 0; i < api_def.in_arg_size(); ++i) { - if (api_def.in_arg(i).name() == name) { - return &api_def.in_arg(i); - } - } - return nullptr; -} - struct OpInfo { // graph_op_def: The OpDef used by the runtime, has the names that // must be used when calling NodeBuilder. diff --git a/tensorflow/cc/framework/gradient_checker.cc b/tensorflow/cc/framework/gradient_checker.cc index de2645cb440bda1f35e764af9197ca97bb760c08..e9f9c59e3aa0e8a9dc5d5e658540e9da73adaca5 100644 --- a/tensorflow/cc/framework/gradient_checker.cc +++ b/tensorflow/cc/framework/gradient_checker.cc @@ -247,7 +247,7 @@ Status ComputeNumericJacobianTranspose(const Scope& scope, const OutputList& xs, auto y_pos_flat = y_pos[y_idx].flat(); auto y_neg_flat = y_neg[y_idx].flat(); const int64 y_size = y_shapes[y_idx].num_elements(); - const Y_T scale = Y_T{2 * delta}; + const Y_T scale = 2 * delta; auto jacobian = (*jacobian_ts)[x_idx * y_num + y_idx].matrix(); for (int c = 0; c < y_size; ++c) { SetJacobian(&jacobian, r * x_stride + unit_dimension, @@ -351,7 +351,14 @@ Status ComputeGradientErrorInternal(const Scope& scope, const OutputList& xs, auto jac_n = jacobian_ns[i].matrix(); for (int r = 0; r < jacobian_ts[i].dim_size(0); ++r) { for (int c = 0; c < jacobian_ts[i].dim_size(1); ++c) { - *max_error = std::max(*max_error, std::fabs(jac_t(r, c) - jac_n(r, c))); + auto cur_error = std::fabs(jac_t(r, c) - jac_n(r, c)); + // Treat any NaN as max_error and immediately return. + // (Note that std::max may ignore NaN arguments.) + if (std::isnan(cur_error)) { + *max_error = cur_error; + return Status::OK(); + } + *max_error = std::max(*max_error, cur_error); } } } @@ -409,6 +416,7 @@ Status ComputeGradientError(const Scope& scope, const Output& x, const Output& y, const TensorShape& y_shape, JAC_T* max_error); INSTANTIATE_GRAD_ERR_TYPE(float, float, float); +INSTANTIATE_GRAD_ERR_TYPE(double, float, double); INSTANTIATE_GRAD_ERR_TYPE(double, double, double); INSTANTIATE_GRAD_ERR_TYPE(complex64, float, float); INSTANTIATE_GRAD_ERR_TYPE(float, complex64, float); diff --git a/tensorflow/cc/framework/gradient_checker_test.cc b/tensorflow/cc/framework/gradient_checker_test.cc index d4f0a7f5ab3716be41e22c02a21aca028f76fb88..8dd762c282eff287bddd49ea6f38b2b8060949b0 100644 --- a/tensorflow/cc/framework/gradient_checker_test.cc +++ b/tensorflow/cc/framework/gradient_checker_test.cc @@ -28,12 +28,14 @@ namespace { using ops::Complex; using ops::Const; +using ops::Div; using ops::MatMul; using ops::Placeholder; using ops::Real; using ops::Split; using ops::Square; using ops::Stack; +using ops::Sub; using ops::Unstack; TEST(GradientCheckerTest, BasicFloat) { @@ -104,6 +106,20 @@ TEST(GradientCheckerTest, Complex64ToFloat) { EXPECT_LT(max_error, 1e-4); } +// When calculating gradients that are undefined, test we get NaN +// as the computed error rather than 0. +TEST(GradientCheckerTest, BasicNan) { + Scope scope = Scope::NewRootScope(); + TensorShape shape({2, 4, 3}); + auto x = Placeholder(scope, DT_FLOAT, Placeholder::Shape(shape)); + // y = x/(x-x) should always return NaN + auto y = Div(scope, x, Sub(scope, x, x)); + float max_error; + TF_ASSERT_OK((ComputeGradientError( + scope, {x}, {shape}, {y}, {shape}, &max_error))); + EXPECT_TRUE(std::isnan(max_error)); +} + TEST(GradientCheckerTest, MatMulGrad) { Scope scope = Scope::NewRootScope(); diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index b353accddcb6db9a07c112de03ead2f02c4ee6a6..e9173227aadbf86eab666e6c17bacacb92888572 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -120,6 +120,24 @@ Status SplitGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Split", SplitGrad); +Status FillGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + // y = fill(fill_shape, x) + // No gradient returned for the fill_shape argument. + grad_outputs->push_back(NoGradient()); + // The gradient for x (which must be a scalar) is just the sum of + // all the gradients from the shape it fills. + // We use ReduceSum to implement this, which needs an argument providing + // the indices of all the dimensions of the incoming gradient. + // grad(x) = reduce_sum(grad(y), [0..rank(grad(y))]) + auto all_dims = Range(scope, Const(scope, 0), Rank(scope, grad_inputs[0]), + Const(scope, 1)); + grad_outputs->push_back(ReduceSum(scope, grad_inputs[0], all_dims)); + return scope.status(); +} +REGISTER_GRADIENT_OP("Fill", FillGrad); + Status DiagGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { diff --git a/tensorflow/cc/gradients/array_grad_test.cc b/tensorflow/cc/gradients/array_grad_test.cc index d09275b6487b4212aa35a0476002f2bb587fa210..f41de3dc2098df55fbbb616557f264a4e70db6b6 100644 --- a/tensorflow/cc/gradients/array_grad_test.cc +++ b/tensorflow/cc/gradients/array_grad_test.cc @@ -108,6 +108,14 @@ TEST_F(ArrayGradTest, SplitGrad) { RunTest({x}, {x_shape}, y.output, {y_shape, y_shape}); } +TEST_F(ArrayGradTest, FillGrad) { + TensorShape x_shape({}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + TensorShape y_shape({2, 5, 3}); + auto y = Fill(scope_, {2, 5, 3}, x); + RunTest(x, x_shape, y, y_shape); +} + TEST_F(ArrayGradTest, DiagGrad) { TensorShape x_shape({5, 2}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); diff --git a/tensorflow/cc/gradients/image_grad.cc b/tensorflow/cc/gradients/image_grad.cc new file mode 100644 index 0000000000000000000000000000000000000000..882709e1e2817431a32c453fe0f35f2b2e6c69b0 --- /dev/null +++ b/tensorflow/cc/gradients/image_grad.cc @@ -0,0 +1,74 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include "tensorflow/cc/framework/grad_op_registry.h" +#include "tensorflow/cc/framework/gradients.h" +#include "tensorflow/cc/ops/image_ops_internal.h" +#include "tensorflow/cc/ops/standard_ops.h" + +namespace tensorflow { +namespace ops { +namespace { + +Status ResizeNearestNeighborGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + bool align_corners; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "align_corners", &align_corners)); + // The internal gradient implementation needs the shape of the input image. + // x_shape = shape(x)[1:3] + // = slice(shape(x), {1}, {3 - 1}) + auto x_shape = Slice(scope, Shape(scope, op.input(0)), {1}, {2}); + grad_outputs->push_back(internal::ResizeNearestNeighborGrad( + scope, grad_inputs[0], x_shape, + internal::ResizeNearestNeighborGrad::AlignCorners(align_corners))); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("ResizeNearestNeighbor", ResizeNearestNeighborGradHelper); + +Status ResizeBilinearGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + bool align_corners; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "align_corners", &align_corners)); + grad_outputs->push_back(internal::ResizeBilinearGrad( + scope, grad_inputs[0], op.input(0), + internal::ResizeBilinearGrad::AlignCorners(align_corners))); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("ResizeBilinear", ResizeBilinearGradHelper); + +Status ResizeBicubicGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + bool align_corners; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "align_corners", &align_corners)); + grad_outputs->push_back(internal::ResizeBicubicGrad( + scope, grad_inputs[0], op.input(0), + internal::ResizeBicubicGrad::AlignCorners(align_corners))); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("ResizeBicubic", ResizeBicubicGradHelper); + +} // anonymous namespace +} // namespace ops +} // namespace tensorflow diff --git a/tensorflow/cc/gradients/image_grad_test.cc b/tensorflow/cc/gradients/image_grad_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..2e55c7561b030c50bd67bd53fd0d55710085c5d2 --- /dev/null +++ b/tensorflow/cc/gradients/image_grad_test.cc @@ -0,0 +1,157 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/client/client_session.h" +#include "tensorflow/cc/framework/grad_op_registry.h" +#include "tensorflow/cc/framework/gradient_checker.h" +#include "tensorflow/cc/framework/testutil.h" +#include "tensorflow/cc/gradients/grad_testutil.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { +namespace { + +using ops::Const; +using ops::ResizeBicubic; +using ops::ResizeBilinear; +using ops::ResizeNearestNeighbor; + +class ImageGradTest : public ::testing::Test { + protected: + ImageGradTest() : scope_(Scope::NewRootScope()) {} + + enum OpType { RESIZE_NEAREST, RESIZE_BILINEAR, RESIZE_BICUBIC }; + + template + Tensor MakeData(const TensorShape& data_shape) { + DataType data_type = DataTypeToEnum::v(); + Tensor data(data_type, data_shape); + auto data_flat = data.flat(); + for (int i = 0; i < data_flat.size(); ++i) { + data_flat(i) = T(i); + } + return data; + } + + template + void MakeOp(const OpType op_type, const Tensor& x_data, const Input& y_shape, + const bool align_corners, Output* x, Output* y) { + *x = Const(scope_, x_data); + switch (op_type) { + case RESIZE_NEAREST: + *y = ResizeNearestNeighbor( + scope_, *x, y_shape, + ResizeNearestNeighbor::AlignCorners(align_corners)); + return; + case RESIZE_BILINEAR: + *y = ResizeBilinear(scope_, *x, y_shape, + ResizeBilinear::AlignCorners(align_corners)); + return; + case RESIZE_BICUBIC: + *y = ResizeBicubic(scope_, *x, y_shape, + ResizeBicubic::AlignCorners(align_corners)); + return; + } + assert(false); + } + + template + void TestResizedShapeForType(const OpType op_type, const bool align_corners) { + TensorShape x_shape({1, 2, 2, 1}); + Tensor x_data = MakeData(x_shape); + Output x, y; + MakeOp(op_type, x_data, {4, 6}, align_corners, &x, &y); + + ClientSession session(scope_); + std::vector outputs; + TF_ASSERT_OK(session.Run({y}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + EXPECT_EQ(outputs[0].shape(), TensorShape({1, 4, 6, 1})); + } + + void TestResizedShape(OpType op_type) { + for (const bool align_corners : {true, false}) { + TestResizedShapeForType(op_type, align_corners); + TestResizedShapeForType(op_type, align_corners); + TestResizedShapeForType(op_type, align_corners); + } + } + + template + void TestResizeToSmallerAndAlign(const OpType op_type, + const bool align_corners) { + TensorShape x_shape({1, 4, 6, 1}); + Tensor x_data = MakeData(x_shape); + Output x, y; + MakeOp(op_type, x_data, {2, 3}, align_corners, &x, &y); + JAC_T max_error; + TF_ASSERT_OK((ComputeGradientError( + scope_, x, x_data, y, {1, 2, 3, 1}, &max_error))); + EXPECT_LT(max_error, 1e-3); + } + + template + void TestResizeToLargerAndAlign(const OpType op_type, + const bool align_corners) { + TensorShape x_shape({1, 2, 3, 1}); + Tensor x_data = MakeData(x_shape); + Output x, y; + MakeOp(op_type, x_data, {4, 6}, align_corners, &x, &y); + JAC_T max_error; + TF_ASSERT_OK((ComputeGradientError( + scope_, x, x_data, y, {1, 4, 6, 1}, &max_error))); + EXPECT_LT(max_error, 1e-3); + } + + template + void TestResize(OpType op_type) { + for (const bool align_corners : {true, false}) { + TestResizeToSmallerAndAlign(op_type, align_corners); + TestResizeToLargerAndAlign(op_type, align_corners); + } + } + + Scope scope_; +}; + +TEST_F(ImageGradTest, TestNearestNeighbor) { + TestResizedShape(RESIZE_NEAREST); + TestResize(RESIZE_NEAREST); + TestResize(RESIZE_NEAREST); +} + +TEST_F(ImageGradTest, TestBilinear) { + TestResizedShape(RESIZE_BILINEAR); + TestResize(RESIZE_BILINEAR); + // Note that Y_T is always float for this op. We choose + // double for the jacobian to capture the higher precision + // between X_T and Y_T. + TestResize(RESIZE_BILINEAR); +} + +TEST_F(ImageGradTest, TestBicubic) { + TestResizedShape(RESIZE_BICUBIC); + TestResize(RESIZE_BICUBIC); + // Note that Y_T is always float for this op. We choose + // double for the jacobian to capture the higher precision + // between X_T and Y_T. + TestResize(RESIZE_BICUBIC); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index 35a01e0341cb08c9b314908b6dcd76fd99c1e68b..1329b568ab8d4cc5cc5eed554e74bf1100d9bdcf 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -441,6 +441,21 @@ Status RealDivGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("RealDiv", RealDivGrad); +Status DivNoNanGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto x_1 = ConjugateHelper(scope, op.input(0)); + auto x_2 = ConjugateHelper(scope, op.input(1)); + // y = x_1 / x_2 + // dy/dx_1 = 1/x_2 + // dy/dx_2 = -x_1/x_2^2 + auto gx_1 = DivNoNan(scope, grad_inputs[0], x_2); + auto gx_2 = Mul(scope, grad_inputs[0], + DivNoNan(scope, DivNoNan(scope, Neg(scope, x_1), x_2), x_2)); + return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2); +} +REGISTER_GRADIENT_OP("DivNoNan", DivNoNanGrad); + Status SquaredDifferenceGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { @@ -1007,6 +1022,26 @@ Status ProdGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Prod", ProdGrad); +Status SegmentSumGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + // The SegmentSum operation sums segments of the Tensor that have the same + // index in the segment_ids parameter. + // i.e z = [2, 3, 4, 5], segment_ids [0, 0, 0, 1] + // will produce [2 + 3 + 4, 5] = [9, 5] + // The gradient that will flow back to the gather operation will look like + // [x1, x2], it will have the same shape as the output of the SegmentSum + // operation. The differentiation step of the SegmentSum operation just + // broadcast the gradient in order to retrieve the z's shape. + // dy/dz = [x1, x1, x1, x2] + grad_outputs->push_back(Gather(scope, grad_inputs[0], op.input(1))); + + // stop propagation along segment_ids + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("SegmentSum", SegmentSumGrad); + // MatMulGrad helper function used to compute two MatMul operations // based on input matrix transposition combinations. Status MatMulGradHelper(const Scope& scope, const bool is_batch, diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index fd7b6fe6625f27bda92e2f56f60908658cdecd7e..c16938322c3555939ace1013f3bb95c5689b503e 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -13,8 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/cc/client/client_session.h" #include "tensorflow/cc/framework/grad_op_registry.h" #include "tensorflow/cc/framework/gradient_checker.h" +#include "tensorflow/cc/framework/gradients.h" #include "tensorflow/cc/framework/testutil.h" #include "tensorflow/cc/gradients/grad_testutil.h" #include "tensorflow/cc/ops/standard_ops.h" @@ -31,6 +33,7 @@ using ops::AddN; using ops::BatchMatMul; using ops::Const; using ops::Div; +using ops::DivNoNan; using ops::MatMul; using ops::Max; using ops::Maximum; @@ -42,6 +45,7 @@ using ops::Placeholder; using ops::Pow; using ops::Prod; using ops::RealDiv; +using ops::SegmentSum; using ops::SquaredDifference; using ops::Sub; using ops::Sum; @@ -475,11 +479,7 @@ TEST_F(CWiseUnaryGradTest, Tan_Complex) { auto x_fn = [this](const int i) { return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); }; - // TODO(kbsriram) - // Enable when tan kernel supports complex inputs - if (false) { - TestCWiseGrad(TAN, x_fn); - } + TestCWiseGrad(TAN, x_fn); } TEST_F(CWiseUnaryGradTest, Atan) { @@ -854,6 +854,36 @@ TEST_F(NaryGradTest, RealDiv) { RunTest({x}, {x_shape}, {y}, {x_shape}); } +TEST_F(NaryGradTest, DivNoNan) { + { + TensorShape x_shape({3, 2, 5}); + const auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large + // division errors in the numeric estimator used by the gradient checker. + const auto y = DivNoNan( + scope_, x, Add(scope_, Const(scope_, 1), Abs(scope_, x))); + RunTest({x}, {x_shape}, {y}, {x_shape}); + } + { + // Return 0 gradient (rather than NaN) for division by zero. + const auto x = Placeholder(scope_, DT_FLOAT); + const auto zero = Const(scope_, 0.0); + const auto y = DivNoNan(scope_, x, zero); + + std::vector grad_outputs; + TF_EXPECT_OK(AddSymbolicGradients(scope_, {y}, {x}, &grad_outputs)); + ClientSession session(scope_); + std::vector grad_result; + TF_EXPECT_OK( + session.Run({{x, {-3.0f, 0.0f, 3.0f}}}, grad_outputs, &grad_result)); + EXPECT_EQ(grad_result.size(), 1); + EXPECT_EQ(grad_result[0].NumElements(), 3); + EXPECT_EQ(grad_result[0].flat()(0), 0.0f); + EXPECT_EQ(grad_result[0].flat()(1), 0.0f); + EXPECT_EQ(grad_result[0].flat()(2), 0.0f); + } +} + TEST_F(NaryGradTest, SquaredDifference) { TensorShape x1_shape({3, 2, 5}); TensorShape x2_shape({2, 5}); @@ -902,5 +932,14 @@ TEST_F(NaryGradTest, Prod) { RunTest({x}, {x_shape}, {y}, {y_shape}); } +TEST_F(NaryGradTest, SegmentSum) { + TensorShape x_shape({3, 4}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + auto y = SegmentSum(scope_, x, {0, 0, 1}); + // the sum is always on the first dimension + TensorShape y_shape({2, 4}); + RunTest({x}, {x_shape}, {y}, {y_shape}); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index d47b02574317f5bbbe9bfdde04e306505062a434..3830416159158cca8bfb8422c2959b49fa42406d 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -74,6 +74,54 @@ void AddAssetsTensorsToInputs(const StringPiece export_dir, } } +// Like Session::Run(), but uses the Make/Run/ReleaseCallable() API to avoid +// leaving behind non-GC'ed state. +// +// Detailed motivation behind this approach, from ashankar@: +// +// Each call to Session::Run() that identifies a new subgraph (based on feeds +// and fetches) creates some datastructures that live as long as the session +// (the partitioned graph, associated executors etc.). +// +// A pathological case of this would be if say the initialization op +// (main_op/legacy_init_op) involves the use of a large constant. Then we +// allocate memory for that large constant that will just stick around till the +// session dies. With this Callable mechanism, that memory will be released +// right after ReleaseCallable returns. +// +// However, the resource manager state remains. +Status RunOnce(const RunOptions& run_options, + const std::vector>& inputs, + const std::vector& output_tensor_names, + const std::vector& target_node_names, + std::vector* outputs, RunMetadata* run_metadata, + Session* session) { + CallableOptions callable_options; + std::vector feed_tensors; + *callable_options.mutable_run_options() = run_options; + for (const auto& input : inputs) { + const string& name = input.first; + const Tensor& tensor = input.second; + callable_options.add_feed(name); + feed_tensors.push_back(tensor); + } + for (const string& output_tensor_name : output_tensor_names) { + callable_options.add_fetch(output_tensor_name); + } + for (const string& target_node_name : target_node_names) { + callable_options.add_target(target_node_name); + } + + Session::CallableHandle callable_handle; + TF_RETURN_IF_ERROR(session->MakeCallable(callable_options, &callable_handle)); + const Status run_status = session->RunCallable(callable_handle, feed_tensors, + outputs, run_metadata); + // Be sure to call ReleaseCallable() regardless of the outcome of + // RunCallable(). + session->ReleaseCallable(callable_handle).IgnoreError(); + return run_status; +} + bool HasMainOp(const MetaGraphDef& meta_graph_def) { const auto& collection_def_map = meta_graph_def.collection_def(); if (collection_def_map.find(kSavedModelMainOpKey) != @@ -100,8 +148,8 @@ Status RunMainOp(const RunOptions& run_options, const string& export_dir, AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); RunMetadata run_metadata; const StringPiece main_op_name = main_op_it->second.node_list().value(0); - return session->Run(run_options, inputs, {}, {main_op_name.ToString()}, - nullptr /* outputs */, &run_metadata); + return RunOnce(run_options, inputs, {}, {main_op_name.ToString()}, + nullptr /* outputs */, &run_metadata, session); } return Status::OK(); } @@ -122,7 +170,8 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, variables_directory, MetaFilename(kSavedModelVariablesFilename)); if (!Env::Default()->FileExists(variables_index_path).ok()) { LOG(INFO) << "The specified SavedModel has no variables; no checkpoints " - "were restored."; + "were restored. File does not exist: " + << variables_index_path; return Status::OK(); } const string variables_path = @@ -138,8 +187,8 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); RunMetadata run_metadata; - return session->Run(run_options, inputs, {}, {restore_op_name.ToString()}, - nullptr /* outputs */, &run_metadata); + return RunOnce(run_options, inputs, {}, {restore_op_name.ToString()}, + nullptr /* outputs */, &run_metadata, session); } Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def, diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index fef8b8d4d4cdcc97a913ae2ba6d1a8b0b0084f89..2220d0786d3757abc378d1a3d0ddc704bba6a4f3 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -8,28 +8,6 @@ load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -# Optional runtime utilities for use by code generated by tfcompile. -cc_library( - name = "runtime", - srcs = ["runtime.cc"], - hdrs = ["runtime.h"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:framework_lite", - ], -) - -tf_cc_test( - name = "runtime_test", - srcs = ["runtime_test.cc"], - deps = [ - ":runtime", - "//tensorflow/core:framework", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - # Don't depend on this directly; this is only used for the benchmark test # generated by tf_library. cc_library( @@ -53,9 +31,9 @@ cc_library( ], deps = [ ":embedded_protocol_buffers", - ":runtime", # needed by codegen to print aligned_buffer_bytes "//tensorflow/compiler/tf2xla", "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:cpu_function_runtime", "//tensorflow/compiler/tf2xla:tf2xla_proto", "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -70,12 +48,14 @@ cc_library( "//tensorflow/compiler/xla/client:compile_only_client", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/compiler/xla/service/cpu:buffer_info_util", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/memory", ], ) @@ -214,6 +194,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", "@llvm//:core", "@llvm//:support", "@llvm//:target", @@ -238,7 +219,6 @@ test_suite( tests = [ ":benchmark_test", ":codegen_test", - ":runtime_test", ":test_graph_tfadd_test", ":test_graph_tfunknownop2_test", ":test_graph_tfunknownop3_test", diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index 28070d60dbbe6dd8f930b8e6509cedcf09f94e11..a8485576ac1cbbb39450ab46f67761533b34e0b6 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -19,11 +19,13 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/aot/embedded_protocol_buffers.h" -#include "tensorflow/compiler/aot/runtime.h" +#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" #include "tensorflow/compiler/tf2xla/str_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" @@ -36,6 +38,8 @@ namespace tfcompile { namespace { +using BufferInfo = cpu_function_runtime::BufferInfo; + bool IsAlpha(char c) { return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z'); } @@ -85,27 +89,36 @@ Status XLATypeToCpp(xla::PrimitiveType type, string* str) { return Status::OK(); } -// total_buffer_bytes returns the sum of each size in `sizes`, skipping -1 -// values. There are `n` entries in `sizes`. -size_t total_buffer_bytes(const intptr_t* sizes, size_t n) { - size_t total = 0; - for (size_t i = 0; i < n; ++i) { - if (sizes[i] != -1) { - total += sizes[i]; - } - } - return total; +// Returns the sum of the size of each buffer in `buffer_infos`. +size_t TotalBufferBytes(const std::vector& buffer_infos) { + return std::accumulate(buffer_infos.begin(), buffer_infos.end(), size_t{0}, + [](size_t size, const BufferInfo& buffer_info) { + return size + buffer_info.size(); + }); } -// Fills in arg_sizes with the byte size of each positional arg. -Status ComputeArgSizes(const CompileResult& compile_result, - std::vector* arg_sizes) { - const xla::ProgramShape& ps = compile_result.program_shape; - for (int i = 0; i < ps.parameters_size(); ++i) { - arg_sizes->push_back(xla::ShapeUtil::ByteSizeOf( - ps.parameters(i), compile_result.pointer_size)); - } - return Status::OK(); +// Returns a vector of BufferInfo instances in `buffer_infos` that are entry +// parameter buffers. +std::vector ExtractEntryParamBufferInfos( + const std::vector& buffer_infos) { + std::vector result; + std::copy_if(buffer_infos.begin(), buffer_infos.end(), + std::back_inserter(result), [](const BufferInfo& buffer_info) { + return buffer_info.is_entry_parameter(); + }); + return result; +} + +// Returns a vector of BufferInfo instances in `buffer_infos` that are temp +// buffers. +std::vector ExtractTempBufferInfos( + const std::vector& buffer_infos) { + std::vector result; + std::copy_if(buffer_infos.begin(), buffer_infos.end(), + std::back_inserter(result), [](const BufferInfo& buffer_info) { + return buffer_info.is_temp_buffer(); + }); + return result; } // Add (from,to) rewrite pairs based on the given shape. These rewrite pairs @@ -278,6 +291,25 @@ Status ValidateFeedFetchCppNames(const tf2xla::Config& config) { return Status::OK(); } +// Returns a list of C++ expressions that, when executed, will construct the +// BufferInfo instances in `buffer_infos`. +std::vector BufferInfosToCppExpression( + const std::vector& buffer_infos) { + std::vector buffer_infos_as_strings; + std::transform(buffer_infos.begin(), buffer_infos.end(), + std::back_inserter(buffer_infos_as_strings), + [](const BufferInfo& buffer_info) { + std::pair encoded = buffer_info.Encode(); + string encoded_second_as_str = + encoded.second == ~0ULL + ? "~0ULL" + : strings::StrCat(encoded.second, "ULL"); + return strings::StrCat( + "::tensorflow::cpu_function_runtime::BufferInfo({", + encoded.first, "ULL, ", encoded_second_as_str, "})"); + }); + return buffer_infos_as_strings; +} } // namespace Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, @@ -286,29 +318,35 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, TF_RETURN_IF_ERROR(ValidateConfig(config)); TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config)); const int64 result_index = compile_result.aot->result_buffer_index(); - const xla::BufferSizes& temp_sizes = compile_result.aot->buffer_sizes(); - if (result_index < 0 || result_index >= temp_sizes.size()) { + const std::vector& buffer_infos = + compile_result.aot->buffer_infos(); + const std::vector arg_index_table = + ::xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos); + std::vector buffer_infos_as_strings = + BufferInfosToCppExpression(buffer_infos); + if (result_index < 0 || result_index >= buffer_infos.size()) { return errors::InvalidArgument("result index: ", result_index, " is outside the range of temp sizes: [0,", - temp_sizes.size(), ")"); + buffer_infos.size(), ")"); } // Compute sizes and generate methods. - std::vector arg_sizes; - TF_RETURN_IF_ERROR(ComputeArgSizes(compile_result, &arg_sizes)); + std::vector buffer_infos_for_args = + ExtractEntryParamBufferInfos(buffer_infos); + std::vector buffer_infos_for_temps = + ExtractTempBufferInfos(buffer_infos); const xla::ProgramShape& ps = compile_result.program_shape; string methods_arg, methods_result; TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg)); TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result)); - const std::vector iarg(arg_sizes.begin(), arg_sizes.end()); - const std::vector itemp(temp_sizes.begin(), temp_sizes.end()); - const size_t arg_bytes_aligned = - runtime::aligned_buffer_bytes(iarg.data(), iarg.size()); - const size_t arg_bytes_total = total_buffer_bytes(iarg.data(), iarg.size()); - const size_t temp_bytes_aligned = - runtime::aligned_buffer_bytes(itemp.data(), itemp.size()); - const size_t temp_bytes_total = - total_buffer_bytes(itemp.data(), itemp.size()); + const size_t arg_bytes_aligned = cpu_function_runtime::AlignedBufferBytes( + buffer_infos_for_args.data(), buffer_infos_for_args.size(), + /*allocate_entry_params=*/true); + const size_t arg_bytes_total = TotalBufferBytes(buffer_infos_for_args); + const size_t temp_bytes_aligned = cpu_function_runtime::AlignedBufferBytes( + buffer_infos_for_temps.data(), buffer_infos_for_temps.size(), + /*allocate_entry_params=*/true); + const size_t temp_bytes_total = TotalBufferBytes(buffer_infos_for_temps); // Create rewrite strings for namespace start and end. string ns_start; @@ -343,8 +381,8 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, // calling HloProfilePrinter::profile_counters_size. const string assign_profile_counters_size = opts.gen_hlo_profile_printer_data - ? "data->profile_counters_size = " - "data->hlo_profile_printer_data->profile_counters_size();" + ? "data->set_profile_counters_size(" + "data->hlo_profile_printer_data()->profile_counters_size());" : ""; // Use a poor-man's text templating mechanism; first populate the full header @@ -414,9 +452,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { static constexpr size_t kNumArgs = {{ARG_NUM}}; // Byte size of each argument buffer. There are kNumArgs entries. - static const intptr_t* ArgSizes() { - static constexpr intptr_t kArgSizes[kNumArgs] = {{{ARG_SIZES}}}; - return kArgSizes; + static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) { + return BufferInfos()[ArgIndexToBufferIndex()[index]].size(); } // Returns static data used to create an XlaCompiledCpuFunction. @@ -424,17 +461,17 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { static XlaCompiledCpuFunction::StaticData* kStaticData = [](){ XlaCompiledCpuFunction::StaticData* data = new XlaCompiledCpuFunction::StaticData; - data->raw_function = {{ENTRY}}; - data->arg_sizes = ArgSizes(); - data->num_args = kNumArgs; - data->temp_sizes = TempSizes(); - data->num_temps = kNumTemps; - data->result_index = kResultIndex; - data->arg_names = StaticArgNames(); - data->result_names = StaticResultNames(); - data->program_shape = StaticProgramShape(); - data->hlo_profile_printer_data = StaticHloProfilePrinterData(); - {{ASSIGN_PROFILE_COUNTERS_SIZE}} + data->set_raw_function({{ENTRY}}); + data->set_buffer_infos(BufferInfos()); + data->set_num_buffers(kNumBuffers); + data->set_arg_index_table(ArgIndexToBufferIndex()); + data->set_num_args(kNumArgs); + data->set_result_index(kResultIndex); + data->set_arg_names(StaticArgNames()); + data->set_result_names(StaticResultNames()); + data->set_program_shape(StaticProgramShape()); + data->set_hlo_profile_printer_data(StaticHloProfilePrinterData()); +{{ASSIGN_PROFILE_COUNTERS_SIZE}} return data; }(); return *kStaticData; @@ -482,17 +519,27 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {{METHODS_RESULT}} private: - // Number of result and temporary buffers for the compiled computation. - static constexpr size_t kNumTemps = {{TEMP_NUM}}; - // The 0-based index of the result tuple in the temporary buffers. - static constexpr size_t kResultIndex = {{RESULT_INDEX}}; + // Number of buffers for the compiled computation. + static constexpr size_t kNumBuffers = {{NUM_BUFFERS}}; - // Byte size of each result / temporary buffer. There are kNumTemps entries. - static const intptr_t* TempSizes() { - static constexpr intptr_t kTempSizes[kNumTemps] = {{{TEMP_SIZES}}}; - return kTempSizes; + static const ::tensorflow::cpu_function_runtime::BufferInfo* BufferInfos() { + static const ::tensorflow::cpu_function_runtime::BufferInfo + kBufferInfos[kNumBuffers] = { +{{BUFFER_INFOS_AS_STRING}} + }; + return kBufferInfos; } + static const ::tensorflow::int32* ArgIndexToBufferIndex() { + static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = { +{{ARG_INDEX_TABLE}} + }; + return kArgIndexToBufferIndex; + } + + // The 0-based index of the result tuple in the temporary buffers. + static constexpr size_t kResultIndex = {{RESULT_INDEX}}; + // Array of names of each positional argument, terminated by nullptr. static const char** StaticArgNames() {{ARG_NAMES_CODE}} @@ -523,8 +570,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {"{{ARG_BYTES_ALIGNED}}", strings::StrCat(arg_bytes_aligned)}, {"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)}, {"{{ARG_NAMES_CODE}}", arg_names_code}, - {"{{ARG_NUM}}", strings::StrCat(arg_sizes.size())}, - {"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")}, + {"{{ARG_NUM}}", strings::StrCat(arg_index_table.size())}, + {"{{ARG_INDEX_TABLE}}", str_util::Join(arg_index_table, ", ")}, {"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size}, {"{{CLASS}}", opts.class_name}, {"{{DECLS_FROM_OBJ_FILE}}", @@ -546,8 +593,9 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {"{{RESULT_NAMES_CODE}}", result_names_code}, {"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)}, {"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)}, - {"{{TEMP_NUM}}", strings::StrCat(temp_sizes.size())}, - {"{{TEMP_SIZES}}", str_util::Join(temp_sizes, ", ")}}; + {"{{NUM_BUFFERS}}", strings::StrCat(buffer_infos.size())}, + {"{{BUFFER_INFOS_AS_STRING}}", + str_util::Join(buffer_infos_as_strings, ",\n")}}; str_util::ReplaceAllPairs(header, rewrites); return Status::OK(); } @@ -570,7 +618,7 @@ Status GenerateMetadata(const CodegenOpts& opts, if (opts.gen_program_shape) { program_shape = - tensorflow::MakeUnique(compile_result.program_shape); + absl::make_unique(compile_result.program_shape); // The parameter names are currently meaningless, and redundant with the // rest of our metadata, so clear them out to avoid confusion and save // space. diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 29bc9c13b889c86c2ba8776c7b067c54cb05bc43..60d59ae996e8f7ec490c98aeab05182626e61976 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -32,6 +32,8 @@ namespace tensorflow { namespace tfcompile { namespace { +using ::tensorflow::cpu_function_runtime::BufferInfo; + void ExpectErrorContains(const Status& status, StringPiece str) { EXPECT_NE(Status::OK(), status); EXPECT_TRUE(str_util::StrContains(status.error_message(), str)) @@ -171,8 +173,14 @@ TEST(CodegenTest, Golden) { fetch->mutable_id()->set_node_name("fetch0"); fetch->set_name("myfetch"); CompileResult compile_result; - compile_result.aot.reset( - new xla::cpu::CpuAotCompilationResult({}, {1, -1, 2, -1, 3, 120}, 5, {})); + compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult( + {}, + {BufferInfo::MakeTempBuffer(1), + BufferInfo::MakeEntryParameter(/*size=*/8, /*param_number=*/0), + BufferInfo::MakeTempBuffer(2), + BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1), + BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)}, + 5, {})); compile_result.program_shape = xla::ShapeUtil::MakeProgramShape( { xla::ShapeUtil::MakeShape(xla::F32, {1, 2}), diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index 6641d45e83020f4144616a6a2837c844330298f5..e4d8a02877c75fa72c5747650ab9c7ac229955b3 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -65,9 +65,8 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { static constexpr size_t kNumArgs = 2; // Byte size of each argument buffer. There are kNumArgs entries. - static const intptr_t* ArgSizes() { - static constexpr intptr_t kArgSizes[kNumArgs] = {8, 96}; - return kArgSizes; + static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) { + return BufferInfos()[ArgIndexToBufferIndex()[index]].size(); } // Returns static data used to create an XlaCompiledCpuFunction. @@ -75,17 +74,17 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { static XlaCompiledCpuFunction::StaticData* kStaticData = [](){ XlaCompiledCpuFunction::StaticData* data = new XlaCompiledCpuFunction::StaticData; - data->raw_function = entry_point; - data->arg_sizes = ArgSizes(); - data->num_args = kNumArgs; - data->temp_sizes = TempSizes(); - data->num_temps = kNumTemps; - data->result_index = kResultIndex; - data->arg_names = StaticArgNames(); - data->result_names = StaticResultNames(); - data->program_shape = StaticProgramShape(); - data->hlo_profile_printer_data = StaticHloProfilePrinterData(); - + data->set_raw_function(entry_point); + data->set_buffer_infos(BufferInfos()); + data->set_num_buffers(kNumBuffers); + data->set_arg_index_table(ArgIndexToBufferIndex()); + data->set_num_args(kNumArgs); + data->set_result_index(kResultIndex); + data->set_arg_names(StaticArgNames()); + data->set_result_names(StaticResultNames()); + data->set_program_shape(StaticProgramShape()); + data->set_hlo_profile_printer_data(StaticHloProfilePrinterData()); + return data; }(); return *kStaticData; @@ -215,17 +214,32 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { } private: - // Number of result and temporary buffers for the compiled computation. - static constexpr size_t kNumTemps = 6; - // The 0-based index of the result tuple in the temporary buffers. - static constexpr size_t kResultIndex = 5; + // Number of buffers for the compiled computation. + static constexpr size_t kNumBuffers = 6; + + static const ::tensorflow::cpu_function_runtime::BufferInfo* BufferInfos() { + static const ::tensorflow::cpu_function_runtime::BufferInfo + kBufferInfos[kNumBuffers] = { +::tensorflow::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}), +::tensorflow::cpu_function_runtime::BufferInfo({34ULL, 0ULL}), +::tensorflow::cpu_function_runtime::BufferInfo({9ULL, ~0ULL}), +::tensorflow::cpu_function_runtime::BufferInfo({386ULL, 1ULL}), +::tensorflow::cpu_function_runtime::BufferInfo({13ULL, ~0ULL}), +::tensorflow::cpu_function_runtime::BufferInfo({481ULL, ~0ULL}) + }; + return kBufferInfos; + } - // Byte size of each result / temporary buffer. There are kNumTemps entries. - static const intptr_t* TempSizes() { - static constexpr intptr_t kTempSizes[kNumTemps] = {1, -1, 2, -1, 3, 120}; - return kTempSizes; + static const ::tensorflow::int32* ArgIndexToBufferIndex() { + static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = { +1, 3 + }; + return kArgIndexToBufferIndex; } + // The 0-based index of the result tuple in the temporary buffers. + static constexpr size_t kResultIndex = 5; + // Array of names of each positional argument, terminated by nullptr. static const char** StaticArgNames() { static const char* kNames[] = {"myfeed", nullptr}; diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc index 4e27aafec7747655d8e4ea3ddd1788d495ca0710..8fb2fad31c680c5dbbd058a1b9a9265607224429 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.cc +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "llvm/ADT/Triple.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/LLVMContext.h" @@ -27,7 +28,6 @@ limitations under the License. #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetOptions.h" #include "tensorflow/compiler/tf2xla/str_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/util.h" @@ -105,7 +105,7 @@ GetTargetMachineFromTriple(StringPiece target_triple) { error.c_str()); } - return WrapUnique(target->createTargetMachine( + return absl::WrapUnique(target->createTargetMachine( normalized_triple, /*CPU=*/"", /*Features=*/"", llvm::TargetOptions(), llvm::None)); } @@ -118,7 +118,7 @@ StatusOr CreateEmbeddedProtocolBuffers( llvm::LLVMContext llvm_context; std::unique_ptr module_with_serialized_proto = - MakeUnique("embedded_data_module", llvm_context); + absl::make_unique("embedded_data_module", llvm_context); EmbeddedProtocolBuffers result; diff --git a/tensorflow/compiler/aot/runtime.h b/tensorflow/compiler/aot/runtime.h deleted file mode 100644 index d1a669ceb17b9fd71d26e978035283f8824b0376..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/aot/runtime.h +++ /dev/null @@ -1,58 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file contains utilities to make it easier to invoke functions generated -// by tfcompile. Usage of these utilities is optional. - -#ifndef TENSORFLOW_COMPILER_AOT_RUNTIME_H_ -#define TENSORFLOW_COMPILER_AOT_RUNTIME_H_ - -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -namespace tfcompile { -namespace runtime { - -// Align to 64-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment. -static constexpr size_t kAlign = 64; - -// aligned_buffer_bytes returns the sum of each size in `sizes`, skipping -1 -// values. There are `n` entries in `sizes`. Each buffer is aligned to kAlign -// byte boundaries. -size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n); - -// MallocContiguousBuffers allocates buffers for use by the entry point -// generated by tfcompile. `sizes` is an array of byte sizes for each buffer, -// where -1 causes the buffer pointer to be nullptr. There are `n` entries in -// `sizes`. If `annotate_initialized` is set, the allocated memory will be -// annotated as having been initialized - this is useful when allocating -// temporary buffers. -// -// A single contiguous block of memory is allocated, and portions of it are -// parceled out into `bufs`, which must have space for `n` entries. Returns the -// head of the allocated contiguous block, which should be passed to -// FreeContiguous when the buffers are no longer in use. -void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs, - bool annotate_initialized); - -// FreeContiguous frees the contiguous block of memory allocated by -// MallocContiguousBuffers. -void FreeContiguous(void* contiguous); - -} // namespace runtime -} // namespace tfcompile -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_AOT_RUNTIME_H_ diff --git a/tensorflow/compiler/aot/test.cc b/tensorflow/compiler/aot/test.cc index 6b098049cbd7539a2b2e2696b13139a8a6b28e0f..5deb47d12310d24dce847227bd119249210ffb8d 100644 --- a/tensorflow/compiler/aot/test.cc +++ b/tensorflow/compiler/aot/test.cc @@ -51,11 +51,9 @@ namespace tensorflow { namespace tfcompile { namespace { -void zero_buffers(void** bufs, const intptr_t* sizes, size_t n) { - for (int i = 0; i < n; ++i) { - if (sizes[i] != -1) { - memset(bufs[i], 0, sizes[i]); - } +void zero_buffers(XlaCompiledCpuFunction* computation) { + for (int i = 0; i < computation->num_args(); ++i) { + memset(computation->arg_data(i), 0, computation->arg_size(i)); } } @@ -66,7 +64,7 @@ TEST(TEST_NAME, NoCrash) { CPP_CLASS computation; computation.set_thread_pool(&device); - zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs); + zero_buffers(&computation); EXPECT_TRUE(computation.Run()); } @@ -80,7 +78,7 @@ void BM_NAME(int iters) { CPP_CLASS computation; computation.set_thread_pool(&device); - zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs); + zero_buffers(&computation); testing::StartTiming(); while (--iters) { diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index fee46280e9a0e7ba2cf7c3ed46469ae8cc0841d4..0c0c676ece78565e03578d3e33633c7e23b77669 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -44,8 +44,8 @@ using ::testing::IsSupersetOf; TEST(TFCompileTest, Add) { AddComp add; - EXPECT_EQ(add.arg0_data(), add.args()[0]); - EXPECT_EQ(add.arg1_data(), add.args()[1]); + EXPECT_EQ(add.arg0_data(), add.arg_data(0)); + EXPECT_EQ(add.arg1_data(), add.arg_data(1)); add.arg0() = 1; add.arg1() = 2; @@ -67,10 +67,10 @@ TEST(TFCompileTest, Add) { EXPECT_EQ(add_const.error_msg(), ""); EXPECT_EQ(add_const.arg0(), 123); EXPECT_EQ(add_const.arg0_data()[0], 123); - EXPECT_EQ(add_const.arg0_data(), add.args()[0]); + EXPECT_EQ(add_const.arg0_data(), add.arg_data(0)); EXPECT_EQ(add_const.arg1(), 456); EXPECT_EQ(add_const.arg1_data()[0], 456); - EXPECT_EQ(add_const.arg1_data(), add.args()[1]); + EXPECT_EQ(add_const.arg1_data(), add.arg_data(1)); EXPECT_EQ(add_const.result0(), 579); EXPECT_EQ(add_const.result0_data()[0], 579); EXPECT_EQ(add_const.result0_data(), add_const.results()[0]); @@ -85,8 +85,8 @@ TEST(TFCompileTest, Add_SetArg) { int32 arg_y = 32; add.set_arg0_data(&arg_x); add.set_arg1_data(&arg_y); - EXPECT_EQ(add.arg0_data(), add.args()[0]); - EXPECT_EQ(add.arg1_data(), add.args()[1]); + EXPECT_EQ(add.arg0_data(), add.arg_data(0)); + EXPECT_EQ(add.arg1_data(), add.arg_data(1)); EXPECT_TRUE(add.Run()); EXPECT_EQ(add.error_msg(), ""); @@ -97,7 +97,7 @@ TEST(TFCompileTest, Add_SetArg) { TEST(TFCompileTest, AddWithCkpt) { AddWithCkptComp add; - EXPECT_EQ(add.arg0_data(), add.args()[0]); + EXPECT_EQ(add.arg0_data(), add.arg_data(0)); add.arg0() = 1; EXPECT_TRUE(add.Run()); @@ -117,7 +117,7 @@ TEST(TFCompileTest, AddWithCkpt) { EXPECT_EQ(add_const.error_msg(), ""); EXPECT_EQ(add_const.arg0(), 111); EXPECT_EQ(add_const.arg0_data()[0], 111); - EXPECT_EQ(add_const.arg0_data(), add_const.args()[0]); + EXPECT_EQ(add_const.arg0_data(), add_const.arg_data(0)); EXPECT_EQ(add_const.result0(), 153); EXPECT_EQ(add_const.result0_data()[0], 153); EXPECT_EQ(add_const.result0_data(), add_const.results()[0]); @@ -125,7 +125,7 @@ TEST(TFCompileTest, AddWithCkpt) { TEST(TFCompileTest, AddWithCkptSaver) { AddWithCkptSaverComp add; - EXPECT_EQ(add.arg0_data(), add.args()[0]); + EXPECT_EQ(add.arg0_data(), add.arg_data(0)); add.arg0() = 1; EXPECT_TRUE(add.Run()); @@ -145,7 +145,7 @@ TEST(TFCompileTest, AddWithCkptSaver) { EXPECT_EQ(add_const.error_msg(), ""); EXPECT_EQ(add_const.arg0(), 111); EXPECT_EQ(add_const.arg0_data()[0], 111); - EXPECT_EQ(add_const.arg0_data(), add_const.args()[0]); + EXPECT_EQ(add_const.arg0_data(), add_const.arg_data(0)); EXPECT_EQ(add_const.result0(), 153); EXPECT_EQ(add_const.result0_data()[0], 153); EXPECT_EQ(add_const.result0_data(), add_const.results()[0]); @@ -153,9 +153,9 @@ TEST(TFCompileTest, AddWithCkptSaver) { TEST(TFCompileTest, Cond) { CondComp cond; - EXPECT_EQ(cond.arg0_data(), cond.args()[0]); - EXPECT_EQ(cond.arg1_data(), cond.args()[1]); - EXPECT_EQ(cond.arg2_data(), cond.args()[2]); + EXPECT_EQ(cond.arg0_data(), cond.arg_data(0)); + EXPECT_EQ(cond.arg1_data(), cond.arg_data(1)); + EXPECT_EQ(cond.arg2_data(), cond.arg_data(2)); cond.arg1() = 10; cond.arg2() = 20; { @@ -178,8 +178,8 @@ TEST(TFCompileTest, Cond) { TEST(TFCompileTest, Gather) { GatherComp gather; - EXPECT_EQ(gather.arg0_data(), gather.args()[0]); - EXPECT_EQ(gather.arg1_data(), gather.args()[1]); + EXPECT_EQ(gather.arg0_data(), gather.arg_data(0)); + EXPECT_EQ(gather.arg1_data(), gather.arg_data(1)); // Successful gather. { @@ -202,12 +202,12 @@ TEST(TFCompileTest, Gather) { EXPECT_EQ(gather_const.arg0(i), params[i]); EXPECT_EQ(gather_const.arg0_data()[i], params[i]); } - EXPECT_EQ(gather_const.arg0_data(), gather_const.args()[0]); + EXPECT_EQ(gather_const.arg0_data(), gather_const.arg_data(0)); for (int i = 0; i < 2; ++i) { EXPECT_EQ(gather_const.arg1(i), indices[i]); EXPECT_EQ(gather_const.arg1_data()[i], indices[i]); } - EXPECT_EQ(gather_const.arg1_data(), gather_const.args()[1]); + EXPECT_EQ(gather_const.arg1_data(), gather_const.arg_data(1)); for (int i = 0; i < 2; ++i) { EXPECT_EQ(gather_const.result0(i), results[i]); EXPECT_EQ(gather_const.result0_data()[i], results[i]); @@ -222,8 +222,8 @@ TEST(TFCompileTest, MatMul2) { foo::bar::MatMulComp matmul; matmul.set_thread_pool(&device); - EXPECT_EQ(matmul.arg0_data(), matmul.args()[0]); - EXPECT_EQ(matmul.arg1_data(), matmul.args()[1]); + EXPECT_EQ(matmul.arg0_data(), matmul.arg_data(0)); + EXPECT_EQ(matmul.arg1_data(), matmul.arg_data(1)); // Test using the argN() methods. { @@ -271,12 +271,12 @@ TEST(TFCompileTest, MatMul2) { EXPECT_EQ(matmul_const.arg0(i / 3, i % 3), args[i]); EXPECT_EQ(matmul_const.arg0_data()[i], args[i]); } - EXPECT_EQ(matmul_const.arg0_data(), matmul.args()[0]); + EXPECT_EQ(matmul_const.arg0_data(), matmul.arg_data(0)); for (int i = 0; i < 6; ++i) { EXPECT_EQ(matmul_const.arg1(i / 2, i % 2), args[i + 6]); EXPECT_EQ(matmul_const.arg1_data()[i], args[i + 6]); } - EXPECT_EQ(matmul_const.arg1_data(), matmul.args()[1]); + EXPECT_EQ(matmul_const.arg1_data(), matmul.arg_data(1)); for (int i = 0; i < 4; ++i) { EXPECT_EQ(matmul_const.result0(i / 2, i % 2), results[i]); EXPECT_EQ(matmul_const.result0_data()[i], results[i]); @@ -300,8 +300,8 @@ TEST(TFCompileTest, MatMul2_SetArg) { float arg1[3][2] = {{7, 8}, {9, 10}, {11, 12}}; matmul.set_arg0_data(&arg0); matmul.set_arg1_data(&arg1); - EXPECT_EQ(matmul.arg0_data(), matmul.args()[0]); - EXPECT_EQ(matmul.arg1_data(), matmul.args()[1]); + EXPECT_EQ(matmul.arg0_data(), matmul.arg_data(0)); + EXPECT_EQ(matmul.arg1_data(), matmul.arg_data(1)); EXPECT_TRUE(matmul.Run()); EXPECT_EQ(matmul.error_msg(), ""); @@ -319,8 +319,8 @@ TEST(TFCompileTest, MatMulAndAdd1) { MatMulAndAddComp muladd; muladd.set_thread_pool(&device); - EXPECT_EQ(muladd.arg0_data(), muladd.args()[0]); - EXPECT_EQ(muladd.arg1_data(), muladd.args()[1]); + EXPECT_EQ(muladd.arg0_data(), muladd.arg_data(0)); + EXPECT_EQ(muladd.arg1_data(), muladd.arg_data(1)); // Test methods with positional args and results. { @@ -346,12 +346,12 @@ TEST(TFCompileTest, MatMulAndAdd1) { EXPECT_EQ(muladd_const.arg0(i / 2, i % 2), args[i]); EXPECT_EQ(muladd_const.arg0_data()[i], args[i]); } - EXPECT_EQ(muladd_const.arg0_data(), muladd.args()[0]); + EXPECT_EQ(muladd_const.arg0_data(), muladd.arg_data(0)); for (int i = 0; i < 4; ++i) { EXPECT_EQ(muladd_const.arg1(i / 2, i % 2), args[i + 4]); EXPECT_EQ(muladd_const.arg1_data()[i], args[i + 4]); } - EXPECT_EQ(muladd_const.arg1_data(), muladd.args()[1]); + EXPECT_EQ(muladd_const.arg1_data(), muladd.arg_data(1)); for (int i = 0; i < 4; ++i) { EXPECT_EQ(muladd_const.result0(i / 2, i % 2), results0[i]); EXPECT_EQ(muladd_const.result0_data()[i], results0[i]); @@ -387,12 +387,12 @@ TEST(TFCompileTest, MatMulAndAdd1) { EXPECT_EQ(muladd_const.arg_x(i / 2, i % 2), args[i]); EXPECT_EQ(muladd_const.arg_x_data()[i], args[i]); } - EXPECT_EQ(muladd_const.arg_x_data(), muladd.args()[0]); + EXPECT_EQ(muladd_const.arg_x_data(), muladd.arg_data(0)); for (int i = 0; i < 4; ++i) { EXPECT_EQ(muladd_const.arg_y(i / 2, i % 2), args[i + 4]); EXPECT_EQ(muladd_const.arg_y_data()[i], args[i + 4]); } - EXPECT_EQ(muladd_const.arg_y_data(), muladd.args()[1]); + EXPECT_EQ(muladd_const.arg_y_data(), muladd.arg_data(1)); for (int i = 0; i < 4; ++i) { EXPECT_EQ(muladd_const.result_x_y_prod(i / 2, i % 2), results0[i]); EXPECT_EQ(muladd_const.result_x_y_prod_data()[i], results0[i]); @@ -407,8 +407,8 @@ TEST(TFCompileTest, MatMulAndAdd1) { TEST(TFCompileTest, Function) { // The function is equivalent to an addition FunctionComp add_fn; - EXPECT_EQ(add_fn.arg0_data(), add_fn.args()[0]); - EXPECT_EQ(add_fn.arg1_data(), add_fn.args()[1]); + EXPECT_EQ(add_fn.arg0_data(), add_fn.arg_data(0)); + EXPECT_EQ(add_fn.arg1_data(), add_fn.arg_data(1)); add_fn.arg0() = 1; add_fn.arg1() = 2; @@ -451,8 +451,8 @@ TEST(TFCompileTest, AssertEqAndReturnDiff) { // Assert is converted into a no-op in XLA, so there is no failure even if the // two args are different. AssertComp assert; - EXPECT_EQ(assert.arg0_data(), assert.args()[0]); - EXPECT_EQ(assert.arg1_data(), assert.args()[1]); + EXPECT_EQ(assert.arg0_data(), assert.arg_data(0)); + EXPECT_EQ(assert.arg1_data(), assert.arg_data(1)); assert.arg0() = 2; assert.arg1() = 1; diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 5c57fee326ca743dcb8aaae354d261ed4d7f44be..326f73b975aec3a7a6bc7cdc9a92f540ad545ad6 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -16,339 +16,365 @@ tf_library( ) """ -load("//tensorflow:tensorflow.bzl", - "if_android", "tf_cc_test", "tf_copts") - -def tf_library(name, graph, config, - freeze_checkpoint=None, freeze_saver=None, - cpp_class=None, gen_test=True, gen_benchmark=True, - visibility=None, testonly=None, - tfcompile_flags=None, - tfcompile_tool="//tensorflow/compiler/aot:tfcompile", - include_standard_runtime_deps=True, - enable_xla_hlo_profiling=False, deps=None, tags=None): - """Runs tfcompile to compile a TensorFlow graph into executable code. - - Given an invocation of tf_library(name="foo", ...), generates the following - build targets: - foo: A cc_library containing the generated header and computation. - foo_test: A cc_test with simple tests and benchmarks. Only created if - gen_test=True. - foo_benchmark: A cc_binary that runs a minimal-dependency benchmark, useful - for mobile devices or other platforms that can't compile the - full test libraries. Only created if gen_benchmark=True. - - Args: - name: The name of the build rule. - graph: The TensorFlow GraphDef to compile. If the file ends in '.pbtxt' it - is expected to be in the human-readable proto text format, otherwise it is - expected to be in the proto binary format. - config: File containing tensorflow.tf2xla.Config proto. If the file ends - in '.pbtxt' it is expected to be in the human-readable proto text format, - otherwise it is expected to be in the proto binary format. - freeze_checkpoint: If provided, run freeze_graph with this checkpoint to - convert variables into constants. - freeze_saver: If provided, run freeze_graph with this saver, in SaverDef - binary form, to convert variables into constants. - cpp_class: The name of the generated C++ class, wrapping the generated - function. The syntax of this flag is - [[::],...]. This mirrors the C++ syntax - for referring to a class, where multiple namespaces may precede the class - name, separated by double-colons. The class will be generated in the - given namespace(s), or if no namespaces are given, within the global - namespace. - gen_test: If True, also generate a cc_test rule that builds a simple - test and benchmark. - gen_benchmark: If True, also generate a binary with a simple benchmark. - Unlike the output of gen_test, this benchmark can be run on android. - visibility: Bazel build visibility. - testonly: Bazel testonly attribute. - tfcompile_flags: Extra flags to pass to tfcompile to control compilation. - tfcompile_tool: The tfcompile binary. A non-default can be passed to - use a tfcompile built with extra dependencies. - include_standard_runtime_deps: If True, the standard list of kernel/runtime - deps is added to deps. If False, deps must contain the full set of deps - needed by the generated library. - enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated program, - and emit metadata that lets us pretty-print the gathered profile counters. - deps: a list of deps to include on the build rules for the generated - library, added to the standard deps if standard_runtime_deps is True. - tags: tags to apply to subsidiary build rules. - - The output header is called .h. - """ - if not cpp_class: - fail("cpp_class must be specified") - - tfcompile_graph = graph - if freeze_checkpoint or freeze_saver: - if not freeze_checkpoint: - fail("freeze_checkpoint must be specified when freeze_saver is specified") +load( + "//tensorflow:tensorflow.bzl", + "if_android", + "tf_cc_test", + "tf_copts", +) - freeze_name = "freeze_" + name - freeze_file = freeze_name + ".pb" +def tf_library( + name, + graph, + config, + freeze_checkpoint = None, + freeze_saver = None, + cpp_class = None, + gen_test = True, + gen_benchmark = True, + visibility = None, + testonly = None, + tfcompile_flags = None, + tfcompile_tool = "//tensorflow/compiler/aot:tfcompile", + include_standard_runtime_deps = True, + enable_xla_hlo_profiling = False, + deps = None, + tags = None): + """Runs tfcompile to compile a TensorFlow graph into executable code. - # First run tfcompile to generate the list of out_nodes. - out_nodes_file = "out_nodes_" + freeze_name - native.genrule( - name=("gen_" + out_nodes_file), - srcs=[config], - outs=[out_nodes_file], - cmd=("$(location " + tfcompile_tool + ")" + - " --config=$(location " + config + ")" + - " --dump_fetch_nodes > $@"), - tools=[tfcompile_tool], - # Run tfcompile on the build host, rather than forge, since it's - # typically way faster on the local machine. - local=1, - tags=tags, - ) + Given an invocation of tf_library(name="foo", ...), generates the following + build targets: + foo: A cc_library containing the generated header and + computation. + foo_test: A cc_test with simple tests and benchmarks. Only created if + gen_test=True. + foo_benchmark: A cc_binary that runs a minimal-dependency benchmark, + useful for mobile devices or other platforms that can't + compile the full test libraries. Only created if + gen_benchmark=True. + The output header is called .h. - # Now run freeze_graph to convert variables into constants. - freeze_args = (" --input_graph=$(location " + graph + ")" + - " --checkpoint_version=1" + - " --input_binary=" + str(not graph.endswith(".pbtxt")) + - " --input_checkpoint=$(location " + freeze_checkpoint + ")" + - " --output_graph=$(location " + freeze_file + ")" + - " --output_node_names=$$(<$(location " + out_nodes_file + - "))") - freeze_saver_srcs = [] - if freeze_saver: - freeze_args += " --input_saver=$(location " + freeze_saver + ")" - freeze_saver_srcs += [freeze_saver] - native.genrule( - name=freeze_name, - srcs=[ - graph, - freeze_checkpoint, - out_nodes_file, - ] + freeze_saver_srcs, - outs=[freeze_file], - cmd=("$(location //tensorflow/python/tools:freeze_graph)" + - freeze_args), - tools=["//tensorflow/python/tools:freeze_graph"], - tags=tags, - ) - tfcompile_graph = freeze_file + Args: + name: The name of the build rule. + graph: The TensorFlow GraphDef to compile. If the file ends in '.pbtxt' + it is expected to be in the human-readable proto text format, otherwise + it is expected to be in the proto binary format. + config: File containing tensorflow.tf2xla.Config proto. If the file ends + in '.pbtxt' it is expected to be in the human-readable proto text + format, otherwise it is expected to be in the proto binary format. + freeze_checkpoint: If provided, run freeze_graph with this checkpoint to + convert variables into constants. + freeze_saver: If provided, run freeze_graph with this saver, in SaverDef + binary form, to convert variables into constants. + cpp_class: The name of the generated C++ class, wrapping the generated + function. The syntax of this flag is + [[::],...]. This mirrors the C++ syntax + for referring to a class, where multiple namespaces may precede the + class name, separated by double-colons. The class will be generated in + the given namespace(s), or if no namespaces are given, within the global + namespace. + gen_test: If True, also generate a cc_test rule that builds a simple + test and benchmark. + gen_benchmark: If True, also generate a binary with a simple benchmark. + Unlike the output of gen_test, this benchmark can be run on android. + visibility: Bazel build visibility. + testonly: Bazel testonly attribute. + tfcompile_flags: Extra flags to pass to tfcompile to control compilation. + tfcompile_tool: The tfcompile binary. A non-default can be passed to + use a tfcompile built with extra dependencies. + include_standard_runtime_deps: If True, the standard list of + kernel/runtime deps is added to deps. If False, deps must contain the + full set of deps needed by the generated library. + enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated + program, and emit metadata that lets us pretty-print the gathered + profile counters. + deps: a list of deps to include on the build rules for the generated + library, added to the standard deps if standard_runtime_deps is True. + tags: tags to apply to subsidiary build rules. + """ + if not cpp_class: + fail("cpp_class must be specified") - # Rule that runs tfcompile to produce the header and object file. - header_file = name + ".h" - metadata_object_file = name + "_tfcompile_metadata.o" - function_object_file = name + "_tfcompile_function.o" - ep = ("__" + native.package_name() + "__" + name).replace("/", "_") - if type(tfcompile_flags) == type(""): - flags = tfcompile_flags - else: - flags = " ".join(["'" + arg.replace("'", "'\\''") + "'" for arg in (tfcompile_flags or [])]) - if enable_xla_hlo_profiling: - profiling_flag = "--xla_hlo_profile" - else: - profiling_flag = "" - native.genrule( - name=("gen_" + name), - srcs=[ - tfcompile_graph, - config, - ], - outs=[ - header_file, - metadata_object_file, - function_object_file, - ], - cmd=("$(location " + tfcompile_tool + ")" + - " --graph=$(location " + tfcompile_graph + ")" + - " --config=$(location " + config + ")" + - " --entry_point=" + ep + - " --cpp_class=" + cpp_class + - " --target_triple=" + target_llvm_triple() + - " --out_header=$(@D)/" + header_file + - " --out_metadata_object=$(@D)/" + metadata_object_file + - " --out_function_object=$(@D)/" + function_object_file + - " " + flags + " " + profiling_flag), - tools=[tfcompile_tool], - visibility=visibility, - testonly=testonly, - # Run tfcompile on the build host since it's typically faster on the local - # machine. - # - # Note that setting the local=1 attribute on a *test target* causes the - # test infrastructure to skip that test. However this is a genrule, not a - # test target, and runs with --genrule_strategy=forced_forge, meaning the - # local=1 attribute is ignored, and the genrule is still run. - # - # https://www.bazel.io/versions/master/docs/be/general.html#genrule - local=1, - tags=tags, - ) + tfcompile_graph = graph + if freeze_checkpoint or freeze_saver: + if not freeze_checkpoint: + fail("freeze_checkpoint must be specified when freeze_saver is " + + "specified") - # Rule that runs tfcompile to produce the SessionModule proto, useful for - # debugging. TODO(b/64813587): Once the SessionModule proto is - # deterministic, move this into the main rule above. - session_module_pb = name + "_session_module.pb" - native.genrule( - name=(name + "_session_module"), - srcs=[ - tfcompile_graph, - config, - ], - outs=[ - session_module_pb, - ], - cmd=("$(location " + tfcompile_tool + ")" + - " --graph=$(location " + tfcompile_graph + ")" + - " --config=$(location " + config + ")" + - " --entry_point=" + ep + - " --cpp_class=" + cpp_class + - " --target_triple=" + target_llvm_triple() + - " --out_session_module=$(@D)/" + session_module_pb + - " " + flags), - tools=[tfcompile_tool], - visibility=visibility, - testonly=testonly, - local=1, - tags=tags, - ) + freeze_name = "freeze_" + name + freeze_file = freeze_name + ".pb" - # The cc_library rule packaging up the header and object file, and needed - # kernel implementations. - need_xla_data_proto = (flags and flags.find("--gen_program_shape") != -1) - native.cc_library( - name=name, - srcs=[function_object_file, metadata_object_file], - hdrs=[header_file], - visibility=visibility, - testonly=testonly, - deps = [ - # These deps are required by all tf_library targets even if - # include_standard_runtime_deps is False. Without them, the - # generated code will fail to compile. - "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function", - "//tensorflow/core:framework_lite", - ] + (need_xla_data_proto and [ - # If we're generating the program shape, we must depend on the proto. - "//tensorflow/compiler/xla:xla_data_proto", - ] or []) + (enable_xla_hlo_profiling and [ - "//tensorflow/compiler/xla/service:hlo_profile_printer_data" - ] or []) + (include_standard_runtime_deps and [ - # TODO(cwhipkey): only depend on kernel code that the model actually needed. - "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d", - "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d", - "//tensorflow/compiler/xla/service/cpu:runtime_conv2d", - "//tensorflow/compiler/xla/service/cpu:runtime_matmul", - "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d", - "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", - "//third_party/eigen3", - ] or []) + (deps or []), - tags=tags, - ) + # First run tfcompile to generate the list of out_nodes. + out_nodes_file = "out_nodes_" + freeze_name + native.genrule( + name = ("gen_" + out_nodes_file), + srcs = [config], + outs = [out_nodes_file], + cmd = ("$(location " + tfcompile_tool + ")" + + " --config=$(location " + config + ")" + + " --dump_fetch_nodes > $@"), + tools = [tfcompile_tool], + # Run tfcompile on the build host, rather than forge, since it's + # typically way faster on the local machine. + local = 1, + tags = tags, + ) - # Variables used for gen_test and gen_benchmark. - no_ns_name = "" - cpp_class_split = cpp_class.rsplit("::", maxsplit=2) - if len(cpp_class_split) == 1: - no_ns_name = cpp_class_split[0] - else: - no_ns_name = cpp_class_split[1] - sed_replace = ( - "-e \"s|{{TFCOMPILE_HEADER}}|$(location " + header_file + ")|g\" " + - "-e \"s|{{TFCOMPILE_CPP_CLASS}}|" + cpp_class + "|g\" " + - "-e \"s|{{TFCOMPILE_NAME}}|" + no_ns_name + "|g\" ") + # Now run freeze_graph to convert variables into constants. + freeze_args = ( + " --input_graph=$(location " + graph + ")" + + " --checkpoint_version=1" + + " --input_binary=" + str(not graph.endswith(".pbtxt")) + + " --input_checkpoint=$(location " + freeze_checkpoint + ")" + + " --output_graph=$(location " + freeze_file + ")" + + " --output_node_names=$$(<$(location " + out_nodes_file + + "))" + ) + freeze_saver_srcs = [] + if freeze_saver: + freeze_args += " --input_saver=$(location " + freeze_saver + ")" + freeze_saver_srcs += [freeze_saver] + native.genrule( + name = freeze_name, + srcs = [ + graph, + freeze_checkpoint, + out_nodes_file, + ] + freeze_saver_srcs, + outs = [freeze_file], + cmd = ("$(location " + + "//tensorflow/python/tools:freeze_graph)" + + freeze_args), + tools = ["//tensorflow/python/tools:freeze_graph"], + tags = tags, + ) + tfcompile_graph = freeze_file - if gen_test: - test_name = name + "_test" - test_file = test_name + ".cc" - # Rule to rewrite test.cc to produce the test_file. + # Rule that runs tfcompile to produce the header and object file. + header_file = name + ".h" + metadata_object_file = name + "_tfcompile_metadata.o" + function_object_file = name + "_tfcompile_function.o" + ep = ("__" + native.package_name() + "__" + name).replace("/", "_") + if type(tfcompile_flags) == type(""): + flags = tfcompile_flags + else: + flags = " ".join([ + "'" + arg.replace("'", "'\\''") + "'" + for arg in (tfcompile_flags or []) + ]) + if enable_xla_hlo_profiling: + profiling_flag = "--xla_hlo_profile" + else: + profiling_flag = "" native.genrule( - name=("gen_" + test_name), - testonly=1, - srcs=[ - "//tensorflow/compiler/aot:test.cc", + name = ("gen_" + name), + srcs = [ + tfcompile_graph, + config, + ], + outs = [ header_file, + metadata_object_file, + function_object_file, ], - outs=[test_file], - cmd=("sed " + sed_replace + - " $(location //tensorflow/compiler/aot:test.cc) " + - "> $(OUTS)"), - tags=tags, - ) - - # The cc_test rule for the generated code. To ensure that this works - # reliably across build configurations, we must use tf_cc_test instead of - # native.cc_test. This is related to how we build - # //tensorflow/core:lib -- see the note in tensorflow/core/BUILD - # for more details. - tf_cc_test( - name=test_name, - srcs=[test_file], - deps=[ - ":" + name, - "//tensorflow/compiler/aot:runtime", - "//tensorflow/compiler/aot:tf_library_test_main", - "//tensorflow/compiler/xla:executable_run_options", - "//third_party/eigen3", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], - tags=tags, + cmd = ("$(location " + tfcompile_tool + ")" + + " --graph=$(location " + tfcompile_graph + ")" + + " --config=$(location " + config + ")" + + " --entry_point=" + ep + + " --cpp_class=" + cpp_class + + " --target_triple=" + target_llvm_triple() + + " --out_header=$(@D)/" + header_file + + " --out_metadata_object=$(@D)/" + metadata_object_file + + " --out_function_object=$(@D)/" + function_object_file + + " " + flags + " " + profiling_flag), + tools = [tfcompile_tool], + visibility = visibility, + testonly = testonly, + # Run tfcompile on the build host since it's typically faster on the + # local machine. + # + # Note that setting the local=1 attribute on a *test target* causes the + # test infrastructure to skip that test. However this is a genrule, not + # a test target, and runs with --genrule_strategy=forced_forge, meaning + # the local=1 attribute is ignored, and the genrule is still run. + # + # https://www.bazel.io/versions/master/docs/be/general.html#genrule + local = 1, + tags = tags, ) - if gen_benchmark: - benchmark_name = name + "_benchmark" - benchmark_file = benchmark_name + ".cc" - benchmark_main = ("//tensorflow/compiler/aot:" + - "benchmark_main.template") - - # Rule to rewrite benchmark.cc to produce the benchmark_file. + # Rule that runs tfcompile to produce the SessionModule proto, useful for + # debugging. TODO(b/64813587): Once the SessionModule proto is + # deterministic, move this into the main rule above. + session_module_pb = name + "_session_module.pb" native.genrule( - name=("gen_" + benchmark_name), - srcs=[ - benchmark_main, - header_file, + name = (name + "_session_module"), + srcs = [ + tfcompile_graph, + config, ], + outs = [ + session_module_pb, + ], + cmd = ("$(location " + tfcompile_tool + ")" + + " --graph=$(location " + tfcompile_graph + ")" + + " --config=$(location " + config + ")" + + " --entry_point=" + ep + + " --cpp_class=" + cpp_class + + " --target_triple=" + target_llvm_triple() + + " --out_session_module=$(@D)/" + session_module_pb + + " " + flags), + tools = [tfcompile_tool], + visibility = visibility, testonly = testonly, - outs=[benchmark_file], - cmd=("sed " + sed_replace + - " $(location " + benchmark_main + ") " + - "> $(OUTS)"), - tags=tags, + local = 1, + tags = tags, ) - # The cc_benchmark rule for the generated code. This does not need the - # tf_cc_binary since we (by deliberate design) do not depend on - # //tensorflow/core:lib. - # - # Note: to get smaller size on android for comparison, compile with: - # --copt=-fvisibility=hidden - # --copt=-D_LIBCPP_TYPE_VIS=_LIBCPP_HIDDEN - # --copt=-D_LIBCPP_EXCEPTION_ABI=_LIBCPP_HIDDEN - native.cc_binary( - name=benchmark_name, - srcs=[benchmark_file], + # The cc_library rule packaging up the header and object file, and needed + # kernel implementations. + need_xla_data_proto = (flags and flags.find("--gen_program_shape") != -1) + native.cc_library( + name = name, + srcs = [function_object_file, metadata_object_file], + hdrs = [header_file], + visibility = visibility, testonly = testonly, - copts = tf_copts(), - linkopts = if_android(["-pie", "-s"]), - deps=[ - ":" + name, - "//tensorflow/compiler/aot:benchmark", - "//tensorflow/compiler/aot:runtime", - "//tensorflow/compiler/xla:executable_run_options", + deps = [ + # These deps are required by all tf_library targets even if + # include_standard_runtime_deps is False. Without them, the + # generated code will fail to compile. + "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function", + "//tensorflow/core:framework_lite", + ] + (need_xla_data_proto and [ + # If we're generating the program shape, we must depend on the + # proto. + "//tensorflow/compiler/xla:xla_data_proto", + ] or []) + (enable_xla_hlo_profiling and [ + "//tensorflow/compiler/xla/service:hlo_profile_printer_data", + ] or []) + (include_standard_runtime_deps and [ + # TODO(cwhipkey): only depend on kernel code that the model actually + # needed. + "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d", + "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d", + "//tensorflow/compiler/xla/service/cpu:runtime_conv2d", + "//tensorflow/compiler/xla/service/cpu:runtime_matmul", + "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d", + "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", "//third_party/eigen3", - ] + if_android([ - "//tensorflow/compiler/aot:benchmark_extra_android", - ]), - tags=tags, + ] or []) + (deps or []), + tags = tags, + ) + + # Variables used for gen_test and gen_benchmark. + cpp_class_split = cpp_class.rsplit("::", maxsplit = 2) + if len(cpp_class_split) == 1: + no_ns_name = cpp_class_split[0] + else: + no_ns_name = cpp_class_split[1] + sed_replace = ( + "-e \"s|{{TFCOMPILE_HEADER}}|$(location " + header_file + ")|g\" " + + "-e \"s|{{TFCOMPILE_CPP_CLASS}}|" + cpp_class + "|g\" " + + "-e \"s|{{TFCOMPILE_NAME}}|" + no_ns_name + "|g\" " ) + if gen_test: + test_name = name + "_test" + test_file = test_name + ".cc" + + # Rule to rewrite test.cc to produce the test_file. + native.genrule( + name = ("gen_" + test_name), + testonly = 1, + srcs = [ + "//tensorflow/compiler/aot:test.cc", + header_file, + ], + outs = [test_file], + cmd = ( + "sed " + sed_replace + + " $(location //tensorflow/compiler/aot:test.cc) " + + "> $(OUTS)" + ), + tags = tags, + ) + + # The cc_test rule for the generated code. To ensure that this works + # reliably across build configurations, we must use tf_cc_test instead + # of native.cc_test. This is related to how we build + # //tensorflow/core:lib -- see the note in + # tensorflow/core/BUILD for more details. + tf_cc_test( + name = test_name, + srcs = [test_file], + deps = [ + ":" + name, + "//tensorflow/compiler/aot:tf_library_test_main", + "//tensorflow/compiler/xla:executable_run_options", + "//third_party/eigen3", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], + tags = tags, + ) + + if gen_benchmark: + benchmark_name = name + "_benchmark" + benchmark_file = benchmark_name + ".cc" + benchmark_main = ("//tensorflow/compiler/aot:" + + "benchmark_main.template") + + # Rule to rewrite benchmark.cc to produce the benchmark_file. + native.genrule( + name = ("gen_" + benchmark_name), + srcs = [ + benchmark_main, + header_file, + ], + testonly = testonly, + outs = [benchmark_file], + cmd = ("sed " + sed_replace + + " $(location " + benchmark_main + ") " + + "> $(OUTS)"), + tags = tags, + ) + + # The cc_benchmark rule for the generated code. This does not need the + # tf_cc_binary since we (by deliberate design) do not depend on + # //tensorflow/core:lib. + # + # Note: to get smaller size on android for comparison, compile with: + # --copt=-fvisibility=hidden + # --copt=-D_LIBCPP_TYPE_VIS=_LIBCPP_HIDDEN + # --copt=-D_LIBCPP_EXCEPTION_ABI=_LIBCPP_HIDDEN + native.cc_binary( + name = benchmark_name, + srcs = [benchmark_file], + testonly = testonly, + copts = tf_copts(), + linkopts = if_android(["-pie", "-s"]), + deps = [ + ":" + name, + "//tensorflow/compiler/aot:benchmark", + "//tensorflow/compiler/xla:executable_run_options", + "//third_party/eigen3", + ] + if_android([ + "//tensorflow/compiler/aot:benchmark_extra_android", + ]), + tags = tags, + ) + def target_llvm_triple(): - """Returns the target LLVM triple to be used for compiling the target.""" - # TODO(toddw): Add target_triple for other targets. For details see: - # http://llvm.org/docs/doxygen/html/Triple_8h_source.html - return select({ - "//tensorflow:android_armeabi": "armv5-none-android", - "//tensorflow:android_arm": "armv7-none-android", - "//tensorflow:android_arm64": "aarch64-none-android", - "//tensorflow:android_x86": "i686-none-android", - "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu", - "//tensorflow:darwin": "x86_64-none-darwin", - "//conditions:default": "x86_64-pc-linux", - }) + """Returns the target LLVM triple to be used for compiling the target.""" + + # TODO(toddw): Add target_triple for other targets. For details see: + # http://llvm.org/docs/doxygen/html/Triple_8h_source.html + return select({ + "//tensorflow:android_armeabi": "armv5-none-android", + "//tensorflow:android_arm": "armv7-none-android", + "//tensorflow:android_arm64": "aarch64-none-android", + "//tensorflow:android_x86": "i686-none-android", + "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu", + "//tensorflow:darwin": "x86_64-none-darwin", + "//conditions:default": "x86_64-pc-linux", + }) diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 9174a67cc6d110ac21c7bb09346bb1b2dfad0579..2c9adfe4f0d8b53a987e5338d1e7f82de47747b7 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -128,11 +128,11 @@ cc_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:shaped_buffer", - "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", ], ) @@ -160,12 +160,14 @@ cc_library( "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:dump_graph", + "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:stream_pool", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -177,6 +179,7 @@ cc_library( "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:fifo_queue", + "//tensorflow/core/kernels:function_ops", "//tensorflow/core/kernels:identity_n_op", "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:no_op", @@ -185,6 +188,10 @@ cc_library( "//tensorflow/core/kernels:sendrecv_ops", "//tensorflow/core/kernels:shape_ops", "//tensorflow/core/kernels:variable_ops", + "//tensorflow/core/kernels/data:generator_dataset_op", + "//tensorflow/core/kernels/data:iterator_ops", + "//tensorflow/core/kernels/data:prefetch_dataset_op", + "@com_google_absl//absl/memory", ], ) @@ -229,6 +236,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:variable_ops", + "@com_google_absl//absl/memory", ], ) @@ -277,6 +285,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/memory", ], alwayslink = 1, ) @@ -297,6 +306,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "@com_google_absl//absl/memory", ], ) @@ -305,14 +315,19 @@ cc_library( srcs = [ "build_xla_launch_ops_pass.cc", "deadness_analysis.cc", + "deadness_analysis_internal.h", "encapsulate_subgraphs_pass.cc", "mark_for_compilation_pass.cc", + "mark_for_compilation_pass_test_helper.cc", + "partially_decluster_pass.cc", ], hdrs = [ "build_xla_launch_ops_pass.h", "deadness_analysis.h", "encapsulate_subgraphs_pass.h", "mark_for_compilation_pass.h", + "mark_for_compilation_pass_test_helper.h", + "partially_decluster_pass.h", ], deps = [ ":common", @@ -347,6 +362,7 @@ cc_library( "//tensorflow/compiler/jit/graphcycles", "//tensorflow/core:framework", "//tensorflow/core:graph", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:bounds_check", ], @@ -377,16 +393,46 @@ tf_cc_test( ) tf_cc_test( - name = "compilation_passes_test", + name = "deadness_analysis_test", size = "small", srcs = [ + "deadness_analysis_internal.h", "deadness_analysis_test.cc", + ], + deps = [ + ":common", + ":compilation_passes", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:sendrecv_ops", + "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +tf_cc_test( + name = "compilation_passes_test", + size = "small", + srcs = [ "encapsulate_subgraphs_pass_test.cc", "mark_for_compilation_pass_test.cc", + "partially_decluster_pass_test.cc", ], deps = [ ":common", ":compilation_passes", + ":xla_cluster_util", "//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc index a2e6285339f9ed0bde8d72f5b4752b1ecc22f426..1b1ce78ed2b79d0948b6fc951f82a2cebe8009e5 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/create_xla_launch_op.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/kernels/xla_launch_op.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" @@ -223,8 +224,8 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def, &fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types, fbody->ret_types, output_memory_types, flr->graph_def_version(), &s); - *kernel = MakeUnique(&construction, constant_arg_indices, - resource_arg_indices, function); + *kernel = absl::make_unique( + &construction, constant_arg_indices, resource_arg_indices, function); return s; } diff --git a/tensorflow/compiler/jit/create_xla_launch_op_test.cc b/tensorflow/compiler/jit/create_xla_launch_op_test.cc index b75ab486b80e098bc0a59f9ea8cdbaa23a28fef9..73866607621cd745f6e640a14405daebf0dd9985 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op_test.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/jit/create_xla_launch_op.h" +#include "absl/memory/memory.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function_testlib.h" @@ -65,11 +66,11 @@ class CreateXlaLaunchOpTest : public ::testing::Test { for (const auto& fdef : flib) { *(proto.add_function()) = fdef; } - lib_def_ = - MakeUnique(OpRegistry::Global(), proto); + lib_def_ = absl::make_unique( + OpRegistry::Global(), proto); OptimizerOptions opts; - device_mgr_ = MakeUnique(devices_); - pflr_ = MakeUnique( + device_mgr_ = absl::make_unique(devices_); + pflr_ = absl::make_unique( device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index d81e5fe9008975c126bcd8e0ea7cef19f1eb1bf3..0ca0f949dcd13992ccd9504d75ca65d2aff72a19 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -14,24 +14,86 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/deadness_analysis.h" +#include "tensorflow/compiler/jit/deadness_analysis_internal.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/hash/hash.h" // ALGORITHM OVERVIEW +// ================== // // We map every output produced by each node in the TensorFlow graph (including // control dependence) into an instance of the Predicate class. Instances of // Predicate denote logical formulas and mapping a node `n` to a predicate -// `pred` implies that `n` is executed whenver `pred` is true. Then we can -// deduce mismatching liveness in the inputs to node by comparing the predicate -// those inputs are mapped to. +// `pred` implies that `n` is live whenever `pred` is true. Then we can deduce +// mismatching liveness in the inputs to node by comparing the predicate those +// inputs are mapped to. The core logic of this pass resides in creating the +// map from TensorFlow nodes to predicates. // -// Loops are handled pessimistically -- we map Merge nodes with backedges to -// uninterpreted symbols (the same kind we use to represent Switch and _Recv). -// Predicate equality has to hold over all possible assignments to these -// uninterpreted symbols. +// +// MAPPING NODES TO PREDICATES, MODULO CYCLES +// ------------------------------------------ +// +// If we ignore cycles for a moment, computing predicates is fairly +// straightforward. We traverse the graph in RPO, mapping each node to a +// predicate based on the predicates its inputs are mapped to. For instance a +// Merge(X, Y) node will be mapped to OR(PredicateFor(X), PredicateFor(Y)). +// Roughtly speaking, we abstract interpret each node on the "liveness" domain, +// where values in the domain represent if a tensor carries a dead signal or +// not. +// +// +// DEALING WITH CYCLES +// ------------------- +// +// We map Merge nodes that are the target of a backedge to AndRecurrence +// instances. An AndRecurrence with start() = S and step() = X, printed as +// {S,&,X}, *roughly* represents the infinite list of predicates +// [S,S&X,S&X&X,S&X&X, ...]. So {S,&,X} can be used to represent the predicate +// for Merge in a graph like: +// +// Init +// | +// v +// Merge <-----------+ +// | | +// v | +// Incr | +// | | +// v | +// Switch <- Cond | +// | | +// v (oidx: 1) | +// | | +// +---------------+ +// +// Where S is the predicate for Init and X is the predicate that asserts that +// Cond is true. {S,&,X} states that Merge is live on the first "iteration" iff +// S is true, live on the second iteration iff "S&X" is true, live on the third +// iteration iff "S&X&X" is true etc. There is a subtlety here, S&X&X would +// normally be equivalent to S&X which isn't quite what we want to represent. +// Instead we want {S,&,X} to denote the infinite list [S, S&X, +// S&X&X',S&X&X'&X'', ...] where X, X', X'' are predicates that assert Cond is +// true on iteration 0, 1, 2 respectively. This is made more precise in the +// comment on the AndRecurrence class. +// +// The general algorithm that deals with cycles does two RPO (reverse post +// order) passes over the graph. On the first pass it assigns a symbolic +// predicate to merge nodes with backedges. On the second pass it tries to +// pattern matche the predicates for the backedges of these merges and infer an +// AndRecurrence for the merge. +// +// In other words, we do a pessimistic data flow analysis where the data-flow +// lattice has two elements, Symbolic and NonSymbolic with Symbolic > +// NonSymbolic. The lattice has height = 2 so two iterations are sufficient to +// converge. We don't do an optimistic data flow analysis to make pattern +// matching easier: if we assigned the predicate of the initial value to the +// merge during the first pass, on the second pass the backedge may see a +// simplified value that would be difficult to pattern match. +// +// We still use symbolic predicates for merges for which we can't pattern match +// on the backedge predicate. This is conservatively correct. namespace tensorflow { @@ -41,14 +103,21 @@ namespace { // above. class Predicate { public: - enum class Kind { kAnd, kOr, kNot, kSymbol }; + enum class Kind { kAnd, kOr, kNot, kAndRecurrence, kSymbol }; virtual string ToString() const = 0; int64 hash() const { return hash_; } + virtual gtl::ArraySlice GetOperands() const = 0; virtual Kind kind() const = 0; virtual ~Predicate() {} + // Invokes func on p and on all of its operands recursively. Does not invoke + // `func` on the same Predicate instance twice. Aborts the search if `func` + // returns true. + template + static void Visit(Predicate* p, const FunctionTy& func); + protected: explicit Predicate(int64 hash) : hash_(hash) {} @@ -89,7 +158,8 @@ class AndPredicate : public Predicate { Kind kind() const override { return Kind::kAnd; } - const gtl::ArraySlice operands() const { return operands_; } + gtl::ArraySlice GetOperands() const override { return operands_; } + gtl::ArraySlice operands() const { return operands_; } private: std::vector operands_; @@ -116,7 +186,8 @@ class OrPredicate : public Predicate { } Kind kind() const override { return Kind::kOr; } - const gtl::ArraySlice operands() const { return operands_; } + gtl::ArraySlice GetOperands() const override { return operands_; } + gtl::ArraySlice operands() const { return operands_; } private: std::vector operands_; @@ -127,23 +198,58 @@ class NotPredicate : public Predicate { public: explicit NotPredicate(Predicate* operand) : Predicate(HashPredicateSequence(Kind::kNot, {operand})), - operand_(operand) {} + operands_({operand}) {} string ToString() const override { return strings::StrCat("~", operand()->ToString()); } Kind kind() const override { return Kind::kNot; } - Predicate* operand() const { return operand_; } + Predicate* operand() const { return operands_[0]; } + gtl::ArraySlice GetOperands() const override { return operands_; } private: - Predicate* operand_; + std::array operands_; +}; + +// Represents an infinite list of predicates. +// +// An AndRecurrence with start = S and step = X is printed as {S,&,X} and stands +// for the list of predicates: +// +// S, S & GenSym(X,1), S & GenSym(X,1) & GenSym(X,2), ... +// +// where GenSym(, ) renames every SymbolPredicate in +// by appending to it, in effect creating a "fresh" symbol. +// This means {P,&,Q} is not equal to "P on the first iteration; P&Q on +// subsequent iterations". +class AndRecurrencePredicate : public Predicate { + public: + explicit AndRecurrencePredicate(Predicate* start, Predicate* step) + : Predicate(HashPredicateSequence(Kind::kAndRecurrence, {start, step})), + operands_({start, step}) {} + + Predicate* start() const { return operands_[0]; } + Predicate* step() const { return operands_[1]; } + + string ToString() const override { + return strings::StrCat("{", start()->ToString(), ",&,", step()->ToString(), + "}"); + } + + Kind kind() const override { return Kind::kAndRecurrence; } + + gtl::ArraySlice GetOperands() const override { return operands_; } + + private: + std::array operands_; }; // Represents an uninterpreted symbol in a logical predicate. // // Two predicates are equivalent iff they are equivalent for all assignments to -// the symbols contained in them. +// the symbols contained in them, i.e. predicates are forall qualified over +// symbols. class SymbolPredicate : public Predicate { public: explicit SymbolPredicate(TensorId tensor_id, bool must_be_true) @@ -151,8 +257,13 @@ class SymbolPredicate : public Predicate { tensor_id_(std::move(tensor_id)), must_be_true_(must_be_true) {} - string ToString() const override { return tensor_id_.ToString(); } + string ToString() const override { + return must_be_true() ? strings::StrCat("*", tensor_id_.ToString()) + : tensor_id_.ToString(); + } + Kind kind() const override { return Kind::kSymbol; } + gtl::ArraySlice GetOperands() const override { return {}; } // If `must_be_true()` is true this SymbolPredicate represents the proposition // "tensor_id() is live and evaluates to true". @@ -174,6 +285,29 @@ class SymbolPredicate : public Predicate { } }; +template +/*static*/ void Predicate::Visit(Predicate* p, const FunctionTy& func) { + gtl::FlatSet visited; + std::vector stack; + + stack.push_back(p); + visited.insert(p); + + while (!stack.empty()) { + Predicate* current = stack.back(); + stack.pop_back(); + bool done = func(current); + if (done) { + return; + } + for (Predicate* op : current->GetOperands()) { + if (visited.insert(op).second) { + stack.push_back(op); + } + } + } +} + // Creates and owns Predicate instances. Simplifies predicates as it creates // them. class PredicateFactory { @@ -199,6 +333,21 @@ class PredicateFactory { } } + Predicate* MakeAndRecurrencePredicate(Predicate* start, Predicate* step) { + auto it = interned_and_rec_instances_.find({start, step}); + if (it != interned_and_rec_instances_.end()) { + return it->second.get(); + } + + std::unique_ptr new_pred = + Make(start, step); + Predicate* new_pred_ptr = new_pred.get(); + CHECK(interned_and_rec_instances_ + .emplace(SignatureForAndRec(start, step), std::move(new_pred)) + .second); + return new_pred_ptr; + } + Predicate* MakeSymbolPredicate(TensorId tensor_id, bool must_be_true) { SignatureForSymbol signature = {tensor_id, must_be_true}; auto it = interned_symbol_instances_.find(signature); @@ -239,6 +388,7 @@ class PredicateFactory { using SignatureForAndOr = std::pair>; using SignatureForNot = Predicate*; + using SignatureForAndRec = std::pair; using SignatureForSymbol = std::pair; struct HashSignatureForAndOr { @@ -263,6 +413,8 @@ class PredicateFactory { interned_and_or_instances_; gtl::FlatMap> interned_not_instances_; + gtl::FlatMap> + interned_and_rec_instances_; gtl::FlatMap, HashSignatureForSymbol> interned_symbol_instances_; @@ -283,10 +435,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice operands, if (op->kind() == pred_kind) { // "Inline" the operands of an inner And/Or into the parent And/Or. - gtl::ArraySlice operands = - is_and ? dynamic_cast(op)->operands() - : dynamic_cast(op)->operands(); - for (Predicate* subop : operands) { + for (Predicate* subop : op->GetOperands()) { if (simplified_ops_set.insert(subop).second) { simplified_ops.push_back(subop); } @@ -346,27 +495,49 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { : graph_(*graph), vlog_(VLOG_IS_ON(2)) {} Status Populate(); + Status PopulateWithReversePostOrder(gtl::ArraySlice rpo); bool HasInputsWithMismatchingDeadness(const Node& node) override; void Print() const override; + gtl::FlatMap PredicateMapAsString() const; private: enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly }; std::vector GetIncomingPreds(Node* n, EdgeKind edge_kind); - void SetPred(Node* n, int output_idx, Predicate* pred) { - CHECK( - predicate_map_.insert({TensorId(n->name(), output_idx), pred}).second); + + // Sets the predicate for output `output_idx` of `n` to `pred`. Sets the i'th + // bit of `should_revisit` if `pred` is different from the current predicate + // for the `output_idx` output of `n`. + void SetPredicate(Node* n, int output_idx, Predicate* pred, + std::vector* should_revisit) { + auto insert_result = + predicate_map_.insert({TensorId(n->name(), output_idx), pred}); + if (!insert_result.second && insert_result.first->second != pred) { + VLOG(4) << "For " << n->name() << ":" << output_idx << " from " + << insert_result.first->second->ToString() << " " + << insert_result.first->second << " to " << pred->ToString() + << " " << pred; + insert_result.first->second = pred; + if (should_revisit != nullptr) { + for (const Edge* e : n->out_edges()) { + (*should_revisit)[e->dst()->id()] = true; + } + } + } } - void SetPred(Node* n, gtl::ArraySlice output_idxs, Predicate* pred) { + + void SetPredicate(Node* n, gtl::ArraySlice output_idxs, Predicate* pred, + std::vector* should_revisit) { for (int output_idx : output_idxs) { - SetPred(n, output_idx, pred); + SetPredicate(n, output_idx, pred, should_revisit); } } - Status HandleSwitch(Node* n); - Status HandleMerge(Node* n); - Status HandleRecv(Node* n); - Status HandleGeneric(Node* n); + Status HandleSwitch(Node* n, std::vector* should_revisit); + Status HandleMerge(Node* n, std::vector* should_revisit); + Status HandleRecv(Node* n, std::vector* should_revisit); + Status HandleGeneric(Node* n, std::vector* should_revisit); + Status HandleNode(Node* n, std::vector* should_revisit); const Graph& graph_; gtl::FlatMap predicate_map_; @@ -389,14 +560,15 @@ std::vector DeadnessAnalysisImpl::GetIncomingPreds( if (should_process) { auto it = predicate_map_.find(InputEdgeToTensorId(in_edge)); - CHECK(it != predicate_map_.end()); + CHECK(it != predicate_map_.end()) << n->name(); incoming_preds.push_back(it->second); } } return incoming_preds; } -Status DeadnessAnalysisImpl::HandleSwitch(Node* n) { +Status DeadnessAnalysisImpl::HandleSwitch(Node* n, + std::vector* should_revisit) { std::vector input_preds = GetIncomingPreds(n, EdgeKind::kDataAndControl); const Edge* pred_edge; @@ -408,84 +580,252 @@ Status DeadnessAnalysisImpl::HandleSwitch(Node* n) { // Output 0 is alive iff all inputs are alive and the condition is false. input_preds.push_back(false_switch); - SetPred(n, 0, predicate_factory_.MakeAndPredicate(input_preds)); + SetPredicate(n, 0, predicate_factory_.MakeAndPredicate(input_preds), + should_revisit); input_preds.pop_back(); // Output 1 is alive iff all inputs are alive and the condition is true. input_preds.push_back(true_switch); - SetPred(n, 1, predicate_factory_.MakeAndPredicate(input_preds)); + SetPredicate(n, 1, predicate_factory_.MakeAndPredicate(input_preds), + should_revisit); input_preds.pop_back(); - // Control is alive iff any inputs are alive. - SetPred(n, Graph::kControlSlot, - predicate_factory_.MakeAndPredicate(input_preds)); + // Control is alive iff all inputs are alive. + SetPredicate(n, Graph::kControlSlot, + predicate_factory_.MakeAndPredicate(input_preds), + should_revisit); return Status::OK(); } -Status DeadnessAnalysisImpl::HandleMerge(Node* n) { +namespace { +const Edge* FindUniqueBackedge(Node* merge) { + CHECK(merge->IsMerge()); + const Edge* result = nullptr; + for (const Edge* e : merge->in_edges()) { + if (e->src()->IsNextIteration()) { + CHECK_EQ(result, nullptr) + << "Multiple backedges to " << merge->DebugString(); + result = e; + } + } + return result; +} + +// If `backedge_predicate` is equal to `symbolic_predicate` & Step where Step +// does not contain `symbolic_predicate` as an inner (not top-level) operand +// then returns `Step`. Otherwise returns nullptr. +Predicate* DeduceStepPredicate(PredicateFactory* predicate_factory, + Predicate* symbolic_predicate, + Predicate* backedge_predicate) { + CHECK(dynamic_cast(symbolic_predicate)); + if (backedge_predicate->kind() != Predicate::Kind::kAnd) { + return nullptr; + } + + std::vector and_ops; + gtl::ArraySlice recurrent_pred_ops = + backedge_predicate->GetOperands(); + + bool found_sym = false; + for (Predicate* and_op : recurrent_pred_ops) { + // We want the `symbol_predicate` to be the one of the operands of + // `backedge_predicate`, + if (and_op == symbolic_predicate) { + found_sym = true; + continue; + } + + // but we don't want it to be present anywhere else in the formula. E.g. we + // don't want the recurrent predicate to be + // symbol_predicate&(X|symbol_predicate). + bool found_sym_as_inner_operand = false; + auto has_self_as_inner_operand = [&](Predicate* p) { + if (p == symbolic_predicate) { + found_sym_as_inner_operand = true; + return true; // Stop searching, we're done. + } + + // Continue searching. + return false; + }; + + Predicate::Visit(and_op, has_self_as_inner_operand); + if (found_sym_as_inner_operand) { + return nullptr; + } + and_ops.push_back(and_op); + } + + return found_sym ? predicate_factory->MakeAndPredicate(and_ops) : nullptr; +} +} // namespace + +Status DeadnessAnalysisImpl::HandleMerge(Node* n, + std::vector* should_revisit) { // Merge ignores deadness of its control inputs. A merge that isn't the - // target of a backedge has is alive iff any of its data inputs are. We treat - // the liveness of a merge that is the target of a backedge symbolically. + // target of a backedge has is alive iff any of its data inputs are. The + // liveness of a merge that is the target of a backedge can sometimes be + // represented using a AndRecurrencePredicate. If neither apply, we represent + // the liveness of the merge symbolically. + + bool has_unvisited_backedge = false; + for (const Edge* e : n->in_edges()) { + if (!e->IsControlEdge() && e->src()->IsNextIteration()) { + has_unvisited_backedge |= !predicate_map_.count(InputEdgeToTensorId(e)); + } + } - bool has_backedge = std::any_of( - n->in_edges().begin(), n->in_edges().end(), [](const Edge* e) { - return !e->IsControlEdge() && e->src()->IsNextIteration(); - }); + auto it = predicate_map_.find(TensorId(n->name(), 0)); + if (it == predicate_map_.end()) { + if (has_unvisited_backedge) { + // We're visiting this merge for the first time and it has an unvisited + // backedge. + Predicate* input_data_pred = predicate_factory_.MakeSymbolPredicate( + TensorId(n->name(), 0), /*must_be_true=*/false); + SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred, + should_revisit); + return Status::OK(); + } - Predicate* input_data_pred = - has_backedge ? predicate_factory_.MakeSymbolPredicate( - TensorId(n->name(), 0), /*must_be_true=*/false) - : predicate_factory_.MakeOrPredicate( - GetIncomingPreds(n, EdgeKind::kDataOnly)); + // We're visiting this merge for the first time and it is a acyclic merge. + Predicate* input_data_pred = predicate_factory_.MakeOrPredicate( + GetIncomingPreds(n, EdgeKind::kDataOnly)); + SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred, + should_revisit); + return Status::OK(); + } + + if (it->second->kind() == Predicate::Kind::kSymbol) { + // Last time we visited this merge we only got a symbolic predicate because + // of an unvisited backedge. Try to pattern match the predicate expression + // for that backedge (which should be visited now) into an and recurrence + // for the merge node. + if (const Edge* unique_backedge = FindUniqueBackedge(n)) { + if (Predicate* step = DeduceStepPredicate( + &predicate_factory_, it->second, + predicate_map_[InputEdgeToTensorId(unique_backedge)])) { + // If the predicate for the backedge is "Sym&X" where "Sym" is the + // predicate for the merge then the merge has predicate {S,&,X} where S + // is the predicate for the merge ignoring the backedge. + std::vector non_recurrent_inputs; + for (const Edge* e : n->in_edges()) { + if (e != unique_backedge) { + non_recurrent_inputs.push_back( + predicate_map_[InputEdgeToTensorId(e)]); + } + } - SetPred(n, {0, 1, Graph::kControlSlot}, input_data_pred); + Predicate* start = + predicate_factory_.MakeOrPredicate(non_recurrent_inputs); + Predicate* and_rec = + predicate_factory_.MakeAndRecurrencePredicate(start, step); + SetPredicate(n, {0, 1, Graph::kControlSlot}, and_rec, should_revisit); + return Status::OK(); + } + } + } return Status::OK(); } -Status DeadnessAnalysisImpl::HandleRecv(Node* n) { +Status DeadnessAnalysisImpl::HandleRecv(Node* n, + std::vector* should_revisit) { // In addition to being alive or dead based on the inputs, a _Recv can also // acquire a dead signal from a _Send. std::vector input_preds = GetIncomingPreds(n, EdgeKind::kDataAndControl); input_preds.push_back(predicate_factory_.MakeSymbolPredicate( TensorId(n->name(), 0), /*must_be_true=*/false)); - SetPred(n, {0, Graph::kControlSlot}, - predicate_factory_.MakeAndPredicate(input_preds)); + SetPredicate(n, {0, Graph::kControlSlot}, + predicate_factory_.MakeAndPredicate(input_preds), + should_revisit); return Status::OK(); } -Status DeadnessAnalysisImpl::HandleGeneric(Node* n) { +Status DeadnessAnalysisImpl::HandleGeneric(Node* n, + std::vector* should_revisit) { // Generally nodes are alive iff all their inputs are alive. Predicate* pred = predicate_factory_.MakeAndPredicate( GetIncomingPreds(n, EdgeKind::kDataAndControl)); for (int output_idx = 0; output_idx < n->num_outputs(); output_idx++) { - SetPred(n, output_idx, pred); + SetPredicate(n, output_idx, pred, should_revisit); + } + SetPredicate(n, Graph::kControlSlot, pred, should_revisit); + return Status::OK(); +} + +Status DeadnessAnalysisImpl::HandleNode(Node* n, + std::vector* should_revisit) { + if (n->IsSwitch()) { + TF_RETURN_IF_ERROR(HandleSwitch(n, should_revisit)); + } else if (n->IsMerge()) { + TF_RETURN_IF_ERROR(HandleMerge(n, should_revisit)); + } else if (n->IsControlTrigger()) { + SetPredicate(n, Graph::kControlSlot, predicate_factory_.MakeTrue(), + nullptr); + } else if (n->IsRecv() || n->IsHostRecv()) { + TF_RETURN_IF_ERROR(HandleRecv(n, should_revisit)); + } else if (n->IsNextIteration()) { + TF_RETURN_IF_ERROR(HandleGeneric(n, should_revisit)); + } else { + TF_RETURN_IF_ERROR(HandleGeneric(n, should_revisit)); } - SetPred(n, Graph::kControlSlot, pred); return Status::OK(); } Status DeadnessAnalysisImpl::Populate() { std::vector rpo; - GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/{}, + GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/NodeComparatorName(), /*edge_filter=*/[](const Edge& edge) { return !edge.src()->IsNextIteration(); }); + return PopulateWithReversePostOrder(rpo); +} +Status DeadnessAnalysisImpl::PopulateWithReversePostOrder( + gtl::ArraySlice rpo) { // This an abstract interpretation over the deadness propagation semantics of // the graph executor. + // + // We iterate over the graph twice, each time in RPO. On the first iteration + // merge nodes with backedges are mapped to symbolic predicates. On the + // second iteration we use the predicates assigned to the backedges in the + // previous iteration to infer a more precise predicate for the backedge merge + // nodes and all the nodes that transitively use it. + // + // We don't track the output indices for should_revisit. Instead, putting a + // node in `should_revisit` denotes that the deadness flowing out from any + // output from said node may have changed. This is fine; only switches + // propagate different deadness along different output edges, and since the + // delta is solely due to the input *values* (and not input deadness), the + // delta should not change in the second iteration. + std::vector should_revisit; + should_revisit.resize(graph_.num_node_ids()); for (Node* n : rpo) { - if (n->IsSwitch()) { - TF_RETURN_IF_ERROR(HandleSwitch(n)); - } else if (n->IsMerge()) { - TF_RETURN_IF_ERROR(HandleMerge(n)); - } else if (n->IsControlTrigger()) { - SetPred(n, Graph::kControlSlot, predicate_factory_.MakeTrue()); - } else if (n->IsRecv() || n->IsHostRecv()) { - TF_RETURN_IF_ERROR(HandleRecv(n)); - } else { - TF_RETURN_IF_ERROR(HandleGeneric(n)); + VLOG(4) << "Visiting " << n->name(); + TF_RETURN_IF_ERROR(HandleNode(n, /*should_revisit=*/nullptr)); + if (n->IsNextIteration()) { + // If this is a backedge for a merge node then remember to reprocess the + // merge the next time we run. + for (const Edge* e : n->out_edges()) { + if (e->dst()->IsMerge()) { + should_revisit[e->dst()->id()] = true; + } + } + } + } + + for (Node* n : rpo) { + // The nodes added to should_revisit in the previous loop need to be + // revisited now. Reprocesing these initial nodes may add *their* consumers + // to should_revisit, and these newly added nodes will also be processed by + // this very same loop. Since we're traversing the graph in reverse post + // order (producers before consumers) and HandleNode(n) can only ever add + // n's consumers to should_revisit, we won't "miss" an addition to + // should_revisit. + if (should_revisit[n->id()]) { + VLOG(4) << "Revisiting " << n->name(); + TF_RETURN_IF_ERROR(HandleNode(n, &should_revisit)); } } @@ -563,4 +903,33 @@ DeadnessAnalysis::~DeadnessAnalysis() {} return Status::OK(); } +gtl::FlatMap +DeadnessAnalysisImpl::PredicateMapAsString() const { + gtl::FlatMap result; + std::vector tensor_ids; + for (const auto& kv_pair : predicate_map_) { + CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second); + } + return result; +} + +namespace deadness_analysis_internal { +Status ComputePredicates(const Graph& graph, + PredicateMapTy* out_predicate_map) { + DeadnessAnalysisImpl impl(&graph); + TF_RETURN_IF_ERROR(impl.Populate()); + *out_predicate_map = impl.PredicateMapAsString(); + return Status::OK(); +} + +Status ComputePredicates(const Graph& graph, + gtl::ArraySlice reverse_post_order, + PredicateMapTy* out_predicate_map) { + DeadnessAnalysisImpl impl(&graph); + TF_RETURN_IF_ERROR(impl.PopulateWithReversePostOrder(reverse_post_order)); + *out_predicate_map = impl.PredicateMapAsString(); + return Status::OK(); +} +} // namespace deadness_analysis_internal + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/deadness_analysis_internal.h b/tensorflow/compiler/jit/deadness_analysis_internal.h new file mode 100644 index 0000000000000000000000000000000000000000..401d6e406ab3db81d0cbd69b480d5962dab1f357 --- /dev/null +++ b/tensorflow/compiler/jit/deadness_analysis_internal.h @@ -0,0 +1,40 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_ +#define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_ + +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace tensorflow { +namespace deadness_analysis_internal { + +// Returns a map describing the predicate each Tensor was mapped to. For +// testing purposes only. +using PredicateMapTy = gtl::FlatMap; +Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map); + +// Returns a map describing the predicate each Tensor was mapped to. For +// testing purposes only. Makes deadness analysis visit the graph in the order +// specified in `reverse_post_order` which must be a valid RPO for the graph +// minus NextIteration->Merge edges. +Status ComputePredicates(const Graph& graph, + gtl::ArraySlice reverse_post_order, + PredicateMapTy* out_predicate_map); +} // namespace deadness_analysis_internal +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_ diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index 584385cab7665dce9c7c92eab6293436ca22c9b7..cc9f1023985560be0bce5971931d2ec8e742b377 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/deadness_analysis_internal.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -37,6 +38,9 @@ limitations under the License. namespace tensorflow { namespace { +using deadness_analysis_internal::ComputePredicates; +using deadness_analysis_internal::PredicateMapTy; + Status AnalyzeDeadness(Graph* graph, std::unique_ptr* result) { FixupSourceAndSinkEdges(graph); @@ -50,13 +54,73 @@ ops::Switch CreateSwitch(const Scope& root, const string& prefix) { return ops::Switch(root.WithOpName(prefix + "/switch"), value, predicate); } -Output CreateInductionVariable(const Scope& root, const string& prefix, - const string& frame_name, int32 init) { - Output initial_value = ops::Const(root.WithOpName(prefix + "/init"), init); +TensorId ControlOutputFor(const Output& o) { + return {o.node()->name(), Graph::kControlSlot}; +} + +void VLogGraphIfAsked(const Graph& graph) { + if (VLOG_IS_ON(3)) { + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + string serialized; + ::tensorflow::protobuf::TextFormat::PrintToString(graph_def, &serialized); + LOG(INFO) << serialized; + } +} + +struct InductionVarInfo { + Output induction_var; + Output loop_cond; +}; + +// Creates an induction variable with the following structure (simplified for +// brevity): +// +// +---------------+ +// | initial_value | +// +---------------+ +// | +// | +// v +// +---------------+ +// | Enter | +// +---------------+ +// | +// | +// v +// +---------------+ +// +> | Merge | -+ +// | +---------------+ | +// | | | +// | | | +// | v | +// | +---------------+ | +// | | LessThan10 | | +// | +---------------+ | +// | | | +// | | | +// | v | +// | +---------------+ | +// +----+- | Switch | <+ +// | | +---------------+ +// | | | +// | | | +// | | v +// | | +---------------+ +// | +- | AddOne | +// | +---------------+ +// | +---------------+ +// +-----> | Exit | +// +---------------+ +InductionVarInfo CreateInductionVariable(const Scope& root, + const string& prefix, + const string& frame_name, + const Output& initial_value) { Output enter_initial_value = ops::internal::Enter( root.WithOpName(prefix + "/enter"), initial_value, frame_name); - ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_initial_value}); + ops::Merge iv(root.WithOpName(prefix + "/iv"), + {enter_initial_value, enter_initial_value}); Output increment_by = ops::Const(root.WithOpName(prefix + "/incr"), 1); Output final_value = ops::Const(root.WithOpName(prefix + "/final"), 10); Output loop_cond_expr = @@ -65,16 +129,84 @@ Output CreateInductionVariable(const Scope& root, const string& prefix, ops::LoopCond(root.WithOpName(prefix + "/cond"), loop_cond_expr); ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond); ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output); - Output iv_next = - ops::Add(root.WithOpName(prefix + "/ivnext"), iv.output, increment_by); + Output iv_next = ops::Add(root.WithOpName(prefix + "/ivnext"), + latch.output_true, increment_by); Output next_iteration = - ops::NextIteration(root.WithOpName(prefix + "next_iteration"), iv_next); + ops::NextIteration(root.WithOpName(prefix + "/next_iteration"), iv_next); - root.graph()->AddEdge(next_iteration.node(), 0, iv.output.node(), 1); + CHECK(root.graph() + ->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1) + .ok()); root.graph()->AddControlEdge(iv.output.node(), increment_by.node()); root.graph()->AddControlEdge(iv.output.node(), final_value.node()); - return iv.output; + return {iv.output, loop_cond}; +} + +InductionVarInfo CreateInductionVariable(const Scope& root, + const string& prefix, + const string& frame_name, int32 init) { + return CreateInductionVariable( + root, prefix, frame_name, + ops::Const(root.WithOpName(prefix + "/init"), init)); +} + +// Creates an induction variable with the following structure: +// +// +---------------+ +// | initial_value | +// +---------------+ +// | +// | +// v +// +---------------+ +// | Enter | +// +---------------+ +// | +// | +// v +// +---------------+ +// | Merge | <+ +// +---------------+ | +// | | +// | | +// v | +// +-----------+ +---------------+ | +// | loop_cond | --> | Switch | -+ +// +-----------+ +---------------+ +// | +// | +// v +// +---------------+ +// | Exit | +// +---------------+ +struct DependentInductionVar { + Output induction_var; + ops::Switch latch; +}; + +DependentInductionVar CreateDependentLoopInvariantValue( + const Scope& root, const string& prefix, const string& frame_name, + const Output& loop_cond, const Output& value) { + Output enter_value = ops::internal::Enter(root.WithOpName(prefix + "/enter"), + value, frame_name); + ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_value, enter_value}); + ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond); + ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output); + Output next_iteration = ops::NextIteration( + root.WithOpName(prefix + "/next_iteration"), latch.output_true); + CHECK(root.graph() + ->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1) + .ok()); + return {iv.output, latch}; +} + +DependentInductionVar CreateDependentLoopInvariantValue( + const Scope& root, const string& prefix, const string& frame_name, + const Output& loop_cond, int32 value) { + return CreateDependentLoopInvariantValue( + root, prefix, frame_name, loop_cond, + ops::Const(root.WithOpName(prefix + "/init"), value)); } TEST(DeadnessAnalysisTest, BasicPositive) { @@ -336,21 +468,224 @@ TEST(DeadnessAnalysisTest, HostRecv) { TEST(DeadnessAnalysisTest, Loop) { Scope root = Scope::NewRootScope().ExitOnError(); - Output iv0 = CreateInductionVariable(root, "iv0", "fr0", 0); - Output iv1 = CreateInductionVariable(root, "iv1", "fr0", 0); - Output iv2 = CreateInductionVariable(root, "iv2", "fr0", 1); + Output iv0 = CreateInductionVariable(root, "iv0", "fr0", 0).induction_var; + Output iv1 = CreateInductionVariable(root, "iv1", "fr0", 0).induction_var; + Output iv2 = CreateInductionVariable(root, "iv2", "fr0", 1).induction_var; Output add0 = ops::Add(root.WithOpName("add0"), iv0, iv1); Output add1 = ops::Add(root.WithOpName("add1"), iv1, iv2); - std::unique_ptr result; - TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); - // NB! iv0 and iv1 are equivalent and a smarter deadness analysis would have // noticed that. Today we are pessimistic here because we assign an // uninterpreted symbol to merges with backedges. - EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node())); - EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add1.node())); + VLogGraphIfAsked(*root.graph()); + + { + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node())); + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add1.node())); + } + { + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + // In theory we should be able to tell that iv0/cond:0 and iv1/cond:0 + // produce the same deadness. But we're not that smart today. + EXPECT_EQ(predicate_map[ControlOutputFor(iv0)], "{#true,&,*iv0/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv1)], "{#true,&,*iv1/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv2)], "{#true,&,*iv2/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(add0)], + "({#true,&,*iv1/cond:0} & {#true,&,*iv0/cond:0})"); + EXPECT_EQ(predicate_map[ControlOutputFor(add1)], + "({#true,&,*iv1/cond:0} & {#true,&,*iv2/cond:0})"); + } +} + +TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) { + Scope root = Scope::NewRootScope().ExitOnError(); + InductionVarInfo iv = CreateInductionVariable(root, "iv0", "frame", 0); + Output dependent_iv0 = + CreateDependentLoopInvariantValue(root, "div0", "frame", iv.loop_cond, 0) + .induction_var; + Output dependent_iv1 = + CreateDependentLoopInvariantValue(root, "div1", "frame", iv.loop_cond, 0) + .induction_var; + Output add0 = ops::Add(root.WithOpName("add0"), dependent_iv0, dependent_iv1); + + VLogGraphIfAsked(*root.graph()); + + { + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add0.node())); + } + { + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], + "{#true,&,*iv0/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv0)], + "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); + EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv1)], + "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); + EXPECT_EQ(predicate_map[ControlOutputFor(add0)], + "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); + } +} + +TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) { + // Create a merge that "looks like" a loop but isn't really. It has a value + // that does not depend on the merge on its backedge. + Scope root = Scope::NewRootScope().ExitOnError(); + InductionVarInfo iv = CreateInductionVariable(root, "iv0", "frame", 0); + DependentInductionVar dependent_iv = + CreateDependentLoopInvariantValue(root, "div0", "frame", iv.loop_cond, 0); + FixupSourceAndSinkEdges(root.graph()); + + // To make deadness analysis think that dependent_iv is a loop we need an RPO + // that visits the merge before the backedge. This is a legal RPO for + // deadness analysis since it ignores NextIteration->Merge edges during RPO. + // Right now dependent_iv has an edge from Merge to NextIteration so do the + // RPO with this edge in place. Then remove this edge to get our test case. + std::vector rpo; + GetReversePostOrder(*root.graph(), &rpo, /*stable_comparator=*/{}, + /*edge_filter=*/[](const Edge& edge) { + return !edge.src()->IsNextIteration(); + }); + TF_ASSERT_OK(root.graph()->UpdateEdge( + iv.induction_var.node(), 0, dependent_iv.latch.output_true.node(), 0)); + + VLogGraphIfAsked(*root.graph()); + + { + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), rpo, &predicate_map)); + + EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)], + "div0/iv:0"); + } +} + +TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) { + Scope root = Scope::NewRootScope().ExitOnError(); + InductionVarInfo iv_outer = + CreateInductionVariable(root, "iv_outer", "frame", 0); + ops::Switch inner_value(root.WithOpName("outer_is_live"), + ops::Const(root.WithOpName("constant"), 5), + iv_outer.loop_cond); + InductionVarInfo iv_inner = CreateInductionVariable( + root, "iv_inner", "frame", + ops::internal::Enter(root.WithOpName("iv_inner/enter"), + inner_value.output_true, "frame_inner")); + + Output dependent_outer_iv0 = + CreateDependentLoopInvariantValue(root, "dependent_outer_iv0", "frame", + iv_outer.loop_cond, 0) + .induction_var; + Output dependent_outer_iv1 = + CreateDependentLoopInvariantValue(root, "dependent_outer_iv1", "frame", + iv_outer.loop_cond, 0) + .induction_var; + + Output dependent_inner_iv0 = + CreateDependentLoopInvariantValue(root, "dependent_inner_iv0", "frame", + iv_inner.loop_cond, dependent_outer_iv0) + .induction_var; + Output dependent_inner_iv1 = + CreateDependentLoopInvariantValue(root, "dependent_inner_iv1", "frame", + iv_inner.loop_cond, dependent_outer_iv1) + .induction_var; + + Output add0 = ops::Add(root.WithOpName("add0"), dependent_inner_iv0, + dependent_inner_iv1); + + VLogGraphIfAsked(*root.graph()); + + { + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add0.node())); + } + { + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)], + "{#true,&,*iv_outer/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner.induction_var)], + "{(*iv_outer/cond:0 & {#true,&,*iv_outer/cond:0}),&," + "*iv_inner/cond:0}"); + + EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv0)], + "{{#true,&,(iv_outer/iv:0 & *iv_outer/cond:0)},&," + "(*iv_inner/cond:0 & iv_inner/iv:0)}"); + EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)], + "{{#true,&,(iv_outer/iv:0 & *iv_outer/cond:0)},&," + "(*iv_inner/cond:0 & iv_inner/iv:0)}"); + EXPECT_EQ(predicate_map[ControlOutputFor(add0)], + "{{#true,&,(iv_outer/iv:0 & *iv_outer/cond:0)},&," + "(*iv_inner/cond:0 & iv_inner/iv:0)}"); + } +} + +TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) { + Scope root = Scope::NewRootScope().ExitOnError(); + InductionVarInfo iv_outer_0 = + CreateInductionVariable(root, "iv_outer_0", "frame", 0); + ops::Switch inner_value_0(root.WithOpName("outer_0_is_live"), + ops::Const(root.WithOpName("constant"), 5), + iv_outer_0.loop_cond); + InductionVarInfo iv_inner_0 = CreateInductionVariable( + root, "iv_inner_0", "frame", + ops::internal::Enter(root.WithOpName("iv_inner_0/enter"), + inner_value_0.output_true, "frame_inner")); + + InductionVarInfo iv_outer_1 = + CreateInductionVariable(root, "iv_outer_1", "frame", 1); + ops::Switch inner_init_value_1(root.WithOpName("outer_1_is_live"), + ops::Const(root.WithOpName("constant"), 5), + iv_outer_1.loop_cond); + InductionVarInfo iv_inner_1 = CreateInductionVariable( + root, "iv_inner_1", "frame", + ops::internal::Enter(root.WithOpName("iv_inner_1/enter"), + inner_init_value_1.output_true, "frame_inner")); + Output add0 = ops::Add(root.WithOpName("add0"), iv_inner_0.induction_var, + iv_inner_1.induction_var); + + VLogGraphIfAsked(*root.graph()); + + { + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node())); + } + + { + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer_0.induction_var)], + "{#true,&,*iv_outer_0/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner_0.induction_var)], + "{(*iv_outer_0/cond:0 & {#true,&,*iv_outer_0/cond:0}),&," + "*iv_inner_0/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer_1.induction_var)], + "{#true,&,*iv_outer_1/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner_1.induction_var)], + "{(*iv_outer_1/cond:0 & {#true,&,*iv_outer_1/cond:0}),&," + "*iv_inner_1/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(add0)], + "({(*iv_outer_1/cond:0 & {#true,&,*iv_outer_1/cond:0}),&," + "*iv_inner_1/cond:0} & " + "{(*iv_outer_0/cond:0 & {#true,&,*iv_outer_0/cond:0}),&," + "*iv_inner_0/cond:0})"); + } } TEST(DeadnessAnalysisTest, ControlInputs) { @@ -439,5 +774,27 @@ TEST(DeadnessAnalysisTest, RecvVsSwitch) { EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*logical_and.node())); } +TEST(DeadnessAnalysisTest, RecvVsSwitchText) { + // Demonstrates why we need the must_be_true bit on SymbolP. + Scope root = Scope::NewRootScope().ExitOnError(); + + Output recv = ops::_Recv(root.WithOpName("recv"), DT_BOOL, "tensor", "sender", + 0, "receiver"); + Output value = ops::Placeholder(root.WithOpName("value"), DT_BOOL); + ops::Switch sw(root.WithOpName("switch"), value, recv); + Output logical_and = + ops::LogicalAnd(root.WithOpName("and"), recv, sw.output_true); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + TensorId logical_and_output_0 = {logical_and.node()->name(), + Graph::kControlSlot}; + EXPECT_EQ(predicate_map[logical_and_output_0], "(recv:0 & *recv:0)"); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index fdd71c6a588ad96301f543651c8531e6f9c3ca05..f150bf1819d407e1c6a279673a89de4307b5426b 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -1161,8 +1161,7 @@ Status Encapsulator::Subgraph::ReplaceFunctionDef( strings::StrCat("replace_encapsulate_fdef_", name), fdef); } - TF_RETURN_IF_ERROR(library->RemoveFunction(name)); - TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); + TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef)); return Status::OK(); } diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc index 4d49a14b24d53bbcb434560d59b8c97a17e18f86..c37b6112cc8a92047d495d057f59e2281710e678 100644 --- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc +++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" +#include "tensorflow/compiler/jit/partially_decluster_pass.h" #include "tensorflow/core/common_runtime/optimization_registry.h" namespace tensorflow { @@ -23,15 +24,18 @@ namespace tensorflow { REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10, MarkForCompilationPass); +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20, + PartiallyDeclusterPass); + // The EncapsulateSubgraphs pass must run after the MarkForCompilationPass. We // also need to run it after the graph been rewritten to have _Send nodes added // for fetches. Before the _Send nodes are added, fetch nodes are identified by // name, and encapsulation might remove that node from the graph. -REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20, +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30, EncapsulateSubgraphsPass); // Must run after EncapsulateSubgraphsPass. -REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30, +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40, BuildXlaLaunchOpsPass); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 00a6f4075f9a18efc3895b033eb6d08e36088a53..8f78c110cb15f3cbc0344d102764241996b0d7de 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -16,6 +16,7 @@ cc_library( "//tensorflow/compiler/jit:xla_device", "//tensorflow/compiler/jit:xla_launch_util", "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index c5d0e4f8fb61b90eb58d9df398d680b3c5481196..7f4370b5b07b249bc9cf1f2ecf4086de359be68c 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" @@ -153,6 +154,10 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { XlaCompiler::Options options; options.client = client; + if (ctx->op_device_context() != nullptr) { + options.device_ordinal = + ctx->op_device_context()->stream()->parent()->device_ordinal(); + } options.device_type = cache->device_type(); options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); options.graph_def_version = ctx->function_library()->graph_def_version(); @@ -195,7 +200,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { run_options.set_stream(stream); run_options.set_allocator(xla_allocator); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); - run_options.set_rng_seed(ctx->step_id()); + run_options.set_rng_seed(GetXLARandomSeed()); Env* env = Env::Default(); auto start_time = env->NowMicros(); @@ -205,7 +210,8 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { auto elapsed = env->NowMicros() - start_time; VLOG(2) << "Elapsed time: " << elapsed << "us"; - launch_context.PopulateOutputs(ctx, kernel, run_result.ConsumeValueOrDie()); + OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs( + ctx, kernel, run_result.ConsumeValueOrDie())); VLOG(1) << "Done"; } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 38eb6d830f4d4e889810acd0f928e93d0b22bde8..f4e179dab246cc65b14946f112d1398237fd2906 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -39,7 +39,9 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/version.h" namespace tensorflow { @@ -65,6 +67,7 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { // XLA cluster so it can't implement the forward-tensor-ref semantic. Leave // such nodes out of XLA clusters. if (HasForwardedRefInput(node)) { + VLOG(2) << "Rejecting " << node.name() << ": Identity with unsafe cast."; return false; } @@ -84,14 +87,13 @@ bool IsCompilableCall(const NodeDef& call_def, bool IsCompilableWhile(const Node& while_node, const DeviceType& jit_device_type, int depth, FunctionLibraryRuntime* lib_runtime) { - VLOG(2) << "Loop marking: " << while_node.type_string(); - const NameAttrList* name_attr; NodeDef call; Status status; status = GetNodeAttr(while_node.attrs(), "cond", &name_attr); if (!status.ok()) { - VLOG(2) << "Missing 'cond' attribute on While node."; + VLOG(2) << "Rejecting While " << while_node.name() + << ": missing 'cond' attribute on While node."; return false; } const string cond_func = name_attr->name(); @@ -99,12 +101,14 @@ bool IsCompilableWhile(const Node& while_node, call.set_op(cond_func); *call.mutable_attr() = name_attr->attr(); if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) { - VLOG(2) << "Can't compile loop condition: " << cond_func; + VLOG(2) << "Rejecting While " << while_node.name() + << ": can't compile loop condition: " << cond_func; return false; } status = GetNodeAttr(while_node.attrs(), "body", &name_attr); if (!status.ok()) { - VLOG(2) << "Missing 'body' attribute on While node."; + VLOG(2) << "Rejecting While " << while_node.name() + << ": missing 'body' attribute on While node."; return false; } const string body_func = name_attr->name(); @@ -112,10 +116,10 @@ bool IsCompilableWhile(const Node& while_node, call.set_op(body_func); *call.mutable_attr() = name_attr->attr(); if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) { - VLOG(2) << "Can't compile loop body: " << body_func; + VLOG(2) << "Rejecting While " << while_node.name() + << ": can't compile loop body: " << body_func; return false; } - VLOG(2) << "Loop is compilable."; return true; } @@ -125,10 +129,9 @@ bool IsCompilableWhile(const Node& while_node, bool IsCompilableCall(const NodeDef& call_def, const DeviceType& jit_device_type, int depth, FunctionLibraryRuntime* lib_runtime) { - VLOG(2) << "Function marking: " << call_def.op(); - if (depth > kMaxRecursionDepth) { - VLOG(2) << "Function depth limit exceeded"; + VLOG(2) << "Rejecting " << call_def.op() + << ": function depth limit exceeded."; return false; } @@ -136,9 +139,14 @@ bool IsCompilableCall(const NodeDef& call_def, Status status = lib_runtime->Instantiate(call_def.op(), AttrSlice(call_def), &handle); if (!status.ok()) { - VLOG(2) << "Could not instantiate " << call_def.op() << ": " << status; + VLOG(2) << "Rejecting " << call_def.op() + << ": could not instantiate: " << status; return false; } + + auto release_handle_on_return = gtl::MakeCleanup( + [&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); }); + const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle); CHECK(fbody); const FunctionDef& fdef = fbody->fdef; @@ -150,7 +158,8 @@ bool IsCompilableCall(const NodeDef& call_def, // tf2xla to translate the TF graph into XLA. So we avoid this for now. // // TODO(b/36139787): Create a mechanism to set inlining hints. - VLOG(2) << "Can't compile noinline function: " << fdef.DebugString(); + VLOG(2) << "Rejecting " << call_def.op() + << ": can't compile noinline function."; return false; } @@ -164,23 +173,14 @@ bool IsCompilableCall(const NodeDef& call_def, if (!HasXLAKernel(*node, jit_device_type) && !IsCompilableCall(node->def(), jit_device_type, depth + 1, lib_runtime)) { - VLOG(2) << "Function marking failed: unsupported op " << node->name() - << ": " << node->def().ShortDebugString(); + VLOG(2) << "Rejecting " << call_def.op() << ": unsupported op " + << node->name() << ": " << node->def().ShortDebugString(); return false; } } - VLOG(2) << "Function is compilable: " << call_def.op(); return true; } -// Tests whether `node` has a DT_RESOURCE typed input or output. -bool HasResourceInputOrOutput(const Node& node) { - return std::find(node.input_types().begin(), node.input_types().end(), - DT_RESOURCE) != node.input_types().end() || - std::find(node.output_types().begin(), node.output_types().end(), - DT_RESOURCE) != node.output_types().end(); -} - // Returns true if the op can be decomposed into XLA ops for which // there are fusable elemental implementations. // @@ -357,24 +357,27 @@ Status FindCompilationCandidates( } std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID()); + if (fuel >= std::numeric_limits::max() / 2) { + // The assumption is that if fuel started out as INT64_MAX, it will forever + // stay greater than INT64_MAX / 2. + VLOG(2) << "Starting fuel: infinity"; + } else { + VLOG(2) << "Starting fuel: " << fuel; + } + for (Node* node : sorted_nodes) { - VLOG(2) << "Fuel: " << fuel; if (fuel <= 0) { - VLOG(2) + VLOG(1) << "Hit fuel limit; not marking any remaining ops as clusterable."; break; } - VLOG(2) << "FindCompilationCandidates(): Processing " - << node->DebugString(); - DeviceType device_type(""); TF_RETURN_IF_ERROR( DeviceToDeviceType(node->assigned_device_name(), &device_type)); if (is_compilable_fn && !is_compilable_fn(node, device_type)) { - VLOG(2) << "Compilation rejected node: not compilable " << node->name() - << ": " << node->type_string(); + // is_compilable_fn has already logged the reason if it returned false. continue; } @@ -384,14 +387,14 @@ Status FindCompilationCandidates( DeviceType jit_device_type(registration->compilation_device_name); if (!HasXLAKernel(*node, jit_device_type) && !IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime)) { - VLOG(2) << "Compilation rejected node: unsupported op " << node->name() - << ": " << node->type_string(); + VLOG(2) << "Rejecting " << node->name() << ": unsupported op " + << node->type_string(); continue; } if (!registration->compile_resource_ops && HasResourceInputOrOutput(*node)) { - VLOG(2) << "Compilation rejected node: resource input/output " - << node->name() << ": " << node->type_string(); + VLOG(2) << "Rejecting: " << node->name() << ": resource input/output " + << node->type_string(); continue; } if (node->type_string() == "While" && @@ -401,15 +404,11 @@ Status FindCompilationCandidates( // _Arg nodes in a top-level function represent feeds. // Do not compile them. if (node->type_string() == "_Arg") { - VLOG(2) << "Skipping jit compilation for '_Arg'-typed node " - << node->DebugString(); continue; } // _Retval nodes in a top-level function represent fetches. // Do not compile them. if (node->type_string() == "_Retval") { - VLOG(2) << "Compilation rejected node: return value " << node->name() - << ": " << node->type_string(); continue; } candidates->insert(node); @@ -462,6 +461,7 @@ Status MarkForCompilationPass::Run( VLOG(1) << "flags->tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit; VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only; + VLOG(1) << "flags->tf_xla_auto_jit = " << flags->tf_xla_auto_jit; const FunctionLibraryDefinition* fld = options.flib_def; std::unique_ptr deadness; @@ -474,6 +474,7 @@ Status MarkForCompilationPass::Run( const XlaOpRegistry::DeviceRegistration* registration; if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { + VLOG(2) << "Rejecting " << node->name() << ": could not find JIT device."; return false; } @@ -483,21 +484,36 @@ Status MarkForCompilationPass::Run( // If there is a _XlaCompile annotation, use its value. bool compile = false; Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile); - if (status.ok()) return compile; + if (status.ok()) { + if (!compile) { + VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr(" + << kXlaCompileAttr << ") is false."; + } + return compile; + } status = fld->GetAttr(*node, kXlaCompileAttr, &compile); - if (status.ok()) return compile; + if (status.ok()) { + if (!compile) { + VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr(" + << kXlaCompileAttr << ") on callee is false."; + } + return compile; + } // If inputs to `node` can have conflicting deadness (i.e. some are alive // and some are dead) then don't compile it. XLA cannot represent the // deadness semantics of these nodes correctly and auto-clustering these // nodes can cause deadness to propagate to nodes that should be live. if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) { + VLOG(2) << "Rejecting " << node->name() << ": mismatching deadness."; return false; } // Check for fusable ops only if requested. if (global_jit_level > 0 && fusion_only && !IsXlaFusable(node->def())) { + VLOG(2) << "Rejecting " << node->name() + << ": not fusable op but fusion_only enabled."; return false; } @@ -505,12 +521,75 @@ Status MarkForCompilationPass::Run( // Ignore enable_jit_by_default if global jit compilation for CPU // is explicitly requested via tf_xla_cpu_global_jit flag bool ignore_registration = cpu_global_jit && device_type == DEVICE_CPU; - return (ignore_registration || registration->enable_jit_by_default) && - global_jit_level > 0; + bool should_compile = + (ignore_registration || registration->enable_jit_by_default) && + global_jit_level > 0; + if (!should_compile) { + if (global_jit_level <= 0) { + VLOG(2) << "Rejecting " << node->name() << ": global jit disabled."; + } else { + VLOG(2) << "Rejecting " << node->name() << ": JIT for device disabled."; + } + } + return should_compile; }; return RunImpl(options, is_compilable); } +static string RatioToString(int numerator, int denominator) { + return strings::Printf("%d / %d (%.2f%%)", numerator, denominator, + (100.0 * numerator) / denominator); +} + +static void VLogClusteringSummary(const Graph& g) { + if (!VLOG_IS_ON(2)) { + return; + } + + std::map cluster_name_to_size; + std::map> + cluster_name_to_op_histogram; + std::map unclustered_op_histogram; + int clustered_node_count = 0; + + for (Node* n : g.nodes()) { + gtl::optional cluster_name = GetXlaClusterForNode(*n); + if (cluster_name) { + clustered_node_count++; + cluster_name_to_size[*cluster_name]++; + cluster_name_to_op_histogram[*cluster_name][n->type_string()]++; + } else { + unclustered_op_histogram[n->type_string()]++; + } + } + + int unclustered_node_count = g.num_nodes() - clustered_node_count; + + VLOG(2) << "*** Clustering info for graph of size " << g.num_nodes(); + VLOG(2) << " Built " << cluster_name_to_size.size() << " clusters, size " + << RatioToString(clustered_node_count, g.num_nodes()); + + for (const auto& cluster_name_size_pair : cluster_name_to_size) { + StringPiece cluster_name = cluster_name_size_pair.first; + int size = cluster_name_size_pair.second; + VLOG(2) << " " << cluster_name << " " + << RatioToString(size, g.num_nodes()); + for (const auto& op_count_pair : + cluster_name_to_op_histogram[cluster_name]) { + VLOG(3) << " " << op_count_pair.first << ": " << op_count_pair.second + << " instances"; + } + } + + if (!unclustered_op_histogram.empty()) { + VLOG(2) << " Unclustered nodes: " + << RatioToString(unclustered_node_count, g.num_nodes()); + for (const auto& pair : unclustered_op_histogram) { + VLOG(3) << " " << pair.first << ": " << pair.second << " instances"; + } + } +} + // Is 'node' an operator that consumes only the shape of its input, not the // data itself? static bool IsShapeConsumerOp(const Node& node) { @@ -699,6 +778,9 @@ Status MarkForCompilationPass::RunImpl( dump_graph::DumpGraphToFile("mark_for_compilation", **options.graph, options.flib_def); } + + VLogClusteringSummary(*graph); + return Status::OK(); } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.h b/tensorflow/compiler/jit/mark_for_compilation_pass.h index e9acbfb19e42cb43cb0b986c438a569de29b2ebc..f1137af3c1e8539fda318d88d2c5b5187953ccab 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.h +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.h @@ -40,20 +40,18 @@ class MarkForCompilationPass : public GraphOptimizationPass { Status Run(const GraphOptimizationPassOptions& options) override; - // Run() just calls RunImpl() if --tf_xla_auto_jit is enabled. To run the pass - // unconditionally, call RunImpl() directly. - // is_compilable_fn, if set, is a predicate that must be true for a node to - // be compiled. + private: Status RunImpl(const GraphOptimizationPassOptions& options, const std::function& is_compilable_fn = {}); + + friend class MarkForCompilationPassTestHelper; }; // Returns true iff 'ndef' is a call to a function that is compilable. A // function is compilable iff every operator in the function body is // compilable. bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef); - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 2c5f4fb774fcab082c0d0d316cdc6757cacc1e96..a780d4a936a3b757495c26d337f19c80a67f343a 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/jit/mark_for_compilation_pass.h" +#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/array_ops.h" @@ -39,27 +39,6 @@ namespace { REGISTER_OP("UncompilableNullary").Output("o: float"); REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float"); -Status MarkForCompilation(std::unique_ptr* graph, - FunctionLibraryDefinition* flib_def) { - // Assign all nodes to the CPU device. - static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; - for (Node* n : (*graph)->nodes()) { - n->set_assigned_device_name(kCpuDevice); - } - - GraphOptimizationPassOptions opt_options; - opt_options.graph = graph; - opt_options.flib_def = flib_def; - MarkForCompilationPass pass; - return pass.RunImpl(opt_options); -} - -Status MarkForCompilation(std::unique_ptr* graph) { - FunctionDefLibrary flib; - FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib); - return MarkForCompilation(graph, &flib_def); -} - std::unordered_map GetClusters(const Graph& graph) { std::unordered_map ids; for (Node* node : graph.nodes()) { @@ -88,7 +67,7 @@ TEST(XlaCompilationTest, Chains) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(4, clusters.size()); EXPECT_EQ(clusters["B"], clusters["C"]); @@ -113,7 +92,7 @@ TEST(XlaCompilationTest, UncompilableCycles) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_TRUE(clusters.empty()); @@ -133,7 +112,7 @@ TEST(XlaCompilationTest, CompilableCycles) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(3, clusters.size()); @@ -156,7 +135,7 @@ TEST(XlaCompilationTest, Complex128Unsupported) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_TRUE(clusters.empty()); } @@ -177,7 +156,7 @@ TEST(XlaCompilationTest, HalfSupported) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_FALSE(clusters.empty()); } @@ -206,7 +185,7 @@ TEST(XlaCompilationTest, ConcatWithConstArg) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(3, clusters.size()); // Everything should be compiled. } @@ -241,7 +220,8 @@ TEST(XlaCompilationTest, FunctionCalls) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph, &flib_def)); + TF_ASSERT_OK( + MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def)); auto clusters = GetClusters(*graph); EXPECT_EQ(2, clusters.size()); @@ -272,7 +252,7 @@ TEST(XlaCompilationTest, MetadataOpsDontStartClusters) { ops::UnaryOp("Shape", d, builder.opts().WithName("E")); TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(0, clusters.size()); // Nothing should be compiled. } @@ -359,7 +339,7 @@ TEST(XlaCompilationTest, SymbolicGradients) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(2, clusters.size()); @@ -384,7 +364,7 @@ TEST(XlaCompilationTest, Loops) { std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_EXPECT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); // Nothing should be compiled. In particular, 'd' and 'c' must not be @@ -411,7 +391,7 @@ TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) { TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); // The computation is: C = A + relu(A) @@ -442,7 +422,7 @@ TEST(XlaCompilationTest, CyclesWithSplittingScopes) { TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); // The computation is: D = relu(A) + (A @ relu(A)) @@ -472,7 +452,7 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) { TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); // The computation is: C = A @ relu(A) @@ -512,7 +492,7 @@ TEST(XlaCompilationTest, Resources) { ops::UnaryOp("Relu", d, builder.opts().WithName("E")); TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(0, clusters.size()); // Nothing should be compiled. } @@ -542,7 +522,7 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { TF_EXPECT_OK(root.ToGraph(graph.get())); - Status status = MarkForCompilation(&graph); + Status status = MarkForCompilationPassTestHelper::MarkForCompilation(&graph); EXPECT_FALSE(status.ok()); EXPECT_TRUE(str_util::StrContains(status.ToString(), "Edge from c to a would create a cycle.\n" @@ -570,7 +550,7 @@ TEST(XlaCompilationTest, Retval) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(2, clusters.size()); @@ -588,7 +568,7 @@ TEST(XlaCompilationTest, DontCountIdentityOps) { auto r = ops::_Retval(root.WithOpName("R"), c, 0); } TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_TRUE(clusters.empty()); @@ -604,7 +584,7 @@ TEST(XlaCompilationTest, DontCountIdentityOpsWithLocalJit) { auto r = ops::_Retval(root.WithOpName("R"), b, 0); } TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_TRUE(clusters.empty()); @@ -618,7 +598,7 @@ TEST(XlaCompilationTest, ConstOp) { auto c = ops::Const(root.WithOpName("const"), 0.5f); c.node()->AddAttr(kXlaCompileAttr, true); TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); EXPECT_EQ(1, GetClusters(*graph).size()); } @@ -629,7 +609,7 @@ TEST(XlaCompilationTest, ConstOp) { auto c = ops::Const(root.WithOpName("const"), string("string")); c.node()->AddAttr(kXlaCompileAttr, true); TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); EXPECT_TRUE(GetClusters(*graph).empty()); } } @@ -644,7 +624,7 @@ TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) { std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); std::unordered_map clusters = GetClusters(*graph); @@ -667,7 +647,7 @@ TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) { std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); std::unordered_map clusters = GetClusters(*graph); @@ -699,7 +679,7 @@ TEST(XlaCompilationTest, ClusterControlTrigger) { std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); std::unordered_map clusters = GetClusters(*graph); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc new file mode 100644 index 0000000000000000000000000000000000000000..a84b82e47923b2e7eec0e7eb848bd4377befbd07 --- /dev/null +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc @@ -0,0 +1,40 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" + +namespace tensorflow { +/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation( + std::unique_ptr* graph, FunctionLibraryDefinition* flib_def) { + // Assign all nodes to the CPU device. + static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; + for (Node* n : (*graph)->nodes()) { + n->set_assigned_device_name(kCpuDevice); + } + + GraphOptimizationPassOptions opt_options; + opt_options.graph = graph; + opt_options.flib_def = flib_def; + MarkForCompilationPass pass; + return pass.RunImpl(opt_options); +} + +/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation( + std::unique_ptr* graph) { + FunctionDefLibrary flib; + FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib); + return MarkForCompilation(graph, &flib_def); +} +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..b9a0531cb0e431a98d57a6d9a2e3e41b51e7b743 --- /dev/null +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_TEST_HELPER_H_ +#define TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_TEST_HELPER_H_ + +#include "tensorflow/compiler/jit/mark_for_compilation_pass.h" + +namespace tensorflow { +class MarkForCompilationPassTestHelper { + public: + // Runs the MarkForCompilation pass on `graph` after assigning all nodes in + // `graph` to the CPU device. To make testing easier, ignores device + // registration, _XlaCompile attributes, input deadness and global jit level. + static Status MarkForCompilation(std::unique_ptr* graph, + FunctionLibraryDefinition* flib_def); + + // Like `MarkForCompilation` but creates `flib_def` from the op registry. + static Status MarkForCompilation(std::unique_ptr* graph); +}; +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_TEST_HELPER_H_ diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..68ead39424c35c1ef0bcc92e57af7931c0c57462 --- /dev/null +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -0,0 +1,177 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/partially_decluster_pass.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/core/framework/memory_types.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/gtl/flatset.h" + +namespace tensorflow { +namespace { +Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet* result, + gtl::ArraySlice post_order) { + // Find nodes that have at least one user outside their cluster that expects + // hostmem output. These nodes should be cloned to outside the cluster to + // avoid the device-host copy we'd otherwise need. + + MemoryTypeVector input_mtypes, output_mtypes; + + for (Node* n : post_order) { + gtl::optional from_cluster = GetXlaClusterForNode(*n); + if (!from_cluster) { + continue; + } + + // We assume the only XLA-auto-clusterable operations with side effects are + // resource variable updates. We can't execute these twice. + if (HasResourceInputOrOutput(*n)) { + continue; + } + + DeviceType device_type(""); + TF_RETURN_IF_ERROR( + DeviceToDeviceType(n->assigned_device_name(), &device_type)); + TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), device_type, + n->def(), &input_mtypes, + &output_mtypes)); + for (const Edge* e : n->out_edges()) { + Node* dst = e->dst(); + + if (e->IsControlEdge()) { + continue; + } + + bool edge_incurs_extra_device_to_host_copy; + if (output_mtypes[e->src_output()] == DEVICE_MEMORY) { + // If the output of the *TensorFlow* operation is in DEVICE_MEMORY then + // keep the node clustered -- XLA will also produce the output in device + // memory and we will get some benefit from clustering. + edge_incurs_extra_device_to_host_copy = false; + } else { + MemoryTypeVector dst_input_mtypes, dst_output_mtypes; + DeviceType dst_device_type(""); + TF_RETURN_IF_ERROR( + DeviceToDeviceType(dst->assigned_device_name(), &dst_device_type)); + TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), device_type, + dst->def(), &dst_input_mtypes, + &dst_output_mtypes)); + edge_incurs_extra_device_to_host_copy = + dst_input_mtypes[e->dst_input()] == HOST_MEMORY; + } + + if (!edge_incurs_extra_device_to_host_copy) { + continue; + } + + // Check if `dst` is in a different cluster, unclustered, or about to be + // partially declustered (here we rely on the post-order traversal order). + // If yes, decluster `n` to avoid the device-to-host memcpy. + gtl::optional dst_cluster = + result->count(dst) ? gtl::nullopt : GetXlaClusterForNode(*dst); + if (from_cluster != dst_cluster) { + CHECK(result->insert(n).second); + break; + } + } + } + return Status::OK(); +} + +Status PartiallyDeclusterNode(Graph* graph, Node* n) { + StringPiece cluster_name = *GetXlaClusterForNode(*n); + gtl::InlinedVector out_edges_to_clone; + for (const Edge* out_edge : n->out_edges()) { + if (out_edge->IsControlEdge()) { + continue; + } + + Node* dst = out_edge->dst(); + gtl::optional dst_cluster_name = GetXlaClusterForNode(*dst); + if (dst_cluster_name != cluster_name) { + out_edges_to_clone.push_back(out_edge); + } + } + + CHECK(!out_edges_to_clone.empty()) << n->DebugString(); + + NodeDef ndef = n->def(); + ndef.set_name(strings::StrCat(n->name(), "/declustered")); + RemoveFromXlaCluster(&ndef); + Status s; + Node* cloned_node = graph->AddNode(ndef, &s); + cloned_node->set_assigned_device_name(n->assigned_device_name()); + TF_RETURN_IF_ERROR(s); + + for (const Edge* in_edge : n->in_edges()) { + graph->AddEdge(in_edge->src(), in_edge->src_output(), cloned_node, + in_edge->dst_input()); + } + + for (const Edge* out_edge_to_clone : out_edges_to_clone) { + graph->AddEdge(cloned_node, out_edge_to_clone->src_output(), + out_edge_to_clone->dst(), out_edge_to_clone->dst_input()); + graph->RemoveEdge(out_edge_to_clone); + } + + return Status::OK(); +} +} // namespace + +Status PartiallyDeclusterPass::Run( + const GraphOptimizationPassOptions& options) { + // NB! In this pass we assume the only XLA-auto-clusterable operations that + // may have side effects are resource variable operations so we don't cluster + // those. The pass will have to be updated if this assumption becomes + // invalid. + + Graph* graph = options.graph->get(); + + // When deciding whether to decluster a particular node, we base our decision + // on if we've decided that some of its consumers have to be declustered too. + // Iterating the graph in post-order guarantees that consumers have been + // visited before producers. + std::vector post_order; + GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(), + /*edge_filter=*/[](const Edge& edge) { + return !edge.src()->IsNextIteration(); + }); + + gtl::FlatSet nodes_to_partially_decluster; + TF_RETURN_IF_ERROR(FindNodesToDecluster( + **options.graph, &nodes_to_partially_decluster, post_order)); + + if (VLOG_IS_ON(3)) { + for (Node* n : post_order) { + if (nodes_to_partially_decluster.count(n)) { + VLOG(3) << n->DebugString(); + } + } + } + + for (Node* n : post_order) { + if (nodes_to_partially_decluster.count(n)) { + TF_RETURN_IF_ERROR(PartiallyDeclusterNode(graph, n)); + } + } + + nodes_to_partially_decluster.clear(); + TF_RETURN_IF_ERROR(FindNodesToDecluster( + **options.graph, &nodes_to_partially_decluster, post_order)); + CHECK(nodes_to_partially_decluster.empty()); + + return Status::OK(); +} +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/partially_decluster_pass.h b/tensorflow/compiler/jit/partially_decluster_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..6949b5028ee55e182b27589f9a9711dad7839e86 --- /dev/null +++ b/tensorflow/compiler/jit/partially_decluster_pass.h @@ -0,0 +1,58 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_PARTIALLY_DECLUSTER_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_PARTIALLY_DECLUSTER_PASS_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" + +namespace tensorflow { + +// Clones nodes from within a cluster to outside the cluster if profitable. +// +// Today this only clones to avoid device-to-host copies, but in the future we +// may consider other reasons to clone. For instance, we convert this: +// +// ..... +// | +// v +// A_Clustered ====> C_Unclustered +// | +// v +// B_Clustered +// +// to: +// +// ..... +// | | +// | +-------------+ +// | | +// v v +// A_Clustered A_Unclustered ====> C_Unclustered +// | +// v +// B_Clustered +// +// where the ===> arrow has a hostmem source and destination and would entail a +// device to host copy if the source and destination were not in the same XLA +// cluster. +class PartiallyDeclusterPass : public GraphOptimizationPass { + public: + Status Run(const GraphOptimizationPassOptions& options) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_PARTIALLY_DECLUSTER_PASS_H_ diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..08a956e4c6478ff76a0fe8f1f60d94824daf535c --- /dev/null +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -0,0 +1,284 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/partially_decluster_pass.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/control_flow_ops_internal.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { +REGISTER_OP("FakeNullary").Output("out: float"); + +REGISTER_OP("FakeBinary") + .Input("host_in: float") + .Input("device_in: float") + .Output("host_out: float") + .Output("device_out: float"); + +REGISTER_OP("FakeResourceVar").Output("out: resource"); + +REGISTER_OP("FakeResourceUpdate") + .Input("in: resource") + .Output("out: resource") + .Output("something_else: float"); + +class FakeBinaryOp : public OpKernel { + public: + explicit FakeBinaryOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* ctx) override { CHECK(false); } +}; + +class FakeResourceVarUpdateOp : public OpKernel { + public: + explicit FakeResourceVarUpdateOp(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* ctx) override { CHECK(false); } +}; + +REGISTER_KERNEL_BUILDER(Name("FakeBinary") + .Device(DEVICE_CPU) + .HostMemory("host_in") + .HostMemory("host_out"), + FakeBinaryOp); + +REGISTER_KERNEL_BUILDER(Name("FakeResourceVarUpdate") + .Device(DEVICE_CPU) + .HostMemory("something_else"), + FakeResourceVarUpdateOp); + +Status PartiallyDecluster(std::unique_ptr* graph) { + FixupSourceAndSinkEdges(graph->get()); + // Assign all nodes to the CPU device. + static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; + for (Node* n : (*graph)->nodes()) { + n->set_assigned_device_name(kCpuDevice); + } + + GraphOptimizationPassOptions opt_options; + opt_options.graph = graph; + PartiallyDeclusterPass pass; + return pass.Run(opt_options); +} + +const Node* FindNodeByName(const Graph& graph, const string& name) { + for (const Node* node : graph.nodes()) { + if (node->name() == name) { + return node; + } + } + return nullptr; +} + +bool GetInputsForNode(const Graph& graph, const string& node_name, + std::vector* inputs) { + const Node* node = FindNodeByName(graph, node_name); + if (node == nullptr) { + return false; + } + for (const Edge* e : node->in_edges()) { + inputs->push_back(e->src()); + } + std::sort(inputs->begin(), inputs->end(), NodeComparatorName()); + return true; +} + +TEST(PartiallyDeclusterPassTest, ClusteredAndUnclustered) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* input = + ops::SourceOp("FakeNullary", builder.opts().WithName("Input")); + Node* clustered_producer = + ops::BinaryOp("FakeBinary", input, input, + builder.opts().WithName("ClusteredProducer")); + ops::BinaryOp("FakeBinary", clustered_producer, input, + builder.opts().WithName("UnclusteredConsumer")); + Node* clustered_consumer = + ops::BinaryOp("FakeBinary", {clustered_producer, 1}, input, + builder.opts().WithName("ClusteredConsumer")); + clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0"); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + std::vector unclustered_consumer_inputs; + ASSERT_TRUE(GetInputsForNode(*graph, "UnclusteredConsumer", + &unclustered_consumer_inputs)); + ASSERT_EQ(unclustered_consumer_inputs.size(), 2); + EXPECT_EQ(unclustered_consumer_inputs[0]->name(), + "ClusteredProducer/declustered"); + EXPECT_EQ(unclustered_consumer_inputs[1]->name(), "Input"); + + std::vector clustered_consumer_inputs; + ASSERT_TRUE(GetInputsForNode(*graph, "ClusteredConsumer", + &clustered_consumer_inputs)); + ASSERT_EQ(clustered_consumer_inputs.size(), 2); + EXPECT_EQ(clustered_consumer_inputs[0]->name(), "ClusteredProducer"); + EXPECT_EQ(clustered_consumer_inputs[1]->name(), "Input"); +} + +TEST(PartiallyDeclusterPassTest, DifferentClusters) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* input = + ops::SourceOp("FakeNullary", builder.opts().WithName("Input")); + Node* clustered_producer = + ops::BinaryOp("FakeBinary", input, input, + builder.opts().WithName("ClusteredProducer")); + Node* consumer_in_different_cluster = + ops::BinaryOp("FakeBinary", clustered_producer, input, + builder.opts().WithName("ConsumerInDifferentCluster")); + Node* clustered_consumer = + ops::BinaryOp("FakeBinary", input, {clustered_producer, 1}, + builder.opts().WithName("ClusteredConsumer")); + clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0"); + consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1"); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + std::vector inputs; + ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs)); + ASSERT_EQ(inputs.size(), 2); + EXPECT_EQ(inputs[0]->name(), "ClusteredProducer/declustered"); + EXPECT_EQ(inputs[1]->name(), "Input"); +} + +TEST(PartiallyDeclusterPassTest, DontDeclusterIfUserIsDeviceMem) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* input = + ops::SourceOp("FakeNullary", builder.opts().WithName("Input")); + Node* clustered_producer = + ops::BinaryOp("FakeBinary", input, input, + builder.opts().WithName("ClusteredProducer")); + // The first input is hostmem and the second input is devicemem. + Node* consumer_in_different_cluster = + ops::BinaryOp("FakeBinary", input, clustered_producer, + builder.opts().WithName("ConsumerInDifferentCluster")); + Node* clustered_consumer = + ops::BinaryOp("FakeBinary", input, {clustered_producer, 1}, + builder.opts().WithName("ClusteredConsumer")); + clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0"); + consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1"); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + std::vector inputs; + ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs)); + ASSERT_EQ(inputs.size(), 2); + EXPECT_EQ(inputs[0]->name(), "ClusteredProducer"); + EXPECT_EQ(inputs[1]->name(), "Input"); +} + +TEST(PartiallyDeclusterPassTest, DontDuplicateResourceVarOps) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* input = + ops::SourceOp("FakeNullary", builder.opts().WithName("Input")); + Node* resource_var = ops::SourceOp("FakeResourceVar", + builder.opts().WithName("ResourceVar")); + Node* clustered_producer = + ops::UnaryOp("FakeResourceUpdate", resource_var, + builder.opts().WithName("ClusteredProducer")); + Node* consumer_in_different_cluster = + ops::BinaryOp("FakeBinary", {clustered_producer, 1}, input, + builder.opts().WithName("ConsumerInDifferentCluster")); + Node* clustered_consumer = + ops::BinaryOp("FakeBinary", input, {clustered_producer, 1}, + builder.opts().WithName("ClusteredConsumer")); + clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0"); + consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1"); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + std::vector inputs; + ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs)); + ASSERT_EQ(inputs.size(), 2); + EXPECT_EQ(inputs[0]->name(), "ClusteredProducer"); + EXPECT_EQ(inputs[1]->name(), "Input"); +} + +TEST(PartiallyDeclusterPassTest, DeclusterDependentNodes) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* input = + ops::SourceOp("FakeNullary", builder.opts().WithName("Input")); + Node* clustered_producer_0 = + ops::BinaryOp("FakeBinary", input, input, + builder.opts().WithName("ClusteredProducer0")); + Node* clustered_producer_1 = + ops::BinaryOp("FakeBinary", clustered_producer_0, input, + builder.opts().WithName("ClusteredProducer1")); + ops::BinaryOp("FakeBinary", clustered_producer_1, input, + builder.opts().WithName("UnclusteredConsumer")); + Node* clustered_consumer = + ops::BinaryOp("FakeBinary", {clustered_producer_1, 1}, input, + builder.opts().WithName("ClusteredConsumer")); + clustered_producer_0->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_producer_1->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0"); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + std::vector unclustered_consumer_inputs, declustered_producer_1_inputs; + + ASSERT_TRUE(GetInputsForNode(*graph, "UnclusteredConsumer", + &unclustered_consumer_inputs)); + ASSERT_EQ(unclustered_consumer_inputs.size(), 2); + EXPECT_EQ(unclustered_consumer_inputs[0]->name(), + "ClusteredProducer1/declustered"); + EXPECT_EQ(unclustered_consumer_inputs[1]->name(), "Input"); + + ASSERT_TRUE(GetInputsForNode(*graph, "ClusteredProducer1/declustered", + &declustered_producer_1_inputs)); + ASSERT_EQ(declustered_producer_1_inputs.size(), 2); + EXPECT_EQ(declustered_producer_1_inputs[0]->name(), + "ClusteredProducer0/declustered"); + EXPECT_EQ(declustered_producer_1_inputs[1]->name(), "Input"); +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index a5628b12a27c9ed052e22c784517a07f2c1c059a..0a025a1fc0b268963069a8c1a3be700040be3f8e 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -185,4 +185,26 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { return Status::OK(); } +gtl::optional GetXlaClusterForNode(const Node& node) { + const AttrValue* attr_value = node.attrs().Find(kXlaClusterAttr); + if (attr_value == nullptr) { + return gtl::nullopt; + } + Status s = AttrValueHasType(*attr_value, "string"); + if (!s.ok()) { + return gtl::nullopt; + } + return attr_value->s(); +} + +bool HasResourceInputOrOutput(const Node& node) { + return std::find(node.input_types().begin(), node.input_types().end(), + DT_RESOURCE) != node.input_types().end() || + std::find(node.output_types().begin(), node.output_types().end(), + DT_RESOURCE) != node.output_types().end(); +} + +void RemoveFromXlaCluster(NodeDef* node_def) { + node_def->mutable_attr()->erase(kXlaClusterAttr); +} } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index bcce082aaf6044ff0654efa4d78c0f493a350d00..bff76da6f9bcb06170e5aeb111da8545a6d291f8 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/lib/gtl/optional.h" namespace tensorflow { @@ -44,6 +45,16 @@ bool HasForwardedRefInput(const Node& node); // the enclosing graph. Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles); +// Returns the XLA cluster in which `node` is placed if it is in an XLA cluster, +// otherwise returns nullopt. +gtl::optional GetXlaClusterForNode(const Node& node); + +// Removes `node_def` its XLA cluster (by clearing its _XlaCluster attribute). +void RemoveFromXlaCluster(NodeDef* node_def); + +// Returns true if `node` has a DT_RESOURCE typed input or output. +bool HasResourceInputOrOutput(const Node& node); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 54a41a4daa790401c797277e7eaab531dd34ac80..7140d47a9421ec73d0144e855b490f89569e6ae9 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -209,7 +209,9 @@ Status XlaCompilationCache::BuildExecutable( argument_layouts[i] = &result.xla_input_shapes[i]; } xla::ExecutableBuildOptions build_options; - build_options.set_device_ordinal(client_->default_device_ordinal()); + build_options.set_device_ordinal(options.device_ordinal != -1 + ? options.device_ordinal + : client_->default_device_ordinal()); build_options.set_result_layout(result.xla_output_shape); build_options.set_device_allocator(options.device_allocator); @@ -256,6 +258,7 @@ Status XlaCompilationCache::CompileImpl( xla::LocalExecutable** executable, const XlaCompiler::CompileOptions* compile_options, bool compile_single_op) { + CHECK_NE(executable, nullptr); VLOG(1) << "XlaCompilationCache::Compile " << DebugString(); if (VLOG_IS_ON(2)) { @@ -293,7 +296,7 @@ Status XlaCompilationCache::CompileImpl( // protect the contents of the cache entry. Entry* entry; { - mutex_lock lock(mu_); + mutex_lock lock(compile_cache_mu_); // Find or create a cache entry. std::unique_ptr& e = cache_[signature]; if (!e) { @@ -309,6 +312,8 @@ Status XlaCompilationCache::CompileImpl( if (!entry->compiled) { VLOG(1) << "Compilation cache miss for signature: " << SignatureDebugString(signature); + tensorflow::Env* env = tensorflow::Env::Default(); + const uint64 compile_start_us = env->NowMicros(); // Do the actual JIT compilation without holding the lock (it can take // a long time.) std::vector args; @@ -327,18 +332,35 @@ Status XlaCompilationCache::CompileImpl( compile_options ? *compile_options : XlaCompiler::CompileOptions(), function, args, &entry->compilation_result); } - } - *compilation_result = &entry->compilation_result; - if (entry->compilation_status.ok() && executable) { - if (entry->executable == nullptr) { - entry->compilation_status = BuildExecutable( - options, entry->compilation_result, &entry->executable); + TF_RETURN_IF_ERROR(entry->compilation_status); + CHECK_EQ(entry->executable.get(), nullptr); + entry->compilation_status = + BuildExecutable(options, entry->compilation_result, &entry->executable); + + const uint64 compile_end_us = env->NowMicros(); + const uint64 compile_time_us = compile_end_us - compile_start_us; + { + mutex_lock lock(compile_stats_mu_); + auto it = compile_stats_.emplace(function.name(), CompileStats{}).first; + it->second.compile_count++; + it->second.cumulative_compile_time_us += compile_time_us; + VLOG(1) << "compiled " << function.name() << " " + << it->second.compile_count + << " times, compile time: " << compile_time_us + << " us, cumulative: " << it->second.cumulative_compile_time_us + << " us (" + << tensorflow::strings::HumanReadableElapsedTime(compile_time_us / + 1.0e6) + << " / " + << tensorflow::strings::HumanReadableElapsedTime( + it->second.cumulative_compile_time_us / 1.0e6) + << ")"; } - *executable = entry->executable.get(); } - - Status status = entry->compilation_status; - return status; + TF_RETURN_IF_ERROR(entry->compilation_status); + *compilation_result = &entry->compilation_result; + *executable = entry->executable.get(); + return Status::OK(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index be1043d8c3fc0573922837e541615114a6d7a1a5..fc5f008f4f52c32d97e680784082d0e7bcb7d8eb 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -150,9 +151,22 @@ class XlaCompilationCache : public ResourceBase { std::unique_ptr executable GUARDED_BY(mu); }; - mutex mu_; - std::unordered_map, Signature::Hash> cache_ - GUARDED_BY(mu_); + mutex compile_cache_mu_; + gtl::FlatMap, Signature::Hash> cache_ + GUARDED_BY(compile_cache_mu_); + + struct CompileStats { + // Number of times the cluster has been (re-)compiled. + int64 compile_count = 0; + + // Cumulative time spent compiling the cluster. + int64 cumulative_compile_time_us = 0; + }; + mutex compile_stats_mu_; + + // Maps cluster names to compilation statistics for said cluster. + gtl::FlatMap compile_stats_ + GUARDED_BY(compile_stats_mu_); TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache); }; diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index d288d37bc75380168a31937024dd41bdbe7dce9d..dd84fb34c171f8d2174444ddd3b3b476e7142718 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -71,13 +72,14 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, run_options.set_stream(stream); run_options.set_allocator(client->backend().memory_allocator()); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); - run_options.set_rng_seed(ctx->step_id()); + run_options.set_rng_seed(GetXLARandomSeed()); xla::StatusOr run_result = executable->Run(launch_context.arguments(), run_options); TF_RETURN_IF_ERROR(run_result.status()); - launch_context.PopulateOutputs(ctx, result, run_result.ConsumeValueOrDie()); + TF_RETURN_IF_ERROR(launch_context.PopulateOutputs( + ctx, result, run_result.ConsumeValueOrDie())); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index c55eba2f79ddcf10931ea659a64df559cef06ec5..70e6d0be0f2cffe98fd77fddac5866789c411a51 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_device_context.h" @@ -26,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/dma_helper.h" @@ -100,7 +102,7 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( } std::unique_ptr alloc = - xla::MakeUnique(); + absl::make_unique(); XlaDeviceAllocator* alloc_ptr = alloc.get(); state.allocators_[{backend, device_ordinal}] = std::move(alloc); return alloc_ptr; @@ -211,17 +213,20 @@ XlaDevice::XlaDevice( use_multiple_streams), device_ordinal_(device_ordinal), jit_device_name_(jit_device_name), - xla_allocator_(nullptr), platform_(platform), use_multiple_streams_(use_multiple_streams), transfer_as_literal_(transfer_as_literal), shape_representation_fn_(shape_representation_fn) { - VLOG(1) << "Created XLA device " << jit_device_name; + VLOG(1) << "Created XLA device " << jit_device_name << " " << this; + thread_pool_.reset(new thread::ThreadPool(options.env, "xla_device", + /*num_threads=*/1)); } XlaDevice::~XlaDevice() { - if (gpu_device_info_ != nullptr) { - gpu_device_info_->default_context->Unref(); + VLOG(1) << "Destroying XLA device " << jit_device_name_ << " " << this; + mutex_lock lock(mu_); + if (device_context_) { + device_context_->Unref(); } } @@ -237,6 +242,11 @@ xla::LocalClient* XlaDevice::client() const { } Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) { + mutex_lock lock(mu_); + return GetAllocatorLocked(attr); +} + +Allocator* XlaDevice::GetAllocatorLocked(AllocatorAttributes attr) { if (attr.on_host()) { return cpu_allocator(); } @@ -249,83 +259,111 @@ Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) { return xla_allocator_; } -xla::StatusOr XlaDevice::GetStream() { - if (!stream_) { - xla::Backend* backend = client()->mutable_backend(); - TF_ASSIGN_OR_RETURN(stream_, backend->BorrowStream(device_ordinal_)); - } - return stream_.get(); +Status XlaDevice::EnsureDeviceContextOk() { + mutex_lock lock(mu_); + return GetDeviceContextLocked().status(); } -xla::StatusOr XlaDevice::GetDeviceToHostStream() { - if (!use_multiple_streams_) { - return GetStream(); - } - if (!device_to_host_stream_) { - xla::Backend* backend = client()->mutable_backend(); - TF_ASSIGN_OR_RETURN(device_to_host_stream_, - backend->BorrowStream(device_ordinal_)); +Status XlaDevice::EnsureStreamOkLocked(xla::Backend* backend, + const string& name, + std::shared_ptr* stream, + bool* stream_was_changed) { + if (!(*stream) || !(*stream)->ok()) { + xla::StreamPool::Ptr ptr; + TF_ASSIGN_OR_RETURN(ptr, backend->BorrowStream(device_ordinal_)); + *stream = std::shared_ptr(std::move(ptr)); + VLOG(1) << "XlaDevice " << this << " new " << name << " " + << (*stream)->DebugStreamPointers(); + *stream_was_changed = true; } - return device_to_host_stream_.get(); + return Status::OK(); } -xla::StatusOr XlaDevice::GetHostToDeviceStream() { - if (!use_multiple_streams_) { - return GetStream(); +xla::StatusOr XlaDevice::GetDeviceContextLocked() { + xla::Backend* backend = client()->mutable_backend(); + + // Ensure all our streams are valid, borrowing new streams if necessary. + bool need_new_device_context = !device_context_; + TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_, + &need_new_device_context)); + + std::shared_ptr host_to_device_stream = stream_; + std::shared_ptr device_to_host_stream = stream_; + if (use_multiple_streams_) { + TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream", + &host_to_device_stream_, + &need_new_device_context)); + TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "device_to_host_stream", + &device_to_host_stream_, + &need_new_device_context)); + host_to_device_stream = host_to_device_stream_; + device_to_host_stream = device_to_host_stream_; } - if (!host_to_device_stream_) { - xla::Backend* backend = client()->mutable_backend(); - TF_ASSIGN_OR_RETURN(host_to_device_stream_, - backend->BorrowStream(device_ordinal_)); + + if (!need_new_device_context) { + return device_context_; } - return host_to_device_stream_.get(); -} -Status XlaDevice::CreateAndSetGpuDeviceInfo() { - if (gpu_device_info_ == nullptr) { - TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); - // Call GetAllocator for the side-effect of ensuring the allocator - // is created. - GetAllocator({}); - // XlaDevice owns both gpu_device_info_ and - // gpu_device_info_->default_context. - gpu_device_info_ = MakeUnique(); - gpu_device_info_->stream = stream; - gpu_device_info_->default_context = - new XlaDeviceContext(stream, stream, stream, client(), - transfer_as_literal_, shape_representation_fn_); - set_tensorflow_gpu_device_info(gpu_device_info_.get()); + // At this point we know we need a new device context. + // Call GetAllocator for the side-effect of ensuring the allocator is created. + GetAllocatorLocked({}); + if (device_context_) { + device_context_->Unref(); + } + // The XlaDeviceContext keeps a reference count to the streams, and the + // XlaDeviceContext remains live for the duration of a Executor run. This + // ensures that the streams remain live for the duration of a run, even if + // an error is encountered and the streams are replaced with new ones. + device_context_ = new XlaDeviceContext( + stream_, host_to_device_stream, device_to_host_stream, client(), + transfer_as_literal_, shape_representation_fn_, thread_pool_.get()); + VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext " + << device_context_; + + // Create and set a new GpuDeviceInfo, if necessary. + // + // TODO(b/78232898): This isn't thread-safe; there is a race between the call + // to set_tensorflow_gpu_device_info() with ops that call the getter + // tensorflow_gpu_device_info(). This isn't trivially fixed by adding locking + // to those methods; see the bug for details. Our only saving grace at the + // moment is that this race doesn't seem to occur in practice. + if (use_gpu_device_info_) { + auto gpu_device_info = absl::make_unique(); + gpu_device_info->stream = stream_.get(); + gpu_device_info->default_context = device_context_; + set_tensorflow_gpu_device_info(gpu_device_info.get()); + gpu_device_info_ = std::move(gpu_device_info); + VLOG(1) << "XlaDevice " << this << " new GpuDeviceInfo " + << gpu_device_info_.get(); } - return Status::OK(); + return device_context_; +} + +Status XlaDevice::UseGpuDeviceInfo() { + mutex_lock lock(mu_); + use_gpu_device_info_ = true; + return GetDeviceContextLocked().status(); } Status XlaDevice::FillContextMap(const Graph* graph, DeviceContextMap* device_context_map) { VLOG(1) << "XlaDevice::FillContextMap"; - device_context_map->resize(graph->num_node_ids()); - TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); - TF_ASSIGN_OR_RETURN(se::Stream * device_to_host_stream, - GetDeviceToHostStream()); - TF_ASSIGN_OR_RETURN(se::Stream * host_to_device_stream, - GetHostToDeviceStream()); + mutex_lock lock(mu_); + TF_ASSIGN_OR_RETURN(XlaDeviceContext * device_context, + GetDeviceContextLocked()); - // Call GetAllocator for the side-effect of ensuring the allocator is created. - GetAllocator({}); - auto ctx = new XlaDeviceContext( - stream, host_to_device_stream, device_to_host_stream, client(), - transfer_as_literal_, shape_representation_fn_); + device_context_map->resize(graph->num_node_ids()); for (Node* n : graph->nodes()) { VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name(); - ctx->Ref(); - (*device_context_map)[n->id()] = ctx; + device_context->Ref(); + (*device_context_map)[n->id()] = device_context; } - ctx->Unref(); return Status::OK(); } void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { - VLOG(1) << "XlaDevice::Compute " << op_kernel->name() << ":" + VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":" << op_kernel->type_string(); // When Xprof profiling is off (which is the default), constructing the // activity is simple enough that its overhead is negligible. @@ -336,13 +374,29 @@ void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) { - VLOG(1) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":" + VLOG(2) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":" << op_kernel->type_string(); tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(), op_kernel->IsExpensive()); op_kernel->ComputeAsync(context, done); } +Status XlaDevice::Sync() { + VLOG(1) << "XlaDevice::Sync"; + std::shared_ptr stream; + { + mutex_lock lock(mu_); + stream = stream_; + } + if (!stream) return Status::OK(); + + if (!stream->parent()->SynchronizeAllActivity() || !stream->ok()) { + return errors::Internal("XlaDevice::Sync() failed."); + } + VLOG(1) << "XlaDevice::Sync completed"; + return Status::OK(); +} + Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, Tensor* tensor) { @@ -358,21 +412,17 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, if (alloc_attrs.on_host()) { *tensor = parsed; } else { - Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape()); + mutex_lock lock(mu_); + TF_ASSIGN_OR_RETURN(XlaDeviceContext * device_context, + GetDeviceContextLocked()); + Allocator* allocator = GetAllocatorLocked(alloc_attrs); + Tensor copy(allocator, parsed.dtype(), parsed.shape()); Notification n; - TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); - TF_ASSIGN_OR_RETURN(se::Stream * device_to_host_stream, - GetDeviceToHostStream()); - TF_ASSIGN_OR_RETURN(se::Stream * host_to_device_stream, - GetHostToDeviceStream()); - XlaTransferManager manager(stream, host_to_device_stream, - device_to_host_stream, client(), - transfer_as_literal_, shape_representation_fn_); - manager.CopyCPUTensorToDevice(&parsed, this, ©, - [&n, &status](const Status& s) { - status = s; - n.Notify(); - }); + device_context->CopyCPUTensorToDevice(&parsed, this, ©, + [&n, &status](const Status& s) { + status = s; + n.Notify(); + }); n.WaitForNotification(); *tensor = copy; } diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index fccdb143680353ccbe3106bd48aa297980179d55..dbf35f349f84268ebac0f73a86c9ca0704e90835 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -25,6 +25,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ +#include "tensorflow/compiler/jit/xla_device_context.h" #include "tensorflow/compiler/jit/xla_tensor.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -39,6 +40,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace tensorflow { @@ -116,62 +118,88 @@ class XlaDevice : public LocalDevice { const PaddedShapeFn& padded_shape_fn); ~XlaDevice() override; - Allocator* GetAllocator(AllocatorAttributes attr) override; + Allocator* GetAllocator(AllocatorAttributes attr) override + LOCKS_EXCLUDED(mu_); void Compute(OpKernel* op_kernel, OpKernelContext* context) override; void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) override; - Status Sync() override { return Status::OK(); } + Status Sync() override; Status FillContextMap(const Graph* graph, - DeviceContextMap* device_context_map) override; + DeviceContextMap* device_context_map) override + LOCKS_EXCLUDED(mu_); Status MakeTensorFromProto(const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, - Tensor* tensor) override; + Tensor* tensor) override LOCKS_EXCLUDED(mu_); - xla::LocalClient* client() const; const Metadata& metadata() { return xla_metadata_; } - xla::StatusOr GetStream(); - xla::StatusOr GetHostToDeviceStream(); - xla::StatusOr GetDeviceToHostStream(); - // If not already set, create and set GpuDeviceInfo. - // Not thread-safe - Status CreateAndSetGpuDeviceInfo(); + // Ensures the DeviceContext associated with this XlaDevice is created and + // valid (i.e. all streams are ok). If any state is not valid, a new + // DeviceContext will be created. + // + // TODO(b/111859745): The Eager context needs to call this method to recover + // from failures. + Status EnsureDeviceContextOk() LOCKS_EXCLUDED(mu_); + + // Instructs this XlaDevice to set a GpuDeviceInfo, which holds extra + // information for GPU and TPU devices. + Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_); private: + xla::LocalClient* client() const; + Allocator* GetAllocatorLocked(AllocatorAttributes attr) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + Status EnsureStreamOkLocked(xla::Backend* backend, const string& name, + std::shared_ptr* stream, + bool* stream_was_changed) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + xla::StatusOr GetDeviceContextLocked() + EXCLUSIVE_LOCKS_REQUIRED(mu_); + + mutex mu_; // The metadata of this XlaDevice. const Metadata xla_metadata_; // Which hardware device in the client's platform this XlaDevice controls. const int device_ordinal_; // The name of the device that is used to compile Ops for this XlaDevice. - DeviceType jit_device_name_; + const DeviceType jit_device_name_; + // The platform for this device. + se::Platform* const platform_; // Not owned. // Memory allocator associated with this device. - Allocator* xla_allocator_; // Not owned. - se::Platform* platform_; // Not owned. + Allocator* xla_allocator_ GUARDED_BY(mu_) = nullptr; // Not owned. // Stream associated with this device. Operations enqueued on this // stream are executed on the device. Operations include data // copying back and forth between CPU and the device, and // computations enqueued by XLA. - xla::Backend::StreamPtr stream_; - // If true, only stream_ is valid and all computation and transfers use - // stream_. If false, computation is performed by stream_ and transfers are + std::shared_ptr stream_ GUARDED_BY(mu_); + // If false, only stream_ is valid and all computation and transfers use + // stream_. If true, computation is performed by stream_ and transfers are // performed by host_to_device/device_to_host_stream. - bool use_multiple_streams_; + const bool use_multiple_streams_; // If use_multiple_streams_, host to device transfers are performed using this // stream. - xla::Backend::StreamPtr host_to_device_stream_; + std::shared_ptr host_to_device_stream_ GUARDED_BY(mu_); // If use_multiple_streams_, device to host transfers are performed using this // stream. - xla::Backend::StreamPtr device_to_host_stream_; + std::shared_ptr device_to_host_stream_ GUARDED_BY(mu_); // Must we use XLA's transfer manager for correct host<->device transfers? if // false, we can use ThenMemcpy() instead. - bool transfer_as_literal_; - XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + const bool transfer_as_literal_; + const XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + + // The device context accessed by all users of the XlaDevice, set by calls to + // EnsureDeviceContextOk. If gpu_device_info_ is non-null, this pointer is + // also filled in to that struct. XlaDeviceContext is a ref-counted object. + XlaDeviceContext* device_context_ GUARDED_BY(mu_) = nullptr; + + // Holds extra information for GPU and TPU devices, e.g. the device context. + bool use_gpu_device_info_ GUARDED_BY(mu_) = false; + std::unique_ptr gpu_device_info_ GUARDED_BY(mu_); - // If set, holds default device context (that we must Unref) - // and its stream. - std::unique_ptr gpu_device_info_; + // Thread pool used for running closures + std::unique_ptr thread_pool_; }; // Builds OpKernel registrations on 'device' for the JIT operators diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 8cf198239c84c3720585f53ebc95876ce4396793..0a0c0892411e8ebcd5624a29f3bd020fe6483944 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_device_context.h" +#include + +#include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" @@ -48,17 +51,20 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) { void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); } XlaTransferManager::XlaTransferManager( - se::Stream* compute_stream, se::Stream* host_to_device_stream, - se::Stream* device_to_host_stream, xla::LocalClient* client, + std::shared_ptr compute_stream, + std::shared_ptr host_to_device_stream, + std::shared_ptr device_to_host_stream, xla::LocalClient* client, bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn) - : stream_(compute_stream), - host_to_device_stream_(host_to_device_stream), - device_to_host_stream_(device_to_host_stream), + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + thread::ThreadPool* thread_pool) + : stream_(std::move(compute_stream)), + host_to_device_stream_(std::move(host_to_device_stream)), + device_to_host_stream_(std::move(device_to_host_stream)), client_(client), transfer_manager_(client->backend().transfer_manager()), transfer_as_literal_(transfer_as_literal), - shape_representation_fn_(std::move(shape_representation_fn)) { + shape_representation_fn_(std::move(shape_representation_fn)), + thread_pool_(thread_pool) { CHECK(host_to_device_stream_ != nullptr); CHECK(device_to_host_stream_ != nullptr); CHECK(stream_ != nullptr); @@ -88,47 +94,40 @@ Status XlaTransferManager::TransferLiteralToDevice( if (UseMultipleStreams()) { // Initially wait for the compute stream so that memory allocations are // synchronized. - host_to_device_stream_->ThenWaitFor(stream_); + host_to_device_stream_->ThenWaitFor(stream_.get()); } TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync( - host_to_device_stream_, *literal, shaped_buffer)); + host_to_device_stream_.get(), *literal, shaped_buffer)); if (UseMultipleStreams()) { - se::Event event(stream_->parent()); - TF_RET_CHECK(event.Init()) << "Event failed to initialize!"; - host_to_device_stream_->ThenRecordEvent(&event); - xla_tensor->SetDefinedOn(host_to_device_stream_, std::move(event)); + auto event = std::make_shared(stream_->parent()); + TF_RET_CHECK(event->Init()) << "Event failed to initialize!"; + host_to_device_stream_->ThenRecordEvent(event.get()); + xla_tensor->SetDefinedOn(host_to_device_stream_.get(), std::move(event)); } // Unref the host tensor, and capture the literal shared_ptr too so it goes // out of scope when the lambda completes. host_to_device_stream_->ThenDoHostCallback([ref, literal]() { ref.Unref(); }); + return Status::OK(); } void XlaTransferManager::TransferLiteralFromDevice( Tensor* host_tensor, const Tensor& device_tensor, const StatusCallback& done) const { + xla::MutableBorrowingLiteral literal; + TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(host_tensor, &literal)); + const xla::ShapedBuffer& shaped_buffer = XlaTensor::FromTensor(&device_tensor)->shaped_buffer(); TensorReference ref(device_tensor); transfer_manager_->TransferLiteralFromDevice( - device_to_host_stream_, shaped_buffer, - [=, &shaped_buffer]( - xla::StatusOr > literal_or) { + device_to_host_stream_.get(), shaped_buffer, literal, + [=, &shaped_buffer, &literal](xla::Status status) { ref.Unref(); done([&]() -> Status { - TF_ASSIGN_OR_RETURN(auto literal, std::move(literal_or)); - VLOG(1) << "Transfer from device as literal: " << literal->ToString() + VLOG(1) << "Transfer from device as literal: " << literal.ToString() << " " << shaped_buffer.ToString(); - Tensor tensor; - TF_RETURN_IF_ERROR( - LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor)); - // Reshape the tensor back to its declared shape. - Status status; - if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) { - status = errors::Internal( - "Tensor::CopyFrom failed when copying from XLA device to CPU"); - } return status; }()); }); @@ -186,8 +185,14 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor); if (status.ok()) { xla_tensor->set_host_tensor(*cpu_tensor); - host_to_device_stream_->ThenDoHostCallback( - [done]() { done(Status::OK()); }); + host_to_device_stream_->ThenDoHostCallback([this, done]() { + // We must not call the done closure directly from DoHostCallback + // to avoid a deadlock. If done() is the callback that ends an + // Executor's run, the Executor may call XlaDevice::Sync() inside the + // callback. This deadlocks, because XlaDevice::Sync() waits for all + // stream activity to complete. + thread_pool_->Schedule([done]() { done(Status::OK()); }); + }); return; } } else { @@ -199,7 +204,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, if (!block_status.ok()) { status = xla::InternalError( "Failed to complete data transfer on stream %p: %s", - host_to_device_stream_, block_status.error_message().c_str()); + host_to_device_stream_.get(), block_status.error_message().c_str()); } } xla_tensor->set_host_tensor(*cpu_tensor); @@ -232,9 +237,9 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); if (se::Event* event = - xla_tensor->GetDefinitionEvent(device_to_host_stream_)) { + xla_tensor->GetDefinitionEvent(device_to_host_stream_.get())) { device_to_host_stream_->ThenWaitFor(event); - xla_tensor->SetDefinedOn(device_to_host_stream_); + xla_tensor->SetDefinedOn(device_to_host_stream_.get()); } Status status; @@ -247,7 +252,7 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, Status block_status = device_to_host_stream_->BlockHostUntilDone(); if (!block_status.ok()) { status = xla::InternalError( - "Failed to complete data transfer on stream %p: %s", stream_, + "Failed to complete data transfer on stream %p: %s", stream_.get(), block_status.error_message().c_str()); } } @@ -285,14 +290,14 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor, if (stream_ != device_to_device_stream) { // Initially wait for the compute stream so that memory allocations are // synchronized. - device_to_device_stream->ThenWaitFor(stream_); + device_to_device_stream->ThenWaitFor(stream_.get()); } } if (se::Event* event = - xla_src->GetDefinitionEvent(device_to_device_stream)) { + xla_src->GetDefinitionEvent(device_to_device_stream.get())) { device_to_device_stream->ThenWaitFor(event); - xla_src->SetDefinedOn(device_to_device_stream); + xla_src->SetDefinedOn(device_to_device_stream.get()); } auto from_iter = xla_src->shaped_buffer().buffers().begin(); @@ -304,28 +309,37 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor, } if (UseMultipleStreams()) { - se::Event event(stream_->parent()); - CHECK(event.Init()); - device_to_device_stream->ThenRecordEvent(&event); - xla_dst->SetDefinedOn(device_to_device_stream, std::move(event)); + auto event = std::make_shared(stream_->parent()); + TF_RET_CHECK(event->Init()) << "Event failed to initialize"; + device_to_device_stream->ThenRecordEvent(event.get()); + xla_dst->SetDefinedOn(device_to_device_stream.get(), std::move(event)); } return Status::OK(); }(); if (!status.ok()) { return done(status); } else { - stream_->ThenDoHostCallback([=]() { done(Status::OK()); }); + stream_->ThenDoHostCallback([this, done]() { + // We must not call the done closure directly from DoHostCallback to avoid + // a deadlock. If done() is the callback that ends an Executor's run, the + // Executor may call XlaDevice::Sync() inside the callback. This + // deadlocks, because XlaDevice::Sync() waits for all stream activity to + // complete. + thread_pool_->Schedule([done]() { done(Status::OK()); }); + }); } } XlaDeviceContext::XlaDeviceContext( - se::Stream* compute_stream, se::Stream* host_to_device_stream, - se::Stream* device_to_host_stream, xla::LocalClient* client, + std::shared_ptr compute_stream, + std::shared_ptr host_to_device_stream, + std::shared_ptr device_to_host_stream, xla::LocalClient* client, bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn) - : manager_(compute_stream, host_to_device_stream, device_to_host_stream, - client, transfer_as_literal, - std::move(shape_representation_fn)) {} + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + thread::ThreadPool* thread_pool) + : manager_(std::move(compute_stream), std::move(host_to_device_stream), + std::move(device_to_host_stream), client, transfer_as_literal, + std::move(shape_representation_fn), thread_pool) {} void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index 912f8d779e72f44821bc4fb25efa30bd35d01412..2e7445340cbaf788bfd06260f4376596895231c1 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -47,10 +47,12 @@ class XlaDeviceAllocator : public Allocator { class XlaTransferManager { public: explicit XlaTransferManager( - se::Stream* compute_stream, se::Stream* host_to_device_stream, - se::Stream* device_to_host_stream, xla::LocalClient* client, - bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn); + std::shared_ptr compute_stream, + std::shared_ptr host_to_device_stream, + std::shared_ptr device_to_host_stream, + xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + thread::ThreadPool* thread_pool); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, StatusCallback done) const; @@ -61,7 +63,7 @@ class XlaTransferManager { void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, const StatusCallback& done); - se::Stream* stream() const { return stream_; } + se::Stream* stream() const { return stream_.get(); } private: Status TransferLiteralToDevice(const Tensor& host_tensor, @@ -73,13 +75,13 @@ class XlaTransferManager { // The main compute stream of the device, used to synchronize the transfer // streams if they are set. - se::Stream* stream_; + std::shared_ptr stream_; // The stream to use for transferring data from host to device. Can be // idential to stream_, but must not be nullptr. - se::Stream* host_to_device_stream_; + std::shared_ptr host_to_device_stream_; // The stream to use for transferring data from device to host. Can be // idential to stream_, but must not be nullptr. - se::Stream* device_to_host_stream_; + std::shared_ptr device_to_host_stream_; // For the underlying memory allocator and XLA's TransferManager. xla::LocalClient* client_; // Transfer manager, for marshalling data to and from the device. @@ -87,6 +89,9 @@ class XlaTransferManager { // True if we must use XLA's TransferManager for correct device transfers. const bool transfer_as_literal_; XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + + // Thread pool used for running closures + thread::ThreadPool* thread_pool_; }; // DeviceContext for operators assigned to XlaDevice devices. The @@ -95,10 +100,12 @@ class XlaTransferManager { class XlaDeviceContext : public DeviceContext { public: explicit XlaDeviceContext( - se::Stream* compute_stream, se::Stream* host_to_device_stream, - se::Stream* device_to_host_stream, xla::LocalClient* client, - bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn); + std::shared_ptr compute_stream, + std::shared_ptr host_to_device_stream, + std::shared_ptr device_to_host_stream, + xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + thread::ThreadPool* thread_pool); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 6adda327f186a607b4e7371bf4c5071dd86582da..da3e329247e825d4a33a53dc310899d6ba6ce9cf 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -23,7 +23,11 @@ limitations under the License. #include "tensorflow/core/kernels/cast_op.h" #include "tensorflow/core/kernels/constant_op.h" #include "tensorflow/core/kernels/control_flow_ops.h" +#include "tensorflow/core/kernels/data/generator_dataset_op.h" +#include "tensorflow/core/kernels/data/iterator_ops.h" +#include "tensorflow/core/kernels/data/prefetch_dataset_op.h" #include "tensorflow/core/kernels/fifo_queue.h" +#include "tensorflow/core/kernels/function_ops.h" #include "tensorflow/core/kernels/identity_n_op.h" #include "tensorflow/core/kernels/identity_op.h" #include "tensorflow/core/kernels/no_op.h" @@ -166,7 +170,69 @@ class XlaAssignVariableOp : public AsyncOpKernel { QueueIsClosedOp); \ \ REGISTER_KERNEL_BUILDER( \ - Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp); + Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name(kArgOp).Device(DEVICE).HostMemory("output").TypeConstraint("T", \ + TYPES), \ + ArgOp); \ + REGISTER_KERNEL_BUILDER(Name(kArgOp) \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + ArgOp); \ + \ + REGISTER_KERNEL_BUILDER(Name(kRetOp) \ + .Device(DEVICE) \ + .TypeConstraint("T", TYPES) \ + .HostMemory("input"), \ + RetvalOp); \ + REGISTER_KERNEL_BUILDER(Name(kRetOp) \ + .Device(DEVICE) \ + .TypeConstraint("T") \ + .HostMemory("input"), \ + RetvalOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("RemoteCall").Device(DEVICE).HostMemory("target"), RemoteCallOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("GeneratorDataset").Device(DEVICE).HostMemory("handle"), \ + GeneratorDatasetOp); \ + REGISTER_KERNEL_BUILDER(Name("PrefetchDataset") \ + .Device(DEVICE) \ + .HostMemory("buffer_size") \ + .HostMemory("input_dataset") \ + .HostMemory("handle"), \ + PrefetchDatasetOp); \ + \ + REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE), \ + IteratorHandleOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("MakeIterator").Device(DEVICE).HostMemory("dataset"), \ + MakeIteratorOp); \ + REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE), \ + AnonymousIteratorHandleOp); \ + REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \ + IteratorGetNextOp); \ + REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle") \ + .Device(DEVICE) \ + .HostMemory("string_handle"), \ + IteratorToStringHandleOp); \ + REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2") \ + .Device(DEVICE) \ + .HostMemory("string_handle"), \ + IteratorFromStringHandleOp); \ + REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kArgOp) \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + ArgOp); \ + REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kRetOp) \ + .Device(DEVICE) \ + .TypeConstraint("T") \ + .HostMemory("input"), \ + RetvalOp); // TODO(phawkins): currently we do not register the QueueEnqueueMany, // QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 851b118b0c18cfd752302b8f8dec27dae3e12acd..ef4466f0056ea98adc1ae6774105466af0d14293 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -59,7 +59,7 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options, } // TODO(b/78468222): Uncomment after fixing this bug - // status = device->CreateAndSetGpuDeviceInfo(); + // status = device->UseGpuDeviceInfo(); // if (!status.ok()) { // errors::AppendToMessage(&status, "while setting up ", DEVICE_GPU_XLA_JIT, // " device"); diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 6134b8c6946429918a5ca37188cbff13a6cd1c79..2ffce9298d99e1e136e15e9a4b0e3f5b26121bd5 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_launch_util.h" +#include + +#include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -173,7 +176,7 @@ void XlaComputationLaunchContext::PopulateInputs( << " not the same as on-host shape " << xla::ShapeUtil::HumanStringWithLayout(shape); se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t); - arg_buffers_[i] = xla::MakeUnique( + arg_buffers_[i] = absl::make_unique( /*on_host_shape=*/shape, /*on_device_shape=*/shape, client_->platform(), client_->default_device_ordinal()); arg_buffers_[i]->set_buffer(dmem, /*index=*/{}); @@ -182,7 +185,7 @@ void XlaComputationLaunchContext::PopulateInputs( } } -void XlaComputationLaunchContext::PopulateOutputs( +Status XlaComputationLaunchContext::PopulateOutputs( OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, ScopedShapedBuffer output) { se::Stream* stream = @@ -211,6 +214,15 @@ void XlaComputationLaunchContext::PopulateOutputs( output = ScopedShapedBuffer(std::move(buffer), output.memory_allocator()); } + std::shared_ptr definition_event; + if (use_multiple_streams_) { + definition_event = std::make_shared(stream->parent()); + if (!definition_event->Init()) { + return errors::Internal("Failed to initialize tensor definition event."); + } + stream->ThenRecordEvent(definition_event.get()); + } + // Copy XLA results to the OpOutputList. int output_num = 0; for (int i = 0; i < ctx->num_outputs(); ++i) { @@ -228,12 +240,13 @@ void XlaComputationLaunchContext::PopulateOutputs( // reallocate the device buffer later. VLOG(1) << "Constant output tensor on device"; - OP_REQUIRES_OK( - ctx, ctx->allocate_output(i, const_tensor.shape(), &output_tensor)); + TF_RETURN_IF_ERROR( + ctx->allocate_output(i, const_tensor.shape(), &output_tensor)); Device* device = dynamic_cast(ctx->device()); - OP_REQUIRES(ctx, device != nullptr, - errors::Internal("DeviceBase was not a Device.")); + if (device == nullptr) { + return errors::Internal("DeviceBase was not a Device."); + } ctx->op_device_context()->CopyCPUTensorToDevice( &const_tensor, device, output_tensor, [&](Status status) { TF_CHECK_OK(status); }); @@ -263,16 +276,13 @@ void XlaComputationLaunchContext::PopulateOutputs( se::DeviceMemoryBase buffer = output.buffer({output_num}); if (allocate_xla_tensors_) { Tensor* output_tensor; - OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output_tensor)); + TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor)); XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor); if (xla_tensor) { xla_tensor->set_shaped_buffer(ScopedShapedBuffer( ExtractSubShapedBuffer(&output, output_num, xla_allocator_))); if (use_multiple_streams_) { - se::Event event(stream->parent()); - CHECK(event.Init()); - stream->ThenRecordEvent(&event); - xla_tensor->SetDefinedOn(stream, std::move(event)); + xla_tensor->SetDefinedOn(stream, definition_event); } } else { // xla_tensor wasn't valid, which must mean this is a zero-element @@ -298,41 +308,39 @@ void XlaComputationLaunchContext::PopulateOutputs( for (int i = 0; i < kernel->resource_updates.size(); ++i) { Allocator* allocator = ctx->device()->GetAllocator({}); const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i]; - OP_REQUIRES(ctx, - write.input_index >= 0 && write.input_index < ctx->num_inputs(), - errors::Internal("Invalid input index for variable write.")); + if (write.input_index < 0 || write.input_index >= ctx->num_inputs()) { + return errors::Internal("Invalid input index for variable write."); + } se::DeviceMemoryBase buffer = output.buffer({output_num}); Var* variable = nullptr; // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor, // not a Tensor. - OP_REQUIRES_OK(ctx, LookupOrCreateResource( - ctx, HandleFromInput(ctx, write.input_index), - &variable, [this, ctx, &write](Var** ptr) { - *ptr = new Var(write.type); - return Status::OK(); - })); + TF_RETURN_IF_ERROR(LookupOrCreateResource( + ctx, HandleFromInput(ctx, write.input_index), &variable, + [&write](Var** ptr) { + *ptr = new Var(write.type); + return Status::OK(); + })); core::ScopedUnref s(variable); mutex_lock ml(*variable->mu()); - OP_REQUIRES(ctx, variable->tensor()->dtype() == write.type, - errors::Internal("Mismatched type in variable write")); + if (variable->tensor()->dtype() != write.type) { + return errors::Internal("Mismatched type in variable write"); + } if (allocate_xla_tensors_) { Tensor output_tensor; - OP_REQUIRES_OK( - ctx, ctx->allocate_temp(write.type, write.shape, &output_tensor)); + TF_RETURN_IF_ERROR( + ctx->allocate_temp(write.type, write.shape, &output_tensor)); XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor); CHECK(xla_tensor); xla_tensor->set_shaped_buffer( ExtractSubShapedBuffer(&output, output_num, xla_allocator_)); if (use_multiple_streams_) { - se::Event event(stream->parent()); - CHECK(event.Init()); - stream->ThenRecordEvent(&event); - xla_tensor->SetDefinedOn(stream, std::move(event)); + xla_tensor->SetDefinedOn(stream, definition_event); } *variable->tensor() = output_tensor; } else { @@ -343,6 +351,7 @@ void XlaComputationLaunchContext::PopulateOutputs( } ++output_num; } + return Status::OK(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 1ea3fa4cf29266e8c452385226e56bd0b82622d9..4232f514b3b48681bf510ee568f916f5f4ebe882 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -93,9 +93,9 @@ class XlaComputationLaunchContext { const std::map& variables); // Given the XLA output in `output`, populate all outputs of `ctx`. - void PopulateOutputs(OpKernelContext* ctx, - const XlaCompiler::CompilationResult* kernel, - xla::ScopedShapedBuffer output); + Status PopulateOutputs(OpKernelContext* ctx, + const XlaCompiler::CompilationResult* kernel, + xla::ScopedShapedBuffer output); // Return the argument list. Only valid after PopulateInputs() has been // called. diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc index d777dfa5a34fb9615ddcf393ed53be1491cb70af..92ba7de1b7d32fcf693cd12a380d7a1e0d861d71 100644 --- a/tensorflow/compiler/jit/xla_tensor.cc +++ b/tensorflow/compiler/jit/xla_tensor.cc @@ -75,7 +75,7 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape, se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) { mutex_lock lock(mu_); - if (!definition_event_.has_value()) { + if (!definition_event_) { return nullptr; } @@ -87,10 +87,11 @@ se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) { return nullptr; } - return &*definition_event_; + return definition_event_.get(); } -void XlaTensor::SetDefinedOn(se::Stream* stream, se::Event event) { +void XlaTensor::SetDefinedOn(se::Stream* stream, + std::shared_ptr event) { mutex_lock lock(mu_); definition_event_ = std::move(event); streams_defined_on_ = {stream}; diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h index f7e401c731163200c518074f2caa6907efb1f684..07a9bf0d4a7f732629afab97d2d69c9a7effa18d 100644 --- a/tensorflow/compiler/jit/xla_tensor.h +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_ #define TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_ +#include + +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/core/framework/allocator.h" @@ -68,7 +71,7 @@ class XlaTensor { // Mutates the XlaTensor to set the ShapedBuffer. void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) { shaped_buffer_ = - xla::MakeUnique(std::move(shaped_buffer)); + absl::make_unique(std::move(shaped_buffer)); } // Some tensors on the device may have known values on the host. We use these @@ -94,7 +97,7 @@ class XlaTensor { // Assert that the tensor's content is defined on 'stream' by the time 'event' // triggers. - void SetDefinedOn(se::Stream* stream, se::Event event); + void SetDefinedOn(se::Stream* stream, std::shared_ptr event); // Assert that the tensor's content is defined on 'stream'. This version does // not provide an event, and must be called *after* SetDefinedOn(Stream, @@ -116,7 +119,7 @@ class XlaTensor { // An optional event that is triggered when the tensor's content has been // defined. If this event is nullptr, it is assumed that the tensor's content // is always defined. - gtl::optional definition_event_; + std::shared_ptr definition_event_; // A list of all streams for which the tensor's content is defined for any // newly enqueued command. gtl::InlinedVector streams_defined_on_ GUARDED_BY(mu_); diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 080bed50e68ba353a5029f5eb959003b51327f4a..ae98b3f0f9d5dac66b9716ad84a9f0371511e9b6 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -673,6 +673,7 @@ tf_xla_py_test( "cpu", "cpu_ondemand", ], + shard_count = 5, tags = ["optonly"], deps = [ ":xla_test", @@ -690,11 +691,7 @@ tf_xla_py_test( size = "small", srcs = ["random_ops_test.py"], disabled_backends = [ - # TODO(b/110300529): RngNormal doesn't return values with the expected variance - "cpu", "cpu_ondemand", - # TODO(b/31361304): enable RNG ops on GPU when parallelized. - "gpu", ], deps = [ ":xla_test", @@ -1002,6 +999,7 @@ tf_xla_py_test( name = "sort_ops_test", size = "medium", srcs = ["sort_ops_test.py"], + shard_count = 5, # Times out in fastbuild mode. tags = ["optonly"], deps = [ diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py index 03554d6933aca39b428c6af4be0c78e2c7ccb0c9..0d2e4d029636577adc74784d9a8b3494b94dc67d 100644 --- a/tensorflow/compiler/tests/adam_test.py +++ b/tensorflow/compiler/tests/adam_test.py @@ -52,6 +52,9 @@ class AdamOptimizerTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: + # TODO: test fails for float16 due to excessive precision requirements. + if dtype == np.float16: + continue with self.test_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) @@ -91,6 +94,9 @@ class AdamOptimizerTest(xla_test.XLATestCase): def testTensorLearningRate(self): for dtype in self.float_types: + # TODO: test fails for float16 due to excessive precision requirements. + if dtype == np.float16: + continue with self.test_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) @@ -130,6 +136,9 @@ class AdamOptimizerTest(xla_test.XLATestCase): def testSharing(self): for dtype in self.float_types: + # TODO: test fails for float16 due to excessive precision requirements. + if dtype == np.float16: + continue with self.test_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 0aafda7fb4d710f154157ee352d6616e5aa8935f..5b7001b5a463ae0bd4e8f07032256717aab70d49 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -1165,6 +1165,16 @@ class BinaryOpsTest(xla_test.XLATestCase): def testTile(self): for dtype in self.numeric_types: + self._testBinary( + array_ops.tile, + np.array([[6], [3], [4]], dtype=dtype), + np.array([2, 0], dtype=np.int32), + expected=np.empty([6, 0], dtype=dtype)) + self._testBinary( + array_ops.tile, + np.array([[6, 3, 4]], dtype=dtype), + np.array([2, 0], dtype=np.int32), + expected=np.empty([2, 0], dtype=dtype)) self._testBinary( array_ops.tile, np.array([[6]], dtype=dtype), diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index 6ead15da13b86b9d2b4cf2c19e5cf2a90b061b91..3d21fb5864c22a6f449c54d03abc0f234e28dab1 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -32,6 +32,7 @@ from tensorflow.python.layers import convolutional from tensorflow.python.layers import pooling from tensorflow.python.ops import array_ops from tensorflow.python.ops import embedding_ops +from tensorflow.python.ops import gen_random_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops @@ -122,6 +123,14 @@ class EagerTest(xla_test.XLATestCase): with self.test_scope(): self.assertAllEqual(2, array_ops.identity(2)) + def testRandomOps(self): + with self.test_scope(): + tensor = gen_random_ops.random_uniform((2, 2), dtypes.float32) + row0 = tensor[0].numpy() + row1 = tensor[1].numpy() + # It should be very unlikely to rng to generate two equal rows. + self.assertFalse((row0 == row1).all()) + def testIdentityOnVariable(self): with self.test_scope(): v = resource_variable_ops.ResourceVariable(True) @@ -400,6 +409,21 @@ class EagerFunctionTest(xla_test.XLATestCase): self.assertEqual(75, y.numpy()) self.assertEqual(30, dy.numpy()) + def testGradientTapeInDefun(self): + with self.test_scope(): + v0 = resource_variable_ops.ResourceVariable(5.0) + + @function.defun + def f(): + x = constant_op.constant(1.0) + with backprop.GradientTape() as tape: + y = v0 * x + dy = tape.gradient(y, v0) + return dy + + dy = f() + self.assertEqual(1.0, dy.numpy()) + def testSliceInDefun(self): with self.test_scope(): @@ -419,7 +443,6 @@ class EagerFunctionTest(xla_test.XLATestCase): self.assertAllEqual((2, 3, 4), dz.shape.as_list()) def testNestedDefun(self): - self.skipTest('Nested defuns do not work on TPU at the moment') with self.test_scope(): @function.defun diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 8b01ef96db3e8ab58850df234c2e05b764be52ba..bf986ade06b11358552ee92df3169f965ce3f534 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -26,6 +26,7 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.compiler.tests import xla_test +from tensorflow.python.compat import compat from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -579,5 +580,140 @@ class ResizeBilinearTest(xla_test.XLATestCase): large_tolerance=True) +class NonMaxSuppressionTest(xla_test.XLATestCase): + + def testNMS128From1024(self): + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + with compat.forward_compatibility_horizon(2018, 8, 8): + num_boxes = 1024 + boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4") + scores_np = np.random.normal(0.5, 0.1, (num_boxes,)).astype("f4") + + max_output_size = 128 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.0, dtype=np.float32) + + with self.test_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + score_threshold: score_threshold_np, + iou_threshold: iou_threshold_np + } + (indices_tf, _) = sess.run(selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + + def testNMS3From6Boxes(self): + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + with compat.forward_compatibility_horizon(2018, 8, 8): + # Three boxes are selected based on IOU. + boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], + [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] + boxes_np = np.array(boxes_data, dtype=np.float32) + + scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] + scores_np = np.array(scores_data, dtype=np.float32) + + max_output_size = 3 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.0, dtype=np.float32) + + with self.test_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + score_threshold: score_threshold_np, + iou_threshold: iou_threshold_np + } + (indices_tf, num_valid) = sess.run( + selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + self.assertEqual(num_valid, 3) + self.assertAllClose(indices_tf[:num_valid], [3, 0, 5]) + + def testNMS3Then2WithScoreThresh(self): + # Three boxes are selected based on IOU. + # One is filtered out by score threshold. + + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + with compat.forward_compatibility_horizon(2018, 8, 8): + boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], + [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] + boxes_np = np.array(boxes_data, dtype=np.float32) + + scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] + scores_np = np.array(scores_data, dtype=np.float32) + max_output_size = 3 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.4, dtype=np.float32) + + with self.test_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + iou_threshold: iou_threshold_np, + score_threshold: score_threshold_np + } + (indices_tf, num_valid) = sess.run( + selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + self.assertEqual(num_valid, 2) + self.assertAllClose(indices_tf[:num_valid], [3, 0]) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index 14c5e7a975e478ca6ceed37c28339b40612801c8..8c4e16e4e075726d741f6ff8cdfb6b1aad6cd33e 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -57,7 +57,8 @@ class RandomOpsTest(xla_test.XLATestCase): def testRandomUniformIsNotConstant(self): def rng(dtype): - return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=1000000) + dtype = dtypes.as_dtype(dtype) + return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=dtype.max) for dtype in self._random_types(): self._testRngIsNotConstant(rng, dtype) @@ -73,6 +74,11 @@ class RandomOpsTest(xla_test.XLATestCase): def testRandomUniformIsInRange(self): for dtype in self._random_types(): + # TODO (b/112272078): enable bfloat16 for CPU and GPU when the bug is + # fixed. + if (self.device in ["XLA_GPU", "XLA_CPU" + ]) and (dtype in [dtypes.bfloat16, dtypes.half]): + continue with self.test_session() as sess: with self.test_scope(): x = random_ops.random_uniform( @@ -95,7 +101,7 @@ class RandomOpsTest(xla_test.XLATestCase): for dtype in [dtypes.float32]: with self.test_session() as sess: with self.test_scope(): - x = random_ops.truncated_normal(shape=[count], dtype=dtype, seed=42) + x = random_ops.truncated_normal(shape=[count], dtype=dtype) y = sess.run(x) def normal_cdf(x): @@ -124,20 +130,23 @@ class RandomOpsTest(xla_test.XLATestCase): # Department of Scientific Computing website. Florida State University. expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma actual_mean = np.mean(y) - self.assertAllClose(actual_mean, expected_mean, atol=2e-4) + self.assertAllClose(actual_mean, expected_mean, atol=2e-3) expected_median = mu + probit( (normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma actual_median = np.median(y) - self.assertAllClose(actual_median, expected_median, atol=8e-4) + self.assertAllClose(actual_median, expected_median, atol=1e-2) expected_variance = sigma**2 * (1 + ( (alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - ( (normal_pdf(alpha) - normal_pdf(beta)) / z)**2) actual_variance = np.var(y) - self.assertAllClose(actual_variance, expected_variance, rtol=3e-4) + self.assertAllClose(actual_variance, expected_variance, rtol=2*1e-3) def testShuffle1d(self): + # TODO(b/26783907): this test requires the CPU backend to implement sort. + if self.device in ["XLA_CPU"]: + return with self.test_session() as sess: with self.test_scope(): x = math_ops.range(1 << 16) diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index 16f293891d56d78885dd515bb7b9899faf0690f7..c0ea242044540b1cef44186880ba3cd92b8849d6 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -62,6 +62,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" @@ -101,6 +102,9 @@ class OpTestBuilder { OpTestBuilder& RandomInput(DataType type); OpTestBuilder& RandomInput(DataType type, std::vector dims); + // As RandomInput but the values are unique. + OpTestBuilder& RandomUniqueInput(DataType type, std::vector dims); + // Sets an attribute. template OpTestBuilder& Attr(StringPiece attr_name, T&& value); @@ -126,6 +130,7 @@ class OpTestBuilder { DataType type = DT_INVALID; bool has_dims = false; + bool needs_unique_values = false; std::vector dims; }; @@ -167,6 +172,18 @@ OpTestBuilder& OpTestBuilder::RandomInput(DataType type, return *this; } +OpTestBuilder& OpTestBuilder::RandomUniqueInput(DataType type, + std::vector dims) { + VLOG(1) << "Adding input: " << type << " " << TensorShape(dims).DebugString(); + InputDescription input; + input.type = type; + input.has_dims = true; + input.needs_unique_values = true; + input.dims = std::move(dims); + inputs_.push_back(input); + return *this; +} + template OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name, T&& value) { AddNodeAttr(attr_name, std::forward(value), &node_def_); @@ -289,7 +306,8 @@ class OpTest : public ::testing::Test { // Returns a tensor filled with random but "reasonable" values from the middle // of the type's range. If the shape is omitted, a random shape is used. // TODO(phawkins): generalize this code to a caller-supplied distribution. - Tensor RandomTensor(DataType dtype, gtl::ArraySlice shape); + Tensor RandomTensor(DataType dtype, bool needs_unique_values, + gtl::ArraySlice shape); Tensor RandomTensor(DataType dtype); // Like RandomTensor, but uses values >= 0. @@ -432,49 +450,90 @@ std::vector OpTest::RandomDims(int min_rank, int max_rank, return dims; } -Tensor OpTest::RandomTensor(DataType dtype, gtl::ArraySlice shape) { +Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, + gtl::ArraySlice shape) { Tensor tensor(dtype, TensorShape(shape)); switch (dtype) { case DT_FLOAT: { + gtl::FlatSet already_generated; std::uniform_real_distribution distribution(-1.0f, 1.0f); - test::FillFn(&tensor, [this, &distribution](int i) -> float { - return distribution(generator()); + test::FillFn(&tensor, [&](int i) -> float { + float generated; + do { + generated = distribution(generator()); + } while (needs_unique_values && + !already_generated.insert(generated).second); + return generated; }); break; } case DT_DOUBLE: { + gtl::FlatSet already_generated; std::uniform_real_distribution distribution(-1.0, 1.0); - test::FillFn(&tensor, [this, &distribution](int i) -> double { - return distribution(generator()); + test::FillFn(&tensor, [&](int i) -> double { + double generated; + do { + generated = distribution(generator()); + } while (needs_unique_values && + !already_generated.insert(generated).second); + return generated; }); break; } case DT_COMPLEX64: { + gtl::FlatSet> already_generated; std::uniform_real_distribution distribution(-1.0f, 1.0f); - test::FillFn(&tensor, [this, &distribution](int i) { - return complex64(distribution(generator()), distribution(generator())); + test::FillFn(&tensor, [&](int i) { + complex64 generated; + do { + generated = + complex64(distribution(generator()), distribution(generator())); + } while ( + needs_unique_values && + !already_generated + .insert(std::make_pair(generated.real(), generated.imag())) + .second); + return generated; }); break; } case DT_INT32: { + gtl::FlatSet already_generated; std::uniform_int_distribution distribution(-(1 << 20), 1 << 20); - test::FillFn(&tensor, [this, &distribution](int i) -> int32 { - return distribution(generator()); + test::FillFn(&tensor, [&](int i) -> int32 { + int32 generated; + do { + generated = distribution(generator()); + } while (needs_unique_values && + !already_generated.insert(generated).second); + return generated; }); break; } case DT_INT64: { + gtl::FlatSet already_generated; std::uniform_int_distribution distribution(-(1LL << 40), 1LL << 40); - test::FillFn(&tensor, [this, &distribution](int i) -> int64 { - return distribution(generator()); + test::FillFn(&tensor, [&](int i) -> int64 { + int64 generated; + do { + generated = distribution(generator()); + } while (needs_unique_values && + !already_generated.insert(generated).second); + return generated; }); break; } case DT_BOOL: { + gtl::FlatSet already_generated; std::bernoulli_distribution distribution; - test::FillFn(&tensor, [this, &distribution](int i) -> bool { - return distribution(generator()); + test::FillFn(&tensor, [&](int i) -> bool { + bool generated; + do { + generated = distribution(generator()); + } while (needs_unique_values && + !already_generated.insert(generated).second); + return generated; }); break; } @@ -485,7 +544,7 @@ Tensor OpTest::RandomTensor(DataType dtype, gtl::ArraySlice shape) { } Tensor OpTest::RandomTensor(DataType dtype) { - return RandomTensor(dtype, RandomDims()); + return RandomTensor(dtype, /*needs_unique_values=*/false, RandomDims()); } Tensor OpTest::RandomNonNegativeTensor(DataType dtype, @@ -761,7 +820,8 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( VLOG(1) << "Ignoring oversize dims."; return kInvalid; } - input_tensors.push_back(RandomTensor(input.type, dims)); + input_tensors.push_back( + RandomTensor(input.type, input.needs_unique_values, dims)); } VLOG(1) << "Input: " << input_tensors.back().DebugString(); } @@ -960,7 +1020,7 @@ TEST_F(OpTest, ArgMax) { std::uniform_int_distribution(-num_dims, num_dims)(generator()); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("ArgMax") - .RandomInput(DT_FLOAT, dims) + .RandomUniqueInput(DT_FLOAT, dims) .Input(test::AsScalar(reduce_dim)) .Attr("T", DT_FLOAT) .Attr("Tidx", DT_INT32) @@ -976,7 +1036,7 @@ TEST_F(OpTest, ArgMin) { std::uniform_int_distribution(-num_dims, num_dims)(generator()); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("ArgMin") - .RandomInput(DT_FLOAT, dims) + .RandomUniqueInput(DT_FLOAT, dims) .Input(test::AsScalar(reduce_dim)) .Attr("T", DT_FLOAT) .Attr("Tidx", DT_INT32) diff --git a/tensorflow/compiler/tests/reverse_ops_test.py b/tensorflow/compiler/tests/reverse_ops_test.py index d01c676e7c2fe705344f26818350c46c30451c67..32ab5d08f0b925ee6b7b641ddba6b950149a6d20 100644 --- a/tensorflow/compiler/tests/reverse_ops_test.py +++ b/tensorflow/compiler/tests/reverse_ops_test.py @@ -32,14 +32,20 @@ class ReverseOpsTest(xla_test.XLATestCase): def testReverseOneDim(self): shape = (7, 5, 9, 11) - for revdim in range(len(shape)): + for revdim in range(-len(shape), len(shape)): self._AssertReverseEqual([revdim], shape) def testReverseMoreThanOneDim(self): shape = (7, 5, 9, 11) + # The offset is used to test various (but not all) combinations of negative + # and positive axis indices that are guaranteed to not collide at the same + # index. for revdims in itertools.chain.from_iterable( - itertools.combinations(range(len(shape)), k) - for k in range(2, len(shape)+1)): + itertools.combinations(range(-offset, + len(shape) - offset), k) + for k in range(2, + len(shape) + 1) + for offset in range(0, len(shape))): self._AssertReverseEqual(revdims, shape) def _AssertReverseEqual(self, revdims, shape): @@ -50,15 +56,16 @@ class ReverseOpsTest(xla_test.XLATestCase): p = array_ops.placeholder(dtypes.int32, shape=shape) axis = constant_op.constant( np.array(revdims, dtype=np.int32), - shape=(len(revdims),), dtype=dtypes.int32) + shape=(len(revdims),), + dtype=dtypes.int32) rval = array_ops.reverse(p, axis).eval({p: pval}) slices = [ - slice(-1, None, -1) if d in revdims else slice(None) - for d in range(len(shape))] - self.assertEqual( - pval[slices].flatten().tolist(), - rval.flatten().tolist()) + slice(-1, None, -1) + if d in revdims or d - len(shape) in revdims else slice(None) + for d in range(len(shape)) + ] + self.assertEqual(pval[slices].flatten().tolist(), rval.flatten().tolist()) if __name__ == '__main__': diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 5f25ff9002964e94db384d7b01f07cfc4f8938b1..124cf9da813861fb3774e3bb29ad947af1598059 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -361,6 +361,12 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([[-0.05, 6.05, 5]], dtype=dtype), expected=np.array([[0, 6, 5]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + nn_ops.softmax, + np.array([1, 2, 3, 4], dtype=dtype), + expected=np.array([0.032058604, 0.087144323, 0.23688284, 0.64391428], + dtype=dtype)) + self._assertOpOutputMatchesExpected( nn_ops.softmax, np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), @@ -369,6 +375,14 @@ class UnaryOpsTest(xla_test.XLATestCase): [0.032058604, 0.087144323, 0.23688284, 0.64391428]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + nn_ops.softmax, + np.array([[[1, 1], [1, 1]], [[1, 2], [3, 4]]], dtype=dtype), + expected=np.array( + [[[0.5, 0.5], [0.5, 0.5]], + [[0.26894142, 0.73105858], [0.26894142, 0.73105858]]], + dtype=dtype)) + self._assertOpOutputMatchesExpected( nn_ops.softsign, np.array([[-2, -1, 0, 1, 2]], dtype=dtype), @@ -382,6 +396,11 @@ class UnaryOpsTest(xla_test.XLATestCase): expected=np.array( [[True, False, True], [False, True, True]], dtype=np.bool)) + self._assertOpOutputMatchesExpected( + math_ops.lgamma, + np.array(0.5, dtype=dtype), + expected=np.array(np.log(np.pi) / 2, dtype=dtype)) + self._assertOpOutputMatchesExpected( math_ops.lgamma, np.array( @@ -406,6 +425,19 @@ class UnaryOpsTest(xla_test.XLATestCase): ], dtype=dtype)) + # The actual result is complex. Take the real part. + self._assertOpOutputMatchesExpected( + math_ops.lgamma, + np.array([-1 / 2, -5 / 2, -9 / 2], dtype=dtype), + expected=np.array( + [ + np.log(np.pi) / 2 + np.log(2), + np.log(np.pi) / 2 - np.log(15) + np.log(8), + np.log(np.pi) / 2 - np.log(945) + np.log(32), + ], + dtype=dtype), + atol=1e-4) + self._assertOpOutputMatchesExpected( math_ops.digamma, np.array( diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py index 06d977b93c28792704b910c688af510bc650d2a4..85084bb1240cf05f6eabfbea772df113cabe613c 100644 --- a/tensorflow/compiler/tests/xla_device_test.py +++ b/tensorflow/compiler/tests/xla_device_test.py @@ -21,6 +21,8 @@ from __future__ import print_function import numpy as np from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_control_flow_ops @@ -47,6 +49,34 @@ class XlaDeviceTest(xla_test.XLATestCase): result = sess.run(z, {x: inputs}) self.assertAllCloseAccordingToType(result, inputs + inputs) + def testCopiesOfUnsupportedTypesFailGracefully(self): + """Tests that copies of unsupported types don't crash.""" + test_types = set([ + np.uint8, np.uint16, np.uint32, np.uint64, np.int8, np.int16, np.int32, + np.int64, np.float16, np.float32, np.float16, + dtypes.bfloat16.as_numpy_dtype + ]) + shape = (10, 10) + for unsupported_dtype in test_types - self.all_types: + with self.test_session() as sess: + with ops.device("CPU"): + x = array_ops.placeholder(unsupported_dtype, shape) + with self.test_scope(): + y, = array_ops.identity_n([x]) + with ops.device("CPU"): + z = array_ops.identity(y) + + inputs = np.random.randint(-100, 100, shape) + inputs = inputs.astype(unsupported_dtype) + # Execution should either succeed or raise an InvalidArgumentError, + # but not crash. Even "unsupported types" may succeed here since some + # backends (e.g., the CPU backend) are happy to handle buffers of + # unsupported types, even if they cannot compute with them. + try: + sess.run(z, {x: inputs}) + except errors.InvalidArgumentError: + pass + def testControlTrigger(self): with self.test_session() as sess: with self.test_scope(): diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 881624fff8c575f90051c06575364135e355b0f0..575917d078d6042f3c90c858bad3f4366085cc5b 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -91,6 +91,22 @@ cc_library( ], ) +cc_library( + name = "cpu_function_runtime", + srcs = ["cpu_function_runtime.cc"], + hdrs = ["cpu_function_runtime.h"], + visibility = [ + "//tensorflow/compiler/aot:__pkg__", + "//tensorflow/compiler/xla/service/cpu:__pkg__", + ], + deps = [ + # Keep dependencies to a minimum here; this library is used in every AOT + # binary produced by tfcompile. + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/core:framework_lite", + ], +) + cc_library( name = "xla_compiled_cpu_function", srcs = ["xla_compiled_cpu_function.cc"], @@ -99,12 +115,23 @@ cc_library( deps = [ # Keep dependencies to a minimum here; this library is used in every AOT # binary produced by tfcompile. - "//tensorflow/compiler/aot:runtime", + ":cpu_function_runtime", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/core:framework_lite", ], ) +tf_cc_test( + name = "cpu_function_runtime_test", + srcs = ["cpu_function_runtime_test.cc"], + deps = [ + ":cpu_function_runtime", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "xla_jit_compiled_cpu_function", srcs = ["xla_jit_compiled_cpu_function.cc"], @@ -121,6 +148,7 @@ cc_library( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service/cpu:buffer_info_util", "//tensorflow/compiler/xla/service/cpu:cpu_executable", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -140,14 +168,12 @@ cc_library( "xla_op_registry.cc", "xla_resource.cc", "xla_cpu_backend.cc", - "legacy_flags/backend_registration_flags.cc", ] + if_cuda_is_configured([ "xla_gpu_backend.cc", ]), hdrs = [ "const_analysis.h", "graph_compiler.h", - "legacy_flags/backend_registration_flags.h", "xla_compilation_device.h", "xla_compiler.h", "xla_context.h", @@ -173,20 +199,19 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:numeric", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], alwayslink = 1, ) @@ -418,22 +443,95 @@ cc_library( ], ) +cc_library( + name = "functionalize_control_flow_util", + srcs = [ + "functionalize_control_flow_util.cc", + ], + hdrs = [ + "functionalize_control_flow_util.h", + ], + deps = [ + "//tensorflow/compiler/tf2xla/ops:xla_ops", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:graph", + "//tensorflow/core:protos_all_cc", + ], +) + +cc_library( + name = "functionalize_cond", + srcs = [ + "functionalize_cond.cc", + ], + hdrs = [ + "functionalize_cond.h", + ], + deps = [ + ":functionalize_control_flow_util", + ":tf2xla_util", + "//tensorflow/compiler/jit:union_find", + "//tensorflow/compiler/tf2xla:dump_graph", + "//tensorflow/compiler/tf2xla/ops:xla_ops", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + ], +) + cc_library( name = "functionalize_control_flow", - srcs = ["functionalize_control_flow.cc"], - hdrs = ["functionalize_control_flow.h"], + srcs = [ + "functionalize_control_flow.cc", + ], + hdrs = [ + "functionalize_control_flow.h", + ], + deps = [ + ":functionalize_cond", + ":functionalize_control_flow_util", + ":functionalize_while", + ":tf2xla_util", + "//tensorflow/compiler/jit:union_find", + "//tensorflow/compiler/tf2xla:dump_graph", + "//tensorflow/compiler/tf2xla/ops:xla_ops", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "functionalize_while", + srcs = [ + "functionalize_while.cc", + ], + hdrs = [ + "functionalize_while.h", + ], deps = [ + ":functionalize_control_flow_util", ":tf2xla_util", "//tensorflow/compiler/jit:union_find", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:util", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -461,6 +559,32 @@ tf_cc_test( ], ) +tf_cc_test( + name = "functionalize_cond_test", + srcs = ["functionalize_cond_test.cc"], + deps = [ + ":functionalize_cond", + ":functionalize_control_flow", + ":test_util", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:resource_variable_ops", + "//tensorflow/compiler/tf2xla/cc:xla_ops", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:ops", + "//tensorflow/core:resource_variable_ops_op_lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + cc_library( name = "test_util", testonly = 1, diff --git a/tensorflow/compiler/aot/runtime.cc b/tensorflow/compiler/tf2xla/cpu_function_runtime.cc similarity index 70% rename from tensorflow/compiler/aot/runtime.cc rename to tensorflow/compiler/tf2xla/cpu_function_runtime.cc index 5e74079fc158379b8977ada6412141e39142c3d3..fcc4095e39673b786544984a41988c3e9c5b0efb 100644 --- a/tensorflow/compiler/aot/runtime.cc +++ b/tensorflow/compiler/tf2xla/cpu_function_runtime.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,22 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/aot/runtime.h" - -#include +#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" #include "tensorflow/core/platform/dynamic_annotations.h" namespace tensorflow { -namespace tfcompile { -namespace runtime { - namespace { - // Inline memory allocation routines here, because depending on '//base' brings // in libraries which use c++ streams, which adds considerable code size on // android. -inline void* aligned_malloc(size_t size, int minimum_alignment) { +void* aligned_malloc(size_t size, int minimum_alignment) { #if defined(__ANDROID__) || defined(OS_ANDROID) || defined(OS_CYGWIN) return memalign(minimum_alignment, size); #elif defined(_WIN32) @@ -47,7 +41,7 @@ inline void* aligned_malloc(size_t size, int minimum_alignment) { #endif } -inline void aligned_free(void* aligned_memory) { +void aligned_free(void* aligned_memory) { #if defined(_WIN32) _aligned_free(aligned_memory); #else @@ -58,22 +52,29 @@ inline void aligned_free(void* aligned_memory) { size_t align_to(size_t n, size_t align) { return (((n - 1) / align) + 1) * align; } - } // namespace -size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n) { +namespace cpu_function_runtime { +size_t AlignedBufferBytes(const BufferInfo* buffer_infos, size_t n, + bool allocate_entry_params) { size_t total = 0; for (size_t i = 0; i < n; ++i) { - if (sizes[i] != -1) { - total += align_to(sizes[i], kAlign); + bool should_allocate = + buffer_infos[i].is_temp_buffer() || + (buffer_infos[i].is_entry_parameter() && allocate_entry_params); + + if (should_allocate) { + total += align_to(buffer_infos[i].size(), kAlign); } } return total; } -void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs, +void* MallocContiguousBuffers(const BufferInfo* buffer_infos, size_t n, + bool allocate_entry_params, void** bufs, bool annotate_initialized) { - const size_t total = aligned_buffer_bytes(sizes, n); + const size_t total = + AlignedBufferBytes(buffer_infos, n, allocate_entry_params); void* contiguous = nullptr; if (total > 0) { contiguous = aligned_malloc(total, kAlign); @@ -85,11 +86,14 @@ void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs, } uintptr_t pos = reinterpret_cast(contiguous); for (size_t i = 0; i < n; ++i) { - if (sizes[i] == -1) { - bufs[i] = nullptr; - } else { + bool should_allocate = + buffer_infos[i].is_temp_buffer() || + (buffer_infos[i].is_entry_parameter() && allocate_entry_params); + if (should_allocate) { bufs[i] = reinterpret_cast(pos); - pos += align_to(sizes[i], kAlign); + pos += align_to(buffer_infos[i].size(), kAlign); + } else { + bufs[i] = nullptr; } } return contiguous; @@ -100,7 +104,5 @@ void FreeContiguous(void* contiguous) { aligned_free(contiguous); } } - -} // namespace runtime -} // namespace tfcompile +} // namespace cpu_function_runtime } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/cpu_function_runtime.h b/tensorflow/compiler/tf2xla/cpu_function_runtime.h new file mode 100644 index 0000000000000000000000000000000000000000..dfc1e8b8aebcf3142e9f61f60171c6b58634c71d --- /dev/null +++ b/tensorflow/compiler/tf2xla/cpu_function_runtime.h @@ -0,0 +1,165 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_CPU_FUNCTION_RUNTIME_H_ +#define TENSORFLOW_COMPILER_TF2XLA_CPU_FUNCTION_RUNTIME_H_ + +#include "tensorflow/core/platform/types.h" + +#include + +namespace tensorflow { +namespace cpu_function_runtime { +// Stores information about one buffer used by an XLA:CPU compiled function. +// These buffers are used for holding inputs to the computation, outputs from +// the computation and as temporary scratch space. +class BufferInfo { + public: + // Creates a BufferInfo from a serialized encoding generated by `Encode`. + explicit BufferInfo(std::pair encoding) + : entry_param_number_(encoding.second) { + Kind kind; + uint64 size; + Unpack(encoding.first, &kind, &size); + kind_ = kind; + size_ = size; + } + + // Returns true if this buffer stores a constant. These never need to be + // allocated by the runtime. + bool is_constant() const { return kind() == Kind::kConstant; } + + // Returns true if this buffer stores an entry parameter. These may or may + // not need to be allocated by the runtime, depending on + // XlaCompiledCpuFunction::AllocMode. + bool is_entry_parameter() const { return kind() == Kind::kEntryParameter; } + + // Returns the entry parameter number of this buffer. + uint64 entry_parameter_number() const { + assert(is_entry_parameter()); + return entry_param_number_; + } + + // Returns true if this buffer is temporary scratch space required by the XLA + // computations. These are always allocated by the runtime. + bool is_temp_buffer() const { return kind() == Kind::kTempBuffer; } + + // Returns true if this buffer is allocated on the C stack or into registers. + // These buffers are never allocated by the runtime. + bool is_on_stack_buffer() const { return kind() == Kind::kOnStackBuffer; } + + // Returns the size for this buffer. + uint64 size() const { return size_; } + + // Encodes this BufferInfo into two 64 bit integers that can be used to + // reconstruct the BufferInfo later using the constructor. We need this + // because we use BufferInfo in places where using protocol buffers would + // negatively impact binary size. + std::pair Encode() const { + static_assert(sizeof(*this) == 16, ""); + uint64 upper = Pack(kind(), size_); + uint64 lower = entry_param_number_; + return {upper, lower}; + } + + bool operator==(const BufferInfo& buffer_info) const { + if (kind() != buffer_info.kind() || size() != buffer_info.size()) { + return false; + } + return !is_entry_parameter() || + entry_parameter_number() == buffer_info.entry_parameter_number(); + } + + // Factory methods: + + static BufferInfo MakeTempBuffer(uint64 size) { + return BufferInfo(Kind::kTempBuffer, /*size=*/size, + /*entry_param_number=*/-1); + } + static BufferInfo MakeConstant(uint64 size) { + return BufferInfo(Kind::kConstant, /*size=*/size, + /*entry_param_number=*/-1); + } + static BufferInfo MakeEntryParameter(uint64 size, uint64 param_number) { + return BufferInfo(Kind::kEntryParameter, /*size=*/size, + /*entry_param_number=*/param_number); + } + static BufferInfo MakeOnStackBuffer(uint64 size) { + return BufferInfo(Kind::kOnStackBuffer, /*size=*/size, + /*entry_param_number=*/-1); + } + + private: + BufferInfo() = default; + + enum class Kind : unsigned { + kConstant, + kTempBuffer, + kEntryParameter, + kOnStackBuffer + }; + + Kind kind() const { return static_cast(kind_); } + + explicit BufferInfo(Kind kind, uint64 size, uint64 entry_param_number) + : kind_(kind), size_(size), entry_param_number_(entry_param_number) {} + + static uint64 Pack(Kind kind, uint64 size) { + return (static_cast(size) << 2) | static_cast(kind); + } + + static void Unpack(uint64 packed, Kind* kind, uint64* size) { + *size = packed >> 2; + *kind = static_cast((packed << 62) >> 62); + } + + Kind kind_ : 2; + uint64 size_ : 62; + int64 entry_param_number_; +}; + +// Align to 64-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment. +constexpr size_t kAlign = 64; + +// AlignedBufferBytes returns the sum of the size of each buffer in +// `buffer_infos`, skipping constants, on-stack buffers and, if +// allocate_entry_params is false, entry parameters. There are `n` entries in +// `buffer_infos`. Each buffer is aligned to kAlign byte boundaries. +size_t AlignedBufferBytes(const BufferInfo* buffer_infos, size_t n, + bool allocate_entry_params); + +// MallocContiguousBuffers allocates buffers for use by the entry point +// generated by tfcompile. There are `n` entries in `buffer_infos`. If +// `annotate_initialized` is set, the allocated memory will be annotated as +// having been initialized - this is useful when allocating temporary buffers. +// If allocate_entry_params is true then allocates temp buffers and entry +// parameters, otherwise allocated only temp buffers. Slots in `bufs` +// corresponding to unallocated buffers are set to nullptr. +// +// A single contiguous block of memory is allocated, and portions of it are +// parceled out into `bufs`, which must have space for `n` entries. Returns +// the head of the allocated contiguous block, which should be passed to +// FreeContiguous when the buffers are no longer in use. +void* MallocContiguousBuffers(const BufferInfo* buffer_infos, size_t n, + bool allocate_entry_params, void** bufs, + bool annotate_initialized); + +// FreeContiguous frees the contiguous block of memory allocated by +// MallocContiguousBuffers. +void FreeContiguous(void* contiguous); +} // namespace cpu_function_runtime +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_CPU_FUNCTION_RUNTIME_H_ diff --git a/tensorflow/compiler/aot/runtime_test.cc b/tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc similarity index 50% rename from tensorflow/compiler/aot/runtime_test.cc rename to tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc index 06ec623eb2dce5f8dc7156fb7e7b9ad57d90c8ee..8ca628c4eb6700d7184899bc1753dd6c6aa392b0 100644 --- a/tensorflow/compiler/aot/runtime_test.cc +++ b/tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc @@ -13,39 +13,70 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/aot/runtime.h" +#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { -namespace tfcompile { -namespace runtime { namespace { -TEST(Runtime, AlignmentValue) { +using cpu_function_runtime::BufferInfo; + +TEST(XlaCompiledCpuFunctionTest, AlignmentValue) { // We've chosen 64 byte alignment for the tfcompile runtime to mimic the // regular tensorflow allocator, which was chosen to play nicely with Eigen. // The tfcompile runtime also has a requirement that comes from the xla // generated code, on the relation: buffer_size >= 16 ? 2 * sizeof(void*) : 8 // So any value that we choose must abide by that constraint as well. - EXPECT_EQ(kAlign, Allocator::kAllocatorAlignment); + EXPECT_EQ(cpu_function_runtime::kAlign, Allocator::kAllocatorAlignment); +} + +std::vector SizesToBufferInfos(const intptr_t* sizes, size_t n) { + std::vector buffer_infos; + std::transform(sizes, sizes + n, std::back_inserter(buffer_infos), + [&](intptr_t size) { + if (size == -1) { + // Use a dummy on-stack buffer allocation to indicat the + // the current slot does not need an allocation. + int64 on_stack_buffer_size = 4; + return BufferInfo::MakeOnStackBuffer(on_stack_buffer_size); + } + return BufferInfo::MakeTempBuffer(size); + }); + return buffer_infos; +} + +// Simple wrappers to make writing tests more ergonomic. + +size_t AlignedBufferBytesFromSizes(const intptr_t* sizes, size_t n) { + std::vector buffer_infos = SizesToBufferInfos(sizes, n); + return AlignedBufferBytes(buffer_infos.data(), n, + /*allocate_entry_params=*/false); } -TEST(Runtime, AlignedBufferBytes) { - EXPECT_EQ(aligned_buffer_bytes(nullptr, 0), 0); +void* MallocContiguousBuffersFromSizes(const intptr_t* sizes, size_t n, + void** bufs, bool annotate_initialized) { + std::vector buffer_infos = SizesToBufferInfos(sizes, n); + return MallocContiguousBuffers(buffer_infos.data(), n, + /*allocate_entry_params=*/false, bufs, + annotate_initialized); +} + +TEST(XlaCompiledCpuFunctionTest, AlignedBufferBytes) { + EXPECT_EQ(AlignedBufferBytesFromSizes(nullptr, 0), 0); static constexpr intptr_t sizesA[1] = {-1}; - EXPECT_EQ(aligned_buffer_bytes(sizesA, 1), 0); + EXPECT_EQ(AlignedBufferBytesFromSizes(sizesA, 1), 0); static constexpr intptr_t sizesB[1] = {3}; - EXPECT_EQ(aligned_buffer_bytes(sizesB, 1), 64); + EXPECT_EQ(AlignedBufferBytesFromSizes(sizesB, 1), 64); static constexpr intptr_t sizesC[1] = {32}; - EXPECT_EQ(aligned_buffer_bytes(sizesC, 1), 64); + EXPECT_EQ(AlignedBufferBytesFromSizes(sizesC, 1), 64); static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3}; - EXPECT_EQ(aligned_buffer_bytes(sizesD, 7), 320); + EXPECT_EQ(AlignedBufferBytesFromSizes(sizesD, 7), 320); } void* add_ptr(void* base, uintptr_t delta) { @@ -56,48 +87,48 @@ void* add_ptr(void* base, uintptr_t delta) { // expected nullptrs, and write to each byte of allocated memory. We rely on // the leak checker to tell us if there's an inconsistency between malloc and // free. We also check the contiguous property. -TEST(Runtime, MallocFreeContiguousBuffers) { +TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) { // Test empty sizes. - void* base = MallocContiguousBuffers(nullptr, 0, nullptr, false); + void* base = MallocContiguousBuffersFromSizes(nullptr, 0, nullptr, false); EXPECT_EQ(base, nullptr); - FreeContiguous(base); + cpu_function_runtime::FreeContiguous(base); // Test non-empty sizes with 0 sum. static constexpr intptr_t sizesA[1] = {-1}; void* bufA[1]; - base = MallocContiguousBuffers(sizesA, 1, bufA, false); + base = MallocContiguousBuffersFromSizes(sizesA, 1, bufA, false); EXPECT_EQ(base, nullptr); EXPECT_EQ(bufA[0], nullptr); - FreeContiguous(base); + cpu_function_runtime::FreeContiguous(base); // Test non-empty sizes with non-0 sum. static constexpr intptr_t sizesB[1] = {3}; void* bufB[1]; - base = MallocContiguousBuffers(sizesB, 1, bufB, false); + base = MallocContiguousBuffersFromSizes(sizesB, 1, bufB, false); EXPECT_NE(base, nullptr); EXPECT_EQ(bufB[0], add_ptr(base, 0)); char* bufB0_bytes = static_cast(bufB[0]); bufB0_bytes[0] = 'A'; bufB0_bytes[1] = 'B'; bufB0_bytes[2] = 'C'; - FreeContiguous(base); + cpu_function_runtime::FreeContiguous(base); // Test non-empty sizes with non-0 sum, and annotate_initialized. static constexpr intptr_t sizesC[1] = {3}; void* bufC[1]; - base = MallocContiguousBuffers(sizesC, 1, bufC, true); + base = MallocContiguousBuffersFromSizes(sizesC, 1, bufC, true); EXPECT_NE(base, nullptr); EXPECT_EQ(bufC[0], add_ptr(base, 0)); char* bufC0_bytes = static_cast(bufC[0]); bufC0_bytes[0] = 'A'; bufC0_bytes[1] = 'B'; bufC0_bytes[2] = 'C'; - FreeContiguous(base); + cpu_function_runtime::FreeContiguous(base); // Test mixed sizes. static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3}; void* bufD[7]; - base = MallocContiguousBuffers(sizesD, 7, bufD, false); + base = MallocContiguousBuffersFromSizes(sizesD, 7, bufD, false); EXPECT_NE(base, nullptr); EXPECT_EQ(bufD[0], add_ptr(base, 0)); EXPECT_EQ(bufD[1], nullptr); @@ -115,10 +146,26 @@ TEST(Runtime, MallocFreeContiguousBuffers) { } } } - FreeContiguous(base); + cpu_function_runtime::FreeContiguous(base); +} + +void CheckRoundTripIsOk(const BufferInfo& buffer_info) { + BufferInfo round_trip(buffer_info.Encode()); + ASSERT_EQ(round_trip, buffer_info); +} + +TEST(XlaCompiledCpuFunctionTest, BufferInfoTest) { + CheckRoundTripIsOk(BufferInfo::MakeTempBuffer(0)); + CheckRoundTripIsOk(BufferInfo::MakeTempBuffer(4)); + CheckRoundTripIsOk(BufferInfo::MakeOnStackBuffer(0)); + CheckRoundTripIsOk(BufferInfo::MakeOnStackBuffer(4)); + CheckRoundTripIsOk(BufferInfo::MakeConstant(0)); + CheckRoundTripIsOk(BufferInfo::MakeConstant(4)); + CheckRoundTripIsOk( + BufferInfo::MakeEntryParameter(/*size=*/0, /*param_number=*/4)); + CheckRoundTripIsOk( + BufferInfo::MakeEntryParameter(/*size=*/4, /*param_number=*/0)); } } // namespace -} // namespace runtime -} // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc index 03603ee9baefd1d20d220faf63c9c1c427ebdf31..24616c01c7e54b2e8662457ca6af23a0bc563e08 100644 --- a/tensorflow/compiler/tf2xla/dump_graph.cc +++ b/tensorflow/compiler/tf2xla/dump_graph.cc @@ -33,7 +33,7 @@ struct NameCounts { std::unordered_map counts; }; -string MakeUniquePath(string name) { +string MakeUniqueFilename(string name) { static NameCounts& instance = *new NameCounts; // Remove illegal characters from `name`. @@ -50,26 +50,41 @@ string MakeUniquePath(string name) { count = instance.counts[name]++; } - legacy_flags::DumpGraphFlags* flags = legacy_flags::GetDumpGraphFlags(); - string path = strings::StrCat(flags->tf_dump_graph_prefix, "/", name); + string filename = name; if (count > 0) { - strings::StrAppend(&path, "_", count); + strings::StrAppend(&filename, "_", count); } - strings::StrAppend(&path, ".pbtxt"); - return path; + strings::StrAppend(&filename, ".pbtxt"); + return filename; +} + +string WriteTextProtoToUniqueFile( + Env* env, const string& name, const char* proto_type, + const ::tensorflow::protobuf::Message& proto) { + const string& dirname = + legacy_flags::GetDumpGraphFlags()->tf_dump_graph_prefix; + Status status = env->RecursivelyCreateDir(dirname); + if (!status.ok()) { + LOG(WARNING) << "Failed to create " << dirname << " for dumping " + << proto_type << ": " << status; + return "(unavailable)"; + } + string filepath = strings::StrCat(dirname, "/", MakeUniqueFilename(name)); + status = WriteTextProto(Env::Default(), filepath, proto); + if (!status.ok()) { + LOG(WARNING) << "Failed to dump " << proto_type << " to file: " << filepath + << " : " << status; + return "(unavailable)"; + } + LOG(INFO) << "Dumped " << proto_type << " to " << filepath; + return filepath; } } // anonymous namespace string DumpGraphDefToFile(const string& name, GraphDef const& graph_def) { - string path = MakeUniquePath(name); - Status status = WriteTextProto(Env::Default(), path, graph_def); - if (!status.ok()) { - VLOG(1) << "Failed to dump GraphDef to file: " << path << " : " << status; - path.clear(); - path = "(unavailable)"; - } - return path; + return WriteTextProtoToUniqueFile(Env::Default(), name, "GraphDef", + graph_def); } string DumpGraphToFile(const string& name, Graph const& graph, @@ -83,15 +98,7 @@ string DumpGraphToFile(const string& name, Graph const& graph, } string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef) { - string path = MakeUniquePath(name); - Status status = WriteTextProto(Env::Default(), path, fdef); - if (!status.ok()) { - VLOG(1) << "Failed to dump FunctionDef to file: " << path << " : " - << status; - path.clear(); - path = "(unavailable)"; - } - return path; + return WriteTextProtoToUniqueFile(Env::Default(), name, "FunctionDef", fdef); } } // namespace dump_graph diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc new file mode 100644 index 0000000000000000000000000000000000000000..0f5471616e111ddd34d990c5fee396f7297e0fd4 --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -0,0 +1,1380 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/functionalize_cond.h" + +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/compiler/jit/union_find.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/gtl/optional.h" + +using xla::StatusOr; + +namespace tensorflow { +namespace functionalize_cond { + +string DebugString(const CondStateMap::CondNode& node) { + return node.ToString(); +} + +// TODO(jpienaar): Move to OutputTensor. +string DebugString(const OutputTensor& tensor) { + return strings::StrCat(tensor.node->name(), ":", tensor.index); +} + +string DebugString(CondStateMap::CondId cond_state) { + if (cond_state == nullptr || cond_state->empty()) return "[]"; + return strings::StrCat( + "[", + tensorflow::str_util::Join( + *cond_state, ", ", + [](string* output, const CondStateMap::CondNode& node) { + strings::StrAppend(output, node.ToString()); + }), + "]"); +} + +string Branch_Name(BranchType b) { + switch (b) { + case BranchType::kElseBranch: + return "else"; + case BranchType::kThenBranch: + return "then"; + case BranchType::kBoth: + return "both"; + case BranchType::kNeither: + return "neither"; + } +} + +// Returns the predicate of a switch. +Status GetSwitchPredicate(const Node& switch_node, OutputTensor* pred) { + const Edge* pred_edge; + TF_RETURN_IF_ERROR(switch_node.input_edge(1, &pred_edge)); + // The predicate can be preceded by a identity node. Look through + // identity nodes to predicate. + while (pred_edge->src()->IsIdentity()) { + TF_RETURN_IF_ERROR(pred_edge->src()->input_edge(0, &pred_edge)); + } + *pred = OutputTensor(pred_edge->src(), pred_edge->src_output()); + return Status::OK(); +} + +CondStateMap::CondNode::CondNode(Type type, Node* switch_node, + BranchType branch) + : type(type), branch(branch) { + if (type == Type::kSwitch) { + TF_CHECK_OK(GetSwitchPredicate(*switch_node, &predicate)); + } +} + +string CondStateMap::CondNode::ToString() const { + switch (type) { + case Type::kSwitch: + return strings::StrCat("s(", DebugString(predicate), ",", + Branch_Name(branch), ")"); + case Type::kMerge: + return "m"; + case Type::kDead: + return "d"; + } +} + +bool CondStateMap::CondNode::operator==(const CondNode& other) const { + if (type != Type::kSwitch) return type == other.type; + return type == other.type && predicate == other.predicate && + branch == other.branch; +} + +bool CondStateMap::CondNode::operator!=(const CondNode& other) const { + return !(*this == other); +} + +CondStateMap::CondStateMap(Graph* graph) { + node_to_condid_map_.resize(graph->num_node_ids()); + // Initialize the dead state (empty state is designated with a nullptr). + dead_id_ = GetUniqueId({CondNode(CondStateMap::CondNode::Type::kDead)}); +} + +bool CondStateMap::IsDead(CondStateMap::CondId id) const { + return id == dead_id_; +} + +bool CondStateMap::IsEmpty(CondStateMap::CondId id) const { + return id == nullptr; +} + +size_t CondStateMap::CondHash::operator()( + const CondStateMap::CondNode& item) const { + return Hash64Combine(Hash64Combine(OutputTensor::Hash()(item.predicate), + hash()(item.branch)), + hash()(item.type)); +} + +size_t CondStateMap::CondHash::operator()( + const CondStateMap::CondState& vec) const { + if (vec.empty()) return 0; + size_t h = (*this)(vec.front()); + auto it = vec.begin(); + for (++it; it != vec.end(); ++it) { + h = Hash64Combine(h, (*this)(*it)); + } + return h; +} + +// CondArgNode represents a input to the conditional and its corresponding +// switch nodes. +struct CondArgNode { + explicit CondArgNode(Node* src, int src_output) + : src(src), src_output(src_output) {} + + string ToString() const { + return strings::StrCat("src=", src->name(), ":", src_output, + " switches=", NodesToString(switches)); + } + + Node* src; + int src_output; + std::array branch_copy; + std::vector switches; +}; +using CondArgNodes = std::vector; + +string DebugString(const CondArgNodes& nodes) { + return strings::StrCat( + "[", + tensorflow::str_util::Join(nodes, ", ", + [](string* output, const CondArgNode& node) { + strings::StrAppend(output, node.ToString()); + }), + "]"); +} + +CondStateMap::CondId CondStateMap::LookupId(const Node* node) const { + if (node->id() < node_to_condid_map_.size()) + return node_to_condid_map_[node->id()]; + return added_node_mapping_.at(node->id()); +} + +CondStateMap::CondId CondStateMap::GetUniqueId( + const CondStateMap::CondState& state) { + if (state.empty()) return nullptr; + return &*condstate_set_.insert(state).first; +} + +const CondStateMap::CondState& CondStateMap::LookupState( + const Node* node) const { + return *LookupId(node); +} + +void CondStateMap::ResetId(const Node* node, CondStateMap::CondId id) { + if (node->id() < node_to_condid_map_.size()) + node_to_condid_map_[node->id()] = id; + else + added_node_mapping_[node->id()] = id; +} + +void CondStateMap::MarkDead(const Node* node) { ResetId(node, dead_id_); } + +string CondStateMap::CondStateToString(const Node* node) const { + return CondStateToString(LookupId(node)); +} + +string CondStateMap::CondStateToString(CondStateMap::CondId id) const { + return DebugString(id); +} + +FunctionalizeCond::FunctionalizeCond(Graph* graph, + FunctionLibraryDefinition* library) + : cond_state_map_(graph), library_(library), graph_(graph) {} + +// Class representing the merge/switch nodes that will become a conditional. +class Conditional { + public: + Conditional(OutputTensor predicate, FunctionalizeCond* parent, + CondStateMap* cond_state_map); + + // Adds merge node that is part of this conditional. + Status AddMerge(Node* m); + + // Constructs an If node from the merge nodes. + Status BuildAndReplace(Graph* graph, FunctionLibraryDefinition* library); + + private: + // Extracts the then/else bodies: creates new graphs with the nodes + // corresponding to the nodes in the then/else branches as of this conditional + // as function bodies. + Status ExtractBodies(Graph* graph); + + // Builds the arguments that are the input to the If. + Status BuildArgumentNodes(); + + // Builds the If node for the extracted bodies with the given predicate. + Status BuildIfNode(Graph* graph, FunctionLibraryDefinition* library); + + // Adds input edges to If node. + Status AddInputEdges(Graph* graph); + + // Adds output edges from If node. + Status AddOutputEdges(Graph* graph); + + // Adds switch node that is part of this conditional. + Status AddSwitch(Node* s); + + // Internal name of conditional. The name is based on the first merge node + // added. + string name() const; + + // The FunctionalizeCond instance that created this. + FunctionalizeCond* parent_; + + // Mapping between nodes and their cond state. + CondStateMap* cond_state_map_; + + // The predicate of the conditional. + OutputTensor predicate_; + + // The predicate of the switches of the conditional. This may be different + // than predicate (which is initialized from the original graph) as the + // predicate could be the output of a newly created If node. + OutputTensor switch_predicate_; + + // Switch nodes in graph that are part of this conditional. + std::set switches_; + + // Merge nodes in graph that are part of this conditional. + std::set merges_; + + // Vector of control inputs from outside the conditional to a node inside. + std::vector external_control_inputs_; + std::vector external_control_outputs_; + + // Graphs corresponding to the then and else branch. + std::array, 2> bodies_; + + // Maps from graph_ to the branch body's graph. + std::array, 2> node_maps_; + + // The argument nodes created for the switches. + CondArgNodes cond_arg_nodes_; + + // The constructed If node. + Node* if_node_ = nullptr; + + // Whether the merge nodes of this conditional have been replaced. + bool replaced_ = false; +}; + +Conditional::Conditional(OutputTensor predicate, FunctionalizeCond* parent, + CondStateMap* cond_state_map) + : parent_(parent), cond_state_map_(cond_state_map), predicate_(predicate) {} + +Status Conditional::AddMerge(Node* m) { + merges_.insert(m); + return Status::OK(); +} + +Status Conditional::AddSwitch(Node* s) { + VLOG(5) << "Adding switch " << s->DebugString(); + OutputTensor predicate; + TF_RETURN_IF_ERROR(GetSwitchPredicate(*s, &predicate)); + if (switch_predicate_.node == nullptr) switch_predicate_ = predicate; + if (!(switch_predicate_ == predicate)) { + return errors::InvalidArgument( + "Merge nodes ", NodesToString(merges_), + " directly dominated by switch nodes with different predicates (", + DebugString(switch_predicate_), " vs ", DebugString(predicate), ")."); + } + switches_.insert(s); + return Status::OK(); +} + +Status Conditional::BuildArgumentNodes() { + VLOG(1) << "Build function arguments"; + struct Hash { + size_t operator()(const std::pair& item) const { + return Hash64Combine(hash()(item.first), + std::hash()(item.second)); + } + }; + + std::unordered_map, int, Hash> input_index; + for (Node* switch_node : switches_) { + const Edge* e; + TF_RETURN_IF_ERROR(switch_node->input_edge(0, &e)); + std::pair key = std::make_pair(e->src(), e->src_output()); + if (input_index.find(key) == input_index.end()) { + input_index[key] = cond_arg_nodes_.size(); + cond_arg_nodes_.emplace_back(key.first, key.second); + } + cond_arg_nodes_.at(input_index.at(key)).switches.push_back(switch_node); + } + VLOG(5) << "CondArg nodes created: " << DebugString(cond_arg_nodes_); + + int arg_count = 0; + for (CondArgNode& cond_arg_node : cond_arg_nodes_) { + DataType dtype = cond_arg_node.src->output_type(cond_arg_node.src_output); + for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { + int branch_index = static_cast(branch); + TF_RETURN_IF_ERROR( + NodeBuilder(strings::StrCat("_Arg", arg_count), + FunctionLibraryDefinition::kArgOp) + .Attr("T", dtype) + .Attr("index", arg_count) + .Finalize(bodies_[branch_index].get(), + &cond_arg_node.branch_copy[branch_index])); + } + for (Node* node : cond_arg_node.switches) { + for (const Edge* e : node->out_edges()) { + if (e->IsControlEdge()) continue; + int branch_index = e->src_output(); + Node* src_copy = cond_arg_node.branch_copy[branch_index]; + Node* dst_copy = node_maps_[branch_index][e->dst()->id()]; + + // The graph may contain dead switch nodes, + if (dst_copy == nullptr) continue; + + TF_RET_CHECK(dst_copy != nullptr) + << "Unable to find copied node for " << e->dst()->DebugString() + << " on branch " << Branch_Name(BranchType(branch_index)); + // If the input goes directly to a merge then the merge has + // been replaced by a retval so the dst input is 0 instead of + // dst_input. + int dst_input = IsMerge(e->dst()) ? 0 : e->dst_input(); + bodies_[branch_index]->AddEdge(src_copy, 0, dst_copy, dst_input); + } + } + ++arg_count; + } + + // Verify that all retvals have an input. + // TODO(jpienaar): One could add a ZerosLike in the branch that doesn't have + // input. + for (Node* m : merges_) { + for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { + bool has_input = false; + for (auto e : node_maps_[static_cast(branch)][m->id()]->in_edges()) { + if (!e->IsControlEdge()) { + has_input = true; + break; + } + } + if (!has_input) { + return errors::Internal( + "Failed to functionalize control flow with merge '", m->name(), + "' that doesn't have input on ", Branch_Name(branch), " branch."); + } + } + } + + return Status::OK(); +} + +Status Conditional::ExtractBodies(Graph* graph) { + VLOG(2) << "Extracting bodies for " << name(); + for (auto b : {BranchType::kElseBranch, BranchType::kThenBranch}) { + bodies_[static_cast(b)] = + absl::make_unique(graph->op_registry()); + } + + auto find_branch = [&](const Edge* e) { + const auto& id = cond_state_map_->LookupId(e->src()); + return IsSwitch(e->src()) ? BranchType(e->src_output()) + : cond_state_map_->FindBranchOf(id, predicate_); + }; + + std::array, 2> stacks; + VLOG(5) << "Merges: " << NodesToString(merges_); + for (Node* m : merges_) { + VLOG(5) << "For merge: " << m->DebugString() << " " + << cond_state_map_->CondStateToString(m); + for (auto e : m->in_edges()) { + if (e->IsControlEdge()) continue; + BranchType branch = find_branch(e); + TF_RET_CHECK(branch == BranchType::kThenBranch || + branch == BranchType::kElseBranch) + << "Error: " << e->src()->name() + << " is not on either then or else branch (" << Branch_Name(branch) + << ")."; + Node* src = e->src(); + if (IsSwitch(src)) { + // Switch node outputs and dependencies are handled separately. + TF_RETURN_IF_ERROR(AddSwitch(src)); + } else { + stacks[static_cast(branch)].push_back(src); + } + } + } + + for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { + int branch_index = static_cast(branch); + auto output = bodies_[branch_index].get(); + auto& stack = stacks[branch_index]; + VLOG(5) << "In branch: " << Branch_Name(branch) << " " + << NodesToString(stack); + std::vector visited(graph->num_node_ids(), false); + node_maps_[branch_index].resize(graph->num_node_ids(), nullptr); + auto& node_map = node_maps_[branch_index]; + + while (!stack.empty()) { + Node* n = stack.back(); + stack.pop_back(); + + if (visited.at(n->id())) continue; + visited[n->id()] = true; + + // Verify output edges and record control edges exitting scope. + for (const Edge* e : n->out_edges()) { + Node* dst = e->dst(); + if (IsMerge(dst)) continue; + Node* src = e->src(); + + auto dst_id = cond_state_map_->LookupId(dst); + auto src_id = cond_state_map_->LookupId(src); + if (dst_id != src_id) { + if (e->IsControlEdge()) { + external_control_outputs_.push_back(e->src()); + } else { + // Constants are treated specially to workaround the case of + // non-dominated constant nodes. + if (!IsConstant(src)) { + // TODO(b/78882471): A node that feeds into two different + // CondState is not necessarily an error so log a warning for now + // but revisit to improve the testing to enable making this an + // error. + LOG(WARNING) << errors::InvalidArgument( + "Graph contains node ", src->name(), " that feeds into node ", + dst->name(), + " but these nodes are in different control contexts (", + DebugString(src_id), " vs ", DebugString(dst_id), + " (detected during out edge testing)"); + } + } + } + } + + // Copying incomming edges to dst node. + for (const Edge* e : n->in_edges()) { + Node* src = e->src(); + // Skip src/dst node. + if (!src->IsOp()) continue; + + Node* dst = e->dst(); + if (IsSwitch(src)) { + // Switch node outputs and dependencies are handled separately. + TF_RETURN_IF_ERROR(AddSwitch(src)); + continue; + } + + // Verify input is from the same context. + auto src_id = cond_state_map_->LookupId(src); + auto dst_id = cond_state_map_->LookupId(dst); + if (IsMerge(dst) || src_id == dst_id) { + // TODO(jpienaar): The merge case can be more strict. + if (node_map.at(src->id()) == nullptr) { + node_map.at(src->id()) = output->CopyNode(src); + stack.push_back(src); + } + } else if (e->IsControlEdge()) { + external_control_inputs_.push_back(src); + } else { + // This shouldn't happen, this means we have an external data input + // not entering via a switch node. Work around this for constant + // nodes as some constant nodes are inserted without the required + // control context dominance. + if (IsConstant(src)) { + node_map.at(src->id()) = output->CopyNode(src); + } else { + return errors::InvalidArgument( + "Graph contains node ", src->name(), " that feeds into node ", + dst->name(), + " but these nodes are in different control contexts (", + DebugString(src_id), " vs ", DebugString(dst_id), + " (detected during in edge testing)"); + } + } + + Node* src_copy = node_map.at(e->src()->id()); + int src_output = e->src_output(); + if (node_map.at(dst->id()) == nullptr) { + node_map.at(dst->id()) = output->CopyNode(dst); + } + Node* dst_copy = node_map.at(e->dst()->id()); + if (e->IsControlEdge()) { + // Skip control inputs from external context. + if (src_copy != nullptr) output->AddControlEdge(src_copy, dst_copy); + } else { + output->AddEdge(src_copy, src_output, dst_copy, e->dst_input()); + } + } + } + } + + // Build return values from the merge nodes. + int index = 0; + for (Node* m : merges_) { + for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { + int branch_index = static_cast(branch); + auto& node_map = node_maps_[branch_index]; + auto output = bodies_[branch_index].get(); + TF_ASSIGN_OR_RETURN(node_map[m->id()], + BuildRetvalNode(output, m->output_type(0), index)); + } + ++index; + + // Connect the input to the merge_ with the retval, except if it is a + // Swich node, which is handled separately. + for (auto e : m->in_edges()) { + if (e->IsControlEdge()) continue; + int branch_index = static_cast(find_branch(e)); + auto& node_map = node_maps_[branch_index]; + auto output = bodies_[branch_index].get(); + Node* in = e->src(); + if (!IsSwitch(in)) { + if (node_map.at(in->id()) == nullptr) { + node_map[in->id()] = output->CopyNode(in); + } + output->AddEdge(node_map[in->id()], e->src_output(), + node_map.at(m->id()), 0); + } + } + } + return Status::OK(); +} + +Status Conditional::BuildIfNode(Graph* graph, + FunctionLibraryDefinition* library) { + VLOG(2) << "Build cond function for " << name(); + NodeDefBuilder builder(name(), "If"); + const string branch_name[] = {"else_branch", "then_branch"}; + for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { + int branch_index = static_cast(branch); + static std::atomic sequence_num(0LL); + int64 id = ++sequence_num; + + NameAttrList body_name; + body_name.set_name(strings::StrCat("_functionalize_if_", + branch_name[branch_index], "_", id)); + + VLOG(3) << "FunctionalizeControlFlow (" << branch_name[branch_index] + << "): " + << dump_graph::DumpGraphToFile( + "functionalize_cond_body_" + branch_name[branch_index], + *bodies_[branch_index], nullptr); + + FunctionDef body_fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(*bodies_[branch_index], + body_name.name(), &body_fdef)); + TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef)); + builder.Attr(branch_name[branch_index], body_name); + } + + VLOG(3) << "Build input type"; + std::vector inputs; + DataTypeVector in_arg_types; + for (auto& kv : cond_arg_nodes_) { + bool inserted = false; + for (const Node* arg : kv.switches) { + const Edge* in_edge; + TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); + if (in_edge->IsControlEdge()) { + builder.ControlInput(in_edge->src()->name()); + } else { + if (!inserted) { + DataType dtype = arg->input_type(0); + inputs.emplace_back(NodeDefBuilder::NodeOut( + in_edge->src()->name(), in_edge->src_output(), dtype)); + in_arg_types.push_back(dtype); + inserted = true; + } + } + } + } + builder.Attr("Tin", in_arg_types); + + DataTypeVector out_type; + for (const Node* merge : merges_) { + DataType dtype = merge->output_type(0); + out_type.push_back(dtype); + } + builder.Attr("Tout", out_type); + VLOG(3) << "Build output type: " << DataTypeVectorString(out_type); + + builder.Attr("Tcond", DT_BOOL); + builder.Device(predicate_.node->assigned_device_name()); + // Conditional should be the first input ... + builder.Input(NodeDefBuilder::NodeOut(predicate_.node->name(), + predicate_.index, + predicate_.node->output_type(0))); + // ... followed by the other inputs. + builder.Input(inputs); + + VLOG(3) << "Build If node"; + NodeDef if_def; + TF_RETURN_IF_ERROR(builder.Finalize(&if_def)); + TF_ASSIGN_OR_RETURN(if_node_, parent_->AddIfNode(if_def, *merges_.begin())); + + return Status::OK(); +} + +Status Conditional::AddInputEdges(Graph* graph) { + VLOG(2) << "AddInputEdges for " << if_node_->name(); + int index = 0; + // Add predicate input. + graph->AddEdge(const_cast(predicate_.node), predicate_.index, if_node_, + index++); + // Add function body inputs. + for (auto& arg : cond_arg_nodes_) { + if (arg.src_output == Graph::kControlSlot) { + graph->AddControlEdge(arg.src, if_node_); + } else { + graph->AddEdge(arg.src, arg.src_output, if_node_, index++); + } + } + for (Node* n : external_control_inputs_) { + graph->AddControlEdge(n, if_node_); + } + return Status::OK(); +} + +Status Conditional::AddOutputEdges(Graph* graph) { + VLOG(2) << "AddOutputEdges for " << if_node_->name(); + int i = 0; + for (Node* node : merges_) { + TF_RETURN_IF_ERROR(parent_->AddIdentityNode(node, if_node_, i)); + std::vector edges(node->out_edges().begin(), + node->out_edges().end()); + for (const Edge* edge : edges) { + Node* dst = edge->dst(); + int dst_input = edge->dst_input(); + if (edge->src_output() > 0) { + return errors::Unimplemented("Output of index (", edge->src_output(), + ") of merge node ", node->name()); + } + + bool control_edge = edge->IsControlEdge(); + graph->RemoveEdge(edge); + if (control_edge) { + graph->AddControlEdge(if_node_, dst); + } else { + graph->AddEdge(if_node_, i, dst, dst_input); + } + } + ++i; + } + for (Node* n : external_control_outputs_) { + graph->AddControlEdge(if_node_, n); + } + + return Status::OK(); +} + +Status Conditional::BuildAndReplace(Graph* graph, + FunctionLibraryDefinition* library) { + VLOG(1) << "Build If and replace merge nodes " << name(); + if (replaced_) return Status::OK(); + + TF_RETURN_IF_ERROR(ExtractBodies(graph)); + TF_RETURN_IF_ERROR(BuildArgumentNodes()); + + if (VLOG_IS_ON(3)) { + LOG(INFO) << "Extracted bodies:"; + for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { + int branch_index = static_cast(branch); + auto output = bodies_[branch_index].get(); + LOG(INFO) << Branch_Name(branch) << ": " + << DebugString(output->ToGraphDefDebug()); + } + } + + TF_RETURN_IF_ERROR(BuildIfNode(graph, library)); + TF_RETURN_IF_ERROR(AddInputEdges(graph)); + TF_RETURN_IF_ERROR(AddOutputEdges(graph)); + TF_RETURN_IF_ERROR(parent_->PropagateUpdatedState(if_node_)); + for (Node* m : merges_) cond_state_map_->MarkDead(m); + + // Check that the if_node doesn't feed into itself. + TF_RETURN_WITH_CONTEXT_IF_ERROR( + CheckNodeNotInCycle(if_node_, graph->num_node_ids()), + "Converting to If failed."); + + replaced_ = true; + return Status::OK(); +} + +string Conditional::name() const { + CHECK(!merges_.empty()); + return strings::StrCat((*merges_.begin())->name(), "_if"); +} + +bool CondStateMap::ScopeIn(CondStateMap::CondId id, + CondStateMap::CondId* scope) { + if (id == nullptr) { + *scope = nullptr; + return true; + } + CondState state; + for (const CondNode& node : *id) { + if (node.type == CondNode::Type::kSwitch) { + state.push_back(node); + } + if (node.type == CondNode::Type::kMerge) { + if (state.empty()) { + return false; + } + DCHECK(state.back().type == CondNode::Type::kSwitch && + state.back().branch == BranchType::kBoth); + state.pop_back(); + } + } + *scope = GetUniqueId(state); + return true; +} + +Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node, + int port) { + Node* id; + TF_RETURN_IF_ERROR(NodeBuilder(replacee->name(), "Identity") + .Input(if_node, port) + .Finalize(graph_, &id)); + cond_state_map_.ResetId(id, cond_state_map_.LookupId(if_node)); + return Status::OK(); +} + +StatusOr FunctionalizeCond::AddIfNode(const NodeDef& def, + const Node* replacee) { + Status status; + Node* ret = graph_->AddNode(def, &status); + TF_RETURN_IF_ERROR(status); + CondStateMap::CondState state = cond_state_map_.LookupState(replacee); + state.pop_back(); + VLOG(1) << "Adding If for " << replacee->name(); + cond_state_map_.ResetId(ret, cond_state_map_.GetUniqueId(state)); + return ret; +} + +Status FunctionalizeCond::PropagateUpdatedState(const Node* replacee) { + VLOG(2) << "Propagating update state for " << replacee->name() << " " + << cond_state_map_.CondStateToString(replacee); + // Redo topological sort as the order could have changed. + // TODO(jpienaar): The original topological order could also be updated + // dynamically if needed. + std::vector rev_topo_order; + GetPostOrder(*graph_, &rev_topo_order); + + // All the outputs of the new node could potentially be updated. + std::unordered_set changed; + for (auto n : replacee->out_nodes()) + if (n->IsOp()) changed.insert(n); + + // Iterate through the changed/possible changed nodes in topological order. + for (auto it = rev_topo_order.rbegin(); + it != rev_topo_order.rend() && !changed.empty(); ++it) { + if (changed.find(*it) != changed.end()) { + // Update the node state. + Node* n = *it; + CondStateMap::CondId old_state = cond_state_map_.LookupId(n); + cond_state_map_.ResetId(n, nullptr); + TF_RETURN_IF_ERROR(DetermineCondState(n)); + if (cond_state_map_.LookupId(n) != old_state) { + for (auto out : n->out_nodes()) + if (out->IsOp()) changed.insert(out); + } + changed.erase(n); + } + } + return Status::OK(); +} + +// Returns the most restrictive branch of two branches or neither. This is the +// meet operator of the BranchType lattice. +BranchType MeetBranch(const BranchType& lhs, const BranchType& rhs) { + if (lhs == rhs) return lhs; + if (lhs == BranchType::kNeither) return rhs; + if (rhs == BranchType::kNeither) return lhs; + if (lhs == BranchType::kBoth) return rhs; + if (rhs == BranchType::kBoth) return lhs; + return BranchType::kNeither; +} + +CondStateMap::ContainsResult CondStateMap::LhsHoldsWhereverRhsHolds( + CondStateMap::CondId lhs, CondStateMap::CondId rhs) { + CondId lhs_scope; + CondId rhs_scope; + bool could_determine_scope = ScopeIn(lhs, &lhs_scope); + could_determine_scope = could_determine_scope && ScopeIn(rhs, &rhs_scope); + if (!could_determine_scope) return kIncomparable; + + // Returns whether a contains b. + auto contains = [&](CondId a, CondId b) { + // Handle empty states. + if (a == nullptr && b != nullptr) return true; + if (a == nullptr && b == nullptr) return true; + if (a != nullptr && b == nullptr) return false; + + if (a->size() > b->size()) return false; + auto a_it = a->begin(); + auto b_it = b->begin(); + while (a_it != a->end()) { + if (*a_it != *b_it) { + if (!(a_it->predicate == b_it->predicate)) return false; + BranchType mb = MeetBranch(a_it->branch, b_it->branch); + if (mb != b_it->branch) return false; + } + ++a_it; + ++b_it; + } + return true; + }; + + bool lhs_contains_rhs = contains(lhs_scope, rhs_scope); + bool rhs_contains_lhs = contains(rhs_scope, lhs_scope); + if (lhs_contains_rhs && rhs_contains_lhs) return kEqual; + if (lhs_contains_rhs) return kLhsContainsRhs; + if (rhs_contains_lhs) return kRhsContainsLhs; + return kIncomparable; +} + +BranchType CondStateMap::FindBranchOf(CondId id, OutputTensor predicate) const { + if (IsEmpty(id)) return BranchType::kNeither; + gtl::optional b; + const CondState& nodes = *id; + for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) { + if (it->type == CondStateMap::CondNode::Type::kSwitch && + it->predicate == predicate) { + if (b.has_value()) { + b = MeetBranch(*b, it->branch); + } else { + b = it->branch; + } + if (*b == BranchType::kNeither) { + LOG(FATAL) << "Inconsistent state for node: " << DebugString(id); + } + } + } + return b.has_value() ? *b : BranchType::kNeither; +} + +StatusOr FunctionalizeCond::JoinCondStatesNonMerge( + CondStateMap::CondId src, CondStateMap::CondId dst) { + VLOG(4) << "Joining src=" << DebugString(src) << " [" << src + << "] and dst=" << DebugString(dst) << " [" << dst << "]"; + + if (cond_state_map_.IsEmpty(dst) || cond_state_map_.IsDead(src)) return src; + if (cond_state_map_.IsDead(dst)) return dst; + + // Nothing to do if the CondState is the same. + if (src == dst) return src; + + CondStateMap::CondId src_scope; + CondStateMap::CondId dst_scope; + if (!cond_state_map_.ScopeIn(src, &src_scope)) + return errors::Unimplemented( + "Predicates that must hold for node to execute are invalid! ", + DebugString(src)); + if (!cond_state_map_.ScopeIn(dst, &dst_scope)) + return errors::Unimplemented( + "Predicates that must hold for node to execute are invalid! ", + DebugString(dst)); + + auto result = cond_state_map_.LhsHoldsWhereverRhsHolds(src_scope, dst_scope); + switch (result) { + case CondStateMap::kIncomparable: + return errors::InvalidArgument( + "Graph contains node with inputs predicated on incompatible " + "predicates: ", + DebugString(src), " and ", DebugString(dst)); + case CondStateMap::kEqual: + // If both respect the same predicates, propagate the longer constraint. + if ((src != nullptr && dst == nullptr) || + (src != nullptr && dst != nullptr && src->size() > dst->size())) + return src; + else + return dst; + case CondStateMap::kLhsContainsRhs: + // src contains dst, so dst is already more restrictive. + return dst; + case CondStateMap::kRhsContainsLhs: + // dst contains src, so src is more restrictive. + return src; + } +} + +StatusOr +FindThenElseSwitchForPredicate(const OutputTensor& pred, + CondStateMap::CondId id) { + for (auto it = id->begin(); it != id->end(); ++it) { + // Along every path one there can be only one instance of a then or else + // switch for a given predicate, so return once found. + if (it->type == CondStateMap::CondNode::Type::kSwitch && + it->predicate == pred && + (it->branch == BranchType::kThenBranch || + it->branch == BranchType::kElseBranch)) + return it; + } + return errors::Internal("Unable to find then/else branch with predicate ", + DebugString(pred), " for ", DebugString(id)); +} + +StatusOr FunctionalizeCond::JoinCondStatesMerge( + CondStateMap::CondId src, CondStateMap::CondId dst) { + // Determine the flow state when joining two states for a merge + // node. Combining the two states for a merge node is effectively performing a + // disjunction of the states along the different input edges. For a merge that + // can be transformed into a If the two inputs paths have to have a predicate + // on which they differ (e.g., along one edge predicate `p` has to hold while + // on another it should not). This function first determines this predicate + // and then the resultant state is the common path between the two inputs + // followed by s(p, both). + VLOG(4) << "Joining (for merge) " << DebugString(src) << " and " + << DebugString(dst); + if (cond_state_map_.IsEmpty(dst)) return src; + + if (cond_state_map_.IsDead(src)) return src; + if (cond_state_map_.IsDead(dst)) return dst; + + CondStateMap::CondId src_scope; + CondStateMap::CondId dst_scope; + if (!cond_state_map_.ScopeIn(src, &src_scope)) + return errors::Unimplemented( + "Predicates that must hold for node to execute are invalid! ", + DebugString(src)); + if (!cond_state_map_.ScopeIn(dst, &dst_scope)) + return errors::Unimplemented( + "Predicates that must hold for node to execute are invalid! ", + DebugString(dst)); + + TF_RET_CHECK(src_scope != nullptr && dst_scope != nullptr) + << "Illegal merge inputs from outer scope: src=" << DebugString(src) + << " dst=" << DebugString(dst); + auto src_it = src_scope->begin(); + auto dst_it = dst_scope->begin(); + + // Find branch divergent condition. + OutputTensor pred; + while (src_it != src_scope->end() && dst_it != dst_scope->end()) { + if (*src_it != *dst_it) { + VLOG(5) << "Diverges with: " << DebugString(*src_it) << " and " + << DebugString(*dst_it); + if (!(src_it->predicate == dst_it->predicate)) { + return errors::InvalidArgument( + "Unable to find common predicate which holds for one input " + "but not the other of the merge node."); + } + pred = src_it->predicate; + break; + } + ++src_it; + ++dst_it; + } + + if (pred.node == nullptr) + return errors::InvalidArgument("Unable to determine predicate for merge."); + + TF_ASSIGN_OR_RETURN(auto div_src_it, + FindThenElseSwitchForPredicate(pred, src)); + TF_ASSIGN_OR_RETURN(auto div_dst_it, + FindThenElseSwitchForPredicate(pred, dst)); + TF_RET_CHECK(*div_src_it != *div_dst_it); + + CondStateMap::CondState result; + // Populate result with the longest/most restrictive path up to the divergent + // node. For example, if the one input is `[switch(pred:0, then)]` and the + // other is `[switch(pred:0, both), merge, switch(pred:0, else)]` (as created + // in gradient of cond test), then the resultant state here should be + // `[switch(pred:0, both), merge, switch(pred:0, both)]`. + if (std::distance(src->begin(), div_src_it) > + std::distance(dst->begin(), div_dst_it)) { + result.assign(src->begin(), std::next(div_src_it)); + } else { + result.assign(dst->begin(), std::next(div_dst_it)); + } + result.back().branch = BranchType::kBoth; + return cond_state_map_.GetUniqueId(result); +} + +CondStateMap::CondId FunctionalizeCond::StateAlongEdge(const Edge* e) { + Node* src = e->src(); + CondStateMap::CondId id = cond_state_map_.LookupId(e->src()); + if (IsMerge(src)) { + CondStateMap::CondState state; + if (id != nullptr) state = *id; + state.emplace_back(CondStateMap::CondNode::Type::kMerge); + return cond_state_map_.GetUniqueId(state); + } + if (IsSwitch(src)) { + CondStateMap::CondState state; + if (id != nullptr) state = *id; + if (e->IsControlEdge()) { + state.emplace_back(CondStateMap::CondNode::Type::kSwitch, src, + BranchType::kBoth); + } else { + state.emplace_back(CondStateMap::CondNode::Type::kSwitch, src, + BranchType(e->src_output())); + } + return cond_state_map_.GetUniqueId(state); + } + return id; +} + +Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) { + // Only Merge nodes with two inputs are supported, but if this is a redundant + // merge, then the dead edge may already have been removed (if due to a + // switch) and so the input count would be incorrect. + if (cond_state_map_.IsDead(cond_state_map_.LookupId(dst))) + return Status::OK(); + + int data_inputs = 0; + for (auto e : dst->in_edges()) { + Node* src = e->src(); + VLOG(5) << "Processing forward flow for merge: " << e->DebugString() << " " + << cond_state_map_.CondStateToString(src); + if (!src->IsOp()) continue; + if (!e->IsControlEdge()) ++data_inputs; + + CondStateMap::CondId prop = StateAlongEdge(e); + auto id_or = JoinCondStatesMerge(prop, cond_state_map_.LookupId(dst)); + TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", dst->name()); + cond_state_map_.ResetId(dst, id_or.ValueOrDie()); + } + + // Incomplete Merge nodes are not supported. + if (data_inputs != 2) { + return errors::Unimplemented( + dst->name(), " only has ", data_inputs, + " inputs, while only merge nodes with two inputs supported."); + } + return Status::OK(); +} + +Status FunctionalizeCond::DetermineCondState(Node* dst) { + // The logic for the merge and non-merge case differ: for non-merge it is + // the most restrictive CondState, while for merge nodes the + // resultant state is less restrictive than either. + if (IsMerge(dst)) { + TF_RETURN_IF_ERROR(DetermineCondStateMerge(dst)); + } else { + // Handle non-merge join. + for (auto e : dst->in_edges()) { + VLOG(5) << "Processing forward flow for: " << e->DebugString() << " " + << cond_state_map_.CondStateToString(dst); + Node* src = e->src(); + if (!src->IsOp()) continue; + + // Joining the state between the current and propagated state. + CondStateMap::CondId prop = StateAlongEdge(e); + auto id_or = JoinCondStatesNonMerge(prop, cond_state_map_.LookupId(dst)); + TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", dst->name()); + cond_state_map_.ResetId(dst, id_or.ValueOrDie()); + } + } + return Status::OK(); +} + +Status FunctionalizeCond::RemoveRedundantMerge(Node* node) { + // Handle redundant merge nodes. A merge node is considered redundant if + // one input edge is dead while the other has a value. + if (!cond_state_map_.IsDead(cond_state_map_.LookupId(node))) + return Status::OK(); + + const Edge* non_dead_edge = nullptr; + for (auto e : node->in_edges()) { + if (e->IsControlEdge()) continue; + Node* src = e->src(); + + // Handle merge with dead state. + const auto& src_id = cond_state_map_.LookupId(src); + if (!cond_state_map_.IsDead(src_id)) { + non_dead_edge = e; + break; + } + } + + if (non_dead_edge == nullptr) { + return errors::InvalidArgument("Merge node ", node->name(), + " has no non-dead inputs."); + } + cond_state_map_.MarkDead(node); + delete_nodes_.push_back(node->id()); + VLOG(5) << "removing redundant merge: " << node->name(); + while (!node->out_edges().empty()) { + const Edge* oe = *node->out_edges().begin(); + Node* dst_node = oe->dst(); + int dst_port = oe->dst_input(); + graph_->RemoveEdge(oe); + graph_->AddEdge(non_dead_edge->src(), + dst_port == Graph::kControlSlot + ? Graph::kControlSlot + : non_dead_edge->src_output(), + dst_node, dst_port); + } + return Status::OK(); +} + +Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { + // Handle redundant switch nodes. A switch node is considered redundant if + // the predicate of the switch already holds on the current branch. E.g., if + // p is the predicate of the switch but p is already known to hold on this + // branch, then the switch can be removed and the dead state propagated + // along one. The checking of predicate is based on the exact predicate + // (rather than boolean equivalence) and aimed at redundant switches as + // currently generated by gradient code. + OutputTensor pred; + TF_RETURN_IF_ERROR(GetSwitchPredicate(*node, &pred)); + auto dst_id = cond_state_map_.LookupId(node); + BranchType b = cond_state_map_.FindBranchOf(dst_id, pred); + // Determine if we are already on a branch where the switch predicate is + // true/false. + if (b != BranchType::kThenBranch && b != BranchType::kElseBranch) + return Status::OK(); + + VLOG(5) << "Redundant switch " << node->name(); + const Edge* value_edge; + TF_RETURN_IF_ERROR(node->input_edge(0, &value_edge)); + Node* val_node = value_edge->src(); + int val_port = value_edge->src_output(); + while (!node->out_edges().empty()) { + auto e = *node->out_edges().begin(); + Node* dst_node = e->dst(); + int dst_input = e->dst_input(); + int switch_branch = e->src_output(); + graph_->RemoveEdge(e); + if (switch_branch == Graph::kControlSlot) { + if (IsMerge(dst_node)) { + auto id_or = + JoinCondStatesMerge(dst_id, cond_state_map_.LookupId(dst_node)); + TF_RETURN_IF_ERROR(id_or.status()); + cond_state_map_.ResetId(dst_node, id_or.ValueOrDie()); + } else { + auto id_or = + JoinCondStatesNonMerge(dst_id, cond_state_map_.LookupId(dst_node)); + TF_RETURN_IF_ERROR(id_or.status()); + cond_state_map_.ResetId(dst_node, id_or.ValueOrDie()); + } + } else if (BranchType(switch_branch) != b) { + cond_state_map_.MarkDead(dst_node); + delete_nodes_.push_back(dst_node->id()); + continue; + } + graph_->AddEdge( + val_node, + switch_branch == Graph::kControlSlot ? Graph::kControlSlot : val_port, + dst_node, dst_input); + } + return Status::OK(); +} + +Status FunctionalizeCond::DetermineCondStates( + std::vector rev_topo_order) { + // The state that is propagated along the given edge. + for (auto it = rev_topo_order.rbegin(); it != rev_topo_order.rend(); ++it) { + Node* dst = *it; + TF_RETURN_IF_ERROR(DetermineCondState(dst)); + if (IsSwitch(dst)) TF_RETURN_IF_ERROR(RemoveRedundantSwitch(dst)); + if (IsMerge(dst)) TF_RETURN_IF_ERROR(RemoveRedundantMerge(dst)); + + VLOG(5) << dst->name() << " :: " << cond_state_map_.CondStateToString(dst); + } + return Status::OK(); +} + +void FunctionalizeCond::DeleteReachableNodes() { + // Delete all nodes that have been extracted or are reachable from + // deleted/dead nodes. The input and outgoing edges should have already been + // removed. + std::vector deleted(graph_->num_node_ids(), false); + // Don't try to delete source or sink nodes. + deleted[graph_->kSourceId] = true; + deleted[graph_->kSinkId] = true; + while (!delete_nodes_.empty()) { + int d_id = delete_nodes_.front(); + delete_nodes_.pop_front(); + if (deleted[d_id]) continue; + Node* d = graph_->FindNodeId(d_id); + // Switch and Merge nodes could have been deleted already. + if (d == nullptr) continue; + for (const Edge* e : d->out_edges()) { + delete_nodes_.push_back(e->dst()->id()); + } + deleted[d_id] = true; + graph_->RemoveNode(d); + } +} + +void FunctionalizeCond::SortMergeNodes(std::vector* merge_order) { + // Sort merge nodes by nesting depth. + using sort_pair = std::pair; + std::vector inner_to_outer_merge_order; + inner_to_outer_merge_order.reserve(merge_order->size()); + for (auto it = merge_order->rbegin(); it != merge_order->rend(); ++it) { + Node* merge = *it; + CondStateMap::CondId id = cond_state_map_.LookupId(merge); + int depth = 0; + for (auto cond_node_it = id->begin(); cond_node_it != id->end(); + ++cond_node_it) { + if (cond_node_it->type == CondStateMap::CondNode::Type::kSwitch && + (cond_node_it->branch == BranchType::kThenBranch || + cond_node_it->branch == BranchType::kElseBranch)) { + ++depth; + } + } + inner_to_outer_merge_order.emplace_back(depth, merge); + } + std::stable_sort( + inner_to_outer_merge_order.begin(), inner_to_outer_merge_order.end(), + [](sort_pair lhs, sort_pair rhs) { return lhs.first > rhs.first; }); + merge_order->clear(); + for (sort_pair t : inner_to_outer_merge_order) { + merge_order->push_back(t.second); + } +} + +Status FunctionalizeCond::FunctionalizeInternal() { + // The general approach for converting a tf.cond (as lowered via switch/merge + // nodes) to a functional if is as follows: + // 1. Determine the topological order and collect all the switch and merge + // nodes in the graph; + // 2. Compute the predicates and dominance structure for all the nodes in the + // graph - this includes which predicate must be true for a op to execute + // (predicate values are considered directly rather than attempting to + // determine deeper equivalence). We shall refer to this structure as the + // CondState; + // 3. Sort the merge nodes by nesting depth; + // 4. Extract merge nodes together that have the same CondState and whose + // input nodes have the same state from the innermost to the outermost into + // IfOps; Note: In the above only nodes paths that converge to a merge node + // will be considered for removal. + + // Perform a DFS over the graph and + // * Determine the reverse topological order of the nodes (there should be no + // cycles at this point so the post-order numbering corresponds to the + // reverse topological sorting); + // * Record reverse topological for merge and switch nodes; + std::vector rev_topo_order; + std::vector switch_ids; + std::vector merge_order; + DFS(*graph_, nullptr, [&](Node* n) { + if (IsSwitch(n)) { + switch_ids.push_back(n->id()); + } + if (IsMerge(n)) { + merge_order.push_back(n); + } + if (n->IsOp()) { + rev_topo_order.push_back(n); + } + }); + + // No merges to functionalize. + if (merge_order.empty()) { + // No merges mean no switch values consumed (as only considering values + // fetchable as output of merge); + for (auto it = switch_ids.begin(); it != switch_ids.end(); ++it) { + graph_->RemoveNode(graph_->FindNodeId(*it)); + } + return Status::OK(); + } + + TF_RETURN_IF_ERROR(DetermineCondStates(std::move(rev_topo_order))); + + if (VLOG_IS_ON(4)) DumpGraphWithCondState("cond_id"); + + // Sort the merge nodes from innermost outwards. + SortMergeNodes(&merge_order); + + // Extract from innermost out. + for (auto it = merge_order.begin(); it != merge_order.end(); ++it) { + Node* merge = *it; + auto id = cond_state_map_.LookupId(merge); + if (cond_state_map_.IsDead(id)) continue; + + // Construct a Conditional with the predicate of the merge (which is the + // last entry of the CondState for the merge) and this as parent. + DCHECK(id->back().predicate.node != nullptr); + Conditional cond(id->back().predicate, this, &cond_state_map_); + TF_RETURN_IF_ERROR(cond.AddMerge(merge)); + + // Find all merge nodes with the same CondId. This is done repeatedly as + // the CondId can change due replaced conditionals. E.g., the one branch + // could previously have had a conditional nested in it, and so would have + // had CondState with sub-state [switch(p,b),m] (where p is some predicate), + // post removing the nested conditional that sub-state would no longer be + // path of the propagated state along that path. + auto end = merge_order.end(); + for (auto merge_candidate_it = std::next(it); merge_candidate_it != end; + ++merge_candidate_it) { + auto merge_candidate_it_id = + cond_state_map_.LookupId(*merge_candidate_it); + if (merge_candidate_it_id != id) continue; + TF_RETURN_IF_ERROR(cond.AddMerge(*merge_candidate_it)); + } + + TF_RETURN_IF_ERROR(cond.BuildAndReplace(graph_, library_)); + + if (VLOG_IS_ON(4)) DumpGraphWithCondState("after_extract"); + } + + // All remaining Switch nodes are not reachable from a Merge node and + // removed. This is to account for dead Switch nodes. + for (int s_id : switch_ids) delete_nodes_.push_back(s_id); + for (Node* m : merge_order) delete_nodes_.push_back(m->id()); + DeleteReachableNodes(); + + return Status::OK(); +} + +void FunctionalizeCond::DumpGraphWithCondState(const string& name) { + const char* const kCondGroupDebugAttr = "_XlaFunctionalizeCondGroup"; + + for (Node* n : graph_->nodes()) { + n->ClearAttr(kCondGroupDebugAttr); + n->AddAttr(kCondGroupDebugAttr, cond_state_map_.CondStateToString(n)); + } + LOG(INFO) << "FunctionalizeControlFlow (" << name << "): " + << dump_graph::DumpGraphToFile( + strings::StrCat("functionalize_", name), *graph_, library_); +} + +Status FunctionalizeCond::Functionalize(Graph* graph, + FunctionLibraryDefinition* library) { + VLOG(1) << "FunctionalizeCond::Functionalize"; + FunctionalizeCond fc(graph, library); + return fc.FunctionalizeInternal(); +} + +} // namespace functionalize_cond + +Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library) { + // FunctionalizeControlFlow is invoked for every function, so the loops's + // bodies and conditionals that were extracted into functions will be handled + // in successive invocations. + return functionalize_cond::FunctionalizeCond::Functionalize(graph, library); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.h b/tensorflow/compiler/tf2xla/functionalize_cond.h new file mode 100644 index 0000000000000000000000000000000000000000..86436011c6ebdc608a5811a1b0d6a10015d405bd --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_cond.h @@ -0,0 +1,248 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_ +#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_ + +#include +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Functionalize all the switch-merge nodes of a loop-free graph into If +// nodes. That is, attempt to transform every remaining switch and merge nodes +// in the graph into If nodes. +// Precondition: All while loops have been removed from graph. +Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library); + +// Internal functions/classes exposed for testing purposes. +namespace functionalize_cond { + +// All nodes are assumed to be either in no branch, then branch, else branch, +// or both branches (such as merge nodes). +// The code below relies on Else and Then being 0 and 1 (corresponding to the +// switch outputs). Both and Neither are arbitrary. +enum class BranchType { + kElseBranch = 0, + kThenBranch = 1, + kBoth = 2, + kNeither = 3, +}; + +// CondStateMap is responsible for mapping from each graph Node to a CondState, +// where each CondState is the array of CondNodes (corresponding to switch, +// merge or dead states) as described below. For efficiency, this class interns +// the CondState, so that CondState equality comparisons are simply pointer +// comparisons. +class CondStateMap { + public: + explicit CondStateMap(Graph* graph); + + // Represents an entry in the CondState. An entry can either be the + // switch (along with predicate), merge, or dead: + // * switch node indicates a node that is executed along a branch with the + // given predicate - a branch can be then, else or both; + // * merge node indicates that the node is executed as output of a merge; + // * dead indicates that this node can never be executed; + struct CondNode { + enum class Type { kSwitch = 1, kMerge = 2, kDead = 3 }; + + CondNode(Type type, Node* switch_node = nullptr, + BranchType branch = BranchType::kNeither); + + string ToString() const; + bool operator==(const CondNode& other) const; + bool operator!=(const CondNode& other) const; + + // Type of node. + Type type; + + // Predicate and branch, only used when type is kSwitch. + OutputTensor predicate; + BranchType branch; + }; + + // A node in the graph is executed when multiple conditions hold. The order + // represents the nesting of the predicates that hold and is used when + // extracting the nested conditionals. + using CondState = std::vector; + + // Every unique ID is mapped to a CondState. + using CondId = const CondState*; + + // Returns the CondId for a given node. + CondId LookupId(const Node* node) const; + + // Returns the unique CondId for CondState. + CondId GetUniqueId(const CondState& state); + + // Returns the CondState for a Node. + // REQUIRES: node has a non-empty CondState. + const CondState& LookupState(const Node* node) const; + + // Resets the CondId for a given node. + void ResetId(const Node* node, CondId id); + + // Marks `node` as dead. + void MarkDead(const Node* node); + + // Determine branch execution of CondState. + BranchType FindBranchOf(CondId id, OutputTensor predicate) const; + + // Enum to represent whether one cond flow state contains another. + enum ContainsResult { + kIncomparable, + kEqual, + kLhsContainsRhs, + kRhsContainsLhs + }; + + // Returns whether the lhs CondState holds wherever rhs CondState hols. I.e., + // [(p,t)] contains [(p,t), (r,t)]. + ContainsResult LhsHoldsWhereverRhsHolds(CondId lhs, CondId rhs); + + // Returns textual representation of node's CondState. + string CondStateToString(const Node* node) const; + string CondStateToString(CondId id) const; + + // Returns whether the cond state is the dead state. + bool IsDead(CondId id) const; + + // Returns whether the cond state is the empty state. + bool IsEmpty(CondId id) const; + + // Computes the predicates that have to hold for a node to execute and returns + // whether it was possible to determine the predicates that must hold. `scope` + // is populated with these predicates. Scope differs from state in that it + // does not include merge and both nodes. + bool ScopeIn(CondId id, CondId* scope); + + private: + // Hash for CondNode and CondState. + struct CondHash { + size_t operator()(const CondNode& item) const; + size_t operator()(const CondState& vec) const; + }; + + // Set to keep track of unique CondStates. + // Pointers to the entries in the unordered set are used as identifiers: + // unordered_set guarantees that the pointers remain the same. + std::unordered_set condstate_set_; + + // Mapping from Node id to CondId. + std::vector node_to_condid_map_; + + // Track the CondId for newly inserted nodes. We use a vector to quickly map + // from Node id in the original graph to the CondId, but there will be nodes + // added to the original graph (such as If nodes) whose CondState needs to be + // tracked too. + std::unordered_map added_node_mapping_; + + // Identifier of the dead flow state. The empty flow state is represented with + // a nullptr. + CondId dead_id_; +}; + +// FunctionalizeCond groups all the state used by functionalizing conditionals +// of the given graph together. +class FunctionalizeCond { + public: + // Functionalize all the switch-merge nodes of a loop-free graph into If + // nodes. That is, attempt to transform every remaining switch and merge nodes + // in the graph into If nodes. + // Precondition: All while loops have been removed from graph. + static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library); + + // Build identity node with the same name as the merge that will be replaced + // in case the output is fetched/colocated. + Status AddIdentityNode(const Node* replacee, Node* if_node, int port); + + // Add a If node to the graph defined by def that will, amongst other, replace + // replacee in the graph. + xla::StatusOr AddIfNode(const NodeDef& def, const Node* replacee); + + // Propagates the state of a newly inserted node. + Status PropagateUpdatedState(const Node* replacee); + + // Dump graph with the CondState annotated. + void DumpGraphWithCondState(const string& name); + + private: + FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library); + + // Performs the actual cond functionalization. Iterate over groups of merge + // nodes (linked by common predicate & CondIds of the incomming edges), + // from innermost to outermost, and extract into If nodes. + Status FunctionalizeInternal(); + + // Returns the forward flow state propagated along edge `e`. + // This may modify cond_state_map_. + CondStateMap::CondId StateAlongEdge(const Edge* e); + + // Determines the CondState of all the nodes in the given vector where + // the input is expected in reverse topological order. + // This populates the cond_state_map_. + Status DetermineCondStates(std::vector rev_topo_order); + + // Determine the CondState for a given node using the incomming edges + // to the node. Note: it is expected that this node's CondState is only + // determined once its input's CondState is. + Status DetermineCondState(Node* dst); + + // Helper functions for DetermineCondState. + Status DetermineCondStateMerge(Node* dst); + + // Helper functions for DetermineCondStates. Determines the dst node's + // CondState by joining the src and dst's CondState where either + // the dst node is a merge or not. + // These may modify cond_state_map_. + xla::StatusOr JoinCondStatesMerge( + CondStateMap::CondId src, CondStateMap::CondId dst); + xla::StatusOr JoinCondStatesNonMerge( + CondStateMap::CondId src, CondStateMap::CondId dst); + + // Checks if a merge node is redundant and if so removes it from the graph. + Status RemoveRedundantMerge(Node* node); + + // Checks if a switch node is redundant and if so removes it from the graph. + Status RemoveRedundantSwitch(Node* node); + + // Sorts merge nodes (in reverse topological order) in order of increasing + // nesting depth. + void SortMergeNodes(std::vector* merge_order); + + // Deletes all nodes in/consumers of `delete_nodes_`. + void DeleteReachableNodes(); + + // Member used to unique the CondState to a unique CondId and keep track of + // CondState/CondId per Node. + CondStateMap cond_state_map_; + + // Nodes to be deleted. + std::deque delete_nodes_; + + FunctionLibraryDefinition* library_; + Graph* graph_; + + friend class FunctionalizeCondTest; +}; + +} // namespace functionalize_cond + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_ diff --git a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..548c948d09562bf7b98cc0efb06a6aebda4382c3 --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc @@ -0,0 +1,180 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Tests for the backward const analysis. + +#include "tensorflow/compiler/tf2xla/functionalize_cond.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/graph/testlib.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace functionalize_cond { + +class FunctionalizeCondTest : public ::testing::Test { + protected: + FunctionalizeCondTest() { + graph_.reset(new Graph(OpRegistry::Global())); + flib_def_.reset( + new FunctionLibraryDefinition(OpRegistry::Global(), fdef_lib_)); + fc_.reset(new functionalize_cond::FunctionalizeCond(graph_.get(), + flib_def_.get())); + } + + CondStateMap::CondId GetUniqueId( + const CondStateMap::CondStateMap::CondState& state) { + return fc_->cond_state_map_.GetUniqueId(state); + } + + xla::StatusOr JoinCondStatesNonMerge( + CondStateMap::CondId src, CondStateMap::CondId dst) { + return fc_->JoinCondStatesNonMerge(src, dst); + } + + xla::StatusOr JoinCondStatesMerge( + CondStateMap::CondId src, CondStateMap::CondId dst) { + return fc_->JoinCondStatesMerge(src, dst); + } + + bool ScopeIn(CondStateMap::CondId ff, CondStateMap::CondId* scope) { + return fc_->cond_state_map_.ScopeIn(ff, scope); + } + + CondStateMap::ContainsResult LhsHoldsWhereverRhsHolds( + CondStateMap::CondId lhs, CondStateMap::CondId rhs) { + return fc_->cond_state_map_.LhsHoldsWhereverRhsHolds(lhs, rhs); + } + + FunctionDefLibrary fdef_lib_; + std::unique_ptr fc_; + std::unique_ptr flib_def_; + std::unique_ptr graph_; +}; + +namespace { + +TEST_F(FunctionalizeCondTest, ScopeIn) { + Tensor pred_tensor(DT_BOOL, TensorShape()); + Node* pred = test::graph::Constant(graph_.get(), pred_tensor, "pred"); + Tensor val_tensor(DT_INT32, TensorShape()); + Node* val = test::graph::Constant(graph_.get(), val_tensor, "val"); + Node* s = test::graph::Switch(graph_.get(), val, pred); + + { + CondStateMap::CondStateMap::CondState ss; + ss.emplace_back(CondStateMap::CondNode( + CondStateMap::CondNode::Type::kSwitch, s, BranchType::kThenBranch)); + CondStateMap::CondId id = GetUniqueId(ss); + CondStateMap::CondId scope; + ASSERT_TRUE(ScopeIn(id, &scope)); + ASSERT_TRUE(id == scope); + } + + CondStateMap::CondState empty; + { + CondStateMap::CondState ss; + ss.emplace_back(CondStateMap::CondNode( + CondStateMap::CondNode::Type::kSwitch, s, BranchType::kBoth)); + ss.emplace_back( + CondStateMap::CondNode(CondStateMap::CondNode::Type::kMerge)); + CondStateMap::CondId id = GetUniqueId(ss); + CondStateMap::CondId scope_1; + ASSERT_TRUE(ScopeIn(id, &scope_1)); + ASSERT_TRUE(scope_1 == GetUniqueId(empty)); + ASSERT_TRUE(id != scope_1); + + ss.clear(); + ss.emplace_back(CondStateMap::CondNode( + CondStateMap::CondNode::Type::kSwitch, s, BranchType::kBoth)); + id = GetUniqueId(ss); + CondStateMap::CondId scope_2; + ASSERT_TRUE(ScopeIn(id, &scope_2)); + + ASSERT_TRUE(LhsHoldsWhereverRhsHolds(scope_1, scope_2) == + CondStateMap::ContainsResult::kLhsContainsRhs); + } +} + +TEST_F(FunctionalizeCondTest, JoinCondStates) { + Tensor pred_tensor(DT_BOOL, TensorShape()); + Node* pred = test::graph::Constant(graph_.get(), pred_tensor, "pred"); + Tensor val_tensor(DT_INT32, TensorShape()); + Node* val = test::graph::Constant(graph_.get(), val_tensor, "val"); + Node* s = test::graph::Switch(graph_.get(), val, pred); + + CondStateMap::CondId empty = GetUniqueId({}); + + CondStateMap::CondId then_branch; + { + CondStateMap::CondState ss; + ss.emplace_back(CondStateMap::CondNode( + CondStateMap::CondNode::Type::kSwitch, s, BranchType::kThenBranch)); + then_branch = GetUniqueId(ss); + } + CondStateMap::CondId else_branch; + { + CondStateMap::CondState ss; + ss.emplace_back(CondStateMap::CondNode( + CondStateMap::CondNode::Type::kSwitch, s, BranchType::kElseBranch)); + else_branch = GetUniqueId(ss); + } + + // An non-merge op with inputs from then and else branch. + Status status = JoinCondStatesNonMerge(then_branch, else_branch).status(); + EXPECT_TRUE(errors::IsInvalidArgument(status)); + + // Merge between then and else branch. + auto joined_or = JoinCondStatesMerge(then_branch, else_branch); + TF_EXPECT_OK(joined_or.status()); + CondStateMap::CondId joined = joined_or.ValueOrDie(); + + // Merge between then branch and both branch. + auto t = JoinCondStatesNonMerge(then_branch, joined); + // Note: this is OK in terms of constraint predication, but + TF_EXPECT_OK(t.status()); + + // Post merge the propagated forward flow state has an additional merge. + CondStateMap::CondId post_merge; + { + CondStateMap::CondState ss; + ss = *joined; + ss.emplace_back( + CondStateMap::CondNode(CondStateMap::CondNode::Type::kMerge)); + post_merge = GetUniqueId(ss); + } + + t = JoinCondStatesNonMerge(post_merge, joined); + TF_EXPECT_OK(t.status()); + EXPECT_TRUE(joined == t.ValueOrDie()); + + // No predicate that results in two paths predicated on different conditions + // merge. + t = JoinCondStatesMerge(post_merge, joined); + EXPECT_FALSE(t.ok()); + + // Post the merge we are effectively in the root scope and merging should + // result in the more restrictive post merge state. + t = JoinCondStatesNonMerge(post_merge, empty); + TF_EXPECT_OK(t.status()); + EXPECT_TRUE(post_merge == t.ValueOrDie()); +} + +} // namespace +} // namespace functionalize_cond +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 6cc95149a16a59fce8486c5d103ad09e3e262765..188ada7255f0fc15f64f2a2b1a128637add2afe0 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -21,1437 +21,24 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_cond.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" +#include "tensorflow/compiler/tf2xla/functionalize_while.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/gtl/optional.h" namespace tensorflow { -namespace { - -using xla::StatusOr; - -const char* const kArgOp = "_Arg"; -const char* const kRetValOp = "_Retval"; - -// Information about a loop argument. -struct Arg { - // Every loop argument has an Enter node. - Node* enter; - - // Is the loop argument a loop-invariant value? Taken from the `is_constant` - // attribute on the Enter node. - bool is_loop_invariant; - - // If 'is_loop_invariant' is true, the following are all nullptr. Non-constant - // arguments must have all of the following nodes: - Node* merge = nullptr; - Node* switch_node = nullptr; - Node* next_iteration = nullptr; - Node* exit = nullptr; -}; - -// Information about a loop frame. -struct Frame { - string name; - - // Pointer to the parent frame. The root frame has a pointer to itself. - Frame* parent = nullptr; - int num_children = 0; - - // Arguments to this loop. - std::vector args; - - // The loop condition of the loop. There should be exactly one loop condition - // in every loop. - Node* loop_cond = nullptr; - - // Set of nodes that belong to the loop frame. - std::unordered_set nodes; -}; - -// Comparison function used for sorting nodes consistently. -// a) resource variables are last, and -// b) sort lexicographically by name (for deterministic output). -struct NodeCmp { - bool operator()(const Node* lhs, const Node* rhs) const { - bool lhs_is_resource = - lhs->num_inputs() > 0 ? (lhs->input_type(0) == DT_RESOURCE) : false; - bool rhs_is_resource = - rhs->num_inputs() > 0 ? (rhs->input_type(0) == DT_RESOURCE) : false; - return std::tie(lhs_is_resource, lhs->name()) < - std::tie(rhs_is_resource, rhs->name()); - } -}; - -// Returns a textual representation of the names of the nodes in the input. -template -string NodesToString(const T& nodes) { - return strings::StrCat("{", - str_util::Join(nodes, ",", - [](string* output, const Node* node) { - strings::StrAppend(output, - node->name()); - }), - "}"); -} - -// Copies a subgraph from `graph` to `output` by performing a reverse DFS -// starting at nodes in vector `stack`. -// `node_map` is a vector indexed by source node ID to dest nodes. -// Does not traverse into nodes in `node_map`, so by adding nodes to `node_map` -// before the traversal clients can cut the graph. If a frame is provided (frame -// != nullptr), then this functions will return an error if the -// traversal leaves 'frame'; the client must add enough nodes to `node_map` to -// cut the graph and prevent the traversal from escaping. -// -// `squash_src_outputs` contains a bool for each source node ID. If true, then -// the source output on that node will be replaced by zero when copied. This is -// used when replacing a Switch node with an _Arg node. The output we are -// taking from the Switch node was not necessarily the first output, but _Arg -// nodes only have one output. By adding the Switch node to `squash_src_outputs` -// we rewrite the src_output of the corresponding edge to be 0. -Status CopySubgraph(const Graph& graph, const Frame* frame, - std::vector stack, - const std::vector& squash_src_outputs, - std::vector* node_map, Graph* output) { - VLOG(3) << "Stack: " << NodesToString(stack); - std::vector visited(graph.num_node_ids(), false); - while (!stack.empty()) { - Node* n = stack.back(); - stack.pop_back(); - - VLOG(5) << "Copying node " << n->name(); - - if (visited[n->id()]) continue; - visited[n->id()] = true; - - for (const Edge* e : n->in_edges()) { - Node* src = e->src(); - if (frame != nullptr && frame->nodes.find(src) == frame->nodes.end()) { - // We traversed out of the loop frame, without encountering a cut node. - return errors::Internal("Graph traversal of loop frame ", frame->name, - " escaped frame at ", src->name(), - " without encountering an argument node."); - } - if ((*node_map)[src->id()] == nullptr) { - (*node_map)[src->id()] = output->CopyNode(src); - stack.push_back(src); - } - Node* src_copy = (*node_map)[e->src()->id()]; - int src_output = squash_src_outputs[e->src()->id()] && !e->IsControlEdge() - ? 0 - : e->src_output(); - Node* dst_copy = (*node_map)[e->dst()->id()]; - output->AddEdge(src_copy, src_output, dst_copy, e->dst_input()); - } - } - return Status::OK(); -} - -StatusOr AddNode(const NodeDef& node_def, Graph* graph) { - Status status; - Node* inserted_node = graph->AddNode(node_def, &status); - if (!status.ok()) { - return status; - } - return inserted_node; -} - -// Check that the graph has no cycle containing the given node. -Status CheckNoCycleContains(const Node* node, const int num_nodes) { - std::vector ready; - ready.push_back(node); - std::vector visited(num_nodes); - while (!ready.empty()) { - const Node* current_node = ready.back(); - ready.pop_back(); - visited[current_node->id()] = true; - for (const Edge* out : current_node->out_edges()) { - if (out->dst() == node) { - return errors::Internal("Detect a cycle: Node \"", node->name(), "\"(", - node->def().op(), ") feeds into itself."); - } else if (!visited[out->dst()->id()]) { - ready.push_back(out->dst()); - } - } - } - return Status::OK(); -} - -StatusOr BuildArgNode(Graph* graph, DataType type, int index) { - NodeDef arg_def; - NodeDefBuilder builder(strings::StrCat(kArgOp, index), kArgOp); - builder.Attr("T", type); - builder.Attr("index", index); - TF_RETURN_IF_ERROR(builder.Finalize(&arg_def)); - return AddNode(arg_def, graph); -} - -StatusOr BuildRetvalNode(Graph* graph, DataType type, int index) { - NodeDef ret_def; - ret_def.set_op(kRetValOp); - ret_def.set_name(strings::StrCat(kRetValOp, index)); - AddNodeAttr("T", type, &ret_def); - AddNodeAttr("index", index, &ret_def); - return AddNode(ret_def, graph); -} - -// Builds a graph for the loop condition. -Status BuildLoopCondition(const Graph& graph, Frame* frame, - std::unique_ptr* cond_output) { - VLOG(2) << "Building loop condition for " << frame->name; - *cond_output = xla::MakeUnique(graph.op_registry()); - Graph* output = cond_output->get(); - - // Map from nodes in the original graph to the condition graph. - std::vector node_map(graph.num_node_ids(), nullptr); - std::vector squash_src_outputs(graph.num_node_ids(), false); - - // Build one _Arg node for each Enter node. - for (int i = 0; i < frame->args.size(); ++i) { - const Arg& arg = frame->args[i]; - - TF_ASSIGN_OR_RETURN(Node * arg_node, - BuildArgNode(output, arg.enter->input_type(0), i)); - if (arg.is_loop_invariant) { - node_map[arg.enter->id()] = arg_node; - } else { - node_map[arg.merge->id()] = arg_node; - } - } - - // Build a Retval node for the loop condition. The LoopCond nodes are always - // boolean because of the type constraints on the LoopCond op. - TF_ASSIGN_OR_RETURN(node_map[frame->loop_cond->id()], - BuildRetvalNode(output, DT_BOOL, 0)); - - // Performs a reverse DFS, copying nodes and edges to the output graph. - // The _Arg and _Retval nodes were added unconditionally above, so we are - // guaranteed to get the correct function signature. - return CopySubgraph(graph, frame, {frame->loop_cond}, squash_src_outputs, - &node_map, output); -} - -// Builds a graph for the loop body. -Status BuildLoopBody(const Graph& graph, Frame* frame, - DataTypeVector* arg_types, - std::unique_ptr* body_output) { - VLOG(2) << "Building loop body for " << frame->name; - *body_output = xla::MakeUnique(graph.op_registry()); - Graph* output = body_output->get(); - - // Map from nodes in the original graph to the condition graph. - std::vector node_map(graph.num_node_ids(), nullptr); - std::vector squash_src_outputs(graph.num_node_ids(), false); - - // Build one _Arg node for each Enter node. - std::vector next_iterations; - next_iterations.reserve(frame->args.size()); - arg_types->reserve(frame->args.size()); - for (int i = 0; i < frame->args.size(); ++i) { - const Arg& arg = frame->args[i]; - - DataType dtype = arg.enter->input_type(0); - arg_types->push_back(dtype); - - TF_ASSIGN_OR_RETURN(Node * arg_node, BuildArgNode(output, dtype, i)); - - if (dtype == DT_RESOURCE) { - // The convention of the XLA bridge is that resource variable arguments - // are only inputs to the loop body and have no corresponding output. - // TODO(b/37741920): change the convention so that DT_RESOURCE variables - // are both inputs and outputs, and then remove this case. - TF_RET_CHECK(arg.is_loop_invariant); - node_map[arg.enter->id()] = arg_node; - } else { - TF_ASSIGN_OR_RETURN(Node * retval_node, - BuildRetvalNode(output, dtype, i)); - - if (arg.is_loop_invariant) { - // Argument is loop-invariant. Forward it from the Arg to the Retval. - node_map[arg.enter->id()] = arg_node; - output->AddEdge(arg_node, 0, retval_node, 0); - } else { - // Argument is loop-varying. - node_map[arg.switch_node->id()] = arg_node; - // The Switch node has two outputs, but _Arg only has one. This tells - // the CopySubgraph function to rewrite the output number of edges from - // the _Arg node to be 0 rather than copying the output number from the - // Switch node. - squash_src_outputs[arg.switch_node->id()] = true; - node_map[arg.next_iteration->id()] = retval_node; - next_iterations.push_back(arg.next_iteration); - } - } - } - - // Performs a reverse DFS, copying nodes and edges to the output graph. - // The _Arg and _Retval nodes were added unconditionally above, so we are - // guaranteed to get the correct function signature. - TF_RETURN_IF_ERROR(CopySubgraph(graph, frame, std::move(next_iterations), - squash_src_outputs, &node_map, output)); - - return Status::OK(); -} - -// Copy the FunctionDef of given function from lookup_library to library, if -// it can be found in lookup_library but is missing from library. -Status AddMissingFunctionByName(const string& function_name, - const FunctionLibraryDefinition* lookup_library, - FunctionLibraryDefinition* library) { - if (!library->Find(function_name) && lookup_library->Find(function_name)) { - return library->AddFunctionDef(*lookup_library->Find(function_name)); - } - return Status::OK(); -} - -// Iterate over all functions that the given fdef refers to. Copy the missing -// FunctionDefs from lookup_library to library. -Status AddMissingFunctionDef(const FunctionDef& fdef, - const FunctionLibraryDefinition* lookup_library, - FunctionLibraryDefinition* library) { - TF_RET_CHECK(lookup_library); - for (const NodeDef& node : fdef.node_def()) { - if (library->Find(node.op())) { - continue; - } - // The function refered by 'SymbolicGradient' node is specified in its - // attribute 'f'. - if (node.op() == FunctionLibraryDefinition::kGradientOp) { - const AttrValue* attr = - AttrSlice(&node.attr()).Find(FunctionLibraryDefinition::kFuncAttr); - if (!attr) { - return errors::InvalidArgument("SymbolicGradient is missing attr: f"); - } - const string& func_name = attr->func().name(); - TF_RETURN_IF_ERROR( - AddMissingFunctionByName(func_name, lookup_library, library)); - // Copy the user-defined gradient function if it exists. - const string grad_name = lookup_library->FindGradient(func_name); - if (!grad_name.empty() && library->FindGradient(func_name).empty()) { - TF_RETURN_IF_ERROR( - AddMissingFunctionByName(grad_name, lookup_library, library)); - GradientDef grad_def; - grad_def.set_function_name(func_name); - grad_def.set_gradient_func(grad_name); - TF_RETURN_IF_ERROR(library->AddGradientDef(grad_def)); - } - } else if (lookup_library->Find(node.op())) { - TF_RETURN_IF_ERROR( - library->AddFunctionDef(*lookup_library->Find(node.op()))); - } - } - return Status::OK(); -} - -Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, - Graph* graph, Frame* frame, - FunctionLibraryDefinition* library) { - VLOG(2) << "Frame " << frame->name << " before: " - << dump_graph::DumpGraphToFile("functionalize_before", *graph, - library); - - // Split loop-varying Enter nodes with multiple successors. If the same - // Tensor is fed as input to multiple loop arguments, we may end up with a - // shared Enter node. We clone Enter nodes with multiple successors to - // maintain the invariant of a unique Enter node per argument of the final - // loop. - std::vector args; - for (const Arg& arg : frame->args) { - if (arg.is_loop_invariant) { - args.push_back(arg); - } else { - std::vector edges(arg.enter->out_edges().begin(), - arg.enter->out_edges().end()); - for (int i = 0; i < edges.size(); ++i) { - if (edges[i]->IsControlEdge() && edges[i]->dst()->IsSink()) { - continue; - } - TF_RET_CHECK(!edges[i]->IsControlEdge()) << edges[i]->src()->name(); - Arg new_arg; - new_arg.is_loop_invariant = false; - if (i == 0) { - new_arg.enter = arg.enter; - } else { - new_arg.enter = graph->CopyNode(arg.enter); - frame->nodes.insert(new_arg.enter); - for (Edge const* e : arg.enter->in_edges()) { - graph->AddEdge(e->src(), e->src_output(), new_arg.enter, - e->IsControlEdge() ? Graph::kControlSlot : 0); - } - Node* dst = edges[i]->dst(); - int dst_input = edges[i]->dst_input(); - graph->RemoveEdge(edges[i]); - graph->AddEdge(new_arg.enter, 0, dst, dst_input); - } - args.push_back(new_arg); - } - } - } - frame->args = std::move(args); - - std::sort( - frame->args.begin(), frame->args.end(), - [](const Arg& a, const Arg& b) { return NodeCmp()(a.enter, b.enter); }); - - if (frame->loop_cond == nullptr) { - return errors::InvalidArgument("Loop ", frame->name, - " has no LoopCond node"); - } - - // Find the set of Switch nodes that are successors of the LoopCond. - std::unordered_set switches; - for (const Edge* edge : frame->loop_cond->out_edges()) { - if (!edge->IsControlEdge() && IsSwitch(edge->dst()) && - edge->dst_input() == 1) { - switches.insert(edge->dst()); - } - } - - // For each non-constant argument, looks for the following pattern of nodes: - // Enter ----> Merge --------> Switch --> Exit - // ^ ^ - // | | - // NextIteration LoopCond - // ^ ^ - // | | - // ... ... - for (Arg& arg : frame->args) { - if (!arg.is_loop_invariant) { - // Follow the edge from the Enter to Merge. - const Edge* enter_merge = nullptr; - for (const Edge* e : arg.enter->out_edges()) { - // Ignore control-edges to the sink node. These are allowed by the - // graph invariants, although probably they should have been stripped - // off earlier. - if (e->IsControlEdge() && e->dst()->IsSink()) { - continue; - } - if (enter_merge != nullptr) { - return errors::Internal( - "Enter node for loop-varying argument ", arg.enter->name(), - " has multiple successors: ", enter_merge->dst()->name(), " and ", - e->dst()->name()); - } - enter_merge = e; - } - if (enter_merge == nullptr) { - return errors::Internal("Enter node for loop-varying argument ", - arg.enter->name(), " has zero successors"); - } - arg.merge = enter_merge->dst(); - if (!IsMerge(arg.merge)) { - return errors::InvalidArgument( - "Successor of Enter node for loop-varying argument ", - arg.merge->name(), - " is not a Merge node; got: ", arg.merge->type_string()); - } - - // Find the NextIteration from the merge. There should be two inputs to - // the Merge and the NextIteration should be the other input. - if (arg.merge->input_types().size() != 2) { - return errors::InvalidArgument( - "Unexpected number of inputs to Merge node for loop-varying " - "argument ", - arg.merge->name(), "; expected 2, got ", - arg.merge->input_types().size()); - } - TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(), - &arg.next_iteration)); - if (!IsNextIteration(arg.next_iteration)) { - return errors::InvalidArgument( - "Expected NextIteration node as input to Merge node; got node ", - arg.next_iteration->name(), " with kind ", - arg.next_iteration->type_string()); - } - - // Find the Switch successor of the Merge. There should be exactly one - // Switch node that is a successor of both the Merge and the LoopCond. - for (const Edge* edge : arg.merge->out_edges()) { - if (edge->dst_input() == 0 && IsSwitch(edge->dst()) && - switches.find(edge->dst()) != switches.end()) { - if (arg.switch_node != nullptr) { - return errors::InvalidArgument("Duplicate Switch successors to ", - arg.merge->name()); - } - arg.switch_node = edge->dst(); - } - } - if (arg.switch_node == nullptr) { - return errors::InvalidArgument("Missing Switch successor to ", - arg.merge->name()); - } - - // Update the device on the Identity outputs of the switch to match their - // target. These Identity outputs do not - - // Loop over the switch node's output to: - // - Find the Exit successor. - // - Set the sharding on all Identity outputs of the switch. These - // identity nodes are values used by the loop body or condition. - // The Identity node may have the wrong device so copy the device from - // one of its outputs instead. - std::deque possible_exit; - for (const Edge* edge : arg.switch_node->out_edges()) { - if (edge->src_output() == 0) { - possible_exit.push_back(edge); - } - if (IsIdentity(edge->dst())) { - TF_RETURN_IF_ERROR( - SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true)); - } - } - // TODO(b/67425339): Allow general graph between switch and exit. - while (!possible_exit.empty()) { - const Edge* edge = possible_exit.front(); - possible_exit.pop_front(); - if (IsExit(edge->dst())) { - if (arg.exit != nullptr) { - return errors::InvalidArgument("Duplicate Exit successors to ", - arg.switch_node->name()); - } - arg.exit = edge->dst(); - } else { - if (!IsIdentity(edge->dst())) { - return errors::Unimplemented("General graph between switch (", - arg.switch_node->name(), - ") and exit node of frame ", - frame->name, " not supported yet."); - } - for (const Edge* out : edge->dst()->out_edges()) { - possible_exit.push_back(out); - } - } - } - } - } - - // Builds the condition and body functions. - std::unique_ptr cond_graph; - TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph)); - DataTypeVector arg_types; - std::unique_ptr body_graph; - TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph)); - - VLOG(2) << "Frame " << frame->name << " condition: " - << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library) - << " body: " << dump_graph::DumpGraphToFile("loop_body", *body_graph); - - static std::atomic sequence_num(0LL); - int64 id = ++sequence_num; - NameAttrList cond_name; - cond_name.set_name(strings::StrCat("_functionalize_cond_", id)); - NameAttrList body_name; - body_name.set_name(strings::StrCat("_functionalize_body_", id)); - FunctionDef cond_fdef; - TF_RETURN_IF_ERROR( - GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef)); - FunctionDef body_fdef; - TF_RETURN_IF_ERROR( - GraphToFunctionDef(*body_graph, body_name.name(), &body_fdef)); - - TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef)); - TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef)); - if (lookup_library) { - // Copy missing FunctionDefs from lookup_library to library to make library - // self-contained. - TF_RETURN_IF_ERROR( - AddMissingFunctionDef(cond_fdef, lookup_library, library)); - TF_RETURN_IF_ERROR( - AddMissingFunctionDef(body_fdef, lookup_library, library)); - } - - // Builds a While operator. - NodeDef while_def; - NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile"); - builder.Attr("T", arg_types); - builder.Attr("cond", cond_name); - builder.Attr("body", body_name); - std::vector inputs; - for (int i = 0; i < frame->args.size(); ++i) { - const Arg& arg = frame->args[i]; - const Edge* in_edge; - TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge)); - if (in_edge->IsControlEdge()) { - builder.ControlInput(in_edge->src()->name()); - } else { - inputs.push_back(NodeDefBuilder::NodeOut( - in_edge->src()->name(), in_edge->src_output(), arg_types[i])); - } - } - builder.Input(inputs); - TF_RETURN_IF_ERROR(builder.Finalize(&while_def)); - TF_ASSIGN_OR_RETURN(Node * while_node, AddNode(while_def, graph)); - - // Copies edges to the Enter nodes and from the Exit nodes onto the While. - for (int i = 0; i < frame->args.size(); ++i) { - const Arg& arg = frame->args[i]; - const Edge* in_edge; - TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge)); - if (in_edge->IsControlEdge()) { - graph->AddControlEdge(in_edge->src(), while_node); - } else { - graph->AddEdge(in_edge->src(), in_edge->src_output(), while_node, i); - } - - if (!arg.is_loop_invariant) { - // Add output edges if the output of the loop is consumed. - if (arg.exit != nullptr) { - std::vector edges(arg.exit->out_edges().begin(), - arg.exit->out_edges().end()); - for (const Edge* edge : edges) { - Node* dst = edge->dst(); - int dst_input = edge->dst_input(); - graph->RemoveEdge(edge); - - if (dst_input == Graph::kControlSlot) { - graph->AddControlEdge(while_node, dst); - } else { - graph->AddEdge(while_node, i, dst, dst_input); - } - } - } - } - } - - // Remove the old nodes from the graph, and add the while node to the parent - // frame. - for (Node* node : frame->nodes) { - graph->RemoveNode(node); - } - frame->nodes.clear(); - frame->parent->nodes.insert(while_node); - - VLOG(2) << "Frame " << frame->name << " after: " - << dump_graph::DumpGraphToFile("functionalize_after", *graph, - library); - - return Status::OK(); -} - -class FunctionalizeCond { - public: - // All nodes are assumed to be either in no branch, then branch, else branch, - // or both branches (such as merge nodes). - enum Branch { - kElseBranch = 0, - kThenBranch = 1, - kBoth = 2, - kNeither = 3, - kNumBranchTypes = 4 - }; - - // Returns a textual representation of the Branch b. - static string Branch_Name(FunctionalizeCond::Branch b); - - // Functionalize all the switch-merge nodes of a loop-free graph into XlaIf - // nodes. That is, attempt to transform every remaining switch and merge nodes - // in the graph into XlaIf nodes. - // Precondition: All while loops have been removed from graph. - static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library); - - private: - // CondArgNode represents a input to the conditional and its corresponding - // switch nodes. - struct CondArgNode { - explicit CondArgNode(Node* src, int src_output) - : src(src), src_output(src_output) {} - string ToString() const { - return strings::StrCat("src=", src->name(), ":", src_output, - " switches=", NodesToString(switches)); - } - - Node* src; - int src_output; - std::vector switches; - }; - using CondArgNodes = std::vector; - - struct ForwardFlowNode { - explicit ForwardFlowNode(Branch branch = Branch::kNeither) - : branch(branch), count(0) {} - string ToString() const { - return strings::StrCat("branch=", Branch_Name(branch), " count=", count); - } - Branch branch; - int count; - }; - - // Group of switch nodes that will be part of the same XlaIf. - struct SwitchCluster { - explicit SwitchCluster(const Edge* predicate_edge) - : predicate_edge(predicate_edge) {} - string ToString() const { - return strings::StrCat(name, " predicate=", predicate_edge->src()->name(), - " switches=", NodesToString(switches)); - } - - string name; - const Edge* predicate_edge; - std::vector switches; - }; - - FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library, - bool dump_graphs) - : library_(library), graph_(graph), dump_graphs_(dump_graphs) {} - - // Perform the actual cond functionalization. Iterate over groups of switch - // nodes (linked by common predicate), from innermost to outermost, and - // extract into XlaIf nodes. - Status FunctionalizeInternal(); - - // Determines the branch_map (mapping from node to branch of cond) and - // frontier (the nodes where the cond ends). - StatusOr, - std::unordered_set>> - DetermineBranchMapAndFrontier(const SwitchCluster& switch_cluster); - - // Returns XlaIf node created from subgraph of merge and switch nodes. This - // encapsulates the process of extracting the bodies needed for the then and - // else branch, creates a XlaIf node, removing the nodes of the branches from - // the graph and replacing the merge node with a XlaIf. - StatusOr ConvertToXlaIf(const CondArgNodes& cond_arg_nodes, - const SwitchCluster& switch_cluster, - const std::vector& switches); - - // Builds a XlaIfOp to replace the Switch-Graph-Merge cluster with. - StatusOr BuildAndAddXlaIfOp(const CondArgNodes& cond_arg_nodes, - const SwitchCluster& switch_cluster, - const std::vector& merge_nodes); - - // Extracts a function body corresponding to the given input edge of the merge - // node. - Status ExtractBody(const CondArgNodes& cond_arg_nodes, - const std::vector& switches, - const std::vector& merge_nodes, int input_edge, - Graph* body); - - // Adds all the input edges to `if_node` corresponding to the arguments. - Status AddInputEdges(const CondArgNodes& cond_arg_nodes, - const Edge* predicate_edge, Node* if_node); - - // Adds all output edges from the `if_node`. - Status AddOutputEdges(const std::vector& outputs, Node* if_node); - - // Returns the switch clusters of graph_ in postorder. Dead switch nodes are - // skipped and removed from the graph. - StatusOr> DeterminePredicateSwitchOrder(); - - // Update the state for destination based on the state of source and the node - // being updated. - Status Join(const ForwardFlowNode& src_state, const Node* dst, - ForwardFlowNode* dst_state); - - // Ensure that all nodes in the branch_map are dominated by the switch - // nodes. Returns nodes that are not dominated by the switches but are a - // control dependency of a node in the cond, and remove such control - // dependencies. - StatusOr> EnsureDominanceAndReturnNonDominatedControlNodes( - const std::unordered_map& branch_map, - const std::vector& switches); - - // Validates that the frontier of nodes for the conditional - // section are as expected. - Status ValidateFrontier( - const std::unordered_map& branch_map, - const std::unordered_set& frontier); - - FunctionLibraryDefinition* library_; - Graph* graph_; - bool dump_graphs_; -}; - -bool IsDeadSwitch(const Node* node) { - for (const Edge* e : node->out_edges()) { - const Node* dst = e->dst(); - if (!dst->IsIdentity()) { - return false; - } - for (const Edge* ee : dst->out_edges()) { - if (!ee->IsControlEdge() || !ee->dst()->IsSink()) { - return false; - } - } - } - return true; -} - -string FunctionalizeCond::Branch_Name(FunctionalizeCond::Branch b) { - const string branch_name[FunctionalizeCond::kNumBranchTypes + 1] = { - "else", "then", "both", "neither", "count"}; - return branch_name[b]; -} - -Status FunctionalizeCond::ValidateFrontier( - const std::unordered_map& - branch_map, - const std::unordered_set& frontier) { - std::unordered_set pending[kNumBranchTypes]; - for (Node* n : frontier) { - pending[branch_map.at(n).branch].insert(n); - } - TF_RET_CHECK(pending[kNeither].empty()) << NodesToString(pending[kNeither]); - for (const Node* n : pending[kBoth]) { - TF_RET_CHECK(IsMerge(n)) << n->DebugString(); - // Merge nodes may be in then or else branch too - } - int index = (pending[kThenBranch].size() <= pending[kElseBranch].size()) - ? kThenBranch - : kElseBranch; - int other = 1 - index; - for (const Node* n : pending[index]) { - if (pending[other].find(n) != pending[other].end()) { - return errors::Internal( - "Node (", n->DebugString().c_str(), - ") in both Else and Then branch should be in Both."); - } - } - // An empty frontier indicates a dead switch. Above we attempt to remove dead - // switch nodes, but not all are removed so don't treat it as an error yet. - // TODO(jpienaar): Find out why dead switch nodes remain. - // if (pending[kBoth].empty() && pending[kThenBranch].empty() && - // pending[kElseBranch].empty()) { - // return errors::Internal("Unexpected empty frontier for switch nodes"); - // } - return Status::OK(); -} - -Status FunctionalizeCond::Join(const ForwardFlowNode& src_state, - const Node* dst, ForwardFlowNode* dst_state) { - TF_RET_CHECK(dst_state->branch != Branch::kBoth && - dst_state->branch != Branch::kNumBranchTypes) - << "Unexpected/Invalid branch type: Merging " - << Branch_Name(src_state.branch) << " with " - << Branch_Name(dst_state->branch); - if (dst_state->branch == Branch::kNeither) { - dst_state->branch = src_state.branch; - } else if (src_state.branch != dst_state->branch && - src_state.branch != Branch::kNeither) { - if (IsMerge(dst)) { - dst_state->branch = Branch::kBoth; - } else { - return errors::Internal("Illegal merge:\n", src_state.ToString(), - " with ", dst_state->ToString(), " for\n", - dst->DebugString()); - } - } - ++dst_state->count; - return Status::OK(); -} - -StatusOr> -FunctionalizeCond::DeterminePredicateSwitchOrder() { - struct Cluster { - bool operator==(const Cluster& other) const { - return representative == other.representative; - } - int representative = -1; - }; - - // Perform a DFS over the graph and - // * Determine the reverse topological order of the nodes (there should be no - // cycles at this point so the post-order numbering corresponds to the - // reverse topological sorting); - // * Identify dead switches; - // * Initialize the cluster's representative; - std::vector> clusters(graph_->num_node_ids()); - std::vector dead_switches; - std::vector switch_order; - std::vector rev_topo_sorted_nodes; - DFS(*graph_, nullptr, [&](Node* n) { - clusters[n->id()].Get().representative = n->id(); - if (IsSwitch(n)) { - if (IsDeadSwitch(n)) { - dead_switches.push_back(n); - } else { - rev_topo_sorted_nodes.push_back(n); - switch_order.push_back(n); - } - } else if (n->IsOp()) { - // Exclude src and sink nodes from further consideration. - rev_topo_sorted_nodes.push_back(n); - } - }); - - std::vector switch_clusters; - // Return early if there are no switches in the graph. - if (switch_order.empty()) { - return switch_clusters; - } - - // Remove all dead switch nodes. - for (Node* n : dead_switches) { - VLOG(2) << "Removing dead switch: " << n->DebugString(); - graph_->RemoveNode(n); - } - - // Identify switch nodes that are part of the same control flow context by - // considering the operands of operations: an operation is part of the same - // control context as its operands unless the operation is a switch. Control - // dependencies are considered part of the same control flow context if the - // switch depth is the same (see comment below). - - // entry_cluster records the input cluster to a switch node. This is used when - // merging with a merge node where the dst's cluster is merged with the entry - // cluster of the merge node's cluster (which corresponds to a switch cluster - // and so has an entry cluster). - std::unordered_map*> entry_cluster; - - // Returns the output cluster of a node. Where the output cluster is cluster - // where the output of the node is used. For non-merge nodes this is simply - // the cluster they are part of, while for merge nodes it is the entry cluster - // of the cluster they are part of (this will correspond to the entry node of - // a switch node that dominates the merge). - auto find_output_cluster = [&](Node* n) { - UnionFind* cluster = &clusters[n->id()]; - if (!IsMerge(n)) return cluster; - auto it = entry_cluster.find(clusters[n->id()].Get().representative); - // If the cluster is not found in the entry_cluster map then an - // instruction not dominated by a switch node has been merged into the - // cluster of the merge. This indicates a failure of the clustering. - CHECK(it != entry_cluster.end()) - << "Unable to find entry for n=" << n->id() << " (" - << cluster->Get().representative << ")"; - return it->second; - }; - - // TODO(jpienaar): This could be combined with DetermineBranchMapAndFrontier. - std::vector switch_depth(graph_->num_node_ids()); - for (auto it = rev_topo_sorted_nodes.rbegin(); - it != rev_topo_sorted_nodes.rend(); ++it) { - Node* n = *it; - - // Compute switch depth. - int new_switch_depth = 0; - for (const Edge* e : n->in_edges()) { - Node* src = e->src(); - new_switch_depth = std::max( - new_switch_depth, switch_depth[src->id()] - (IsMerge(src) ? 1 : 0)); - } - switch_depth[n->id()] = new_switch_depth + (IsSwitch(n) ? 1 : 0); - - // Only merge the input operands of a switch. The switch's clustering itself - // is determined by the interaction of the switch's outputs. - if (IsSwitch(n)) { - Node* input; - TF_CHECK_OK(n->input_node(0, &input)); - entry_cluster[n->id()] = find_output_cluster(input); - UnionFind* cluster = entry_cluster[n->id()]; - int cluster_depth = switch_depth[cluster->Get().representative]; - // Merge the inputs of the switch node with one another. This results in - // predicates and control input residing in the same cluster. - for (const Edge* e : n->in_edges()) { - // Only consider the data inputs to the Switch node. - if (e->IsControlEdge()) continue; - - Node* src = e->src(); - UnionFind* src_cluster = find_output_cluster(src); - int src_cluster_depth = switch_depth[src_cluster->Get().representative]; - if (cluster_depth != src_cluster_depth) { - return errors::InvalidArgument( - "Unable to functionalize control flow in graph: Switch ('", - n->name(), "') has operands ('", input->name(), "' and '", - src->name(), "') that have different switch depths (", - cluster_depth, " != ", src_cluster_depth, ")"); - } - cluster->Merge(src_cluster); - } - continue; - } - - for (const Edge* e : n->in_edges()) { - Node* src = e->src(); - if (!src->IsOp()) continue; - UnionFind* cluster = find_output_cluster(src); - // Merge a node with its data operands and with its control operands if - // the src and dst are in the same ControlContext. The ControlContext is - // not explicitly available here, and instead the switch depth is used as - // a proxy here. Due to the invariant that control edges can only be from - // a containing scope to an inner scope or from the inner scope to its - // containing scope (for exit nodes), the switch depth will only match if - // the src and dst are in the same ControlContext. Control edges between - // ControlContexts are handled during the extraction. - int src_id = cluster->Get().representative; - int src_depth = switch_depth[src_id]; - if (!e->IsControlEdge() || new_switch_depth == src_depth) { - if (src_depth != new_switch_depth) { - // TODO(b/77601805) remove this when outside_compilation supports - // control flow. - if (str_util::StrContains(src->name(), "outside_compilation") || - str_util::StrContains(n->name(), "outside_compilation")) { - return errors::InvalidArgument( - "outside_compilation is not yet supported within TensorFlow " - "control flow constructs b/77601805"); - } - return errors::InvalidArgument( - "Unable to functionalize control flow in graph: Operand ('", - src->name(), "') and operator ('", n->name(), - "') have different switch depths (", src_depth, - " != ", new_switch_depth, ")"); - } - cluster->Merge(&clusters[n->id()]); - } - } - } - - if (dump_graphs_) { - // Mark the switch cluster each node is part of. - for (Node* n : graph_->nodes()) { - n->ClearAttr("_XlaFunctionalizeSwitchGroup"); - n->AddAttr("_XlaFunctionalizeSwitchGroup", - clusters[n->id()].Get().representative); - } - LOG(INFO) << "FunctionalizeControlFlow (with_clusters): " - << dump_graph::DumpGraphToFile("functionalize_clustered", *graph_, - library_); - } - - // Verify all the nodes of a cluster are at the same depth. - std::unordered_map> cluster_to_depth_node; - for (Node* n : graph_->nodes()) { - int depth = switch_depth[n->id()]; - int cluster_rep = clusters[n->id()].Get().representative; - auto it = cluster_to_depth_node.find(cluster_rep); - if (it == cluster_to_depth_node.end()) { - cluster_to_depth_node[cluster_rep] = std::make_pair(depth, n); - } else { - if (it->second.first != depth) { - return errors::Internal( - "Illegal clustering created, mismatch in depths:", "\n\t", - n->DebugString(), "(", clusters[n->id()].Get().representative, - ") at depth=", depth, " vs\n\t", it->second.second->DebugString(), - "(", clusters[n->id()].Get().representative, ") at depth ", - it->second.first); - } - } - } - - struct Hash { - size_t operator()(const std::pair& item) const { - return Hash64Combine(hash()(item.first), - std::hash()(item.second.representative)); - } - }; - - // Merge Switch nodes with common predicate. - std::unordered_map, int, Hash> predicate_index; - // The nodes in switch_order are in reverse topological order, but the - // clustered switches need not be (i.e., when considered as a cluster one - // element of a cluster may be later in the topological order than another - // node whose cluster is later in the topological order of clustered - // switches). - for (auto it = switch_order.rbegin(); it != switch_order.rend(); ++it) { - const Edge* pred_edge; - TF_CHECK_OK((*it)->input_edge(1, &pred_edge)); - // The predicate can be preceded by a identity node. Look through identity - // nodes to predicate. - while (pred_edge->src()->IsIdentity()) { - TF_CHECK_OK(pred_edge->src()->input_edge(0, &pred_edge)); - } - auto repr = std::make_pair(pred_edge->src(), clusters[(*it)->id()].Get()); - if (predicate_index.find(repr) == predicate_index.end()) { - predicate_index[repr] = switch_clusters.size(); - switch_clusters.emplace_back(pred_edge); - // Generate a name by concatenating with the cluster representative as - // there could be multiple switch clusters with the same predicate. - switch_clusters[predicate_index[repr]].name = strings::StrCat( - pred_edge->src()->name(), "_", repr.second.representative, "_If"); - } - switch_clusters[predicate_index[repr]].switches.push_back(*it); - } - - return switch_clusters; -} - -StatusOr> -FunctionalizeCond::EnsureDominanceAndReturnNonDominatedControlNodes( - const std::unordered_map& branch_map, - const std::vector& switches) { - std::vector old_control_nodes; - for (const auto& kv : branch_map) { - if (kv.second.count != kv.first->in_edges().size()) { - std::vector delete_edges; - for (const Edge* in : kv.first->in_edges()) { - auto it = branch_map.find(in->src()); - if (it == branch_map.end()) { - if (in->IsControlEdge()) { - old_control_nodes.push_back(in->src()); - delete_edges.push_back(in); - } else { - if (IsSwitch(in->src())) { - if (std::find(switches.begin(), switches.end(), in->src()) == - switches.end()) { - return errors::Internal( - "Unexpected switch node found during flow forward: ", - in->src()->DebugString()); - } - continue; - } - return errors::InvalidArgument( - "Value ", kv.first->name(), "'s input, ", in->src()->name(), - ", is not dominated by switch nodes ", NodesToString(switches)); - } - } - } - // Remove control edges from nodes that are not dominated by the switch - // nodes. New control dependencies will be added between these nodes and - // the XlaIf node inserted. - for (const Edge* e : delete_edges) { - graph_->RemoveEdge(e); - } - } - } - return old_control_nodes; -} - -StatusOr< - std::pair, - std::unordered_set>> -FunctionalizeCond::DetermineBranchMapAndFrontier( - const SwitchCluster& switch_cluster) { - std::unordered_map branch_map; - std::unordered_set frontier; - std::vector stack = switch_cluster.switches; - std::vector visited(graph_->num_node_ids(), false); - while (!stack.empty()) { - Node* n = stack.back(); - stack.pop_back(); - - if (visited[n->id()]) { - continue; - } - visited[n->id()] = true; - - // Propagate branch state along each edge of a switch node. - bool sink_only = true; - for (const Edge* e : n->out_edges()) { - Node* out = e->dst(); - if (!out->IsOp()) { - continue; - } - sink_only = false; - // Propagate branch information. - ForwardFlowNode& ffn = branch_map[out]; - if (IsSwitch(n)) { - int index = e->IsControlEdge() ? Branch::kNeither : e->src_output(); - TF_RETURN_WITH_CONTEXT_IF_ERROR( - Join(ForwardFlowNode(Branch(index)), out, &ffn), " when joining ", - e->DebugString()); - } else { - TF_RETURN_WITH_CONTEXT_IF_ERROR(Join(branch_map[n], out, &ffn), - " when joining ", e->DebugString()); - } - if (IsMerge(out)) { - if (out->in_edges().size() == ffn.count) { - frontier.insert(out); - } - } else if (!visited[out->id()]) { - stack.push_back(out); - } - } - if (sink_only) { - if (!IsIdentity(n)) { - VLOG(1) << "Feeding into sink: " << n->DebugString(); - } - } - } - - if (dump_graphs_) { - for (const auto& kv : branch_map) { - // Append attribute to the graph if running with logging to make the - // changes clearer in the visualization. - kv.first->AddAttr("_XlaFunctionalizeBranch", - Branch_Name(kv.second.branch)); - } - } - return std::make_pair(std::move(branch_map), std::move(frontier)); -} - -Status FunctionalizeCond::FunctionalizeInternal() { - TF_ASSIGN_OR_RETURN(std::vector predicate_switch_order, - DeterminePredicateSwitchOrder()); - - // Iterate from innermost set of clustered switches to outermost, replacing - // matching switch->merge subgraphs with single XlaIf nodes. - for (auto it = predicate_switch_order.rbegin(); - it != predicate_switch_order.rend(); ++it) { - auto& ps = *it; - VLOG(3) << "Flow down from: " << ps.ToString(); - - std::unordered_map branch_map; - std::unordered_set frontier; - TF_ASSIGN_OR_RETURN(std::tie(branch_map, frontier), - DetermineBranchMapAndFrontier(ps)); - - if (dump_graphs_) - LOG(INFO) << "FunctionalizeControlFlow (before XlaIf conversion): " - << dump_graph::DumpGraphToFile("functionalize_bc", *graph_, - library_); - TF_RETURN_IF_ERROR(ValidateFrontier(branch_map, frontier)); - - struct Hash { - size_t operator()(const std::pair& item) const { - return Hash64Combine(hash()(item.first), - std::hash()(item.second)); - } - }; - - // Sort the merge and switch nodes using NodeCmp. The switch-nodes are - // further grouped (post sorting) by input to the switch node as in the - // functionalized form each input will be passed in only once. This grouping - // should retain the sorted order. - CondArgNodes cond_arg_nodes; - std::sort(ps.switches.begin(), ps.switches.end(), NodeCmp()); - std::unordered_map, int, Hash> input_index; - for (Node* switch_node : ps.switches) { - const Edge* e; - TF_RETURN_IF_ERROR(switch_node->input_edge(0, &e)); - std::pair key = std::make_pair(e->src(), e->src_output()); - if (input_index.find(key) == input_index.end()) { - input_index[key] = cond_arg_nodes.size(); - cond_arg_nodes.emplace_back(key.first, key.second); - } - cond_arg_nodes.at(input_index.at(key)).switches.push_back(switch_node); - } - std::vector merge_nodes(frontier.begin(), frontier.end()); - std::sort(merge_nodes.begin(), merge_nodes.end(), NodeCmp()); - - TF_ASSIGN_OR_RETURN(std::vector old_control_nodes, - EnsureDominanceAndReturnNonDominatedControlNodes( - branch_map, ps.switches)); - - TF_ASSIGN_OR_RETURN(Node * if_node, - ConvertToXlaIf(cond_arg_nodes, ps, merge_nodes)); - for (Node* old : old_control_nodes) { - graph_->AddControlEdge(old, if_node); - } - - for (auto& del_kv : branch_map) { - graph_->RemoveNode(del_kv.first); - } - for (auto& kv : cond_arg_nodes) { - for (Node* node : kv.switches) { - graph_->RemoveNode(node); - } - } - if (dump_graphs_) - LOG(INFO) << "FunctionalizeControlFlow (after XlaIf conversion): " - << dump_graph::DumpGraphToFile("functionalize_ac", *graph_, - library_); - } - return Status::OK(); -} - -StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( - const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster, - const std::vector& merge_nodes) { - VLOG(2) << "Build if op for " << switch_cluster.name; - - NodeDef if_def; - // Create a new If node using the name of the merge node. - NodeDefBuilder builder(switch_cluster.name, "XlaIf"); - string branch[] = {"else_branch", "then_branch"}; - for (int i = 0; i < 2; ++i) { - static std::atomic sequence_num(0LL); - int64 id = ++sequence_num; - - NameAttrList body_name; - body_name.set_name( - strings::StrCat("_functionalize_if_", branch[i], "_", id)); - auto body = xla::MakeUnique(graph_->op_registry()); - TF_RETURN_IF_ERROR(ExtractBody(cond_arg_nodes, switch_cluster.switches, - merge_nodes, i, body.get())); - VLOG(3) << "Body " << branch[i] << ": " << DebugString(body.get()); - FunctionDef body_fdef; - TF_RETURN_IF_ERROR(GraphToFunctionDef(*body, body_name.name(), &body_fdef)); - TF_RETURN_IF_ERROR(library_->AddFunctionDef(body_fdef)); - builder.Attr(branch[i], body_name); - } - - // Build input type. - std::vector inputs; - DataTypeVector in_arg_types; - for (auto& kv : cond_arg_nodes) { - bool inserted = false; - for (const Node* arg : kv.switches) { - const Edge* in_edge; - TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); - if (in_edge->IsControlEdge()) { - builder.ControlInput(in_edge->src()->name()); - } else { - if (!inserted) { - DataType dtype = arg->input_type(0); - inputs.emplace_back(NodeDefBuilder::NodeOut( - in_edge->src()->name(), in_edge->src_output(), dtype)); - in_arg_types.push_back(dtype); - inserted = true; - } - } - } - } - builder.Attr("Tin", in_arg_types); - - // Build output type. - DataTypeVector out_type; - for (const Node* merge : merge_nodes) { - DataType dtype = merge->output_type(0); - out_type.push_back(dtype); - } - builder.Attr("Tout", out_type); - - builder.Attr("Tcond", DT_BOOL); - builder.Device(switch_cluster.predicate_edge->src()->assigned_device_name()); - // Conditional should be the first input ... - builder.Input(NodeDefBuilder::NodeOut( - switch_cluster.predicate_edge->src()->name(), - switch_cluster.predicate_edge->src_output(), - switch_cluster.predicate_edge->src()->output_type(0))); - // ... followed by the other inputs. - builder.Input(inputs); - - TF_RETURN_IF_ERROR(builder.Finalize(&if_def)); - TF_ASSIGN_OR_RETURN(Node * if_node, AddNode(if_def, graph_)); - return if_node; -} - -Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes, - const std::vector& switches, - const std::vector& merge_nodes, - int input_edge, Graph* body) { - VLOG(2) << "ExtractBody for " << NodesToString(merge_nodes) << " along edge " - << input_edge; - std::vector squash_src_outputs(graph_->num_node_ids(), false); - std::vector node_map(graph_->num_node_ids(), nullptr); - int arg_count = 0; - for (auto& kv : cond_arg_nodes) { - Node* arg_node = nullptr; - for (const auto* arg : kv.switches) { - DataType dtype = arg->input_type(0); - if (arg_node == nullptr) { - TF_ASSIGN_OR_RETURN(arg_node, BuildArgNode(body, dtype, arg_count++)); - } - node_map.at(arg->id()) = arg_node; - squash_src_outputs.at(arg->id()) = true; - } - } - - std::vector stack; - stack.reserve(merge_nodes.size()); - for (int j = 0; j < merge_nodes.size(); ++j) { - Node* node = merge_nodes[j]; - TF_ASSIGN_OR_RETURN(node_map.at(node->id()), - BuildRetvalNode(body, node->output_type(0), - /*index=*/j)); - const Edge* in_edge; - TF_RETURN_IF_ERROR(node->input_edge(input_edge, &in_edge)); - Node* in = in_edge->src(); - if (node_map.at(in->id()) == nullptr) { - node_map.at(in->id()) = body->CopyNode(in); - } - - if (std::find(switches.begin(), switches.end(), in) == switches.end()) { - body->AddEdge(node_map.at(in->id()), in_edge->src_output(), - node_map.at(node->id()), 0); - } else { - body->AddEdge(node_map.at(in->id()), 0, node_map.at(node->id()), 0); - // Don't include input nodes that are already just returned in stack. - continue; - } - stack.push_back(in); - } - - return CopySubgraph(*graph_, nullptr, stack, squash_src_outputs, &node_map, - body); -} - -Status FunctionalizeCond::AddInputEdges(const CondArgNodes& cond_arg_nodes, - const Edge* predicate_edge, - Node* if_node) { - VLOG(3) << "AddInputEdges for " << if_node->name(); - int index = 0; - graph_->AddEdge(predicate_edge->src(), predicate_edge->src_output(), if_node, - index++); - for (auto& arg : cond_arg_nodes) { - if (arg.src_output == Graph::kControlSlot) { - graph_->AddControlEdge(arg.src, if_node); - } else { - graph_->AddEdge(arg.src, arg.src_output, if_node, index++); - } - } - return Status::OK(); -} - -Status FunctionalizeCond::AddOutputEdges(const std::vector& outputs, - Node* if_node) { - VLOG(3) << "AddOutputEdges for " << if_node->name(); - for (int i = 0; i < outputs.size(); ++i) { - Node* node = outputs[i]; - std::vector edges(node->out_edges().begin(), - node->out_edges().end()); - for (const Edge* edge : edges) { - Node* dst = edge->dst(); - int dst_input = edge->dst_input(); - - if (edge->src_output() > 0) { - return errors::Unimplemented("Output of index (", edge->src_output(), - ") of merge node ", node->name()); - } - - int src_output = - dst_input == Graph::kControlSlot ? Graph::kControlSlot : i; - graph_->RemoveEdge(edge); - graph_->AddEdge(if_node, src_output, dst, dst_input); - } - } - return Status::OK(); -} - -StatusOr FunctionalizeCond::ConvertToXlaIf( - const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster, - const std::vector& merge_nodes) { - VLOG(1) << "ConvertToXlaIf for " << switch_cluster.ToString() << " -> " - << NodesToString(merge_nodes); - - // Extract bodies and builds a If operator. - TF_ASSIGN_OR_RETURN( - Node * if_node, - BuildAndAddXlaIfOp(cond_arg_nodes, switch_cluster, merge_nodes)); - TF_RETURN_IF_ERROR( - AddInputEdges(cond_arg_nodes, switch_cluster.predicate_edge, if_node)); - TF_RETURN_IF_ERROR(AddOutputEdges(merge_nodes, if_node)); - // Check that the if_node doesn't feed into itself. - TF_RETURN_WITH_CONTEXT_IF_ERROR( - CheckNoCycleContains(if_node, graph_->num_node_ids()), - "ConvertToXlaIf failed."); - - return if_node; -} - -Status FunctionalizeCond::Functionalize(Graph* graph, - FunctionLibraryDefinition* library) { - VLOG(1) << "FunctionalizeCond::Functionalize"; - FunctionalizeCond fc(graph, library, /*dump_graphs=*/VLOG_IS_ON(2)); - return fc.FunctionalizeInternal(); -} - -} // namespace - -// Transformation that converts TensorFlow's graph control flow constructs into -// functional equivalents. -Status FunctionalizeControlFlow(Graph* graph, - FunctionLibraryDefinition* library) { - return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library); -} - Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, Graph* graph, FunctionLibraryDefinition* library) { @@ -1459,98 +46,26 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, << dump_graph::DumpGraphToFile("functionalize_initial", *graph, library); - // Note: BuildControlFlowInfo() requires that the graph's source node is - // connected to all source nodes in the graph. Many graphs violate this - // invariant. - std::vector cf_info; - std::vector unreachable_nodes; - TF_RETURN_WITH_CONTEXT_IF_ERROR( - BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes), - "FunctionalizeControlFlow failed"); - if (!unreachable_nodes.empty()) { - return errors::InvalidArgument( - "The following nodes are unreachable from the source in the graph: ", - tensorflow::str_util::Join(unreachable_nodes, ", ")); - } - - // Builds Frames, indexed by name. - std::unordered_map frames; - for (Node* node : graph->op_nodes()) { - const ControlFlowInfo& cf = cf_info[node->id()]; - - VLOG(2) << "node: " << node->name() << " (" << node->id() - << ") frame_name: " << cf.frame_name - << " frame: " << (cf.frame ? cf.frame->name() : "---") - << " parent_frame: " - << (cf.parent_frame ? cf.parent_frame->name() : "---"); - TF_RET_CHECK(cf.frame != nullptr && cf.parent_frame != nullptr); - - Frame& frame = frames[cf.frame_name]; - Frame* parent = &frames[cf_info[cf.parent_frame->id()].frame_name]; - if (frame.parent == nullptr) { - frame.parent = parent; - frame.name = cf.frame_name; - ++parent->num_children; - } - - if (IsEnter(node)) { - Arg arg; - arg.enter = node; - TF_RETURN_IF_ERROR(GetNodeAttr(arg.enter->attrs(), "is_constant", - &arg.is_loop_invariant)); - frame.args.push_back(arg); - } else if (IsLoopCond(node)) { - frame.loop_cond = node; - } - frame.nodes.insert(node); - } - - // Adds frames with no children (i.e., the innermost frames) to a worklist. - std::deque worklist; - for (auto& frame : frames) { - if (frame.second.num_children == 0) { - worklist.push_back(&frame.second); - } - } - - // Eliminate loops from innermost to outermost. - while (!worklist.empty()) { - Frame* frame = worklist.front(); - worklist.pop_front(); - if (frame->parent == frame) { - // Skip the root frame. - continue; - } - - TF_RETURN_IF_ERROR( - FunctionalizeLoop(lookup_library, graph, frame, library)); - - // If the parent has no remaining children, add it to the worklist. - --frame->parent->num_children; - if (frame->parent->num_children == 0) { - worklist.push_back(frame->parent); - } - } - // There should be no cycle at this point, since while loops have been removed - // from graph. - // Check that the newly added XlaWhile nodes don't feed into themselves. - for (const Node* node : graph->op_nodes()) { - if (node->def().op() == "XlaWhile") { - TF_RETURN_WITH_CONTEXT_IF_ERROR( - CheckNoCycleContains(node, graph->num_node_ids()), - "FunctionalizeLoop failed."); - } - } + // Functionalize and remove while loops from graph. + TF_RETURN_IF_ERROR(FunctionalizeWhileLoop(lookup_library, graph, library)); // FunctionalizeControlFlow is invoked for every function, so the loops's // bodies and conditionals that were extracted into functions will be handled // in successive invocations. - TF_RETURN_IF_ERROR(FunctionalizeCond::Functionalize(graph, library)); + TF_RETURN_IF_ERROR(FunctionalizeCond(graph, library)); VLOG(2) << "FunctionalizeControlFlow (final): " << dump_graph::DumpGraphToFile("functionalize_final", *graph, library); + return Status::OK(); } +// Transformation that converts TensorFlow's graph control flow constructs into +// functional equivalents. +Status FunctionalizeControlFlow(Graph* graph, + FunctionLibraryDefinition* library) { + return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index d941041d15532446d1413f16fe64602bfb1a7daa..55600f2a8b5302cef26b9be4ccd0f8804476a17a 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -16,14 +16,16 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ #define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/graph/graph.h" namespace tensorflow { // Transformation that converts tf.while_loop() loops into functional While -// operators, suitable for XLA compilation. If lookup_library is provided, use -// it to make the library for control flow self-contained. +// operators and tf.cond() conditionals into function If operators, suitable for +// XLA compilation. If lookup_library is provided, use it to make the library +// for control flow self-contained. Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library); Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index aae2f8ee5acd6249f8b6002d94c877f18064f936..cc52057f214a45a861660c3d34cbbffd9c45a640 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -37,12 +37,12 @@ limitations under the License. namespace tensorflow { namespace { -// Returns the names of the "then" and "else" functions for the XlaIf node in a +// Returns the names of the "then" and "else" functions for the If node in a // graph. Status FindIfThenAndElse(const GraphDef& graph, string* op_name, NameAttrList* then_fn, NameAttrList* else_fn) { for (const NodeDef& node : graph.node()) { - if (node.op() == "XlaIf") { + if (node.op() == "If") { *op_name = node.name(); const NameAttrList* result; TF_RETURN_IF_ERROR(GetNodeAttr(node, "then_branch", &result)); @@ -52,7 +52,7 @@ Status FindIfThenAndElse(const GraphDef& graph, string* op_name, return Status::OK(); } } - return errors::NotFound("No XlaIf node found in graph"); + return errors::NotFound("No If node found in graph"); } // Graph: @@ -115,8 +115,13 @@ TEST(FunctionalizeControlFlow, Conditional) { auto if_op = ops::XlaIf(scope.WithOpName(op_name), less, std::initializer_list{less, y, x}, then_fn, else_fn, {DT_INT32}); + auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]); GraphDef expected; TF_EXPECT_OK(scope.ToGraphDef(&expected)); + // TODO(jpienaar): Create wrapper for IfOp. + for (NodeDef& n : *expected.mutable_node()) { + if (n.op() == "XlaIf") n.set_op("If"); + } TF_EXPECT_GRAPH_EQ(expected, graph_def); } @@ -1013,60 +1018,5 @@ TEST(FunctionalizeControlFlow, Complex) { } } -TEST(FunctionalizeControlFlow, Cycle) { - std::unique_ptr graph(new Graph(OpRegistry::Global())); - // ----------------------------------------------------- - // | | - // | v - // less -> switch_1 --> add -> merge_1 -> identity -> switch_2 - // | ^ | - // | | v - // --------> one -------------------------> add_2 ---> merge_2 - { - Scope scope = Scope::NewRootScope().ExitOnError(); - - auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); - auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); - auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); - auto switch_1 = ops::Switch(scope.WithOpName("cond/Switch"), x, less); - auto two = - ops::Const(scope.WithOpName("cond/two") - .WithControlDependencies(switch_1.output_true), - 2); - auto mul = ops::Multiply(scope.WithOpName("cond/true/mul"), - switch_1.output_true, two); - auto one = - ops::Const(scope.WithOpName("cond/one") - .WithControlDependencies(switch_1.output_false), - 1); - auto add = ops::Add(scope.WithOpName("cond/false/add"), - switch_1.output_false, one); - - auto merge_1 = ops::Merge(scope.WithOpName("cond/Merge"), - std::initializer_list{add, mul}); - auto identity = - ops::Identity(scope.WithOpName("cond/Merge/identity"), merge_1.output); - auto switch_2 = - ops::Switch(scope.WithOpName("grad/cond/Switch"), identity, less); - auto add_2 = ops::Add(scope.WithOpName("cond_2/false/add"), - switch_2.output_false, one); - auto mul_2 = ops::Multiply(scope.WithOpName("cond_2/true/mul"), - switch_2.output_true, two); - auto merge_2 = ops::Merge(scope.WithOpName("cond_2/Merge"), - std::initializer_list{add_2, mul_2}); - TF_ASSERT_OK(scope.ToGraph(graph.get())); - } - // No cycle before functionalize control flow. - TF_EXPECT_OK(graph::ValidateGraphHasNoCycle(*graph)); - FunctionLibraryDefinition library(OpRegistry::Global(), {}); - // switch_1 and switch_2 have the same switch depth. They are replaced by a - // single XlaIf node during FunctionalizeControlFlow, resulting in a cycle: - // less -> XlaIf <--> identity. - Status status = FunctionalizeControlFlow(graph.get(), &library); - EXPECT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), "Detect a cycle")) - << status.error_message(); -} - } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..924fcdd9cd72a6472e0b2748680f2552fa65ec79 --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" + +#include "tensorflow/core/framework/node_def.pb.h" + +namespace tensorflow { + +bool NodeCmpByNameResourcesLast::operator()(const Node* lhs, + const Node* rhs) const { + bool lhs_is_resource = + lhs->num_inputs() > 0 ? (lhs->input_type(0) == DT_RESOURCE) : false; + bool rhs_is_resource = + rhs->num_inputs() > 0 ? (rhs->input_type(0) == DT_RESOURCE) : false; + return std::tie(lhs_is_resource, lhs->name()) < + std::tie(rhs_is_resource, rhs->name()); +} + +xla::StatusOr AddNodeDefToGraph(const NodeDef& node_def, Graph* graph) { + Status status; + Node* inserted_node = graph->AddNode(node_def, &status); + if (!status.ok()) { + return status; + } + return inserted_node; +} + +xla::StatusOr BuildRetvalNode(Graph* graph, DataType type, int index) { + const char* const kRetValOp = "_Retval"; + NodeDef ret_def; + ret_def.set_op(kRetValOp); + ret_def.set_name(strings::StrCat(kRetValOp, index)); + AddNodeAttr("T", type, &ret_def); + AddNodeAttr("index", index, &ret_def); + return AddNodeDefToGraph(ret_def, graph); +} + +// Check that the graph has no cycle containing the given node. +Status CheckNodeNotInCycle(const Node* node, const int num_nodes) { + std::vector ready; + ready.push_back(node); + std::vector visited(num_nodes); + while (!ready.empty()) { + const Node* current_node = ready.back(); + ready.pop_back(); + visited[current_node->id()] = true; + for (const Edge* out : current_node->out_edges()) { + if (out->dst() == node) { + return errors::Internal("Detected a cycle: ", FormatNodeForError(*node), + " (", node->def().op(), ") feeds into itself."); + } else if (!visited[out->dst()->id()]) { + ready.push_back(out->dst()); + } + } + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h new file mode 100644 index 0000000000000000000000000000000000000000..a0544b69e9ea3a1bd16dcd08bc4b4638a8fc31fb --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h @@ -0,0 +1,56 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_ + +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/graph/graph.h" + +// Utility functions shared between functionalize cond and while. + +namespace tensorflow { + +// Check that the graph has no cycle containing the given node. +Status CheckNodeNotInCycle(const Node* node, const int num_nodes); + +// Comparison function used for sorting nodes consistently. +// a) resource variables are last, and +// b) sort lexicographically by name (for deterministic output). +struct NodeCmpByNameResourcesLast { + bool operator()(const Node* lhs, const Node* rhs) const; +}; + +// Returns the Node* created from the NodeDef in the Graph. +xla::StatusOr AddNodeDefToGraph(const NodeDef& node_def, Graph* graph); + +// Build a retval node of given type and index. +xla::StatusOr BuildRetvalNode(Graph* graph, DataType type, int index); + +// Returns a textual representation of the names of the nodes in the input. +template +string NodesToString(const T& nodes) { + return strings::StrCat("{", + str_util::Join(nodes, ",", + [](string* output, const Node* node) { + strings::StrAppend(output, + node->name()); + }), + "}"); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc new file mode 100644 index 0000000000000000000000000000000000000000..4fd134c69809a70c1618a3ebaa0b27a2c1467a54 --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -0,0 +1,668 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/functionalize_while.h" + +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/compiler/jit/union_find.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/gtl/optional.h" + +namespace tensorflow { +namespace { + +using xla::StatusOr; + +// Information about a loop argument. +struct Arg { + // Every loop argument has an Enter node. + Node* enter; + + // Is the loop argument a loop-invariant value? Taken from the `is_constant` + // attribute on the Enter node. + bool is_loop_invariant; + + // If 'is_loop_invariant' is true, the following are all nullptr. Non-constant + // arguments must have all of the following nodes: + Node* merge = nullptr; + Node* switch_node = nullptr; + Node* next_iteration = nullptr; + Node* exit = nullptr; +}; + +// Information about a loop frame. +struct Frame { + string name; + + // Pointer to the parent frame. The root frame has a pointer to itself. + Frame* parent = nullptr; + int num_children = 0; + + // Arguments to this loop. + std::vector args; + + // The loop condition of the loop. There should be exactly one loop condition + // in every loop. + Node* loop_cond = nullptr; + + // Set of nodes that belong to the loop frame. + std::unordered_set nodes; +}; + +// Copies a subgraph from `graph` to `output` by performing a reverse DFS +// starting at nodes in vector `stack`. +// `node_map` is a vector indexed by source node ID to dest nodes. +// Does not traverse into nodes in `node_map`, so by adding nodes to `node_map` +// before the traversal clients can cut the graph. If a frame is provided (frame +// != nullptr), then this functions will return an error if the +// traversal leaves 'frame'; the client must add enough nodes to `node_map` to +// cut the graph and prevent the traversal from escaping. +// +// `squash_src_outputs` contains a bool for each source node ID. If true, then +// the source output on that node will be replaced by zero when copied. This is +// used when replacing a Switch node with an _Arg node. The output we are +// taking from the Switch node was not necessarily the first output, but _Arg +// nodes only have one output. By adding the Switch node to `squash_src_outputs` +// we rewrite the src_output of the corresponding edge to be 0. +Status CopySubgraph(const Graph& graph, const Frame* frame, + std::vector stack, + const std::vector& squash_src_outputs, + std::vector* node_map, Graph* output) { + VLOG(3) << "Stack: " << NodesToString(stack); + std::vector visited(graph.num_node_ids(), false); + while (!stack.empty()) { + Node* n = stack.back(); + stack.pop_back(); + + VLOG(5) << "Copying node " << n->name(); + + if (visited[n->id()]) continue; + visited[n->id()] = true; + + for (const Edge* e : n->in_edges()) { + Node* src = e->src(); + if (frame != nullptr && frame->nodes.find(src) == frame->nodes.end()) { + // We traversed out of the loop frame, without encountering a cut node. + return errors::Internal("Graph traversal of loop frame ", frame->name, + " escaped frame at ", src->name(), + " without encountering an argument node."); + } + if ((*node_map)[src->id()] == nullptr) { + (*node_map)[src->id()] = output->CopyNode(src); + stack.push_back(src); + } + Node* src_copy = (*node_map)[e->src()->id()]; + int src_output = squash_src_outputs[e->src()->id()] && !e->IsControlEdge() + ? 0 + : e->src_output(); + Node* dst_copy = (*node_map)[e->dst()->id()]; + output->AddEdge(src_copy, src_output, dst_copy, e->dst_input()); + } + } + return Status::OK(); +} + +StatusOr BuildArgNode(Graph* graph, DataType type, int index) { + const char* const kArgOp = "_Arg"; + NodeDef arg_def; + NodeDefBuilder builder(strings::StrCat(kArgOp, index), kArgOp); + builder.Attr("T", type); + builder.Attr("index", index); + TF_RETURN_IF_ERROR(builder.Finalize(&arg_def)); + return AddNodeDefToGraph(arg_def, graph); +} + +// Builds a graph for the loop condition. +Status BuildLoopCondition(const Graph& graph, Frame* frame, + std::unique_ptr* cond_output) { + VLOG(2) << "Building loop condition for " << frame->name; + *cond_output = absl::make_unique(graph.op_registry()); + Graph* output = cond_output->get(); + + // Map from nodes in the original graph to the condition graph. + std::vector node_map(graph.num_node_ids(), nullptr); + std::vector squash_src_outputs(graph.num_node_ids(), false); + + // Build one _Arg node for each Enter node. + for (int i = 0; i < frame->args.size(); ++i) { + const Arg& arg = frame->args[i]; + + TF_ASSIGN_OR_RETURN(Node * arg_node, + BuildArgNode(output, arg.enter->input_type(0), i)); + if (arg.is_loop_invariant) { + node_map[arg.enter->id()] = arg_node; + } else { + node_map[arg.merge->id()] = arg_node; + } + } + + // Build a Retval node for the loop condition. The LoopCond nodes are always + // boolean because of the type constraints on the LoopCond op. + TF_ASSIGN_OR_RETURN(node_map[frame->loop_cond->id()], + BuildRetvalNode(output, DT_BOOL, 0)); + + // Performs a reverse DFS, copying nodes and edges to the output graph. + // The _Arg and _Retval nodes were added unconditionally above, so we are + // guaranteed to get the correct function signature. + return CopySubgraph(graph, frame, {frame->loop_cond}, squash_src_outputs, + &node_map, output); +} + +// Builds a graph for the loop body. +Status BuildLoopBody(const Graph& graph, Frame* frame, + DataTypeVector* arg_types, + std::unique_ptr* body_output) { + VLOG(2) << "Building loop body for " << frame->name; + *body_output = absl::make_unique(graph.op_registry()); + Graph* output = body_output->get(); + + // Map from nodes in the original graph to the condition graph. + std::vector node_map(graph.num_node_ids(), nullptr); + std::vector squash_src_outputs(graph.num_node_ids(), false); + + // Build one _Arg node for each Enter node. + std::vector next_iterations; + next_iterations.reserve(frame->args.size()); + arg_types->reserve(frame->args.size()); + for (int i = 0; i < frame->args.size(); ++i) { + const Arg& arg = frame->args[i]; + + DataType dtype = arg.enter->input_type(0); + arg_types->push_back(dtype); + + TF_ASSIGN_OR_RETURN(Node * arg_node, BuildArgNode(output, dtype, i)); + + if (dtype == DT_RESOURCE) { + // The convention of the XLA bridge is that resource variable arguments + // are only inputs to the loop body and have no corresponding output. + // TODO(b/37741920): change the convention so that DT_RESOURCE variables + // are both inputs and outputs, and then remove this case. + TF_RET_CHECK(arg.is_loop_invariant); + node_map[arg.enter->id()] = arg_node; + } else { + TF_ASSIGN_OR_RETURN(Node * retval_node, + BuildRetvalNode(output, dtype, i)); + + if (arg.is_loop_invariant) { + // Argument is loop-invariant. Forward it from the Arg to the Retval. + node_map[arg.enter->id()] = arg_node; + output->AddEdge(arg_node, 0, retval_node, 0); + } else { + // Argument is loop-varying. + node_map[arg.switch_node->id()] = arg_node; + // The Switch node has two outputs, but _Arg only has one. This tells + // the CopySubgraph function to rewrite the output number of edges from + // the _Arg node to be 0 rather than copying the output number from the + // Switch node. + squash_src_outputs[arg.switch_node->id()] = true; + node_map[arg.next_iteration->id()] = retval_node; + next_iterations.push_back(arg.next_iteration); + } + } + } + + // Performs a reverse DFS, copying nodes and edges to the output graph. + // The _Arg and _Retval nodes were added unconditionally above, so we are + // guaranteed to get the correct function signature. + TF_RETURN_IF_ERROR(CopySubgraph(graph, frame, std::move(next_iterations), + squash_src_outputs, &node_map, output)); + + return Status::OK(); +} + +// Copy the FunctionDef of given function from lookup_library to library, if +// it can be found in lookup_library but is missing from library. +Status AddMissingFunctionByName(const string& function_name, + const FunctionLibraryDefinition* lookup_library, + FunctionLibraryDefinition* library) { + if (!library->Find(function_name) && lookup_library->Find(function_name)) { + return library->AddFunctionDef(*lookup_library->Find(function_name)); + } + return Status::OK(); +} + +// Iterate over all functions that the given fdef refers to. Copy the missing +// FunctionDefs from lookup_library to library. +Status AddMissingFunctionDef(const FunctionDef& fdef, + const FunctionLibraryDefinition* lookup_library, + FunctionLibraryDefinition* library) { + TF_RET_CHECK(lookup_library); + for (const NodeDef& node : fdef.node_def()) { + if (library->Find(node.op())) { + continue; + } + // The function referred by 'SymbolicGradient' node is specified in its + // attribute 'f'. + if (node.op() == FunctionLibraryDefinition::kGradientOp) { + const AttrValue* attr = + AttrSlice(&node.attr()).Find(FunctionLibraryDefinition::kFuncAttr); + if (!attr) { + return errors::InvalidArgument("SymbolicGradient is missing attr: f"); + } + const string& func_name = attr->func().name(); + TF_RETURN_IF_ERROR( + AddMissingFunctionByName(func_name, lookup_library, library)); + // Copy the user-defined gradient function if it exists. + const string grad_name = lookup_library->FindGradient(func_name); + if (!grad_name.empty() && library->FindGradient(func_name).empty()) { + TF_RETURN_IF_ERROR( + AddMissingFunctionByName(grad_name, lookup_library, library)); + GradientDef grad_def; + grad_def.set_function_name(func_name); + grad_def.set_gradient_func(grad_name); + TF_RETURN_IF_ERROR(library->AddGradientDef(grad_def)); + } + } else if (lookup_library->Find(node.op())) { + TF_RETURN_IF_ERROR( + library->AddFunctionDef(*lookup_library->Find(node.op()))); + } + } + return Status::OK(); +} + +Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, + Graph* graph, Frame* frame, + FunctionLibraryDefinition* library) { + VLOG(2) << "Frame " << frame->name << " before: " + << dump_graph::DumpGraphToFile("functionalize_before", *graph, + library); + + // Split loop-varying Enter nodes with multiple successors. If the same + // Tensor is fed as input to multiple loop arguments, we may end up with a + // shared Enter node. We clone Enter nodes with multiple successors to + // maintain the invariant of a unique Enter node per argument of the final + // loop. + std::vector args; + for (const Arg& arg : frame->args) { + if (arg.is_loop_invariant) { + args.push_back(arg); + } else { + std::vector edges(arg.enter->out_edges().begin(), + arg.enter->out_edges().end()); + for (int i = 0; i < edges.size(); ++i) { + if (edges[i]->IsControlEdge() && edges[i]->dst()->IsSink()) { + continue; + } + TF_RET_CHECK(!edges[i]->IsControlEdge()) << edges[i]->src()->name(); + Arg new_arg; + new_arg.is_loop_invariant = false; + if (i == 0) { + new_arg.enter = arg.enter; + } else { + new_arg.enter = graph->CopyNode(arg.enter); + frame->nodes.insert(new_arg.enter); + for (Edge const* e : arg.enter->in_edges()) { + graph->AddEdge(e->src(), e->src_output(), new_arg.enter, + e->IsControlEdge() ? Graph::kControlSlot : 0); + } + Node* dst = edges[i]->dst(); + int dst_input = edges[i]->dst_input(); + graph->RemoveEdge(edges[i]); + graph->AddEdge(new_arg.enter, 0, dst, dst_input); + } + args.push_back(new_arg); + } + } + } + frame->args = std::move(args); + + std::sort(frame->args.begin(), frame->args.end(), + [](const Arg& a, const Arg& b) { + return NodeCmpByNameResourcesLast()(a.enter, b.enter); + }); + + if (frame->loop_cond == nullptr) { + return errors::InvalidArgument("Loop ", frame->name, + " has no LoopCond node"); + } + + // Find the set of Switch nodes that are successors of the LoopCond. + std::unordered_set switches; + for (const Edge* edge : frame->loop_cond->out_edges()) { + if (!edge->IsControlEdge() && IsSwitch(edge->dst()) && + edge->dst_input() == 1) { + switches.insert(edge->dst()); + } + } + + // For each non-constant argument, looks for the following pattern of nodes: + // Enter ----> Merge --------> Switch --> Exit + // ^ ^ + // | | + // NextIteration LoopCond + // ^ ^ + // | | + // ... ... + for (Arg& arg : frame->args) { + if (!arg.is_loop_invariant) { + // Follow the edge from the Enter to Merge. + const Edge* enter_merge = nullptr; + for (const Edge* e : arg.enter->out_edges()) { + // Ignore control-edges to the sink node. These are allowed by the + // graph invariants, although probably they should have been stripped + // off earlier. + if (e->IsControlEdge() && e->dst()->IsSink()) { + continue; + } + if (enter_merge != nullptr) { + return errors::Internal("Enter node for loop-varying argument ", + FormatNodeForError(*arg.enter), + " has multiple successors: ", + FormatNodeForError(*enter_merge->dst()), + " and ", FormatNodeForError(*e->dst())); + } + enter_merge = e; + } + if (enter_merge == nullptr) { + return errors::Internal("Enter node for loop-varying argument ", + FormatNodeForError(*arg.enter), + " has zero successors"); + } + arg.merge = enter_merge->dst(); + if (!IsMerge(arg.merge)) { + return errors::InvalidArgument( + "Successor of Enter node for loop-varying argument ", + FormatNodeForError(*arg.merge), + " is not a Merge node; got: ", arg.merge->type_string()); + } + + // Find the NextIteration from the merge. There should be two inputs to + // the Merge and the NextIteration should be the other input. + if (arg.merge->input_types().size() != 2) { + return errors::InvalidArgument( + "Unexpected number of inputs to Merge node for loop-varying " + "argument ", + FormatNodeForError(*arg.merge), "; expected 2, got ", + arg.merge->input_types().size()); + } + TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(), + &arg.next_iteration)); + if (!IsNextIteration(arg.next_iteration)) { + return errors::InvalidArgument( + "Expected NextIteration node as input to Merge node; got node ", + FormatNodeForError(*arg.next_iteration), " with kind ", + arg.next_iteration->type_string()); + } + + // Find the Switch successor of the Merge. There should be exactly one + // Switch node that is a successor of both the Merge and the LoopCond. + for (const Edge* edge : arg.merge->out_edges()) { + if (edge->dst_input() == 0 && IsSwitch(edge->dst()) && + switches.find(edge->dst()) != switches.end()) { + if (arg.switch_node != nullptr) { + return errors::InvalidArgument("Duplicate Switch successors to ", + FormatNodeForError(*arg.merge)); + } + arg.switch_node = edge->dst(); + } + } + if (arg.switch_node == nullptr) { + return errors::InvalidArgument("Missing Switch successor to ", + FormatNodeForError(*arg.merge)); + } + + // Update the device on the Identity outputs of the switch to match their + // target. These Identity outputs do not + + // Loop over the switch node's output to: + // - Find the Exit successor. + // - Set the sharding on all Identity outputs of the switch. These + // identity nodes are values used by the loop body or condition. + // The Identity node may have the wrong device so copy the device from + // one of its outputs instead. + std::deque possible_exit; + for (const Edge* edge : arg.switch_node->out_edges()) { + if (edge->src_output() == 0) { + possible_exit.push_back(edge); + } + if (IsIdentity(edge->dst())) { + TF_RETURN_IF_ERROR( + SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true)); + } + } + // TODO(b/67425339): Allow general graph between switch and exit. + while (!possible_exit.empty()) { + const Edge* edge = possible_exit.front(); + possible_exit.pop_front(); + if (IsExit(edge->dst())) { + if (arg.exit != nullptr) { + return errors::InvalidArgument( + "Duplicate Exit successors to ", + FormatNodeForError(*arg.switch_node)); + } + arg.exit = edge->dst(); + } else { + if (!IsIdentity(edge->dst())) { + return errors::Unimplemented("General graph between switch (", + FormatNodeForError(*arg.switch_node), + ") and exit node of frame ", + frame->name, " not supported yet."); + } + for (const Edge* out : edge->dst()->out_edges()) { + possible_exit.push_back(out); + } + } + } + } + } + + // Builds the condition and body functions. + std::unique_ptr cond_graph; + TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph)); + DataTypeVector arg_types; + std::unique_ptr body_graph; + TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph)); + + VLOG(2) << "Frame " << frame->name << " condition: " + << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library) + << " body: " << dump_graph::DumpGraphToFile("loop_body", *body_graph); + + static std::atomic sequence_num(0LL); + int64 id = ++sequence_num; + NameAttrList cond_name; + cond_name.set_name(strings::StrCat("_functionalize_cond_", id)); + NameAttrList body_name; + body_name.set_name(strings::StrCat("_functionalize_body_", id)); + FunctionDef cond_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef)); + FunctionDef body_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*body_graph, body_name.name(), &body_fdef)); + + TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef)); + TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef)); + if (lookup_library) { + // Copy missing FunctionDefs from lookup_library to library to make library + // self-contained. + TF_RETURN_IF_ERROR( + AddMissingFunctionDef(cond_fdef, lookup_library, library)); + TF_RETURN_IF_ERROR( + AddMissingFunctionDef(body_fdef, lookup_library, library)); + } + + // Builds a While operator. + NodeDef while_def; + NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile"); + builder.Attr("T", arg_types); + builder.Attr("cond", cond_name); + builder.Attr("body", body_name); + std::vector inputs; + for (int i = 0; i < frame->args.size(); ++i) { + const Arg& arg = frame->args[i]; + const Edge* in_edge; + TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge)); + if (in_edge->IsControlEdge()) { + builder.ControlInput(in_edge->src()->name()); + } else { + inputs.push_back(NodeDefBuilder::NodeOut( + in_edge->src()->name(), in_edge->src_output(), arg_types[i])); + } + } + builder.Input(inputs); + TF_RETURN_IF_ERROR(builder.Finalize(&while_def)); + TF_ASSIGN_OR_RETURN(Node * while_node, AddNodeDefToGraph(while_def, graph)); + + // Copies edges to the Enter nodes and from the Exit nodes onto the While. + for (int i = 0; i < frame->args.size(); ++i) { + const Arg& arg = frame->args[i]; + const Edge* in_edge; + TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge)); + if (in_edge->IsControlEdge()) { + graph->AddControlEdge(in_edge->src(), while_node); + } else { + graph->AddEdge(in_edge->src(), in_edge->src_output(), while_node, i); + } + + if (!arg.is_loop_invariant) { + // Add output edges if the output of the loop is consumed. + if (arg.exit != nullptr) { + std::vector edges(arg.exit->out_edges().begin(), + arg.exit->out_edges().end()); + for (const Edge* edge : edges) { + Node* dst = edge->dst(); + int dst_input = edge->dst_input(); + graph->RemoveEdge(edge); + + if (dst_input == Graph::kControlSlot) { + graph->AddControlEdge(while_node, dst); + } else { + graph->AddEdge(while_node, i, dst, dst_input); + } + } + } + } + } + + // Remove the old nodes from the graph, and add the while node to the parent + // frame. + for (Node* node : frame->nodes) { + graph->RemoveNode(node); + } + frame->nodes.clear(); + frame->parent->nodes.insert(while_node); + + VLOG(2) << "Frame " << frame->name << " after: " + << dump_graph::DumpGraphToFile("functionalize_after", *graph, + library); + + return Status::OK(); +} +} // namespace + +Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library, + Graph* graph, + FunctionLibraryDefinition* library) { + // Note: BuildControlFlowInfo() requires that the graph's source node is + // connected to all source nodes in the graph. Many graphs violate this + // invariant. + std::vector cf_info; + std::vector unreachable_nodes; + TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes)); + if (!unreachable_nodes.empty()) { + return errors::InvalidArgument( + "The following nodes are unreachable from the source in the graph: ", + errors::FormatNodeNamesForError(unreachable_nodes)); + } + + // Builds Frames, indexed by name. + std::unordered_map frames; + for (Node* node : graph->op_nodes()) { + const ControlFlowInfo& cf = cf_info[node->id()]; + + VLOG(2) << "node: " << node->name() << " (" << node->id() + << ") frame_name: " << cf.frame_name + << " frame: " << (cf.frame ? cf.frame->name() : "---") + << " parent_frame: " + << (cf.parent_frame ? cf.parent_frame->name() : "---"); + TF_RET_CHECK(cf.frame != nullptr && cf.parent_frame != nullptr); + + Frame& frame = frames[cf.frame_name]; + Frame* parent = &frames[cf_info[cf.parent_frame->id()].frame_name]; + if (frame.parent == nullptr) { + frame.parent = parent; + frame.name = cf.frame_name; + ++parent->num_children; + } + + if (IsEnter(node)) { + Arg arg; + arg.enter = node; + TF_RETURN_IF_ERROR(GetNodeAttr(arg.enter->attrs(), "is_constant", + &arg.is_loop_invariant)); + frame.args.push_back(arg); + } else if (IsLoopCond(node)) { + frame.loop_cond = node; + } + frame.nodes.insert(node); + } + + // Adds frames with no children (i.e., the innermost frames) to a worklist. + std::deque worklist; + for (auto& frame : frames) { + if (frame.second.num_children == 0) { + worklist.push_back(&frame.second); + } + } + + // Eliminate loops from innermost to outermost. + while (!worklist.empty()) { + Frame* frame = worklist.front(); + worklist.pop_front(); + if (frame->parent == frame) { + // Skip the root frame. + continue; + } + + TF_RETURN_IF_ERROR( + FunctionalizeLoop(lookup_library, graph, frame, library)); + + // If the parent has no remaining children, add it to the worklist. + --frame->parent->num_children; + if (frame->parent->num_children == 0) { + worklist.push_back(frame->parent); + } + } + + // There should be no cycle at this point, since while loops have been removed + // from graph. + // Check that the newly added XlaWhile nodes don't feed into themselves. + for (const Node* node : graph->op_nodes()) { + if (node->def().op() == "XlaWhile") { + TF_RETURN_WITH_CONTEXT_IF_ERROR( + CheckNodeNotInCycle(node, graph->num_node_ids()), + "Functionalizing loop failed."); + } + } + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_while.h b/tensorflow/compiler/tf2xla/functionalize_while.h new file mode 100644 index 0000000000000000000000000000000000000000..a708c6e4ec4e13527b4ee2d6c435dddee0a2b4e2 --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_while.h @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_WHILE_H_ +#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_WHILE_H_ + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Transformation that converts tf.while_loop() loops into functional While +// operators, suitable for XLA compilation. If lookup_library is provided, use +// it to make the library for control flow self-contained. +Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library, + Graph* graph, FunctionLibraryDefinition* library); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_WHILE_H_ diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index e1cea03865ce9978e634429b5ce41fe8b245a575..e4fdf0a6186eb69a2e3413838c91616b992ef2d6 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -29,7 +29,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 7f3e32d96d3c6846471f74a0cec53c09f396ebe8..b1366e9e31e28406c5bf1a808b9c5670558ed9c7 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -6,6 +6,10 @@ package( load("//tensorflow:tensorflow.bzl", "tf_copts") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") +load( + "//third_party/mkl:build_defs.bzl", + "if_mkl", +) tf_kernel_library( name = "xla_ops", @@ -123,13 +127,15 @@ tf_kernel_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/lib:numeric", + "//tensorflow/compiler/xla/client/lib:pooling", "//tensorflow/compiler/xla/client/lib:prng", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/lib:sorting", "//tensorflow/core:framework", "//tensorflow/core:image_ops_op_lib", "//tensorflow/core:lib", @@ -152,8 +158,14 @@ tf_kernel_library( "//tensorflow/core/kernels:sparse_to_dense_op", "//tensorflow/core/kernels:stack_ops", "//tensorflow/core/kernels:training_ops", - "//tensorflow/core/kernels:transpose_op", - ], + ] + if_mkl( + [ + "//tensorflow/core/kernels:mkl_transpose_op", + ], + [ + "//tensorflow/core/kernels:transpose_op", + ], + ), ) tf_kernel_library( @@ -165,8 +177,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -182,7 +194,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -219,8 +231,8 @@ tf_kernel_library( "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/kernels:argmax_op", diff --git a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc index e33532828040123243f839ab1aa655b4bbc72520..41a453da80dec6b6f57a4d222e2c33ef6b786a10 100644 --- a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc index 26fc1620a4f032b3af28de6e3a5af0e965e82341..276d744c096f8996c774964204feaa3762bdb844 100644 --- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc @@ -65,6 +65,6 @@ class XlaArgOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(XlaArgOp); }; -REGISTER_XLA_OP(Name("_Arg").AllowResourceTypes(), XlaArgOp); +REGISTER_XLA_OP(Name("_Arg").AllowResourceTypes().CompilationOnly(), XlaArgOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index c4af79281d2162b1dbfb0a7881720892f4bc49d2..b3ad0aea84eef601de08909f760699b8700d28f4 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 26130fd9e7fce75c6d2a5a53cfc85842cf762b35..48f2a005ab16651fe29d0f6f9d881f95693da461 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc index e9b2c0b16d39cb3b747c0316621fb01de709b12e..41f540506ba41fbe7f91393e7b8e26a89e72ef0a 100644 --- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/util/tensor_format.h" diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index d6d4ae89376b67c14af8ef4f3a608fcc83b6fb59..2c328102e0bd84709707f102272691b6aec9a577 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc index efbdb76eaaf78904fe783a018940b1b096ec39bd..5078f8662bd397eaa51274ec816c130b8ced92cc 100644 --- a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc index 62eebf762b3e063da8ec456cc4726d3cc9b77d1d..8cc2479dd555380da7500abe6b2aca380110333b 100644 --- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index 1784e712b56145bbdff5f1daa2e031b65d0774b6..e7fef77edcba0ea5a521956a704225ac4f7fcb22 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc index 4e6d33304c4ae08a0fd1e0a8373267a527087528..547fe48046e8c934e3bc14d02c8448e107c1a406 100644 --- a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index e3a32a5c0e2f93237c8c7ebeea3668b5d1ab6c23..f4106051043859a6786705009d76b02a64cd3ff1 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index f4360d8c3f6fc4007c31fdcfd7f7634de15c76d4..da8cf3fc6fa694f592280f8c249d317827d9cd09 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/tensor.pb.h" diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 48ac4867edcef97be001a24f42f6a35225d466c9..674720e22fbf9d995e74c7dbd0ef7d7765941867 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -120,45 +120,30 @@ xla::XlaOp CreateExpandedFilterMask(const TensorShape& filter_shape, {expanded_filter_shape.dims() - 2}); } -// Expands a filter of shape [H, W, ..., M, N] to [H, W, ..., M, M*N] by adding -// zeros for the cross-depth filters. Used to build a depthwise convolution. -xla::XlaOp ExpandFilterForDepthwiseConvolution(const TensorShape& filter_shape, - DataType dtype, - const xla::XlaOp& filter, - xla::XlaBuilder* builder) { - int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1); - int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2); - TensorShape expanded_filter_shape = - ExpandedFilterShapeForDepthwiseConvolution(filter_shape); +// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to +// build a depthwise convolution. +xla::XlaOp ReshapeFilterForDepthwiseConvolution(const TensorShape& filter_shape, + const xla::XlaOp& filter) { + int64 input_feature_dim = filter_shape.dims() - 2; + int64 output_feature_dim = filter_shape.dims() - 1; + int64 depthwise_multiplier = filter_shape.dim_size(output_feature_dim); + int64 input_feature = filter_shape.dim_size(input_feature_dim); // Create a [H, W, ..., 1, N*M] reshape of the filter. - TensorShape implicit_broadcast_filter_shape = expanded_filter_shape; - implicit_broadcast_filter_shape.set_dim( - implicit_broadcast_filter_shape.dims() - 2, 1); - implicit_broadcast_filter_shape.set_dim( - implicit_broadcast_filter_shape.dims() - 1, - depthwise_multiplier * input_feature); - auto implicit_broadcast_filter = - xla::Reshape(filter, implicit_broadcast_filter_shape.dim_sizes()); - - // Broadcast the filter to [H, W, ..., M, M*N]. - auto expanded_zero = CreateExpandedZero(filter_shape, dtype, builder); - auto expanded_filter = xla::Add(implicit_broadcast_filter, expanded_zero); - - // If the filter mask is set, choose the broadcasted filter, othwerwise, - // choose zero. - return xla::Select(CreateExpandedFilterMask(filter_shape, builder), - expanded_filter, expanded_zero); + TensorShape implicit_broadcast_filter_shape = filter_shape; + implicit_broadcast_filter_shape.set_dim(input_feature_dim, 1); + implicit_broadcast_filter_shape.set_dim(output_feature_dim, + depthwise_multiplier * input_feature); + return xla::Reshape(filter, implicit_broadcast_filter_shape.dim_sizes()); } -// Inverse of ExpandFilterForDepthwiseConvolution. +// Reduces the results of the convolution with an expanded filter to the +// non-expanded filter. xla::XlaOp ContractFilterForDepthwiseBackprop(XlaOpKernelContext* ctx, const TensorShape& filter_shape, DataType dtype, const xla::XlaOp& filter_backprop, xla::XlaBuilder* builder) { - TensorShape expanded_filter_shape = - ExpandedFilterShapeForDepthwiseConvolution(filter_shape); auto masked_expanded_filter = xla::Select( CreateExpandedFilterMask(filter_shape, builder), filter_backprop, CreateExpandedZero(filter_shape, dtype, builder)); @@ -168,8 +153,7 @@ xla::XlaOp ContractFilterForDepthwiseBackprop(XlaOpKernelContext* ctx, // ExpandedZero guarantees that only one element is non zero, so there // cannot be accumulated precision error. xla::Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype), - *ctx->GetOrCreateAdd(dtype), - {expanded_filter_shape.dims() - 2}), + *ctx->GetOrCreateAdd(dtype), {filter_shape.dims() - 2}), filter_shape.dim_sizes()); } @@ -245,15 +229,9 @@ class ConvOp : public XlaOpKernel { "input and filter must have the same depth: ", in_depth, " vs ", input_shape.dim_size(feature_dim))); - xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp filter = ctx->Input(1); - TensorShape expanded_filter_shape = filter_shape; if (depthwise_) { - filter = ExpandFilterForDepthwiseConvolution( - filter_shape, ctx->input_type(0), filter, b); - expanded_filter_shape = - ExpandedFilterShapeForDepthwiseConvolution(filter_shape); + filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter); } xla::ConvolutionDimensionNumbers dims; @@ -280,14 +258,15 @@ class ConvOp : public XlaOpKernel { int64 unused_output_size; OP_REQUIRES_OK( ctx, GetWindowedOutputSizeVerboseV2( - input_shape.dim_size(dim), expanded_filter_shape.dim_size(i), + input_shape.dim_size(dim), filter_shape.dim_size(i), rhs_dilation[i], window_strides[i], padding_, &unused_output_size, &padding[i].first, &padding[i].second)); } - xla::XlaOp conv = - xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding, - lhs_dilation, rhs_dilation, dims); + xla::XlaOp conv = xla::ConvGeneralDilated( + ctx->Input(0), filter, window_strides, padding, lhs_dilation, + rhs_dilation, dims, + /*feature_group_count=*/depthwise_ ? in_depth : 1); ctx->SetOutput(0, conv); } @@ -388,7 +367,6 @@ class ConvBackpropInputOp : public XlaOpKernel { expanded_filter_shape, out_backprop_shape, dilations_, strides_, padding_, data_format_, &dims)); - xla::XlaBuilder* b = ctx->builder(); auto filter = ctx->Input(1); auto out_backprop = ctx->Input(2); @@ -425,12 +403,6 @@ class ConvBackpropInputOp : public XlaOpKernel { rhs_dilation[i] = dilations_[dim]; } - // If this is a depthwise convolution, expand the filter. - if (depthwise_) { - filter = ExpandFilterForDepthwiseConvolution( - filter_shape, ctx->input_type(1), filter, b); - } - // Mirror the filter in the spatial dimensions. xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims); @@ -438,7 +410,11 @@ class ConvBackpropInputOp : public XlaOpKernel { // = gradients (with padding and dilation) mirrored_weights xla::XlaOp in_backprop = xla::ConvGeneralDilated( out_backprop, mirrored_weights, /*window_strides=*/ones, padding, - lhs_dilation, rhs_dilation, dnums); + lhs_dilation, rhs_dilation, dnums, + /*feature_group_count=*/ + depthwise_ ? out_backprop_shape.dim_size(feature_dim) / + filter_shape.dim_size(num_spatial_dims_ + 1) + : 1); ctx->SetOutput(0, in_backprop); } diff --git a/tensorflow/compiler/tf2xla/kernels/cross_op.cc b/tensorflow/compiler/tf2xla/kernels/cross_op.cc index 500a564f3f0489a42dbc9d5b70ae7708a7a43973..db579a5b35d69deb3dca578e31c1b54fada76342 100644 --- a/tensorflow/compiler/tf2xla/kernels/cross_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cross_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc index 9ff3e0222831cb4339943966810eeae451e47a2c..ef1015552d181a183d412f9c269dd5ec608b388f 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/util/bcast.h" diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h index 4f92dbc8740b697322424058530b8477c35d809a..a5b870f8dbf70bcee331992345d63fd5d986bdca 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/util/bcast.h" diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc index f3149200250935629a6e4bf67bff0c048135ce3e..12b0e38288e8f222ed506a75ec2575f27141c859 100644 --- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index 22cda27567a58f17ca92803d4eccfc1f29f0b8b8..ed44ad218b6dc073583ec339da082b6881ad672d 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc index 3b86ea34c9e7d943eb9c7de222e0a2be049ebc68..a3389d5b905bf3ee15744ab4fcee193d312e2ae0 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/compiler/tf2xla/type_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index 958231505b50431b9bb267b0a3cc5ed56e3aeb21..cb73053666d4c32bc0a2ef19b174aee1a29f101e 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc index 81f42e504e4b6f813a29769719a7a7fb5d99b9c5..5fdb1d972c55efb876972d3f472b53a1f7cde1c2 100644 --- a/tensorflow/compiler/tf2xla/kernels/elu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc index 65d42a302fca48c7b5f88813f80e975823f63ddf..c68b0bfd7961892294c2931e5c4c44de534a7740 100644 --- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc index 2fd1a34741e1c7235397f9a69dd8444b4679fa22..cdba6680dee3fade5bdf0c453ed672b653072b0d 100644 --- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/platform/macros.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc index b2b00e51e3b00fa93c258af489cf0f4a3e6e764b..80bcef966360ec9a1ca63a02741108ce41b31846 100644 --- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc index 95faa1d058f4c0d3fa802b157c6daba1e1adaf41..54b21a278229024e3e54e9135548be6b69b077e1 100644 --- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 5f041be5df226ed996b21844c0cf92b6dfac005c..44140304fdf5cdf60d8ad8b85c532fcadff8ba86 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -95,11 +95,11 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, // operand = s32[3,3] parameter(0) // indices = s32[2] parameter(1) // gather = s32[3,2] gather(operand, indices), - // output_window_dims={0}, - // elided_window_dims={1}, - // gather_dims_to_operand_dims={1}, + // offset_dims={0}, + // collapsed_slice_dims={1}, + // start_index_map={1}, // index_vector_dim=1, - // window_bounds={3, 1} + // slice_sizes={3, 1} // // // Example of an N-D gather pulling out slices of shape [1,1,2] out of a @@ -108,42 +108,42 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, // operand = s32[3,3,2] parameter(0) // indices = s32[2,2] parameter(1) // gather = s32[2,2] gather(operand, indices), - // output_window_dims={1}, - // elided_window_dims={0,1}, - // gather_dims_to_operand_dims={0,1}, + // offset_dims={1}, + // collapsed_slice_dims={0,1}, + // start_index_map={0,1}, // index_vector_dim=0, - // window_bounds={1,1,2} + // slice_sizes={1,1,2} xla::GatherDimensionNumbers dim_numbers; - std::vector window_bounds; - window_bounds.reserve(input_shape.dims()); + std::vector slice_sizes; + slice_sizes.reserve(input_shape.dims()); for (int64 i = 0; i < input_shape.dims(); i++) { int64 window_bound; if (axis <= i && i < (axis + num_index_dims)) { - dim_numbers.add_elided_window_dims(i); + dim_numbers.add_collapsed_slice_dims(i); window_bound = 1; } else { window_bound = input_shape.dim_size(i); } - window_bounds.push_back(window_bound); + slice_sizes.push_back(window_bound); if (i < axis) { - dim_numbers.add_output_window_dims(i); + dim_numbers.add_offset_dims(i); } else if (i >= (axis + num_index_dims)) { int64 indices_rank = indices_are_nd ? (indices_shape.dims() - 1) : indices_shape.dims(); - dim_numbers.add_output_window_dims(i + indices_rank - num_index_dims); + dim_numbers.add_offset_dims(i + indices_rank - num_index_dims); } } dim_numbers.set_index_vector_dim(indices_are_nd ? (indices_shape.dims() - 1) : indices_shape.dims()); for (int64 i = axis; i < axis + num_index_dims; i++) { - dim_numbers.add_gather_dims_to_operand_dims(i); + dim_numbers.add_start_index_map(i); } - *gather_output = xla::Gather(input, indices, dim_numbers, window_bounds); + *gather_output = xla::Gather(input, indices, dim_numbers, slice_sizes); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h index d898e43b858bac706d524c7c271f48b1b5fa258f..92346283c31dfe1d638526ac4b26ef762cd7fd14 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/util/bcast.h" diff --git a/tensorflow/compiler/tf2xla/kernels/identity_op.cc b/tensorflow/compiler/tf2xla/kernels/identity_op.cc index e72200bfbcff20c55ac03030f1afc4bacaabf7ce..19dd38c46ef154ea74bcbb6721dd04924702efcc 100644 --- a/tensorflow/compiler/tf2xla/kernels/identity_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/identity_op.cc @@ -25,7 +25,10 @@ class IdentityOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { for (int i = 0; i < ctx->num_inputs(); ++i) { - ctx->SetOutput(i, ctx->Input(i)); + // Forwards using the underlying op_kernel_context so both tensor and + // resource values are forwarded correctly. + ctx->op_kernel_context()->set_output(i, + ctx->op_kernel_context()->input(i)); } } @@ -35,9 +38,10 @@ class IdentityOp : public XlaOpKernel { // XLA_* devices also register a "real" Identity operator so we suppress the // dummy operator using CompilationOnly(). -REGISTER_XLA_OP(Name("Identity").CompilationOnly(), IdentityOp); - -REGISTER_XLA_OP(Name("IdentityN").CompilationOnly(), IdentityOp); +REGISTER_XLA_OP(Name("Identity").AllowResourceTypes().CompilationOnly(), + IdentityOp); +REGISTER_XLA_OP(Name("IdentityN").AllowResourceTypes().CompilationOnly(), + IdentityOp); REGISTER_XLA_OP(Name("PlaceholderWithDefault"), IdentityOp); REGISTER_XLA_OP(Name("PreventGradient"), IdentityOp); REGISTER_XLA_OP(Name("StopGradient"), IdentityOp); diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index e2160feba00a7272635207ebcb53670cacf34620..6e1dbf5472f0b1eb0abcbe29c553ae926ecf2d8a 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { @@ -200,25 +200,24 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { } } - xla::XlaOp outputs = xla::Conditional( - ctx->Input(0), xla::Tuple(b, inputs), *then_result.computation, - xla::Tuple(b, inputs), *else_result.computation); + auto input_tuple = xla::Tuple(b, inputs); + xla::XlaOp outputs = + xla::Conditional(ctx->Input(0), input_tuple, *then_result.computation, + input_tuple, *else_result.computation); // Sets non-variable outputs. for (int i = 0; i < output_types_.size(); ++i) { - if (ctx->input_type(i) != DT_RESOURCE) { - xla::XlaOp output_handle = xla::GetTupleElement(outputs, i); - if (VLOG_IS_ON(2)) { - LOG(INFO) << "Setting output " << i; - auto shape_or = b->GetShape(output_handle); - if (shape_or.ok()) { - LOG(INFO) << "Shape for output " << i << ": " - << xla::ShapeUtil::HumanString(shape_or.ValueOrDie()); - } else { - LOG(INFO) << "Shape unknown for output " << i; - } + xla::XlaOp output_handle = xla::GetTupleElement(outputs, i); + if (VLOG_IS_ON(2)) { + LOG(INFO) << "Setting output " << i; + auto shape_or = b->GetShape(output_handle); + if (shape_or.ok()) { + LOG(INFO) << "Shape for output " << i << ": " + << xla::ShapeUtil::HumanString(shape_or.ValueOrDie()); + } else { + LOG(INFO) << "Shape unknown for output " << i; } - ctx->SetOutput(i, output_handle); } + ctx->SetOutput(i, output_handle); } // Updates the values of any resource variables modified by the conditional @@ -247,6 +246,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { } REGISTER_XLA_OP(Name("If").AllowResourceTypes(), XlaIfOp); +REGISTER_XLA_OP(Name("StatelessIf").AllowResourceTypes(), XlaIfOp); REGISTER_XLA_OP(Name("XlaIf").AllowResourceTypes(), XlaIfOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index cb4caf7bcb4caaa1bf7e0e79e52bb966a8838db3..33a73fe5fdf403e513be085dd7bcea3255277b4a 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -17,7 +17,12 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/sorting.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { namespace { @@ -311,5 +316,150 @@ class AdjustHueOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp); +class NonMaxSuppressionOp : public XlaOpKernel { + public: + explicit NonMaxSuppressionOp(OpKernelConstruction* context) + : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size", + &pad_to_max_output_size_)); + } + + void Compile(XlaOpKernelContext* context) override { + // TODO(b/111646731): Improve scalability of this op, using blocking. + int num_boxes_dim = 0; + int coords_dim = 1; + const TensorShape& boxes_shape = context->InputShape("boxes"); + OP_REQUIRES(context, TensorShapeUtils::IsMatrix(boxes_shape), + errors::InvalidArgument("boxes must be 2-D, currently: ", + boxes_shape.DebugString())); + const int64 num_boxes = boxes_shape.dim_size(num_boxes_dim); + OP_REQUIRES(context, boxes_shape.dim_size(coords_dim) == 4, + errors::InvalidArgument("boxes must have 4 columns", + boxes_shape.DebugString())); + const TensorShape& scores_shape = context->InputShape("scores"); + OP_REQUIRES(context, TensorShapeUtils::IsVector(scores_shape), + errors::InvalidArgument("scores must be 1-D, currently: ", + scores_shape.DebugString())); + OP_REQUIRES( + context, scores_shape.dim_size(0) == num_boxes, + errors::InvalidArgument("scores size must equal number of boxes", + scores_shape.DebugString())); + OP_REQUIRES(context, pad_to_max_output_size_, + errors::InvalidArgument( + "XLA compilation requires pad_to_max_output_size == True")); + + xla::XlaOp boxes = context->Input("boxes"); + xla::XlaOp scores = context->Input("scores"); + int64 output_size; + OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &output_size)); + OP_REQUIRES( + context, output_size >= 0, + errors::InvalidArgument("Need output_size >= 0, got ", output_size)); + xla::XlaOp score_thresh = context->Input("score_threshold"); + xla::XlaOp iou_thresh = context->Input("iou_threshold"); + + xla::XlaBuilder* const builder = context->builder(); + + // Choose a more convenient layout. + xla::XlaOp boxes_t = xla::Transpose(boxes, {1, 0}); + coords_dim = 0; + num_boxes_dim = 1; + + // Shapes are henceforth [1, num_boxes]. + xla::XlaOp coord_y0 = xla::SliceInDim(boxes_t, + /*start_index=*/0, + /*limit_index=*/1, + /*stride=*/1, + /*dimno=*/coords_dim); + xla::XlaOp coord_x0 = xla::SliceInDim(boxes_t, + /*start_index=*/1, + /*limit_index=*/2, + /*stride=*/1, + /*dimno=*/coords_dim); + xla::XlaOp coord_y1 = xla::SliceInDim(boxes_t, + /*start_index=*/2, + /*limit_index=*/3, + /*stride=*/1, + /*dimno=*/coords_dim); + xla::XlaOp coord_x1 = xla::SliceInDim(boxes_t, + /*start_index=*/3, + /*limit_index=*/4, + /*stride=*/1, + /*dimno=*/coords_dim); + xla::XlaOp y1 = + xla::Select(xla::Le(coord_y0, coord_y1), coord_y0, coord_y1); + xla::XlaOp y2 = + xla::Select(xla::Le(coord_y0, coord_y1), coord_y1, coord_y0); + xla::XlaOp x1 = + xla::Select(xla::Le(coord_x0, coord_x1), coord_x0, coord_x1); + xla::XlaOp x2 = + xla::Select(xla::Le(coord_x0, coord_x1), coord_x1, coord_x0); + xla::XlaOp area = (y2 - y1) * (x2 - x1); + + // Transpose the 1xN tensors, instead of the NxN tensors. + xla::XlaOp y1_t = xla::Transpose(y1, {1, 0}); + xla::XlaOp y2_t = xla::Transpose(y2, {1, 0}); + xla::XlaOp x1_t = xla::Transpose(x1, {1, 0}); + xla::XlaOp x2_t = xla::Transpose(x2, {1, 0}); + xla::XlaOp area_t = xla::Transpose(area, {1, 0}); + + // Shapes are henceforth [num_boxes, num_boxes]. + xla::XlaOp i_xmin = xla::Max(x1, x1_t); + xla::XlaOp i_ymin = xla::Max(y1, y1_t); + xla::XlaOp i_xmax = xla::Min(x2, x2_t); + xla::XlaOp i_ymax = xla::Min(y2, y2_t); + auto square_zero = xla::ZerosLike(i_xmin); + + xla::XlaOp i_area = xla::Max(i_xmax - i_xmin, square_zero) * + xla::Max(i_ymax - i_ymin, square_zero); + xla::XlaOp u_area = area + area_t - i_area; + xla::XlaOp iou = i_area / u_area; + + xla::XlaOp iou_thresh_mask = xla::Gt(iou, iou_thresh + square_zero); + xla::XlaOp scores_2d = xla::Reshape(scores, {num_boxes, 1}); + xla::XlaOp score_cmp_mask = + xla::Gt(scores_2d, xla::Transpose(scores_2d, {1, 0})); + xla::XlaOp suppress = xla::And(iou_thresh_mask, score_cmp_mask); + + // Shapes are [num_boxes] after the reduce. + xla::XlaOp included_iou = xla::Not(xla::Reduce( + suppress, + /*init_value=*/xla::ConstantR0(builder, false), + /*computation=*/CreateScalarOrComputation(xla::PRED, builder), + /*dimensions_to_reduce=*/{0})); + xla::XlaOp included_score = + xla::Gt(scores, xla::Broadcast(score_thresh, {num_boxes})); + xla::XlaOp included = xla::And(included_iou, included_score); + xla::XlaOp neg_inf = + xla::Broadcast(xla::MinValue(builder, xla::F32), {num_boxes}); + xla::XlaOp scores_included = xla::Select(included, scores, neg_inf); + + xla::XlaOp ones_included = xla::Select( + included, + xla::Broadcast(xla::ConstantR0(builder, 1), {num_boxes}), + xla::Broadcast(xla::ConstantR0(builder, 0), {num_boxes})); + + // num_valid is scalar. + xla::XlaOp num_valid = xla::Reduce( + ones_included, + /*init_value=*/xla::ConstantR0(builder, 0), + /*computation=*/CreateScalarAddComputation(xla::S32, builder), + /*dimensions_to_reduce=*/{0}); + + xla::XlaOp output_tuple = TopK(scores_included, output_size); + xla::XlaOp selected_indices = xla::GetTupleElement(output_tuple, 1); + + context->SetOutput(0, selected_indices); + context->SetOutput(1, num_valid); + } + + private: + bool pad_to_max_output_size_; +}; + +REGISTER_XLA_OP( + Name("NonMaxSuppressionV4").CompileTimeConstInput("max_output_size"), + NonMaxSuppressionOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index d6bf92fb3df8d38909df99e11c85ede4fac2bf81..8d75624e74028ea083c3facc4f9578ec14c50e6d 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/lib/math/math_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc index 9e64711051d31107db1bf6f1966f9ed6f5630c34..f028e361bccd51de0bd69a1d2227c7afaed53455 100644 --- a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/no_op.h" diff --git a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc index 2fb072f827906d40dcf410f0312394c4f568a28d..a11bbe918f7f8eb050aaa40d4344f9cc9e9a10a4 100644 --- a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc index dc934543cb2f94fbe1e8f1f865156eb082d6a127..87ee2d3aede50eb24e65570f106d49030e1d4236 100644 --- a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index aa45b025512cdeb27e3b0cabb3f194a58c6f86f9..6440770c29894c951f010f6c1deb929f4fe79bbf 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc index e06c87db7adb1840606208fe15cd68a3ca4d137a..8dfd7de591c4a3c4768dd60b41e03d294ad49397 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc index e2ab4b83cfb45b2f9a7f3aba2d2a927d10ad8b85..c0ca881ff82cee04e0c5e35f9a2d5732fabdd8a6 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index 529959dbd90b05f8860360f70e087ef225150600..eedfc3c9140d7b1ccc1944611de98c1d49fbdaf2 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/util/mirror_pad_mode.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/pack_op.cc b/tensorflow/compiler/tf2xla/kernels/pack_op.cc index 3aed47de2603f3e187ad515d4db3f884da4c6cc8..a9b519d8928cc2807831fd6b4f12e60b7d58ea55 100644 --- a/tensorflow/compiler/tf2xla/kernels/pack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pack_op.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc index 89fd610bc63349d008836c3c4e6ec8927c232a54..e5937b56c17d01892928b073da09f38941ea1bbb 100644 --- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index 2a4c0cab4b3a4ba9a883850f1264c286aa2d6782..d4d180aff806f12875f0e43f111ee090f6607ef6 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -21,7 +21,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/lib/pooling.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/util.h" @@ -71,59 +72,53 @@ class PoolingOp : public XlaOpKernel { int num_dims() const { return num_spatial_dims_ + 2; } - // Method that builds an initial value to use in reductions. - virtual xla::XlaOp InitValue(xla::XlaBuilder* b) = 0; - - // The reduction operation to apply to each window. - virtual const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) = 0; - - // A post-processing operation to apply on the outputs of the ReduceWindow. - virtual xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx, - const xla::XlaOp& output, DataType dtype, - const TensorShape& input_shape) = 0; - - void Compile(XlaOpKernelContext* ctx) override { - std::vector ksize = ksize_; - std::vector stride = stride_; - if (ctx->num_inputs() != 1) { - const TensorShape ksize_shape = ctx->InputShape(1); - // Validate input sizes. - OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape), - errors::InvalidArgument("ksize must be a vector, not shape ", - ksize_shape.DebugString())); - OP_REQUIRES(ctx, ksize_shape.num_elements() == num_dims(), - errors::InvalidArgument("Sliding window ksize field must " - "specify ", - num_dims(), " dimensions")); - ksize.clear(); - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &ksize)); - - const TensorShape stride_shape = ctx->InputShape(2); - // Validate input sizes. - OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape), - errors::InvalidArgument("stride must be a vector, not shape ", - stride_shape.DebugString())); - OP_REQUIRES(ctx, stride_shape.num_elements() == num_dims(), - errors::InvalidArgument("Sliding window stride field must " - "specify ", - num_dims(), " dimensions")); - stride.clear(); - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &stride)); + protected: + xla::StatusOr> GetKernelSize(XlaOpKernelContext* ctx) { + if (ctx->num_inputs() == 1) { + return ksize_; } - const TensorShape input_shape = ctx->InputShape(0); - OP_REQUIRES(ctx, input_shape.dims() == num_dims(), - errors::InvalidArgument("Input to ", type_string(), - " operator must have ", num_dims(), - " dimensions")); + const TensorShape ksize_shape = ctx->InputShape(1); + // Validate input sizes. + if (!TensorShapeUtils::IsVector(ksize_shape)) { + return errors::InvalidArgument("ksize must be a vector, not shape ", + ksize_shape.DebugString()); + } + if (ksize_shape.num_elements() != num_dims()) { + return errors::InvalidArgument( + "Sliding window ksize field must " + "specify ", + num_dims(), " dimensions"); + } + std::vector ksize; + auto status = ctx->ConstantInputAsIntVector(1, &ksize); + if (!status.ok()) { + return status; + } + return ksize; + } - xla::XlaBuilder* const b = ctx->builder(); - auto input = - XlaHelpers::ConvertElementType(b, ctx->Input(0), reduction_type_); - auto reduce = xla::ReduceWindow(input, InitValue(b), *Reduction(ctx), ksize, - stride, padding_); - auto pooled = XlaHelpers::ConvertElementType(b, reduce, input_type(0)); - ctx->SetOutput(0, - PostProcessOutput(ctx, pooled, input_type(0), input_shape)); + xla::StatusOr> GetStride(XlaOpKernelContext* ctx) { + if (ctx->num_inputs() == 1) { + return stride_; + } + const TensorShape stride_shape = ctx->InputShape(2); + // Validate input sizes. + if (!TensorShapeUtils::IsVector(stride_shape)) { + return errors::InvalidArgument("stride must be a vector, not shape ", + stride_shape.DebugString()); + } + if (stride_shape.num_elements() != num_dims()) { + return errors::InvalidArgument( + "Sliding window stride field must " + "specify ", + num_dims(), " dimensions"); + } + std::vector stride; + auto status = ctx->ConstantInputAsIntVector(2, &stride); + if (!status.ok()) { + return status; + } + return stride; } protected: @@ -136,24 +131,48 @@ class PoolingOp : public XlaOpKernel { xla::PrimitiveType xla_reduction_type_; }; +// Converts the tensor data format to the one required by the XLA pooling +// library. +xla::TensorFormat XlaTensorFormat(tensorflow::TensorFormat data_format, + int num_spatial_dims) { + int num_dims = num_spatial_dims + 2; + int batch_dimension = GetTensorBatchDimIndex(num_dims, data_format); + int feature_dimension = GetTensorFeatureDimIndex(num_dims, data_format); + gtl::InlinedVector spatial_dimensions(num_spatial_dims); + for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) { + spatial_dimensions[spatial_dim] = + GetTensorSpatialDimIndex(num_dims, data_format, spatial_dim); + } + return xla::TensorFormat(/*batch_dimension=*/batch_dimension, + /*feature_dimension=*/feature_dimension, + /*spatial_dimensions=*/spatial_dimensions); +} + class MaxPoolOp : public PoolingOp { public: MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, /*reduction_type=*/ctx->input_type(0)) {} - xla::XlaOp InitValue(xla::XlaBuilder* b) override { - return xla::MinValue(b, xla_reduction_type_); - } + void Compile(XlaOpKernelContext* ctx) override { + auto ksize_or_error = GetKernelSize(ctx); + OP_REQUIRES_OK(ctx, ksize_or_error.status()); + std::vector ksize = ksize_or_error.ValueOrDie(); - const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override { - return ctx->GetOrCreateMax(reduction_type_); - } + auto stride_or_error = GetStride(ctx); + OP_REQUIRES_OK(ctx, stride_or_error.status()); + std::vector stride = stride_or_error.ValueOrDie(); + + const TensorShape input_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, input_shape.dims() == num_dims(), + errors::InvalidArgument("Input to ", type_string(), + " operator must have ", num_dims(), + " dimensions")); - xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx, - const xla::XlaOp& output, DataType dtype, - const TensorShape& input_shape) override { - return output; + auto pooling = + xla::MaxPool(ctx->Input(0), ksize, stride, padding_, + XlaTensorFormat(data_format_, input_shape.dims() - 2)); + ctx->SetOutput(0, pooling); } }; @@ -180,9 +199,8 @@ class MaxPool3DOp : public MaxPoolOp { }; REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp); -// Common computation shared between AvgPool and AvgPoolGrad. Divide each -// element of an image by the count of elements that contributed to that -// element during pooling. +// Divide each element of an image by the count of elements that contributed to +// that element during pooling. static xla::XlaOp AvgPoolDivideByCount( XlaOpKernelContext* ctx, const xla::XlaOp& output, DataType dtype, const TensorShape& input_shape, xla::Padding padding, @@ -241,20 +259,34 @@ class AvgPoolOp : public PoolingOp { /*reduction_type=*/ XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} - xla::XlaOp InitValue(xla::XlaBuilder* b) override { - return xla::Zero(b, xla_reduction_type_); - } + void Compile(XlaOpKernelContext* ctx) override { + auto ksize_or_error = GetKernelSize(ctx); + OP_REQUIRES_OK(ctx, ksize_or_error.status()); + std::vector ksize = ksize_or_error.ValueOrDie(); - const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override { - return ctx->GetOrCreateAdd(reduction_type_); - } + auto stride_or_error = GetStride(ctx); + OP_REQUIRES_OK(ctx, stride_or_error.status()); + std::vector stride = stride_or_error.ValueOrDie(); + + const TensorShape input_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, input_shape.dims() == num_dims(), + errors::InvalidArgument("Input to ", type_string(), + " operator must have ", num_dims(), + " dimensions")); - xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx, - const xla::XlaOp& output, DataType dtype, - const TensorShape& input_shape) override { - return AvgPoolDivideByCount(ctx, output, dtype, input_shape, padding_, - ksize_, stride_, num_spatial_dims_, - data_format_); + auto xla_data_format = + XlaTensorFormat(data_format_, input_shape.dims() - 2); + auto spatial_padding = MakeSpatialPadding( + input_shape.dim_sizes(), ksize, stride, padding_, xla_data_format); + + // Convert the input to the reduction type. + auto converted_input = + ConvertElementType(ctx->Input(0), xla_reduction_type_); + auto pooling = + xla::AvgPool(converted_input, ksize, stride, spatial_padding, + xla_data_format, padding_ == xla::Padding::kValid); + // Convert the pooling result back to the input type before returning it. + ctx->SetOutput(0, ConvertElementType(pooling, ctx->input_xla_type(0))); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc index 2e632e185d6df1ed188df3f4eca0574871bb17f4..6f4ed496a1774dde68dd9d5fbd37995d615b678c 100644 --- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/platform/macros.h" diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 607cad798a98cfa0c6161a8154001926384e724e..2da9340625db08b14b78340c471f096baf15689d 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -27,7 +27,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc index 23ac45beb783face11d247e511a2214915d4d411..b11a4ce36da9907ce8fe377c075023a4540797fa 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index be7f2bce8cb249aa51ca091e02da7dffc7d06743..0d260fa8fcaa513d7854c1e9215952404d555c70 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h index 8333f9b288e27efe9497306f031980c9eec7c99c..466e79828d111ee7cadcf713703e8f252c63e62c 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h @@ -19,7 +19,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_REDUCTION_OPS_H_ #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index bb8dd3ac909cce9f0ad6801a6079801950e6cef1..b52f0a0ab6290f2019bb58120be5c2364ec15bb6 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc index f4b804e54677c7226d8d3429c9e8c27686d19ccf..d35777ccb1271ec6a7c9972c714d06b2415d9c34 100644 --- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index 354fec9be75e9559b204e2afd6ee08dfc7cea872..121750a82a8c5cbe940068555ad273b7e0d22dfc 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index 5be70a4ded31a988cb77cdabe3fc8a041bc3ad16..64900e4709fd3e16d21096b0cfff8922906cb0d4 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -104,7 +104,7 @@ class RetvalOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp); }; -REGISTER_XLA_OP(Name("_Retval"), RetvalOp); +REGISTER_XLA_OP(Name("_Retval").CompilationOnly(), RetvalOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc index ec15b4cc7a523d5b8d4287bbe3321433f315063b..c0afccaa5b15dd33fcd016dfdd9bb18e244bf90a 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -95,10 +95,24 @@ class ReverseV2Op : public XlaOpKernel { std::vector axes; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &axes)); + // witnessed_axes is used to ensure that the same axis is not marked to be + // reversed multiple times. + gtl::InlinedVector witnessed_axes(x_shape.dims(), false); + for (int d = 0; d < axes.size(); ++d) { - OP_REQUIRES(ctx, (0 <= axes[d]) && (axes[d] < x_shape.dims()), - errors::InvalidArgument(axes[d], " is out of range [0, ", - x_shape.dims(), ").")); + OP_REQUIRES( + ctx, (-x_shape.dims() <= axes[d]) && (axes[d] < x_shape.dims()), + errors::InvalidArgument(axes[d], " is out of range [-", + x_shape.dims(), ", ", x_shape.dims(), ").")); + // Axes can be negative and are shifted to the canonical index before + // being lowered to HLO. + if (axes[d] < 0) { + axes[d] += x_shape.dims(); + } + OP_REQUIRES(ctx, !witnessed_axes[axes[d]], + errors::InvalidArgument("canonicalized axis ", axes[d], + " was repeated.")); + witnessed_axes[axes[d]] = true; } ctx->SetOutput(0, xla::Rev(ctx->Input(0), axes)); diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index c810456f94322acfccae18d78efa861eede4648c..03a50ef8a059e5a005c4cc2e5e98acedfea8619a 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index 56f237d5887fd9c88bb74bafcc5e44470f8807bf..ab094d7dd1ce9856a3c2854fd2776827d6c4b76f 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc index 14709bb6cbce4b3ae0f7ff859b0fa622c6eda293..f1f32699fee5f03f603f830722fe65622dee5d3e 100644 --- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc index e2ac7da2c2630725efe3dbcc51c3f3d30e7aca2c..b22ecb7c6dbb42a33a4f4d90b18b20816df16a50 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc index 5c010c9df23ba6c7732d87fa014879d93ff586ce..6ce50efb4aa6e3434a7c6009cf9f52f6cff9cc9f 100644 --- a/tensorflow/compiler/tf2xla/kernels/select_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/kernels/bounds_check.h" diff --git a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc index 6281d6c6533f7f49a269f5c7e52226ba0f1d29f6..a7f5a8f1698b9d02560de427d356e9e6be5caa7c 100644 --- a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 5798823cd54c66dd179e3611c0041f7c5a1ff2b5..4e0cf99d8e7ff45ed9145981b5e2e637ce4d4e4b 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/kernels/bounds_check.h" diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index 1864584adee357ce35a3e8a38a4e3c58c356bfca..6adc3c58de63ee70c26bed47eebef955893df4a5 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index 60c6a5d349e479001589a0651e05e77768c8ffbf..025ba827410f1a9f993a8a1855558a2daa86609b 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -38,11 +38,15 @@ class SoftmaxOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const TensorShape logits_shape = ctx->InputShape(0); - OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape), - errors::InvalidArgument("logits must be 2-dimensional")); + OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(logits_shape), + errors::InvalidArgument("logits must have >= 1 dimension, got ", + logits_shape.DebugString())); - const int kBatchDim = 0; - const int kClassDim = 1; + // Major dimensions are batch dimensions, minor dimension is the class + // dimension. + std::vector batch_dims(logits_shape.dims() - 1); + std::iota(batch_dims.begin(), batch_dims.end(), 0); + const int kClassDim = logits_shape.dims() - 1; const DataType type = input_type(0); const xla::PrimitiveType xla_type = ctx->input_xla_type(0); @@ -56,7 +60,7 @@ class SoftmaxOp : public XlaOpKernel { xla::Reduce(logits, xla::MinValue(b, xla_type), max_func, {kClassDim}); // Subtract the max in batch b from every element in batch b. Broadcasts // along the batch dimension. - auto shifted_logits = xla::Sub(logits, logits_max, {kBatchDim}); + auto shifted_logits = xla::Sub(logits, logits_max, batch_dims); auto exp_shifted = xla::Exp(shifted_logits); const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); xla::PrimitiveType xla_accumulation_type; @@ -71,9 +75,9 @@ class SoftmaxOp : public XlaOpKernel { auto softmax = log_ // softmax = shifted_logits - log(sum(exp(shifted_logits))) - ? xla::Sub(shifted_logits, xla::Log(sum), {kBatchDim}) + ? xla::Sub(shifted_logits, xla::Log(sum), batch_dims) // softmax = exp(shifted_logits) / sum(exp(shifted_logits)) - : xla::Div(exp_shifted, sum, {kBatchDim}); + : xla::Div(exp_shifted, sum, batch_dims); ctx->SetOutput(0, softmax); } diff --git a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc index faaf8964ff7c40d75a493b03e6b400632117cb45..aaeeae01ccb303091a6d37d1aeb4b2a3377dc638 100644 --- a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index 8a8525efa186ed4aa02c494f7505f6245677e96e..7327258c31f21f45ff7ffffbc9db7a2a70b4a14c 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc index 47d282fe9ec664bbc424793e93f778ebb13c6877..4493539fe34f0ce635fdc58660d4ff90af9c9379 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 242638f981198ffd7a9c5b5f6365168de59a1f85..93fc14e9efca868e84444dd0e07d7f0dfa84c042 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index cc4b13d3b933cdc15efd94d3ce7a353a856bcb88..5412e135478361d08965e4621ec52cfb4a792f1d 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/lib/prng.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index c2165ccd86dfa1c119790beb20af0844fb1bbda8..1062399d91bd9a9bf8c3820c5ecac534c110746d 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 26326f18b844fa9dc48aeedfa5dcff3d09033a18..be1814d8e3ae2c0ddad0134b9288e0ea084aa81b 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_resource.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index c9e56942625a009fb3660f413a845547192460d5..2c7213f322eb6fec1f134a444b569ae72307d00f 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -70,7 +70,7 @@ class TileOp : public XlaOpKernel { bool one_dimension_is_broadcasted_without_multiple = true; for (int i = 0; i < input_dims; ++i) { int multiple = literal.Get({i}); - OP_REQUIRES(ctx, multiple, + OP_REQUIRES(ctx, multiple >= 0, errors::InvalidArgument("Expected multiples[", i, "] >= 0, but got ", multiple)); int64 new_dim = input_shape.dim_size(i) * multiple; diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc index 82d4a69777b06cc3dec1ceb1a0a4163dcb1e4667..183879c7602ccbbd74fca6cb9fa3fc94c066c37d 100644 --- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/lib/sorting.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" @@ -47,31 +47,12 @@ class TopKOp : public XlaOpKernel { context, last_dim_size >= k, errors::InvalidArgument("input must have at least k columns. Had ", last_dim_size, ", needed ", k)); - - xla::XlaBuilder* const b = context->builder(); if (last_dim_size < k) { k = last_dim_size; } - const xla::XlaOp input = context->Input(0); - - xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, last_dim_size); - auto input_dims = input_shape.dim_sizes(); - std::vector broadcast_dims(input_dims.begin(), input_dims.end() - 1); - xla::XlaOp broadcast_s32 = xla::Broadcast(iota_s32, broadcast_dims); - xla::XlaOp sort_result = xla::Sort(xla::Neg(input), broadcast_s32); - - std::vector start_indices(input_shape.dims(), 0); - std::vector limit_indices(input_dims.begin(), input_dims.end()); - limit_indices[last_dim] = k; - std::vector strides(input_shape.dims(), 1); - - xla::XlaOp values = - xla::Neg(xla::Slice(xla::GetTupleElement(sort_result, 0), start_indices, - limit_indices, strides)); - xla::XlaOp indices = xla::Slice(xla::GetTupleElement(sort_result, 1), - start_indices, limit_indices, strides); - context->SetOutput(0, values); - context->SetOutput(1, indices); + xla::XlaOp output_tuple = TopK(context->Input(0), k); + context->SetOutput(0, xla::GetTupleElement(output_tuple, 0)); + context->SetOutput(1, xla::GetTupleElement(output_tuple, 1)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index 98df73024962b8009a74976d473df752d590b47a..be5e91138656716daddcc3c7a68dbb78ecb69103 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index 6c721c48fe3af45aff5cd0bd5e74e2693faf9f97..f9148b394212777271f9eba51313ee17b19819af 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/bounds_check.h" diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index e6ec794cfd4103f622f64a113464c2f4cbfd4215..0bdfc05726105e2d18362a691cbe2aab00bf77f3 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc index f951127bb95cd52864af869676a6b4c4961c1a43..8671632976023fded04c26a9780c1a67638b0916 100644 --- a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index bb27b5d56f3c24dc093a60e698b1080dfb76514d..2c92a585f5679242d672d0402e617ff199b94f17 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index c653a110292da93033d055170aeda81fadde999a..296518229ebf0ba46717afc4f26d5ae1551c2862 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/function.h" @@ -301,6 +301,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { } REGISTER_XLA_OP(Name("While").AllowResourceTypes(), XlaWhileOp); +REGISTER_XLA_OP(Name("StatelessWhile").AllowResourceTypes(), XlaWhileOp); REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes(), XlaWhileOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.cc b/tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.cc deleted file mode 100644 index 661505021f820e2a87a5d414c6fe382bf6153045..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for the XLA bridge's backend registration modules. - -#include // NOLINT -#include - -#include "tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static BackendRegistrationFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new BackendRegistrationFlags; - flags->tf_enable_prng_ops_gpu = false; - flag_list = new std::vector({ - Flag("tf_enable_prng_ops_gpu", &flags->tf_enable_prng_ops_gpu, - "Whether to enable PRNG ops: [RandomStandardNormal | RandomUniform " - "| RandomUniformInt | TruncatedNormal] on GPU."), - }); - xla::legacy_flags::ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with the XLA bridge's -// backend registration modules. -void AppendBackendRegistrationFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the BackendRegistrationFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -BackendRegistrationFlags* GetBackendRegistrationFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.h b/tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.h deleted file mode 100644 index 861c923dd51f90be2acbeb23911a93e873aabdce..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.h +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_TF2XLA_LEGACY_FLAGS_BACKEND_REGISTRATION_FLAGS_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LEGACY_FLAGS_BACKEND_REGISTRATION_FLAGS_H_ - -// Legacy flags for the XLA bridge's backend registration modules. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with the XLA bridge's -// backend registration modules. -void AppendBackendRegistrationFlags(std::vector* append_to); - -// The values of flags associated with the XLA bridge's backend registration -// module. -typedef struct { - // Whether to enable RandomUniform op on GPU backend. - // TODO (b/32333178): Remove this flag or set its default to true. - bool tf_enable_prng_ops_gpu; -} BackendRegistrationFlags; - -// Return a pointer to the BackendRegistrationFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -BackendRegistrationFlags* GetBackendRegistrationFlags(); - -} // namespace legacy_flags -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2XLA_LEGACY_FLAGS_BACKEND_REGISTRATION_FLAGS_H_ diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index e35a457f09de81bc45d90bdc3f49cbc5ee0511a1..cb7a40e23d539f758d963791f1c2b4d37374ade5 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -25,8 +25,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:lib", ], ) @@ -44,9 +44,9 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:lib", ], ) @@ -59,9 +59,9 @@ cc_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:math", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:protos_all_cc", ], ) @@ -78,12 +78,12 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/lib:numeric", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:lib", ], ) @@ -100,9 +100,9 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:lib", ], ) @@ -119,10 +119,10 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:numeric", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:lib", ], ) @@ -142,7 +142,7 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -162,8 +162,8 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:lib", ], ) @@ -200,8 +200,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index 3c4eec081ba9744226cfbd8d5392220cbf7276f3..f666d22ea44216beef74608bb4d9f33fb2fe82c6 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h index dbba5eaf26883186e3c587f52f16bb7c37ea9d8f..8757b16a1ca6a8cec5e3c801c885e7bbbb2f2c76 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.h +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index 35b137aa2cc0b5e6c2d2b917c0a95410522305c2..87d73eb3f07ebd7dfa4fef50ebe76cad0c4ed117 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h index bc1b0ed82f16659d615d3068060e2d7e3c82d941..1bef9bb166c576ec665bb48265b4da200ddca2a0 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.h +++ b/tensorflow/compiler/tf2xla/lib/cholesky.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc index 9c8ac7af25e4222f35bedd3816fc817af7e1f068..fc0c1ee838190b1f1a7ca5b901c97e0a35232a97 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.cc +++ b/tensorflow/compiler/tf2xla/lib/qr.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/tf2xla/lib/qr.h b/tensorflow/compiler/tf2xla/lib/qr.h index 3aa6a9b07539487b954b2d8c8d0e0bbcc49c2b42..abd2316ac961f583dd29f90f43cf6209de30bd6a 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.h +++ b/tensorflow/compiler/tf2xla/lib/qr.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/lib/random.cc b/tensorflow/compiler/tf2xla/lib/random.cc index 8ff10fbd3fbf9308140af84c752a5a50bec8fd32..5e7cf00ee5e063aef36a9531ff87d8fe6928ca1f 100644 --- a/tensorflow/compiler/tf2xla/lib/random.cc +++ b/tensorflow/compiler/tf2xla/lib/random.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/status_macros.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/lib/random.h b/tensorflow/compiler/tf2xla/lib/random.h index 2c573fd85b2783fdac13457cdb277cf988ac40c4..59fc5d0433a51328bc78006ab1c3495d908b44ac 100644 --- a/tensorflow/compiler/tf2xla/lib/random.h +++ b/tensorflow/compiler/tf2xla/lib/random.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_RANDOM_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_RANDOM_H_ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index 739032fef7759daee5d10d209ead5e1ffa60ef8c..ba22eff73abab11abeb57283c63318b2e50a9ca1 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/tf2xla/lib/scatter.h b/tensorflow/compiler/tf2xla/lib/scatter.h index 452fda565d4763f366ab8ffb761f7521ee57d70b..13a5f1b850a612bddeeac39bef431c19925351ca 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.h +++ b/tensorflow/compiler/tf2xla/lib/scatter.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc index 05dad759df734994fe44a485463280357e8d40b3..febb638e5e8a87d78919f1eaa556d9c05ee40112 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -57,7 +57,7 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { // We can grab entire blocks using gather if (n > block_size) { // Construct the starting indices of the diagonal blocks - auto gather_indices = + auto start_indices = Transpose(Broadcast(Mul(Iota(builder, xla::S32, num_blocks), xla::ConstantR0(builder, block_size)), /*broadcast_sizes=*/{2}), @@ -65,13 +65,13 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { // Gather the diagonal blocks xla::GatherDimensionNumbers dim_numbers; - dim_numbers.add_output_window_dims(ndims - 1); - dim_numbers.add_output_window_dims(ndims); - dim_numbers.add_gather_dims_to_operand_dims(ndims - 2); - dim_numbers.add_gather_dims_to_operand_dims(ndims - 1); + dim_numbers.add_offset_dims(ndims - 1); + dim_numbers.add_offset_dims(ndims); + dim_numbers.add_start_index_map(ndims - 2); + dim_numbers.add_start_index_map(ndims - 1); dim_numbers.set_index_vector_dim(1); - diag_blocks = Gather(a, gather_indices, dim_numbers, - /*window_bounds=*/{block_size, block_size}); + diag_blocks = Gather(a, start_indices, dim_numbers, + /*slice_sizes=*/{block_size, block_size}); } // The last block might be smaller than the block size, diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h index 9c4314e275ff8294a20713a4237f91c9d5fa8f74..555760b7efabddfb25c9135b109a1c48b487415e 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc index a29496dec44798eb0b16bf59b7b84e48c6bdd56e..aeebf16028d40189203cdfd815f06a339ee72902 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index a6f5d346cb5ecb85ff6b2306c2502ba31d74cc64..8b5beba383cda45d36e2ee27ca5e3b3c5988b6b7 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index a139873d3204dae222d6c97793b0aca0deaeecfb..b4905c952820a45371e090aa98466654e2db9661 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/gtl/array_slice.h" diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc index 574e70ddeeab8a3041cd730ce2717daec4f82ddf..d64394f1401d7ceea004a59c991ef6f4a1c58b41 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.cc +++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.h b/tensorflow/compiler/tf2xla/lib/while_loop.h index 69cc70bfaf94f80bf3c63a2d0ef3b2a226be8123..9493b1f109be0725f7f733b9f9da664264275a69 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.h +++ b/tensorflow/compiler/tf2xla/lib/while_loop.h @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/stringpiece.h" diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 2fb66913ada375d53512b9a1115326b3cc2afea4..77da1bf29ced60e490f07abad41cf8ce96232982 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -32,6 +32,23 @@ Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, return Status::OK(); } +Status HostTensorToMutableBorrowingLiteral( + Tensor* host_tensor, xla::MutableBorrowingLiteral* literal) { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor->dtype(), + host_tensor->shape(), &xla_shape)); + return HostTensorToMutableBorrowingLiteral(xla_shape, host_tensor, literal); +} + +Status HostTensorToMutableBorrowingLiteral( + const xla::Shape& xla_shape, Tensor* host_tensor, + xla::MutableBorrowingLiteral* literal) { + *literal = xla::MutableBorrowingLiteral( + static_cast(DMAHelper::base(host_tensor)), xla_shape); + + return Status::OK(); +} + Status HostTensorsToBorrowingLiteralTuple( tensorflow::gtl::ArraySlice host_tensors, xla::BorrowingLiteral* literal) { diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index 0610a57029e72dff79a84742346f78a42b7f4ff1..09d6fa811669b422532673540e4da47f47e6be4e 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -30,6 +30,16 @@ namespace tensorflow { // 'host_tensor'. Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, xla::BorrowingLiteral* literal); +// Returns a MutableBorrowingLiteral that utilizes the same underlying buffer +// owned by 'host_tensor', but is mutable via the xla::Literal methods. +Status HostTensorToMutableBorrowingLiteral( + Tensor* host_tensor, xla::MutableBorrowingLiteral* literal); +// Similar as above, except the literal shape is explicitly provided and used +// instead of obtaining it from the 'host_tensor'. The provided literal shape +// 'xla_shape' must be compatible with the shape of 'host_tensor'. +Status HostTensorToMutableBorrowingLiteral( + const xla::Shape& xla_shape, Tensor* host_tensor, + xla::MutableBorrowingLiteral* literal); // Returns a BorrowingLiteral tuple that utilizes the same underlying buffers // owned by 'host_tensors'. diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index 9203e8d9e607e99ad738350a1c3f2b9e900df179..0e07485d1861aa40b14e527b14947c6f8bab647e 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include +#include #include #include @@ -297,4 +298,29 @@ void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype, } } +namespace { +uint32 InitialRandomSeed() { + // Support plumbing the TF seed through to XLA is being worked on. + // If a user wants deterministic behavior, their best option + // is to start with a known checkpoint. This also handles issues when + // multiple random calls can be invoked in any order by TF executor. + // Another option is to use stateless random ops. They have much cleaner + // semantics. + // If a user really wants to set a deterministic seed for XLA-based + // devices, this is the place to do it. + std::random_device rd; + // Make the starting value odd. + return rd() | 1; +} +} // namespace + +uint32 GetXLARandomSeed() { + // We initialize counter with an odd number and increment it by two + // everytime. This ensures that it will never be zero, even + // after an overflow. When seeded with zero, some XLA backends + // can return all zeros instead of random numbers. + static std::atomic counter(InitialRandomSeed()); + return counter.fetch_add(2); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h index 745beb39c1d917cd0d1cd219536ee26a96253ec9..33620ef810bd4fe897f384474e661e341a448b93 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.h +++ b/tensorflow/compiler/tf2xla/tf2xla_util.h @@ -56,6 +56,9 @@ Status SetNodeShardingFromNeighbors(Node* n, bool out_edges); void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype, KernelDef* kdef); +// Returns the next random seed to use for seeding xla rng. +uint32 GetXLARandomSeed(); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index fe7ec633eca2504faf6cbb2f5fd7f59780ab7976..e89f4733281194f0263ae8cc4907caa0ad781165 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/platform/mem.h" diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h index d0b9e34e162f3412cd6662a2e2bbfe3df213c4c2..a6e78825334fec748be5fee80669649df699d2fb 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.h +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/xla_resource.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc index 672e19bd93449ccc31f4af5ded23257b197a3c39..1f0f240135dfcd0c540cc39a42514c67ce979ee0 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc @@ -16,45 +16,47 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" #include -#include "tensorflow/compiler/aot/runtime.h" namespace tensorflow { XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, AllocMode alloc_mode) - : raw_function_(static_data.raw_function), - result_index_(static_data.result_index), - args_(new void*[static_data.num_args]), - temps_(new void*[static_data.num_temps]), - arg_names_(static_data.arg_names), - result_names_(static_data.result_names), - program_shape_(static_data.program_shape), - hlo_profile_printer_data_(static_data.hlo_profile_printer_data) { + : raw_function_(static_data.raw_function_), + result_index_(static_data.result_index_), + buffer_table_(new void*[static_data.num_buffers_]), + buffer_infos_(static_data.buffer_infos_), + arg_index_table_(static_data.arg_index_table_), + num_args_(static_data.num_args_), + arg_names_(static_data.arg_names_), + result_names_(static_data.result_names_), + program_shape_(static_data.program_shape_), + hlo_profile_printer_data_(static_data.hlo_profile_printer_data_) { + bool allocate_entry_params = + alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS; // Allocate arg and temp buffers. - if (alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) { - alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( - static_data.arg_sizes, static_data.num_args, args_, - /*annotate_initialized=*/false); - } - alloc_temps_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( - static_data.temp_sizes, static_data.num_temps, temps_, + alloc_buffer_table_ = cpu_function_runtime::MallocContiguousBuffers( + static_data.buffer_infos_, static_data.num_buffers_, + /*allocate_entry_params=*/allocate_entry_params, buffer_table_, /*annotate_initialized=*/true); - // If Hlo profiling is enabled the generated code expects an appropriately // sized buffer to be passed in as the last argument. If Hlo profiling is // disabled the last function argument is still present in the function // signature, but it is ignored by the generated code and we pass in null for // it. if (hlo_profiling_enabled()) { - profile_counters_ = new int64[static_data.profile_counters_size](); + profile_counters_ = new int64[static_data.profile_counters_size_](); } } +bool XlaCompiledCpuFunction::Run() { + raw_function_(buffer_table_[result_index_], &run_options_, nullptr, + buffer_table_, profile_counters_); + return true; +} + XlaCompiledCpuFunction::~XlaCompiledCpuFunction() { - tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_); - tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_); - delete[] args_; - delete[] temps_; + cpu_function_runtime::FreeContiguous(alloc_buffer_table_); + delete[] buffer_table_; delete[] profile_counters_; } diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index 48a8c083cacf2f6ecf9dc1817b6174c01385d035..425e769346ffcbc548495d93cb7adc779f860110 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/core/platform/types.h" @@ -56,36 +57,85 @@ class XlaCompiledCpuFunction { // StaticData represents the state necessary to run an XLA-compiled // function. For JIT this is backed by data in XlaJitCompiledCpuFunction; for // AOT this is backed by data compiled into the object file. - struct StaticData { + // + // The contents of StaticData are XLA-internal implementation details and + // should not be relied on by clients. + // + // TODO(sanjoy): Come up with a cleaner way to express the contraint we want + // here: generated XlaCompiledCpuFunction subclasses should be able to create + // instances of StaticData but only XlaCompiledCpuFunction should be able to + // read from StaticData instances. + class StaticData { + public: + void set_raw_function(RawFunction raw_function) { + raw_function_ = raw_function; + } + void set_buffer_infos( + const cpu_function_runtime::BufferInfo* buffer_infos) { + buffer_infos_ = buffer_infos; + } + void set_num_buffers(size_t num_buffers) { num_buffers_ = num_buffers; } + void set_arg_index_table(const int32* arg_index_table) { + arg_index_table_ = arg_index_table; + } + void set_num_args(int64 num_args) { num_args_ = num_args; } + void set_result_index(size_t result_index) { result_index_ = result_index; } + void set_arg_names(const char** arg_names) { arg_names_ = arg_names; } + void set_result_names(const char** result_names) { + result_names_ = result_names; + } + void set_program_shape(const xla::ProgramShape* program_shape) { + program_shape_ = program_shape; + } + const xla::HloProfilePrinterData* hlo_profile_printer_data() const { + return hlo_profile_printer_data_; + } + void set_hlo_profile_printer_data( + const xla::HloProfilePrinterData* hlo_profile_printer_data) { + hlo_profile_printer_data_ = hlo_profile_printer_data; + } + void set_profile_counters_size(int64 profile_counters_size) { + profile_counters_size_ = profile_counters_size; + } + + private: // The raw function to call. - RawFunction raw_function; + RawFunction raw_function_; + + // Contains information about the buffers used by the XLA computation. + const cpu_function_runtime::BufferInfo* buffer_infos_ = nullptr; + size_t num_buffers_ = 0; + + // Entry parameter i is described by + // buffer_infos[arg_index_table[i]]. + const int32* arg_index_table_ = nullptr; - // Cardinality and sizes of arg and temp buffers. - const intptr_t* arg_sizes = nullptr; - size_t num_args = 0; - const intptr_t* temp_sizes = nullptr; - size_t num_temps = 0; + // There are num_args entry parameters. + int64 num_args_ = 0; // The 0-based index of the result tuple, in the temp buffers. - size_t result_index = 0; + size_t result_index_ = 0; // [Optional] Arrays of arg and result names. These are arrays of C-style // strings, where the array is terminated by nullptr. - const char** arg_names = nullptr; - const char** result_names = nullptr; + const char** arg_names_ = nullptr; + const char** result_names_ = nullptr; // [Optional] Arg and result shapes. - const xla::ProgramShape* program_shape = nullptr; + const xla::ProgramShape* program_shape_ = nullptr; // [Optional] Profile printer data. Null if profiling is disabled. - const xla::HloProfilePrinterData* hlo_profile_printer_data = nullptr; + const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr; // [Optional] The number of profile counters expected in the profile counter // buffer by the generated code and hlo_profile_printer. 0 if profiling is // disabled. This information is already present in // hlo_profile_printer_data but xla::HloProfilePrinterData is forward // declared so we don't have access to that information here. - int64 profile_counters_size = 0; + int64 profile_counters_size_ = 0; + + // Only XlaCompiledCpuFunction is allowed to read the above fields. + friend class XlaCompiledCpuFunction; }; // AllocMode controls the buffer allocation mode. @@ -113,11 +163,7 @@ class XlaCompiledCpuFunction { // Runs the computation, with inputs read from arg buffers, and outputs // written to result buffers. Returns true on success and false on failure. - bool Run() { - raw_function_(temps_[result_index_], &run_options_, - const_cast(args_), temps_, profile_counters_); - return true; - } + bool Run(); // Returns the error message from the previous failed Run call. // @@ -129,14 +175,25 @@ class XlaCompiledCpuFunction { // ------------------------------ // Arg methods for managing input buffers. Buffers are in row-major order. - // Returns the underlying array of argument buffers, where args()[I] is the - // buffer for the positional argument at index I. - void** args() { return args_; } - const void* const* args() const { return args_; } - // Returns the buffer for the positional argument at the given `index`. - void* arg_data(size_t index) { return args_[index]; } - const void* arg_data(size_t index) const { return args_[index]; } + void* arg_data(size_t index) { + return buffer_table_[arg_index_table_[index]]; + } + const void* arg_data(size_t index) const { + return buffer_table_[arg_index_table_[index]]; + } + + int num_args() const { return num_args_; } + + // Returns the size of entry parameter `idx`. + // + // There is a static version of this method on tfcompile generated subclasses + // of XlaCompiledCpuFunction, but try to prefer this when possible since it + // works both for XlaJitCompiledCpuFunction and AOT compiled subclasses. + int arg_size(int idx) const { + assert(idx < num_args()); + return buffer_infos_[arg_index_table_[idx]].size(); + } // Sets the buffer for the positional argument at the given `index` to `data`. // Must be called before Run to have an effect. May be called under any @@ -149,7 +206,9 @@ class XlaCompiledCpuFunction { // // Aliasing of argument and result buffers is not allowed, and results in // undefined behavior. - void set_arg_data(size_t index, void* data) { args_[index] = data; } + void set_arg_data(size_t index, void* data) { + buffer_table_[arg_index_table_[index]] = data; + } // ------------------------------ // Result methods for managing output buffers. Buffers are in row-major order. @@ -159,9 +218,9 @@ class XlaCompiledCpuFunction { // Returns the underlying array of result buffers, where results()[I] is the // buffer for the positional result at index I. - void** results() { return static_cast(temps_[result_index_]); } + void** results() { return static_cast(buffer_table_[result_index_]); } const void* const* results() const { - return static_cast(temps_[result_index_]); + return static_cast(buffer_table_[result_index_]); } // Profile counters for this XLA computation. @@ -219,14 +278,28 @@ class XlaCompiledCpuFunction { const RawFunction raw_function_; const size_t result_index_; - // Arrays of argument and temp buffers; entries in args_ may be overwritten by - // the user. - void** args_ = nullptr; - void** temps_ = nullptr; + // Array containing pointers to argument and temp buffers (slots corresponding + // to constant and on-stack buffers are null). + void** const buffer_table_; - // Backing memory for individual arg and temp buffers. - void* alloc_args_ = nullptr; - void* alloc_temps_ = nullptr; + // Describes the buffers used by the XLA computation. + const cpu_function_runtime::BufferInfo* const buffer_infos_; + + // Argument i needs to be placed in buffer_table_[arg_index_to_temp_index_[i]] + // for XLA generated code to be able to find it. + // + // For now we need to keep around the args_ array because there is code that + // depends on args() returning a void**. However, in the future we may remove + // args_ in favor of using buffer_table_ as the sole storage for the + // arguments. + const int32* const arg_index_table_; + + // The number of incoming arguments. + const int32 num_args_; + + // Backing memory for buffer_table_ and args_, the latter depending on + // AllocMode. + void* alloc_buffer_table_ = nullptr; // Backing memory for profiling counters. int64* profile_counters_ = nullptr; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 678e209cf6551b7071a67fa62b2d3e4d12f4efb9..43ff5fcef89f23cf08a09607427cec9c03f0f6e5 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" @@ -28,13 +29,14 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" @@ -309,7 +311,7 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, // unique_ptr so we can capture the cleanup status in the end. xla_context->Ref(); Status status; - auto step_container = xla::MakeUnique( + auto step_container = absl::make_unique( step_id, [&status, device](const string& name) { status = device->resource_manager()->Cleanup(name); }); @@ -689,12 +691,12 @@ Status ValidateFunctionDef(const FunctionDef* fdef, Status ValidateGraph(const Graph* graph, const FunctionLibraryDefinition& flib_def, const DeviceType& device_type, const string& name) { - auto maybe_error = [&](const string& op, const Status& s) -> Status { + auto maybe_error = [&](const Node* node, const Status& s) -> Status { if (!s.ok()) { return errors::InvalidArgument(strings::StrCat( "Detected unsupported operations when trying to compile graph ", name, - " on ", device_type.type_string(), ": ", op, " (", s.error_message(), - ")")); + " on ", device_type.type_string(), ": ", node->def().op(), " (", + s.error_message(), ")", FormatNodeForError(*node))); } return Status::OK(); }; @@ -707,15 +709,15 @@ Status ValidateGraph(const Graph* graph, Status s; if (fdef) { s = ValidateFunctionDef(fdef, flib_def); - TF_RETURN_IF_ERROR(maybe_error(node->def().op(), s)); + TF_RETURN_IF_ERROR(maybe_error(node, s)); continue; } const OpDef* op_def; s = OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def); - TF_RETURN_IF_ERROR(maybe_error(node->def().op(), s)); + TF_RETURN_IF_ERROR(maybe_error(node, s)); TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def)); s = FindKernelDef(device_type, node->def(), nullptr, nullptr); - TF_RETURN_IF_ERROR(maybe_error(node->def().op(), s)); + TF_RETURN_IF_ERROR(maybe_error(node, s)); } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index acc64d99d3e7f0be76aada5ac4042787a5f4b0f6..25332c8d8e3210a0217a1ba3f5767115fe6b1d93 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -252,6 +252,12 @@ class XlaCompiler { // The default empty value is invalid. DeviceType device_type = DeviceType(""); + // The device to use during compilation to execute instructions on, for + // example for auto-tuning. + // Valid values are defined by `xla::Backend::devices_ordinal_supported()`. + // -1 indicates the default device should be used. + int device_ordinal = -1; + xla::Client* client = nullptr; // Function library in which to find function definitions. Must be non-null. diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 2fb93be01d6bf4dca22b74f64c1d6c8b0d7f6fb5..7227df96499f6e8f1b5f09ad5e27aa5f7b63e8c8 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -312,7 +312,7 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { str_util::StrContains(status.error_message(), "depends on a parameter")) << status.error_message(); EXPECT_TRUE( - str_util::StrContains(status.error_message(), "[[Node: C = Reshape")) + str_util::StrContains(status.error_message(), "[[{{node C}} = Reshape")) << status.error_message(); } @@ -821,7 +821,10 @@ TEST_F(XlaCompilerTest, Variables) { Scope scope = Scope::NewRootScope().ExitOnError(); auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); - auto write = ops::AssignAddVariableOp(scope, var, a); + // Adds an identity op around the resource to make sure identity ops propagate + // resources correctly. + auto identity = ops::Identity(scope.WithOpName("VIdentity"), var); + auto write = ops::AssignAddVariableOp(scope, identity, a); auto read = ops::ReadVariableOp( scope.WithControlDependencies(std::vector{write}), var, DT_INT32); @@ -1077,6 +1080,8 @@ TEST_F(XlaCompilerTest, FunctionWithInvalidOp) { ASSERT_FALSE(status.ok()); EXPECT_TRUE(str_util::StrContains(status.error_message(), "InvalidOp")) << status.error_message(); + EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node fill_fn}}")) + << status.error_message(); } // Tests a graph which has a node with invalid data type. @@ -1101,6 +1106,8 @@ TEST_F(XlaCompilerTest, NodeWithInvalidDataType) { EXPECT_TRUE(str_util::StrContains(status.error_message(), "is not in the list of allowed values")) << status.error_message(); + EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node Shape}}")) + << status.error_message(); } TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { @@ -1122,9 +1129,10 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", std::move(graph_copy), args, &result); ASSERT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), - "The following nodes are unreachable " - "from the source in the graph: NoOp")) + EXPECT_TRUE( + str_util::StrContains(status.error_message(), + "The following nodes are unreachable " + "from the source in the graph: {{node NoOp}}")) << status.error_message(); } diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 2836cb3df3558bc7470b6aec81676c642744a2de..b24e3aabbe6ba858a8bfb4dd435726984cc7b0f5 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index beee7d48e89a4217b382b27f173f1c2b49c86611..3db37afdba71342cfb20af8841a40cb54709ca73 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc index dc98d4fda6ae21411065981a7b7383ef0ad50f44..1398e9ee536a9675e5b703ec3fabf4a8b9d89cbf 100644 --- a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc +++ b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def.pb.h" @@ -21,20 +20,6 @@ limitations under the License. namespace tensorflow { bool GpuOpFilter(KernelDef* kdef) { - // TODO(b/31361304): The GPU backend does not parallelize PRNG ops, leading to - // slow code. - legacy_flags::BackendRegistrationFlags* flags = - legacy_flags::GetBackendRegistrationFlags(); - VLOG(2) << "flags->tf_enable_prng_ops_gpu: " << flags->tf_enable_prng_ops_gpu; - if (!flags->tf_enable_prng_ops_gpu && - (kdef->op() == "RandomStandardNormal" || kdef->op() == "RandomUniform" || - kdef->op() == "RandomUniformInt" || kdef->op() == "TruncatedNormal")) { - return false; - } - // TODO(b/26783907): The GPU backend currently does not implement sort. - if (kdef->op() == "XlaSort" || kdef->op() == "TopKV2") { - return false; - } if (kdef->op() == "Const") { AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef); } diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 225da168073f6f5bb00293ad2e9621f5a1da2baa..8efb3d55c88757b9366bdf9622287bdd0a72e295 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -26,7 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index d6ca4ab9346593892917e8375b07a8790dc26e79..e6522157a535fc3e4ec96cb0496b6be2e525c336 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -19,7 +19,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_ #include "tensorflow/compiler/tf2xla/xla_context.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/gtl/array_slice.h" diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index 00ccfb1c7873c85564b1bf4cf582cd31baa17ad5..86a78ee429e8913edb4a948727fa692083c472f4 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -35,41 +36,6 @@ limitations under the License. namespace tensorflow { namespace { - -// Returns a vector of positional argument buffer sizes. -xla::StatusOr> ComputeArgSizes( - const xla::ProgramShape& program_shape) { - std::vector arg_sizes; - const size_t num_args = program_shape.parameters_size(); - arg_sizes.reserve(num_args); - for (int i = 0; i < num_args; ++i) { - const xla::Shape& arg_shape = program_shape.parameters(i); - constexpr size_t kPointerSize = sizeof(void*); - arg_sizes.push_back(xla::ShapeUtil::ByteSizeOf(arg_shape, kPointerSize)); - } - return std::move(arg_sizes); -} - -// Returns a vector of positional temporary buffer sizes. -xla::StatusOr> ComputeTempSizes( - const xla::BufferAssignment& buffer_assignment) { - const std::vector& allocations = - buffer_assignment.Allocations(); - std::vector temp_sizes; - temp_sizes.reserve(allocations.size()); - for (const xla::BufferAllocation& allocation : allocations) { - // Callers don't allocate temporary buffers for parameters. Nor for - // thread-local buffers, which are lowered to alloca. - if (allocation.is_entry_computation_parameter() || - allocation.is_thread_local()) { - temp_sizes.push_back(-1); - } else { - temp_sizes.push_back(allocation.size()); - } - } - return std::move(temp_sizes); -} - // Returns the index of the result in the temp buffers. xla::StatusOr ComputeResultIndex( const xla::BufferAssignment& buffer_assignment) { @@ -153,11 +119,11 @@ XlaJitCompiledCpuFunction::Compile( const xla::BufferAssignment& buffer_assignment = cpu_executable->buffer_assignment(); - // Compute buffer sizes and the result index, needed to run the raw function. - TF_ASSIGN_OR_RETURN(std::vector arg_sizes, - ComputeArgSizes(*program_shape)); - TF_ASSIGN_OR_RETURN(std::vector temp_sizes, - ComputeTempSizes(buffer_assignment)); + // Compute buffer infos and the result index, needed to run the raw function. + std::vector buffer_infos = + xla::cpu::CreateBufferInfosFromBufferAssignment(buffer_assignment); + std::vector arg_index_table = + xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos); TF_ASSIGN_OR_RETURN(size_t result_index, ComputeResultIndex(buffer_assignment)); @@ -165,28 +131,28 @@ XlaJitCompiledCpuFunction::Compile( new XlaJitCompiledCpuFunction); XlaJitCompiledCpuFunction* jit = jit_unique_ptr.get(); jit->executable_ = std::move(executable); - jit->arg_sizes_ = std::move(arg_sizes); - jit->temp_sizes_ = std::move(temp_sizes); + jit->buffer_infos_ = std::move(buffer_infos); + jit->arg_index_table_ = std::move(arg_index_table); jit->program_shape_ = std::move(program_shape); - jit->static_data_.raw_function = std::move(raw_function); - jit->static_data_.arg_sizes = jit->arg_sizes_.data(); - jit->static_data_.num_args = jit->arg_sizes_.size(); - jit->static_data_.temp_sizes = jit->temp_sizes_.data(); - jit->static_data_.num_temps = jit->temp_sizes_.size(); - jit->static_data_.result_index = result_index; + jit->static_data_.set_raw_function(raw_function); + jit->static_data_.set_buffer_infos(jit->buffer_infos_.data()); + jit->static_data_.set_num_buffers(jit->buffer_infos_.size()); + jit->static_data_.set_arg_index_table(jit->arg_index_table_.data()); + jit->static_data_.set_num_args(jit->arg_index_table_.size()); + jit->static_data_.set_result_index(result_index); // Optional metadata is collected and set below. CollectNames(config.feed(), &jit->nonempty_arg_names_, &jit->arg_names_); CollectNames(config.fetch(), &jit->nonempty_result_names_, &jit->result_names_); - jit->static_data_.arg_names = jit->arg_names_.data(); - jit->static_data_.result_names = jit->result_names_.data(); - jit->static_data_.program_shape = jit->program_shape_.get(); + jit->static_data_.set_arg_names(jit->arg_names_.data()); + jit->static_data_.set_result_names(jit->result_names_.data()); + jit->static_data_.set_program_shape(jit->program_shape_.get()); if (cpu_executable->hlo_profiling_enabled()) { - jit->static_data_.hlo_profile_printer_data = - &cpu_executable->hlo_profile_printer_data(); - jit->static_data_.profile_counters_size = - cpu_executable->hlo_profile_printer_data().profile_counters_size(); + jit->static_data_.set_hlo_profile_printer_data( + &cpu_executable->hlo_profile_printer_data()); + jit->static_data_.set_profile_counters_size( + cpu_executable->hlo_profile_printer_data().profile_counters_size()); } return std::move(jit_unique_ptr); diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h index af307ae4eff74927242c4650d8a43710e991cc52..d3c8f22a8078d03d15447ed200c914390f40b04f 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h @@ -66,9 +66,11 @@ class XlaJitCompiledCpuFunction { // The static data is backed by the rest of the state in this class. XlaCompiledCpuFunction::StaticData static_data_; - // The backing arrays of arg and temp buffer sizes. - std::vector arg_sizes_; - std::vector temp_sizes_; + // The backing array for buffer infos. + std::vector buffer_infos_; + + // The backing array for the arg index table. + std::vector arg_index_table_; // The backing arrays of arg and result names. We hold the actual strings in // nonempty_*_names_, and hold arrays of pointers in *_names_ for the static diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 38ec559576e8d829a8ca175c52205d384693f221..82028c8b9ca9f65a73f8b50edc0a47c7068aba9a 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/dma_helper.h" diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 71990b57d9b61efaa9f1d276ad067ae4567dfbb3..ac9dfe3369078df7392a4ef04679f7d7beacf8bb 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_ #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index baea8149658ec0849ebb570931ca68518ec5284e..7928fa034725206a752cbfe086d01f15cd235df9 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h index 4de18a77887496d30e3b1407ecd9042e619653af..2438490be13809b9f3571a362900b44cb838e76b 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.h +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index f1c383fd9e3fff8a306ba0ddcc3f9ee42c63d66a..2cf77b71fb21a21f912c7fc2ef9980ca7afe92d2 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -161,7 +161,6 @@ cc_library( "iterator_util.h", "map_util.h", "overflow_util.h", - "ptr_util.h", "util.h", ], visibility = ["//visibility:public"], @@ -172,7 +171,8 @@ cc_library( ":types", ":xla_data_proto", "//tensorflow/core:lib", - "//tensorflow/core:ptr_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", ], ) @@ -210,6 +210,7 @@ tf_cc_test( ":test", ":util", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) @@ -297,6 +298,7 @@ cc_library( ":util", ":xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -315,6 +317,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) @@ -335,6 +338,7 @@ cc_library( ":util", ":xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -405,8 +409,8 @@ cc_library( deps = [ ":array", ":types", - ":util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -489,6 +493,7 @@ cc_library( ":util", ":xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -521,6 +526,7 @@ cc_library( ":xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", ], ) @@ -576,10 +582,10 @@ cc_library( deps = [ ":shape_util", ":status_macros", - ":util", ":xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", ], ) @@ -593,6 +599,7 @@ tf_cc_test( ":xla_data_proto", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) @@ -636,12 +643,13 @@ cc_library( ":window_util", ":xla_data_proto", "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_evaluator", "//tensorflow/compiler/xla/service:shape_inference", "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -660,6 +668,7 @@ tf_cc_test( "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h index ea75ad32d5df7bbadd37e89de6144b264ab6d5d1..2d5d078aa77423cc18bab053b80a7576acbd849e 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -409,7 +409,7 @@ class Array { // Returns the total number of elements in the array. int64 num_elements() const { - return std::accumulate(sizes_.begin(), sizes_.end(), 1, + return std::accumulate(sizes_.begin(), sizes_.end(), 1LL, std::multiplies()); } diff --git a/tensorflow/compiler/xla/array2d.h b/tensorflow/compiler/xla/array2d.h index a17e81f44832f272fd93dce9f854042b4a84fde4..340f94fab72a24fb39cf1dfc1d722e2ee6c3685a 100644 --- a/tensorflow/compiler/xla/array2d.h +++ b/tensorflow/compiler/xla/array2d.h @@ -24,8 +24,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -101,7 +101,7 @@ class Array2D : public Array { template std::unique_ptr> MakeLinspaceArray2D(double from, double to, int64 n1, int64 n2) { - auto array = MakeUnique>(n1, n2); + auto array = absl::make_unique>(n1, n2); int64 count = n1 * n2; NativeT step = static_cast((count > 1) ? (to - from) / (count - 1) : 0); diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 289d3f552aee3514b2b04c3a30ffb904a66fc5d4..6be44b1c390deb67c2b33853675aaa12cdfc8621 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -71,12 +71,12 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -104,7 +104,6 @@ cc_library( "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:compiler", @@ -114,8 +113,10 @@ cc_library( "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:source_map_util", + "//tensorflow/compiler/xla/service:stream_pool", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", "@llvm//:support", ], ) @@ -129,11 +130,11 @@ cc_library( ":xla_computation", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:compile_only_service", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", "@llvm//:support", ], ) @@ -158,6 +159,7 @@ cc_library( "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -185,5 +187,52 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_proto", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "xla_builder", + srcs = ["xla_builder.cc"], + hdrs = ["xla_builder.h"], + visibility = ["//visibility:public"], + deps = [ + ":padding", + ":sharding_builder", + ":xla_computation", + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/compiler/xla/service:shape_inference", + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", + ], +) + +tf_cc_test( + name = "xla_builder_test", + srcs = ["xla_builder_test.cc"], + deps = [ + ":xla_builder", + ":xla_computation", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/core:test", ], ) diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index d0ce5e8a6afa262d4cffdfe8431aab570ffd28df..25608d6616f687825db0fb3d739e52f1ade9ce52 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -18,11 +18,11 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/errors.h" @@ -89,7 +89,7 @@ StatusOr> Client::TransferToServer( "TransferToServer request"); } - return MakeUnique(stub_, response.data()); + return absl::make_unique(stub_, response.data()); } Status Client::TransferToInfeed(const LiteralSlice& literal, int64 replica_id, @@ -248,7 +248,7 @@ StatusOr> Client::Execute( } } - return MakeUnique(stub_, response.output()); + return absl::make_unique(stub_, response.output()); } StatusOr>> Client::ExecuteParallel( @@ -278,7 +278,7 @@ StatusOr>> Client::ExecuteParallel( std::vector> outputs; for (size_t i = 0; i < computations.size(); ++i) { outputs.push_back( - MakeUnique(stub_, response.responses(i).output())); + absl::make_unique(stub_, response.responses(i).output())); if (computations[i].execution_profile != nullptr) { *computations[i].execution_profile = response.responses(i).profile(); } @@ -340,7 +340,7 @@ StatusOr>> Client::DeconstructTuple( std::vector> handles; for (auto& handle : response.element_handles()) { - handles.push_back(MakeUnique(stub_, handle)); + handles.push_back(absl::make_unique(stub_, handle)); } return std::move(handles); } @@ -369,7 +369,7 @@ StatusOr Client::GetComputationStats( StatusOr> Client::GetComputationShape( const XlaComputation& computation) { TF_ASSIGN_OR_RETURN(const auto& result, computation.GetProgramShape()); - return MakeUnique(result); + return absl::make_unique(result); } StatusOr Client::GetShape(const GlobalData& data) { diff --git a/tensorflow/compiler/xla/client/client_library.cc b/tensorflow/compiler/xla/client/client_library.cc index 803a9e40094391ba47ed27713f4538caf875c4f6..27b7fa7b29206affa9f9c2e4becd9e4ea66484ab 100644 --- a/tensorflow/compiler/xla/client/client_library.cc +++ b/tensorflow/compiler/xla/client/client_library.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -94,10 +95,10 @@ ClientLibrary::~ClientLibrary() = default; service_options.set_intra_op_parallelism_threads( options.intra_op_parallelism_threads()); - auto instance = MakeUnique(); + auto instance = absl::make_unique(); TF_ASSIGN_OR_RETURN(instance->service, LocalService::NewService(service_options)); - instance->client = MakeUnique(instance->service.get()); + instance->client = absl::make_unique(instance->service.get()); LocalClient* cl = instance->client.get(); client_library.local_instances_.insert( @@ -134,10 +135,11 @@ ClientLibrary::GetOrCreateCompileOnlyClient(se::Platform* platform) { return it->second->client.get(); } - auto instance = MakeUnique(); + auto instance = absl::make_unique(); TF_ASSIGN_OR_RETURN(instance->service, CompileOnlyService::NewService(platform)); - instance->client = MakeUnique(instance->service.get()); + instance->client = + absl::make_unique(instance->service.get()); CompileOnlyClient* cl = instance->client.get(); client_library.compile_only_instances_.insert( diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc index 5c9abad4c3126be5e45e96c770c0679fe8606788..b6012a0352069917063084c5c5f022ef3e8c27a1 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.cc +++ b/tensorflow/compiler/xla/client/compile_only_client.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/compile_only_client.h" +#include "absl/memory/memory.h" #include "llvm/ADT/Triple.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" namespace xla { diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 45506986c88124920a13be75a41a24ef2b8facf1..a2f32ab97eab10294a607f35fc79ded1cc2c5792 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -29,8 +29,8 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:lib", ], ) @@ -45,7 +45,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", ], ) @@ -58,7 +58,7 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -72,7 +72,7 @@ cc_library( ":constants", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", ], ) @@ -86,7 +86,7 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -101,7 +101,7 @@ cc_library( ":constants", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/core:lib", ], ) @@ -115,7 +115,31 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + +cc_library( + name = "pooling", + srcs = ["pooling.cc"], + hdrs = ["pooling.h"], + deps = [ + ":arithmetic", + ":constants", + "//tensorflow/compiler/tf2xla/lib:util", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "pooling_test", + srcs = ["pooling_test.cc"], + deps = [ + ":pooling", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -131,11 +155,42 @@ cc_library( ":numeric", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/core:lib", ], ) +cc_library( + name = "sorting", + srcs = ["sorting.cc"], + hdrs = ["sorting.h"], + deps = [ + ":numeric", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + ], +) + +xla_test( + name = "sorting_test", + srcs = ["sorting_test.cc"], + blacklisted_backends = [ + "cpu", + "gpu", + ], + tags = ["enable_for_xla_interpreter"], + deps = [ + ":sorting", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "testing", srcs = ["testing.cc"], @@ -150,8 +205,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", ], diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 1872925aba30465e391a1ee8b4287588dca05598..9225b1acd69c214d6f08a45372a8082ed789c18c 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index 80d3f8b95ac0553f27923c739ce083bbd1b2164b..632e8cc8bc64fad236a0226c6e93079aadde7050 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/client/lib/constants.h b/tensorflow/compiler/xla/client/lib/constants.h index b47f5243f008ecb2045456e4505d1a571fbed745..0c8a9b8cc02ba0c1ebdf6a060d4b99262dceb178 100644 --- a/tensorflow/compiler/xla/client/lib/constants.h +++ b/tensorflow/compiler/xla/client/lib/constants.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/client/lib/constants_test.cc b/tensorflow/compiler/xla/client/lib/constants_test.cc index f1e3439862344c01af15ec0571155ca46a579e54..f4320f65c1f76d4d4c384110b39d6606773aaf01 100644 --- a/tensorflow/compiler/xla/client/lib/constants_test.cc +++ b/tensorflow/compiler/xla/client/lib/constants_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index 0221de7672c7b7c02b1f8b9c7ff4f92151e567c6..e569610b85578769750216d18151e635d475db37 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -207,7 +207,11 @@ XlaOp Lgamma(XlaOp input) { XlaOp log_y = log_sqrt_two_pi + (z + one_half) * log_t - t + Log(x); - XlaOp reflection = log_pi - Log(Sin(pi * input)) - log_y; + // If z = a + 0j, the analytic continuation of log reduces to taking the + // absolute value of the real part. + // Re(log(z)) = Re(log|z| + arg(z)j) + // = log|a| + XlaOp reflection = log_pi - Log(Abs(Sin(pi * input))) - log_y; XlaOp result = Select(need_to_reflect, reflection, log_y); return result; } diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h index d003d529cc316dfde63f76284f98ae698e1d8034..13db2325569cf2e25e3ff1200adf4b2544dc2f73 100644 --- a/tensorflow/compiler/xla/client/lib/math.h +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace xla { diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc index 1df287d7db2fb5498900d1bff51b621915a6b0af..14c259a7fa2a47642663b65d2785e5bbdc040cfd 100644 --- a/tensorflow/compiler/xla/client/lib/math_test.cc +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" diff --git a/tensorflow/compiler/xla/client/lib/numeric.h b/tensorflow/compiler/xla/client/lib/numeric.h index 212f6583137390fe1e41bb88b71ba041e2d22ff3..efd8cdc25724198633e0bf1c48c4e7d9e4b4c9e1 100644 --- a/tensorflow/compiler/xla/client/lib/numeric.h +++ b/tensorflow/compiler/xla/client/lib/numeric.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/client/lib/numeric_test.cc b/tensorflow/compiler/xla/client/lib/numeric_test.cc index f56cadc5472cb4df6ad73c7cf27b14dce528761c..8a96ec68d2dca8485215258b1f6731b934e6f2a8 100644 --- a/tensorflow/compiler/xla/client/lib/numeric_test.cc +++ b/tensorflow/compiler/xla/client/lib/numeric_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/client/lib/pooling.cc b/tensorflow/compiler/xla/client/lib/pooling.cc new file mode 100644 index 0000000000000000000000000000000000000000..7199269a6c889f3589c1148687faf0bb2aaae90a --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/pooling.cc @@ -0,0 +1,183 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/pooling.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" + +namespace xla { + +namespace { + +// Common computation shared between AvgPool and AvgPoolGrad. Divide each +// element of an image by the count of elements that contributed to that +// element during pooling. +XlaOp AvgPoolDivideByCountWithGeneralPadding( + XlaOp sums, PrimitiveType dtype, + tensorflow::gtl::ArraySlice input_shape, + tensorflow::gtl::ArraySlice> spatial_padding, + tensorflow::gtl::ArraySlice ksize, + tensorflow::gtl::ArraySlice stride, + const TensorFormat& data_format) { + // The padding shouldn't be included in the counts. We use another + // ReduceWindow to find the right counts. + const int num_spatial_dims = spatial_padding.size(); + + std::vector input_dim_sizes(num_spatial_dims); + std::vector window_dims(num_spatial_dims); + std::vector window_ksize(num_spatial_dims); + std::vector window_stride(num_spatial_dims); + CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims) + << "Invalid number of spatial dimentions in data format specification"; + for (int i = 0; i < num_spatial_dims; ++i) { + int dim = data_format.spatial_dimension(i); + input_dim_sizes[i] = input_shape[dim]; + window_dims[i] = dim; + window_ksize[i] = ksize[dim]; + window_stride[i] = stride[dim]; + } + + XlaBuilder* b = sums.builder(); + // Build a matrix of all 1s, with the same width/height as the input. + auto ones = Broadcast(One(b, dtype), input_dim_sizes); + PaddingConfig padding_config; + for (int i = 0; i < num_spatial_dims; ++i) { + auto dims = padding_config.add_dimensions(); + dims->set_edge_padding_low(spatial_padding[i].first); + dims->set_edge_padding_high(spatial_padding[i].second); + } + auto zero = Zero(b, dtype); + auto padded_ones = Pad(ones, zero, padding_config); + + // Perform a ReduceWindow with the same window size, strides, and padding + // to count the number of contributions to each result element. + auto counts = + ReduceWindow(padded_ones, zero, CreateScalarAddComputation(dtype, b), + window_ksize, window_stride, Padding::kValid); + + return Div(sums, counts, window_dims); +} + +// Sums all elements in the window specified by 'kernel_size' and 'stride'. +XlaOp ComputeSums(XlaOp operand, XlaOp init_value, + tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, + const TensorFormat& data_format) { + XlaBuilder* b = operand.builder(); + return b->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand)); + TF_ASSIGN_OR_RETURN(Shape init_shape, b->GetShape(init_value)); + PrimitiveType accumulation_type = init_shape.element_type(); + auto add_computation = CreateScalarAddComputation(accumulation_type, b); + return ReduceWindow(operand, init_value, add_computation, kernel_size, + stride, Padding::kValid); + }); +} + +// Creates a padding configuration out of spatial padding values. +PaddingConfig MakeSpatialPaddingConfig( + tensorflow::gtl::ArraySlice> spatial_padding, + tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, + const TensorFormat& data_format) { + const int num_spatial_dims = kernel_size.size() - 2; + PaddingConfig padding_config; + for (int i = 0; i < 2 + num_spatial_dims; ++i) { + padding_config.add_dimensions(); + } + CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims) + << "Invalid number of spatial dimentions in data format specification"; + for (int i = 0; i < num_spatial_dims; ++i) { + int dim = data_format.spatial_dimension(i); + auto padding_dimension = padding_config.mutable_dimensions(dim); + padding_dimension->set_edge_padding_low(spatial_padding[i].first); + padding_dimension->set_edge_padding_high(spatial_padding[i].second); + } + return padding_config; +} + +} // namespace + +XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, Padding padding, + const TensorFormat& data_format) { + XlaBuilder* b = operand.builder(); + return b->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand)); + PrimitiveType dtype = operand_shape.element_type(); + auto max_computation = CreateScalarMaxComputation(dtype, b); + auto init_value = MinValue(b, dtype); + return ReduceWindow(operand, init_value, max_computation, kernel_size, + stride, padding); + }); +} + +XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, + tensorflow::gtl::ArraySlice> padding, + const TensorFormat& data_format, + const bool counts_include_padding) { + XlaBuilder* b = operand.builder(); + return b->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand)); + PrimitiveType dtype = operand_shape.element_type(); + auto init_value = Zero(b, dtype); + std::vector input_size(operand_shape.dimensions().begin(), + operand_shape.dimensions().end()); + auto padding_config = + MakeSpatialPaddingConfig(padding, kernel_size, stride, data_format); + auto padded_operand = Pad(operand, Zero(b, dtype), padding_config); + auto pooled = ComputeSums(padded_operand, init_value, kernel_size, stride, + data_format); + if (counts_include_padding) { + // If counts include padding, all windows have the same number of elements + // contributing to each average. Divide by the window size everywhere to + // get the average. + int64 window_size = + std::accumulate(kernel_size.begin(), kernel_size.end(), 1, + [](int64 x, int64 y) { return x * y; }); + + auto divisor = ConstantR0WithType(b, dtype, window_size); + return pooled / divisor; + } else { + return AvgPoolDivideByCountWithGeneralPadding( + pooled, dtype, input_size, padding, kernel_size, stride, data_format); + } + }); +} + +std::vector> MakeSpatialPadding( + tensorflow::gtl::ArraySlice input_size, + tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, Padding padding, + const TensorFormat& data_format) { + const int num_spatial_dims = kernel_size.size() - 2; + std::vector input_spatial_dimensions; + std::vector kernel_size_spatial_dimensions; + std::vector stride_spatial_dimensions; + CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims) + << "Invalid number of spatial dimentions in data format specification"; + for (int i = 0; i < num_spatial_dims; ++i) { + int dim = data_format.spatial_dimension(i); + input_spatial_dimensions.push_back(input_size[dim]); + kernel_size_spatial_dimensions.push_back(kernel_size[dim]); + stride_spatial_dimensions.push_back(stride[dim]); + } + return MakePadding(input_spatial_dimensions, kernel_size_spatial_dimensions, + stride_spatial_dimensions, padding); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/pooling.h b/tensorflow/compiler/xla/client/lib/pooling.h new file mode 100644 index 0000000000000000000000000000000000000000..1699c585d3b09a306c21cfa797a9023a8463bd1f --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/pooling.h @@ -0,0 +1,73 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_ + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace xla { + +// Tensor format for reduce window operations. +class TensorFormat { + public: + TensorFormat(int batch_dimension, int feature_dimension, + tensorflow::gtl::ArraySlice spatial_dimensions) + : batch_dimension_(batch_dimension), + feature_dimension_(feature_dimension), + spatial_dimensions_(spatial_dimensions.begin(), + spatial_dimensions.end()) {} + + int batch_dimension() const { return batch_dimension_; } + + int feature_dimension() const { return feature_dimension_; } + + int spatial_dimension(int dim) const { return spatial_dimensions_[dim]; } + + int num_spatial_dims() const { return spatial_dimensions_.size(); } + + private: + // The number of the dimension that represents the batch. + int batch_dimension_; + // The number of the dimension that represents the features. + int feature_dimension_; + // The dimension numbers for the spatial dimensions. + tensorflow::gtl::InlinedVector spatial_dimensions_; +}; + +// Computes the max pool of 'operand'. +XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, Padding padding, + const TensorFormat& data_format); + +// Computes the average pool of 'operand'. +XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, + tensorflow::gtl::ArraySlice> padding, + const TensorFormat& data_format, + const bool counts_include_padding); + +// Returns the list of low and high padding elements in each spatial dimension +// for the given 'padding' specification. +std::vector> MakeSpatialPadding( + tensorflow::gtl::ArraySlice input_size, + tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, Padding padding, + const TensorFormat& data_format); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_ diff --git a/tensorflow/compiler/xla/client/lib/pooling_test.cc b/tensorflow/compiler/xla/client/lib/pooling_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4b4553b60db555ad7c2ab6b695236df745e30683 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/pooling_test.cc @@ -0,0 +1,185 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/pooling.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +TensorFormat MakeNCHWFormat(int num_spatial_dims) { + tensorflow::gtl::InlinedVector spatial_dimensions; + for (int i = 0; i < num_spatial_dims; ++i) { + spatial_dimensions.push_back(i + 2); + } + return TensorFormat(/*batch_dimension=*/0, /*feature_dimension=*/1, + /*spatial_dimensions=*/spatial_dimensions); +} + +std::vector> MakeGeneralPadding( + XlaOp input, tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, Padding padding, + const xla::TensorFormat& data_format) { + XlaBuilder* b = input.builder(); + Shape operand_shape = b->GetShape(input).ValueOrDie(); + std::vector input_size(operand_shape.dimensions().begin(), + operand_shape.dimensions().end()); + return MakeSpatialPadding(input_size, kernel_size, stride, padding, + data_format); +} + +// Add singleton batch and feature dimensions to spatial dimensions, according +// to 'data_format' specification. +std::vector ExpandWithBatchAndFeatureDimensions( + tensorflow::gtl::ArraySlice spatial_dim_sizes, + const xla::TensorFormat& data_format) { + const int num_spatial_dims = spatial_dim_sizes.size(); + std::vector tensor_sizes(num_spatial_dims + 2, 1); + for (int i = 0; i < num_spatial_dims; ++i) { + int dim = data_format.spatial_dimension(i); + tensor_sizes[dim] = spatial_dim_sizes[i]; + } + return tensor_sizes; +} + +class PoolingTest : public ClientLibraryTestBase { + public: + ErrorSpec error_spec_{0.0001}; +}; + +XLA_TEST_F(PoolingTest, MaxPool2D) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = kernel_size; + MaxPool(input, kernel_size, stride, Padding::kValid, data_format); + + ComputeAndCompareR4(&builder, {{{{5, 4}}}}, {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, MaxPool2DWithPadding) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = kernel_size; + MaxPool(input, kernel_size, stride, Padding::kSame, data_format); + + ComputeAndCompareR4(&builder, {{{{5, 4, 5}}}}, {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, MaxPool2DWithPaddingAndStride) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format); + MaxPool(input, kernel_size, stride, Padding::kSame, data_format); + + ComputeAndCompareR4(&builder, {{{{5, 4, 4, 5, 5}, {5, 4, 3, 2, 1}}}}, + {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, AvgPool2D) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = kernel_size; + auto padding = MakeGeneralPadding(input, kernel_size, stride, Padding::kValid, + data_format); + AvgPool(input, kernel_size, stride, padding, data_format, + /*counts_include_padding=*/true); + + ComputeAndCompareR4(&builder, {{{{3, 3}}}}, {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, AvgPool2DWithPadding) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = kernel_size; + auto padding = MakeGeneralPadding(input, kernel_size, stride, Padding::kSame, + data_format); + AvgPool(input, kernel_size, stride, padding, data_format, + /*counts_include_padding=*/false); + + ComputeAndCompareR4(&builder, {{{{3, 3, 3}}}}, {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, AvgPool2DWithPaddingAndStride) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format); + auto padding = MakeGeneralPadding(input, kernel_size, stride, Padding::kSame, + data_format); + AvgPool(input, kernel_size, stride, padding, data_format, + /*counts_include_padding=*/false); + + ComputeAndCompareR4(&builder, + {{{{3, 3, 3, 3, 3}, {4.5, 3.5, 2.5, 1.5, 1}}}}, {}, + error_spec_); +} + +XLA_TEST_F(PoolingTest, AvgPool2DWithGeneralPaddingCountNotIncludePadding) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({3, 3}, data_format); + auto stride = kernel_size; + AvgPool(input, kernel_size, stride, {{1, 1}, {2, 1}}, data_format, + /*counts_include_padding=*/false); + + ComputeAndCompareR4(&builder, {{{{3, 3}}}}, {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, + AvgPool2DWithGeneralPaddingCountNotIncludePaddingAndStride) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({3, 3}, data_format); + auto stride = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + AvgPool(input, kernel_size, stride, {{2, 1}, {1, 1}}, data_format, + /*counts_include_padding=*/false); + + ComputeAndCompareR4(&builder, {{{{1.5, 3, 4.5}, {3, 3, 3}}}}, {}, + error_spec_); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/prng.cc b/tensorflow/compiler/xla/client/lib/prng.cc index 299a6ac2b630e94567becc3ec139b8c24eab396a..6ef81689489d8117d5951bcb75693c2e3413e4d6 100644 --- a/tensorflow/compiler/xla/client/lib/prng.cc +++ b/tensorflow/compiler/xla/client/lib/prng.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/casts.h" @@ -56,7 +56,7 @@ ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) { // Performs a single round of the Threefry2x32 algorithm, with a rotation // amount 'rotation'. - auto round = [builder](ThreeFry2x32State v, int rotation) { + auto round = [](ThreeFry2x32State v, int rotation) { v[0] = v[0] + v[1]; v[1] = RotateLeftS32(v[1], rotation); v[1] = v[0] ^ v[1]; diff --git a/tensorflow/compiler/xla/client/lib/prng.h b/tensorflow/compiler/xla/client/lib/prng.h index ac86390239668eeff1ad9eed0f6c82e10d5db004..ad000b1fa1d0655c8fccc0bb33379f2499b77f26 100644 --- a/tensorflow/compiler/xla/client/lib/prng.h +++ b/tensorflow/compiler/xla/client/lib/prng.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { diff --git a/tensorflow/compiler/xla/client/lib/sorting.cc b/tensorflow/compiler/xla/client/lib/sorting.cc new file mode 100644 index 0000000000000000000000000000000000000000..a904be259a3870a679b2c4699ec01e2a11b1ce46 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/sorting.cc @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/sorting.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" + +namespace xla { + +XlaOp TopK(XlaOp input, int64 k) { + XlaBuilder* const builder = input.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); + int last_dim = input_shape.dimensions_size() - 1; + int last_dim_size = input_shape.dimensions(last_dim); + + XlaOp iota_s32 = Iota(builder, S32, last_dim_size); + auto input_dims = input_shape.dimensions(); + std::vector broadcast_dims(input_dims.begin(), input_dims.end() - 1); + XlaOp broadcast_s32 = Broadcast(iota_s32, broadcast_dims); + XlaOp sort_result = Sort(Neg(input), broadcast_s32); + std::vector start_indices(input_shape.dimensions_size(), 0); + std::vector limit_indices(input_dims.begin(), input_dims.end()); + limit_indices[last_dim] = k; + std::vector strides(input_shape.dimensions_size(), 1); + + XlaOp values = Neg(Slice(GetTupleElement(sort_result, 0), start_indices, + limit_indices, strides)); + XlaOp indices = Slice(GetTupleElement(sort_result, 1), start_indices, + limit_indices, strides); + return Tuple(builder, {values, indices}); + }); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/ptr_util.h b/tensorflow/compiler/xla/client/lib/sorting.h similarity index 53% rename from tensorflow/compiler/xla/ptr_util.h rename to tensorflow/compiler/xla/client/lib/sorting.h index bfcdfc62f9541ab09b94a48d5121e16bad4d43cd..b9dfafdd6f957ae050e0f5dbd076d5288235b490 100644 --- a/tensorflow/compiler/xla/ptr_util.h +++ b/tensorflow/compiler/xla/client/lib/sorting.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,23 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_ -#define TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_ -// As this was moved to tensorflow/core/util, provide indirections here to -// maintain current functionality of the library. +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" -#include - -#include -#include -#include +namespace xla { -#include "tensorflow/core/util/ptr_util.h" +// Returns a tuple composed of the top `k` values and corresponding indices in +// `input`. Output values are in descending order, from largest to smallest. +XlaOp TopK(XlaOp input, int64 k); -namespace xla { -using tensorflow::MakeUnique; -using tensorflow::WrapUnique; } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_ diff --git a/tensorflow/compiler/xla/client/lib/sorting_test.cc b/tensorflow/compiler/xla/client/lib/sorting_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..fef98c9923096e21a755c6d730de2c7c10852b2d --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/sorting_test.cc @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/sorting.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace { + +using SortingTest = ClientLibraryTestBase; + +XLA_TEST_F(SortingTest, TopK3From8Values) { + XlaBuilder builder(TestName()); + auto x = + ConstantR1(&builder, {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}); + xla::GetTupleElement(xla::TopK(x, 3), 0); + ComputeAndCompareR1(&builder, {7.0, 6.0, 5.0}, {}); +} + +XLA_TEST_F(SortingTest, TopK3From8Indices) { + XlaBuilder builder(TestName()); + auto x_rev = + ConstantR1(&builder, {7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0}); + xla::GetTupleElement(xla::TopK(x_rev, 3), 1); + ComputeAndCompareR1(&builder, {0, 1, 2}, {}); +} + +XLA_TEST_F(SortingTest, TopKFullSort) { + XlaBuilder builder(TestName()); + const int kSize = 16; + std::mt19937 eng; + std::uniform_real_distribution u_dist(0.0, 100.0); + auto gen = std::bind(u_dist, eng); + std::vector inputs(kSize); + std::generate(inputs.begin(), inputs.end(), gen); + auto x = ConstantR1(&builder, inputs); + xla::GetTupleElement(xla::TopK(x, kSize), 0); + + std::sort(inputs.begin(), inputs.end(), std::greater()); + ComputeAndCompareR1(&builder, inputs, {}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index 2de65016dd6fb65d4e19a0d37e4b65105d28d407..081fec7ad92958aa285e4be41394d7b1876e0815 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -15,8 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/testing.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -99,14 +98,13 @@ std::vector> MakeFakeArgumentsOrDie( << "Computation should have progran shape."; auto program_shape = computation.proto().program_shape(); - // For every (unbound) parameter that the computation wants, we manufacture - // some arbitrary data so that we can invoke the computation. - std::vector> fake_arguments; - for (const Shape& parameter : program_shape.parameters()) { - fake_arguments.push_back(MakeFakeDataOrDie(parameter, client)); - } - - return fake_arguments; + // Create and run a program which produces a tuple with one element per + // parameter, then return the tuple's constituent buffers. + std::vector param_shapes(program_shape.parameters().begin(), + program_shape.parameters().end()); + auto fake_input_tuple = + MakeFakeDataOrDie(ShapeUtil::MakeTupleShape(param_shapes), client); + return client->DeconstructTuple(*fake_input_tuple).ValueOrDie(); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 035ee9bf4cbda17b04020efd8511504da94d2835..1cd3e9b22f9cf3383cfcbc19c79acba0e5938190 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -17,12 +17,13 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "llvm/ADT/Triple.h" #include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/service/source_map_util.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/status_macros.h" using xla::source_map_util::InvalidParameterArgument; @@ -30,8 +31,8 @@ using xla::source_map_util::InvalidParameterArgument; namespace xla { namespace { -StatusOr BorrowStreamForDevice(int device_ordinal, - Backend* backend) { +StatusOr BorrowStreamForDevice(int device_ordinal, + Backend* backend) { if (device_ordinal < 0) { device_ordinal = backend->default_device_ordinal(); } @@ -100,11 +101,14 @@ Status LocalExecutable::ValidateExecutionOptions( } } - // Verify that the device the executable was built for is equivalent to the - // device it will run on. - int run_device_ordinal = run_options.device_ordinal() == -1 - ? backend_->default_device_ordinal() - : run_options.device_ordinal(); + // Verify that the device the executable was built for is equivalent + // to the device it will run on. + int run_device_ordinal = run_options.device_ordinal(); + if (run_device_ordinal == -1) { + run_device_ordinal = run_options.stream() != nullptr + ? run_options.stream()->parent()->device_ordinal() + : backend_->default_device_ordinal(); + } TF_ASSIGN_OR_RETURN(bool devices_equivalent, backend_->devices_equivalent( run_device_ordinal, build_options_.device_ordinal())); @@ -142,7 +146,7 @@ StatusOr LocalExecutable::Run( TF_RETURN_IF_ERROR( ValidateExecutionOptions(arguments, run_options, *backend_)); - Backend::StreamPtr stream; + StreamPool::Ptr stream; if (run_options.stream() == nullptr) { // NB! The lifetime of `stream` needs to match the lifetime of // `actual_options` (otherwise we will end up using a returned stream in @@ -253,9 +257,9 @@ StatusOr> LocalClient::Compile( TF_ASSIGN_OR_RETURN(std::unique_ptr executable, local_service_->CompileExecutable( computation, argument_layouts, updated_options)); - return WrapUnique(new LocalExecutable(std::move(executable), - local_service_->mutable_backend(), - updated_options)); + return absl::WrapUnique(new LocalExecutable(std::move(executable), + local_service_->mutable_backend(), + updated_options)); } StatusOr LocalClient::LiteralToShapedBuffer( @@ -299,7 +303,7 @@ StatusOr> LocalClient::TransferFromOutfeedLocal( const Shape& shape, int device_ordinal) { TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device_ordinal)); - auto literal = MakeUnique(); + auto literal = Literal::CreateFromShape(shape); TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed( executor, shape, literal.get())); return std::move(literal); diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc similarity index 91% rename from tensorflow/compiler/xla/client/xla_client/xla_builder.cc rename to tensorflow/compiler/xla/client/xla_builder.cc index 152335e22ace5d437cacf4e4fd0132de54e1dc6d..54fe87a7a8d2b2979e719affc31e63402cb49dba 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include #include @@ -21,6 +21,8 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" @@ -45,21 +47,6 @@ int64 GetUniqueId() { return id; } -// Returns true if an instruction with the given opcode can be the root of the -// computation. -bool CanBeRoot(HloOpcode opcode) { - switch (opcode) { - case HloOpcode::kAfterAll: - case HloOpcode::kSend: - case HloOpcode::kSendDone: - case HloOpcode::kOutfeed: - case HloOpcode::kTrace: - return false; - default: - return true; - } -} - } // namespace XlaOp operator-(const XlaOp& x) { return Neg(x); } @@ -142,28 +129,13 @@ XlaOp XlaBuilder::ReportErrorOrReturn( return ReportErrorOrReturn(op_creator()); } -StatusOr XlaBuilder::GetProgramShape(int64* root_id) const { +StatusOr XlaBuilder::GetProgramShape(int64 root_id) const { TF_RETURN_IF_ERROR(first_error_); - - TF_RET_CHECK(root_id != nullptr); + TF_RET_CHECK((root_id >= 0) && (root_id < instructions_.size())); ProgramShape program_shape; - // Not all instructions can be roots. Walk backwards from the last added - // instruction until a valid root is found. - int64 index = instructions_.size() - 1; - for (; index >= 0; index--) { - TF_ASSIGN_OR_RETURN(HloOpcode opcode, - StringToHloOpcode(instructions_[index].opcode())); - if (CanBeRoot(opcode)) { - break; - } - } - if (index < 0) { - return FailedPrecondition("no root instruction was found"); - } - *root_id = instructions_[index].id(); - *program_shape.mutable_result() = instructions_[index].shape(); + *program_shape.mutable_result() = instructions_[root_id].shape(); // Check that the parameter numbers are continuous from 0, and add parameter // shapes and names to the program shape. @@ -188,8 +160,15 @@ StatusOr XlaBuilder::GetProgramShape(int64* root_id) const { } StatusOr XlaBuilder::GetProgramShape() const { - int64 root; - return GetProgramShape(&root); + TF_RET_CHECK(!instructions_.empty()); + return GetProgramShape(instructions_.back().id()); +} + +StatusOr XlaBuilder::GetProgramShape(XlaOp root) const { + if (root.builder_ != this) { + return InvalidArgument("Given root operation is not in this computation."); + } + return GetProgramShape(root.handle()); } void XlaBuilder::IsConstantVisitor(const int64 op_handle, @@ -257,17 +236,29 @@ StatusOr XlaBuilder::Build() { first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); return AppendStatus(first_error_, backtrace); } + return Build(instructions_.back().id()); +} + +StatusOr XlaBuilder::Build(XlaOp root) { + if (root.builder_ != this) { + return InvalidArgument("Given root operation is not in this computation."); + } + return Build(root.handle()); +} + +StatusOr XlaBuilder::Build(int64 root_id) { + if (!first_error_.ok()) { + string backtrace; + first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); + return AppendStatus(first_error_, backtrace); + } HloComputationProto entry; entry.set_id(GetUniqueId()); // Give the computation a global unique id. entry.set_name(StrCat(name_, entry.id())); // Ensure that the name is unique. - { - int64 root_id; - TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(), - GetProgramShape(&root_id)); - entry.set_root_id(root_id); - } + TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(), GetProgramShape(root_id)); + entry.set_root_id(root_id); for (auto& instruction : instructions_) { // Ensures that the instruction names are unique among the whole graph. @@ -480,8 +471,8 @@ XlaOp XlaBuilder::Call(const XlaComputation& computation, HloInstructionProto instr; std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); - c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), - [](const Shape& shape) { return &shape; }); + absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, computation.GetProgramShape()); TF_ASSIGN_OR_RETURN( @@ -633,8 +624,8 @@ XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice operands, std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); - c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), - [](const Shape& shape) { return &shape; }); + absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), ShapeInference::InferConcatOpShape(operand_shape_ptrs, dimension)); @@ -760,8 +751,8 @@ XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice elements) { HloInstructionProto instr; std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements)); - c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), - [](const Shape& shape) { return &shape; }); + absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferVariadicOpShape( HloOpcode::kTuple, operand_shape_ptrs)); @@ -893,24 +884,28 @@ Status XlaBuilder::VerifyConvolution( XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice window_strides, - Padding padding) { + Padding padding, int64 feature_group_count) { return ConvWithGeneralDimensions( lhs, rhs, window_strides, padding, - CreateDefaultConvDimensionNumbers(window_strides.size())); + CreateDefaultConvDimensionNumbers(window_strides.size()), + feature_group_count); } XlaOp XlaBuilder::ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding) { + tensorflow::gtl::ArraySlice> padding, + int64 feature_group_count) { return ConvGeneral(lhs, rhs, window_strides, padding, - CreateDefaultConvDimensionNumbers(window_strides.size())); + CreateDefaultConvDimensionNumbers(window_strides.size()), + feature_group_count); } XlaOp XlaBuilder::ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers) { + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -937,7 +932,7 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions( return ConvGeneral(lhs, rhs, window_strides, MakePadding(base_area_dimensions, window_dimensions, window_strides, padding), - dimension_numbers); + dimension_numbers, feature_group_count); }); } @@ -945,9 +940,10 @@ XlaOp XlaBuilder::ConvGeneral( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding, - const ConvolutionDimensionNumbers& dimension_numbers) { + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count) { return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {}, - dimension_numbers); + dimension_numbers, feature_group_count); } XlaOp XlaBuilder::ConvGeneralDilated( @@ -956,7 +952,8 @@ XlaOp XlaBuilder::ConvGeneralDilated( tensorflow::gtl::ArraySlice> padding, tensorflow::gtl::ArraySlice lhs_dilation, tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers) { + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -975,12 +972,13 @@ XlaOp XlaBuilder::ConvGeneralDilated( MakeWindow(window_dimensions, window_strides, padding, lhs_dilation, rhs_dilation)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, instr.window(), - dimension_numbers)); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, instr.window(), + dimension_numbers, feature_group_count)); *instr.mutable_convolution_dimension_numbers() = dimension_numbers; + instr.set_feature_group_count(feature_group_count); return AddInstruction(std::move(instr), HloOpcode::kConvolution, {lhs, rhs}); @@ -1084,6 +1082,23 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { "Replicated sharding is not yet supported for infeeds"); } + // Infeed takes a single token operand. Generate the token to pass to the + // infeed. + XlaOp token; + auto make_token = [&]() { + HloInstructionProto token_instr; + *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + return AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {}); + }; + if (sharding()) { + // Arbitrarily assign token to device 0. + OpSharding sharding = sharding_builder::AssignDevice(0); + XlaScopedShardingAssignment scoped_sharding(this, sharding); + TF_ASSIGN_OR_RETURN(token, make_token()); + } else { + TF_ASSIGN_OR_RETURN(token, make_token()); + } + // The sharding is set by the client according to the data tuple shape. // However, the shape of the infeed instruction is a tuple containing the // data and a token. For tuple sharding type, the sharding must be changed @@ -1099,11 +1114,11 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { sharding_builder::AssignDevice(0); XlaScopedShardingAssignment scoped_sharding(this, infeed_instruction_sharding); - TF_ASSIGN_OR_RETURN(infeed, - AddInstruction(std::move(instr), HloOpcode::kInfeed)); + TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr), + HloOpcode::kInfeed, {token})); } else { - TF_ASSIGN_OR_RETURN(infeed, - AddInstruction(std::move(instr), HloOpcode::kInfeed)); + TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr), + HloOpcode::kInfeed, {token})); } // The infeed instruction produces a tuple of the infed data and a token @@ -1169,8 +1184,15 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, instr.set_outfeed_config(outfeed_config); + // Outfeed takes a token as its second operand. Generate the token to pass + // to the outfeed. + HloInstructionProto token_instr; + *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr), + HloOpcode::kAfterAll, {})); + TF_RETURN_IF_ERROR( - AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand}) + AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand, token}) .status()); // The outfeed instruction produces a token. However, existing users expect @@ -1520,8 +1542,8 @@ XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice operands, HloInstructionProto instr; std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); - c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), - [](const Shape& shape) { return &shape; }); + absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, computation.GetProgramShape()); TF_ASSIGN_OR_RETURN( @@ -1611,27 +1633,53 @@ XlaOp XlaBuilder::While(const XlaComputation& condition, }); } -XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices, +XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice window_bounds) { + tensorflow::gtl::ArraySlice slice_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input)); - TF_ASSIGN_OR_RETURN(const Shape& gather_indices_shape, - GetShape(gather_indices)); + TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape, + GetShape(start_indices)); TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), - ShapeInference::InferGatherShape(input_shape, gather_indices_shape, - dimension_numbers, window_bounds)); + ShapeInference::InferGatherShape(input_shape, start_indices_shape, + dimension_numbers, slice_sizes)); *instr.mutable_gather_dimension_numbers() = dimension_numbers; - for (int64 bound : window_bounds) { - instr.add_gather_window_bounds(bound); + for (int64 bound : slice_sizes) { + instr.add_gather_slice_sizes(bound); } return AddInstruction(std::move(instr), HloOpcode::kGather, - {input, gather_indices}); + {input, start_indices}); + }); +} + +XlaOp XlaBuilder::Scatter(const XlaOp& input, const XlaOp& scatter_indices, + const XlaOp& updates, + const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input)); + TF_ASSIGN_OR_RETURN(const Shape& scatter_indices_shape, + GetShape(scatter_indices)); + TF_ASSIGN_OR_RETURN(const Shape& updates_shape, GetShape(updates)); + TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape, + update_computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferScatterShape( + input_shape, scatter_indices_shape, updates_shape, + to_apply_shape, dimension_numbers)); + + *instr.mutable_scatter_dimension_numbers() = dimension_numbers; + + AddCalledComputation(update_computation, &instr); + return AddInstruction(std::move(instr), HloOpcode::kScatter, + {input, scatter_indices, updates}); }); } @@ -1681,7 +1729,7 @@ XlaOp XlaBuilder::Reduce( TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferReduceShape( - operand_shape, init_shape, dimensions_to_reduce, + {&operand_shape, &init_shape}, dimensions_to_reduce, called_program_shape)); for (int64 dim : dimensions_to_reduce) { @@ -1866,6 +1914,61 @@ XlaOp XlaBuilder::CrossReplicaSum( }); } +XlaOp XlaBuilder::AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector& replica_groups) { + return ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + + // The HloInstruction for Alltoall currently only handles the data + // communication: it accepts N already split parts and scatters them to N + // cores, and each core gathers the N received parts into a tuple as the + // output. So here we explicitly split the operand before the hlo alltoall, + // and concat the tuple elements. + // + // First, run shape inference to make sure the shapes are valid. + TF_RETURN_IF_ERROR( + ShapeInference::InferAllToAllShape(operand_shape, split_dimension, + concat_dimension, split_count) + .status()); + + // Split into N parts. + std::vector slices; + slices.reserve(split_count); + const int64 block_size = + operand_shape.dimensions(split_dimension) / split_count; + for (int i = 0; i < split_count; i++) { + slices.push_back(SliceInDim(operand, /*start_index=*/i * block_size, + /*limit_index=*/(i + 1) * block_size, + /*stride=*/1, /*dimno=*/split_dimension)); + } + + // Handle data communication. + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(auto slice_shapes, this->GetOperandShapes(slices)); + std::vector slice_shape_ptrs; + absl::c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs), + [](const Shape& shape) { return &shape; }); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs)); + for (const ReplicaGroup& group : replica_groups) { + *instr.add_replica_groups() = group; + } + TF_ASSIGN_OR_RETURN( + XlaOp alltoall, + AddInstruction(std::move(instr), HloOpcode::kAllToAll, slices)); + + // Concat the N received parts. + std::vector received; + received.reserve(split_count); + for (int i = 0; i < split_count; i++) { + received.push_back(this->GetTupleElement(alltoall, i)); + } + return this->ConcatInDim(received, concat_dimension); + }); +} + XlaOp XlaBuilder::SelectAndScatter( const XlaOp& operand, const XlaComputation& select, tensorflow::gtl::ArraySlice window_dimensions, @@ -2137,11 +2240,6 @@ StatusOr XlaBuilder::BuildConstantSubGraph( TF_ASSIGN_OR_RETURN(const HloInstructionProto* root, LookUpInstruction(root_op)); - TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode())); - if (!CanBeRoot(opcode)) { - return InvalidArgument("the operand with opcode %s cannot be root", - root->opcode().c_str()); - } HloComputationProto entry; entry.set_id(GetUniqueId()); // Give the computation a global unique id. @@ -2200,7 +2298,7 @@ StatusOr XlaBuilder::BuildConstantSubGraph( std::unique_ptr XlaBuilder::CreateSubBuilder( const string& computation_name) { - auto sub_builder = MakeUnique(computation_name); + auto sub_builder = absl::make_unique(computation_name); sub_builder->parent_builder_ = this; sub_builder->die_immediately_on_error_ = this->die_immediately_on_error_; return sub_builder; @@ -2473,32 +2571,38 @@ XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, } XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding) { - return lhs.builder()->Conv(lhs, rhs, window_strides, padding); + tensorflow::gtl::ArraySlice window_strides, Padding padding, + int64 feature_group_count) { + return lhs.builder()->Conv(lhs, rhs, window_strides, padding, + feature_group_count); } XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding) { + tensorflow::gtl::ArraySlice> padding, + int64 feature_group_count) { return lhs.builder()->ConvWithGeneralPadding(lhs, rhs, window_strides, - padding); + padding, feature_group_count); } XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers) { + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count) { return lhs.builder()->ConvWithGeneralDimensions(lhs, rhs, window_strides, - padding, dimension_numbers); + padding, dimension_numbers, + feature_group_count); } XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding, - const ConvolutionDimensionNumbers& dimension_numbers) { + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count) { return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding, - dimension_numbers); + dimension_numbers, feature_group_count); } XlaOp ConvGeneralDilated( @@ -2507,10 +2611,11 @@ XlaOp ConvGeneralDilated( tensorflow::gtl::ArraySlice> padding, tensorflow::gtl::ArraySlice lhs_dilation, tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers) { - return lhs.builder()->ConvGeneralDilated(lhs, rhs, window_strides, padding, - lhs_dilation, rhs_dilation, - dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count) { + return lhs.builder()->ConvGeneralDilated( + lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, + dimension_numbers, feature_group_count); } XlaOp Fft(const XlaOp& operand, FftType fft_type, @@ -2667,6 +2772,13 @@ XlaOp CrossReplicaSum( replica_group_ids, channel_id); } +XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector& replica_groups) { + return operand.builder()->AllToAll(operand, split_dimension, concat_dimension, + split_count, replica_groups); +} + XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, @@ -2796,11 +2908,18 @@ XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, mantissa_bits); } -XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, +XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice window_bounds) { - return input.builder()->Gather(input, gather_indices, dimension_numbers, - window_bounds); + tensorflow::gtl::ArraySlice slice_sizes) { + return input.builder()->Gather(input, start_indices, dimension_numbers, + slice_sizes); +} + +XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, + const XlaOp& updates, const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers) { + return input.builder()->Scatter(input, scatter_indices, updates, + update_computation, dimension_numbers); } void Send(const XlaOp& operand, const ChannelHandle& handle) { diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h similarity index 95% rename from tensorflow/compiler/xla/client/xla_client/xla_builder.h rename to tensorflow/compiler/xla/client/xla_builder.h index 980e84e40c30f87484d24da9a9369b8d7475f632..469d5048b26527bbcf20cbe11b01c8ec7a4bc1e4 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_ #include #include @@ -195,9 +195,14 @@ class XlaBuilder { // Builds the computation with the requested operations, or returns a non-ok // status. Note that all ops that have been enqueued will be moved to the - // computation being returned. + // computation being returned. The root of the computation will be the last + // added operation. StatusOr Build(); + // Overload of Build which specifies a particular root instruction for the + // computation. + StatusOr Build(XlaOp root); + // Builds the computation with the requested operations, or notes an error in // the parent XlaBuilder and returns an empty computation if building failed. // This function is intended to be used where the returned XlaComputation is @@ -225,9 +230,14 @@ class XlaBuilder { // Returns the shape of the given op. StatusOr GetShape(const XlaOp& op) const; - // Returns the (inferred) result for the current computation's shape. + // Returns the (inferred) result for the current computation's shape. This + // assumes the root instruction is the last added instruction. StatusOr GetProgramShape() const; + // Returns the (inferred) result for the current computation's shape using the + // given operation as the root. + StatusOr GetProgramShape(XlaOp root) const; + // Reports an error to the builder, by // * storing it internally and capturing a backtrace if it's the first error // (this deferred value will be produced on the call to @@ -255,6 +265,9 @@ class XlaBuilder { StatusOr IsConstant(const XlaOp& operand) const; private: + // Build helper which takes the id of the root operation.. + StatusOr Build(int64 root_id); + // Enqueues a "retrieve parameter value" instruction for a parameter that was // passed to the computation. XlaOp Parameter(int64 parameter_number, const Shape& shape, @@ -499,22 +512,24 @@ class XlaBuilder { // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - Padding padding); + tensorflow::gtl::ArraySlice window_strides, Padding padding, + int64 feature_group_count = 1); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration in the format returned by MakePadding(). XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); + tensorflow::gtl::ArraySlice> padding, + int64 feature_group_count = 1); // Enqueues a convolution instruction onto the computation, with the caller // provided dimension numbers configuration. XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration as well as the dimension numbers. @@ -522,7 +537,8 @@ class XlaBuilder { const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration, dilation factors and dimension numbers. @@ -532,7 +548,8 @@ class XlaBuilder { tensorflow::gtl::ArraySlice> padding, tensorflow::gtl::ArraySlice lhs_dilation, tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1); // Enqueues an FFT instruction onto the computation, of the given type and // with the given FFT length. @@ -686,9 +703,9 @@ class XlaBuilder { // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. // - // - `channel_id`: for Allreduce nodes from different models, if they have the - // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be - // applied cross models. + // - `channel_id`: for Allreduce nodes from different modules, if they have + // the same channel_id, they will be 'Allreduce'd. If empty, Allreduce will + // not be applied cross modules. // // TODO(b/79737069): Rename this to AllReduce when it's ready to use. XlaOp CrossReplicaSum( @@ -697,6 +714,13 @@ class XlaBuilder { const tensorflow::gtl::optional& channel_id = tensorflow::gtl::nullopt); + // Enqueues an operation that do an Alltoall of the operand cross cores. + // + // TODO(b/110096724): This is NOT YET ready to use. + XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector& replica_groups); + // Enqueues an operation that scatters the `source` array to the selected // indices of each window. XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, @@ -853,9 +877,14 @@ class XlaBuilder { const int mantissa_bits); // Enqueues a Gather node onto the computation. - XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, + XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice window_bounds); + tensorflow::gtl::ArraySlice slice_sizes); + + // Enqueues a Scatter node onto the computation. + XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, + const XlaOp& updates, const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers); // Enqueues a Send node onto the computation for device-to-device // communication, to send the given operand to a Recv instruction that shares @@ -964,9 +993,8 @@ class XlaBuilder { // shape. StatusOr Reshape(const Shape& shape, const XlaOp& operand); - // Returns the (inferred) result for the program shape for the current - // computation and fills the root_id in the pointer. - StatusOr GetProgramShape(int64* root_id) const; + // Returns the (inferred) result for the program shape using the given root. + StatusOr GetProgramShape(int64 root_id) const; // Returns shapes for the operands. StatusOr> GetOperandShapes( @@ -1137,27 +1165,31 @@ class XlaBuilder { const DotDimensionNumbers& dimension_numbers); friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice window_strides, - Padding padding); + Padding padding, int64 feature_group_count); friend XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); + tensorflow::gtl::ArraySlice> padding, + int64 feature_group_count); friend XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count); friend XlaOp ConvGeneral( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count); friend XlaOp ConvGeneralDilated( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding, tensorflow::gtl::ArraySlice lhs_dilation, tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count); friend XlaOp Fft(const XlaOp& operand, FftType fft_type, tensorflow::gtl::ArraySlice fft_length); friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape, @@ -1229,6 +1261,9 @@ class XlaBuilder { const XlaOp& operand, const XlaComputation& computation, tensorflow::gtl::ArraySlice replica_group_ids, const tensorflow::gtl::optional& channel_id); + friend XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector& replica_groups); friend XlaOp SelectAndScatter( const XlaOp& operand, const XlaComputation& select, tensorflow::gtl::ArraySlice window_dimensions, @@ -1293,9 +1328,13 @@ class XlaBuilder { const XlaComputation& false_computation); friend XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, const int mantissa_bits); - friend XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, + friend XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice window_bounds); + tensorflow::gtl::ArraySlice slice_sizes); + friend XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, + const XlaOp& updates, + const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers); friend void Send(const XlaOp& operand, const ChannelHandle& handle); friend XlaOp Recv(XlaBuilder* builder, const Shape& shape, const ChannelHandle& handle); @@ -1615,28 +1654,32 @@ XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding); + tensorflow::gtl::ArraySlice window_strides, Padding padding, + int64 feature_group_count = 1); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration in the format returned by MakePadding(). XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); + tensorflow::gtl::ArraySlice> padding, + int64 feature_group_count = 1); // Enqueues a convolution instruction onto the computation, with the caller // provided dimension numbers configuration. XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration as well as the dimension numbers. XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration, dilation factors and dimension numbers. @@ -1646,7 +1689,8 @@ XlaOp ConvGeneralDilated( tensorflow::gtl::ArraySlice> padding, tensorflow::gtl::ArraySlice lhs_dilation, tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1); // Enqueues an FFT instruction onto the computation, of the given type and // with the given FFT length. @@ -1811,9 +1855,9 @@ XlaOp CrossReplicaSum( // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. // -// - `channel_id`: for Allreduce nodes from different models, if they have the +// - `channel_id`: for Allreduce nodes from different modules, if they have the // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be -// applied cross models. +// applied cross modules. // // TODO(b/79737069): Rename this to AllReduce when it's ready to use. XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation, @@ -1821,6 +1865,13 @@ XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation, const tensorflow::gtl::optional& channel_id = tensorflow::gtl::nullopt); +// Enqueues an operation that do an Alltoall of the operand cross cores. +// +// TODO(b/110096724): This is NOT YET ready to use. +XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector& replica_groups = {}); + // Enqueues an operation that scatters the `source` array to the selected // indices of each window. XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, @@ -1973,9 +2024,14 @@ XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, const int mantissa_bits); // Enqueues a Gather node onto the computation. -XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, +XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice window_bounds); + tensorflow::gtl::ArraySlice slice_sizes); + +// Enqueues a Scatter node onto the computation. +XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, + const XlaOp& updates, const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers); // Enqueues a Send node onto the computation for device-to-device // communication. This operation sends the given operand to @@ -2238,4 +2294,4 @@ XlaOp ConstantR4FromArray4D(XlaBuilder* builder, } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_ diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc similarity index 82% rename from tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc rename to tensorflow/compiler/xla/client/xla_builder_test.cc index b4a5aedfb1765507ec57aa0291fc7bb33015206c..49a15ec3b449bdec07aa6ecfbc40b7b9f62c3f4e 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -46,6 +47,17 @@ class XlaBuilderTest : public ::testing::Test { return HloModule::CreateFromProto(proto, config); } + // Overload which explicitly specifies the root instruction. + StatusOr> BuildHloModule(XlaBuilder* b, + XlaOp root) { + TF_ASSIGN_OR_RETURN(XlaComputation computation, b->Build(root)); + const HloModuleProto& proto = computation.proto(); + TF_ASSIGN_OR_RETURN(const auto& config, + HloModule::CreateModuleConfigFromProto( + proto, legacy_flags::GetDebugOptionsFromFlags())); + return HloModule::CreateFromProto(proto, config); + } + // Returns the name of the test currently being run. string TestName() const { return ::testing::UnitTest::GetInstance()->current_test_info()->name(); @@ -293,6 +305,21 @@ TEST_F(XlaBuilderTest, Transpose) { EXPECT_THAT(root, op::Transpose(op::Parameter())); } +TEST_F(XlaBuilderTest, AllToAll) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x"); + AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, + /*split_count=*/2); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + + // AllToAll is decomposed into slices -> all-to-all -> gte -> concat. + EXPECT_EQ(root->opcode(), HloOpcode::kConcatenate); + EXPECT_EQ(root->operand(0)->operand(0)->opcode(), HloOpcode::kAllToAll); + EXPECT_TRUE( + ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {8, 8}))); +} + TEST_F(XlaBuilderTest, ReportError) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); @@ -320,5 +347,45 @@ TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesErrors) { EXPECT_THAT(statusor.status().error_message(), HasSubstr("a test error")); } +TEST_F(XlaBuilderTest, BuildWithSpecificRoot) { + XlaBuilder b(TestName()); + XlaOp constant = ConstantR0(&b, 1.0); + Add(constant, ConstantR0(&b, 2.0)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/constant)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Constant()); +} + +TEST_F(XlaBuilderTest, BuildWithSpecificRootAndMultipleParameters) { + // Specifying a particular root in Build should still include all entry + // parameters. + XlaBuilder b(TestName()); + const Shape shape = ShapeUtil::MakeShape(F32, {42, 123}); + XlaOp x = Parameter(&b, 0, shape, "x"); + XlaOp y = Parameter(&b, 1, shape, "y"); + XlaOp z = Parameter(&b, 2, shape, "z"); + Add(x, Sub(y, z)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/x)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Parameter()); + EXPECT_EQ(module->entry_computation()->num_parameters(), 3); + EXPECT_EQ(module->entry_computation()->instruction_count(), 5); +} + +TEST_F(XlaBuilderTest, BuildWithSpecificRootWithWrongBuilder) { + XlaBuilder b(TestName()); + XlaBuilder other_b(TestName()); + const Shape shape = ShapeUtil::MakeShape(F32, {42, 123}); + + Parameter(&b, 0, shape, "param"); + XlaOp other_param = Parameter(&other_b, 0, shape, "other_param"); + + Status status = b.Build(other_param).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT( + status.error_message(), + ::testing::HasSubstr("root operation is not in this computation")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD deleted file mode 100644 index a7168e731b064cf11d6aa54c6a56d09ebc421797..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/xla_client/BUILD +++ /dev/null @@ -1,68 +0,0 @@ -# Description: -# The new XLA client libraries. - -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = [":friends"]) - -package_group( - name = "friends", - includes = [ - "//tensorflow/compiler/xla:friends", - ], -) - -# Filegroup used to collect source files for dependency checking. -filegroup( - name = "c_srcs", - data = glob([ - "**/*.cc", - "**/*.h", - ]), -) - -load("//tensorflow:tensorflow.bzl", "tf_cc_test") - -cc_library( - name = "xla_builder", - srcs = ["xla_builder.cc"], - hdrs = ["xla_builder.h"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/compiler/xla:execution_options_util", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/client:sharding_builder", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_proto", - "//tensorflow/compiler/xla/service:shape_inference", - "//tensorflow/core:lib", - ], -) - -tf_cc_test( - name = "xla_builder_test", - srcs = ["xla_builder_test.cc"], - deps = [ - ":xla_builder", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_matchers", - "//tensorflow/core:test", - ], -) diff --git a/tensorflow/compiler/xla/client/xla_computation.cc b/tensorflow/compiler/xla/client/xla_computation.cc index 3543d41fc2656ec028646edebc0bf5b6af7f67a5..22c9e83bb2ae9e3e205bdd480b64c703e31c6ffd 100644 --- a/tensorflow/compiler/xla/client/xla_computation.cc +++ b/tensorflow/compiler/xla/client/xla_computation.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -32,7 +32,7 @@ StatusOr> XlaComputation::Snapshot() const { if (IsNull()) { return InvalidArgument("Computation is invalid."); } - auto session = MakeUnique(); + auto session = absl::make_unique(); *session->mutable_hlo()->mutable_hlo_module() = proto_; return std::move(session); } diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py index abd10b164eaef8e75ed304483861baf250c5b954..fb135f5ceda67ce6c001de15b8f3f084ca164826 100644 --- a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py +++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py @@ -20,7 +20,7 @@ from __future__ import print_function import math -import numpy as np +import numpy as _np # Avoids becoming a part of public Tensorflow API. from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python_api import xla_shape @@ -85,7 +85,7 @@ class Sharding(object): something we really want to expose to users (especially as the contract for tile_assignment is very strict). """ - if not isinstance(tile_assignment, np.ndarray): + if not isinstance(tile_assignment, _np.ndarray): raise TypeError('Tile assignment must be of type np.ndarray') if not isinstance(tile_shape, xla_shape.Shape): raise TypeError('Tile shape must be of type xla_shape.Shape') diff --git a/tensorflow/compiler/xla/iterator_util_test.cc b/tensorflow/compiler/xla/iterator_util_test.cc index 7bc3189507ec5233c6983eb26cfb07dc9bfadd52..ec8b66df2db0b9d8c045fbf6133f607e57c81c26 100644 --- a/tensorflow/compiler/xla/iterator_util_test.cc +++ b/tensorflow/compiler/xla/iterator_util_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/test.h" namespace xla { @@ -27,7 +27,7 @@ namespace { TEST(UnwrappingIteratorTest, Simple) { std::vector> v; for (int i = 0; i < 3; ++i) { - v.push_back(MakeUnique(i)); + v.push_back(absl::make_unique(i)); } int i = 0; for (auto iter = MakeUnwrappingIterator(v.begin()); @@ -51,7 +51,7 @@ TEST(UnwrappingIteratorTest, PostincrementOperator) { TEST(UnwrappingIteratorTest, StdFind) { std::list> l; for (int i = 0; i < 3; ++i) { - l.push_back(MakeUnique(i)); + l.push_back(absl::make_unique(i)); } EXPECT_EQ(l.begin()->get(), *std::find(MakeUnwrappingIterator(l.begin()), diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index 15eeb2ea13607d43c995197f8f0e3c58abd4d94a..b72d190d54591384392e79e73e90cf52df04a902 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -297,7 +297,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { shape.layout().padded_dimensions_size() == 0) { return false; } - CHECK(IsDenseArray(shape)); + CHECK(IsDenseArray(shape)) << shape.ShortDebugString(); CHECK_EQ(shape.dimensions_size(), shape.layout().padded_dimensions_size()); for (int64 i = 0; i < shape.dimensions_size(); ++i) { if (shape.layout().padded_dimensions(i) > shape.dimensions(i)) { diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index f42fb92359f40ec763866af094972046f6407ae1..5d27e4a46b57242c96ee84d37466ffb7d613a974 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -31,7 +31,6 @@ std::vector* flag_objects; std::once_flag flags_init; void SetDebugOptionsDefaults(DebugOptions* flags) { - flags->set_xla_enable_fast_math(true); flags->set_xla_llvm_enable_alias_scope_metadata(true); flags->set_xla_llvm_enable_noalias_metadata(true); flags->set_xla_llvm_enable_invariant_load_metadata(true); @@ -53,6 +52,11 @@ void SetDebugOptionsDefaults(DebugOptions* flags) { // the heuristics needed to decide when to run on multiple streams. See // b/77879207. flags->set_xla_gpu_disable_multi_streaming(true); + + // TODO(jlebar): Disable fastmath once doing so is not a performance + // regression. + flags->set_xla_cpu_enable_fast_math(true); + flags->set_xla_gpu_enable_fast_math(true); } // Allocates flag_values and flag_objects; this function must not be called more @@ -150,10 +154,16 @@ void AllocateFlags() { flag_values->mutable_xla_generate_hlo_text_to(), "Dump all HLO modules as text into the provided directory path."), tensorflow::Flag( - "xla_enable_fast_math", - bool_setter_for(&DebugOptions::set_xla_enable_fast_math), - flag_values->xla_enable_fast_math(), - "Enable unsafe fast-math optimizations in the compiler; " + "xla_cpu_enable_fast_math", + bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math), + flag_values->xla_cpu_enable_fast_math(), + "Enable unsafe fast-math optimizations in the CPU compiler; " + "this may produce faster code at the expense of some accuracy."), + tensorflow::Flag( + "xla_gpu_enable_fast_math", + bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math), + flag_values->xla_cpu_enable_fast_math(), + "Enable unsafe fast-math optimizations in the GPU compiler; " "this may produce faster code at the expense of some accuracy."), tensorflow::Flag( "xla_llvm_enable_alias_scope_metadata", @@ -306,6 +316,13 @@ void AllocateFlags() { bool_setter_for(&DebugOptions::set_xla_cpu_use_mkl_dnn), flag_values->xla_cpu_use_mkl_dnn(), "Generate calls to MKL-DNN in the CPU backend."), + tensorflow::Flag( + "xla_gpu_crash_on_verification_failures", + bool_setter_for( + &DebugOptions::set_xla_gpu_crash_on_verification_failures), + flag_values->xla_gpu_crash_on_verification_failures(), + "Crashes the program on extra verification failures, e.g. cuDNN " + "cross checking failures"), }); ParseFlagsFromEnv(*flag_objects); } diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 0545deb096e9eace5a9713f200e10559aa718441..d54f051a1a959488fe716e17b69ba087e4020ae3 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -71,7 +72,7 @@ std::ostream& operator<<(std::ostream& out, const Literal& literal) { return out; } -Literal::StrideConfig::StrideConfig( +MutableLiteralBase::StrideConfig::StrideConfig( const Shape& source_shape, const Shape& dest_shape, tensorflow::gtl::ArraySlice dimensions) : dimensions(dimensions), @@ -133,7 +134,8 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { } Literal::Literal(const Shape& shape, bool allocate_arrays) - : LiteralBase(), shape_(MakeUnique(shape)) { + : MutableLiteralBase() { + shape_ = absl::make_unique(shape); CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = new Piece(); root_piece_->set_subshape(shape_.get()); @@ -159,7 +161,9 @@ void Literal::DeallocateBuffers() { }); } -Literal::Literal(Literal&& other) : LiteralBase() { *this = std::move(other); } +Literal::Literal(Literal&& other) : MutableLiteralBase() { + *this = std::move(other); +} Literal& Literal::operator=(Literal&& other) { DCHECK(&other.root_piece_->subshape() == other.shape_.get()); @@ -172,7 +176,7 @@ Literal& Literal::operator=(Literal&& other) { } std::unique_ptr LiteralBase::CreateFromShape(const Shape& shape) { - auto literal = MakeUnique(shape); + auto literal = absl::make_unique(shape); literal->root_piece_->ForEachMutableSubpiece( [&](const ShapeIndex& index, Piece* piece) { if (ShapeUtil::IsArray(piece->subshape())) { @@ -187,12 +191,13 @@ const SparseIndexArray* LiteralBase::sparse_indices( return piece(shape_index).sparse_indices(); } -SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) { +SparseIndexArray* MutableLiteralBase::sparse_indices( + const ShapeIndex& shape_index) { return piece(shape_index).sparse_indices(); } template -Status Literal::CopySliceFromInternal( +Status MutableLiteralBase::CopySliceFromInternal( const LiteralBase& src_literal, tensorflow::gtl::ArraySlice src_base, tensorflow::gtl::ArraySlice dest_base, tensorflow::gtl::ArraySlice copy_size) { @@ -225,8 +230,8 @@ Status Literal::CopySliceFromInternal( // proper stride size at the matching dimension. DimensionVector src_indexes(src_base.size(), 0); DimensionVector dest_indexes(dest_base.size(), 0); - Literal::StrideConfig stride_config(src_literal.shape(), shape(), - copy_size); + MutableLiteralBase::StrideConfig stride_config(src_literal.shape(), shape(), + copy_size); auto copy_proc = [&](tensorflow::gtl::ArraySlice indexes) { // Map from multi-dimensional index, to source index. @@ -253,9 +258,10 @@ Status Literal::CopySliceFromInternal( return Status::OK(); } -Status Literal::CopyElementFrom(const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_index, - tensorflow::gtl::ArraySlice dest_index) { +Status MutableLiteralBase::CopyElementFrom( + const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice src_index, + tensorflow::gtl::ArraySlice dest_index) { DCHECK_EQ(shape().element_type(), src_literal.shape().element_type()); const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex( src_literal.shape(), src_index); @@ -275,8 +281,8 @@ Status Literal::CopyElementFrom(const LiteralSlice& src_literal, return Status::OK(); } -/* static */ StatusOr> Literal::CreateFromProto( - const LiteralProto& proto) { +/* static */ StatusOr> +MutableLiteralBase::CreateFromProto(const LiteralProto& proto) { if (!proto.has_shape()) { return InvalidArgument("LiteralProto has no shape"); } @@ -284,7 +290,7 @@ Status Literal::CopyElementFrom(const LiteralSlice& src_literal, return InvalidArgument("LiteralProto has no layout"); } - auto literal = MakeUnique(proto.shape()); + auto literal = absl::make_unique(proto.shape()); TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus( [&](const ShapeIndex& index, Piece* piece) { @@ -405,9 +411,9 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { return Status::OK(); } -Status Literal::CopyFrom(const LiteralSlice& src_literal, - const ShapeIndex& dest_shape_index, - const ShapeIndex& src_shape_index) { +Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal, + const ShapeIndex& dest_shape_index, + const ShapeIndex& src_shape_index) { const Shape& dest_subshape = ShapeUtil::GetSubshape(shape(), dest_shape_index); const Shape& src_subshape = @@ -474,7 +480,7 @@ Status Literal::MoveFrom(Literal&& src_literal, dest_piece.set_sparse_indices(src_piece.sparse_indices()); }); - src_literal.shape_ = MakeUnique(ShapeUtil::MakeNil()); + src_literal.shape_ = absl::make_unique(ShapeUtil::MakeNil()); delete src_literal.root_piece_; src_literal.root_piece_ = new LiteralBase::Piece(); src_literal.root_piece_->set_subshape(src_literal.shape_.get()); @@ -482,10 +488,11 @@ Status Literal::MoveFrom(Literal&& src_literal, return Status::OK(); } -Status Literal::CopySliceFrom(const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size) { +Status MutableLiteralBase::CopySliceFrom( + const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size) { TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape()); TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape())) << ShapeUtil::HumanString(src_literal.shape()); @@ -543,7 +550,7 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal, shape().element_type()); } -void Literal::PopulateR1(const tensorflow::core::Bitmap& values) { +void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(ShapeUtil::Rank(shape()), 1); CHECK_EQ(element_count(), values.bits()); @@ -560,7 +567,7 @@ std::unique_ptr LiteralBase::Relayout( Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index); TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape)); *subshape->mutable_layout() = new_layout; - auto result = MakeUnique(new_shape); + auto result = absl::make_unique(new_shape); TF_CHECK_OK(result->CopyFrom(*this)); return result; } @@ -596,7 +603,7 @@ StatusOr> LiteralBase::Broadcast( result_shape.dimensions(dimensions[i])); } - std::unique_ptr result = MakeUnique(result_shape); + std::unique_ptr result = absl::make_unique(result_shape); // scratch_source_index is temporary storage space for the computed index into // the input literal. We put it here to avoid allocating an std::vector in @@ -685,7 +692,7 @@ std::unique_ptr LiteralBase::Transpose( for (auto index : LayoutUtil::MinorToMajor(shape())) { layout->add_minor_to_major(inverse_permutation[index]); } - auto new_literal = MakeUnique(permuted_shape); + auto new_literal = absl::make_unique(permuted_shape); DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()), ShapeUtil::ByteSizeOf(shape())); std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes()); @@ -696,7 +703,7 @@ template std::unique_ptr LiteralBase::SliceInternal( const Shape& result_shape, tensorflow::gtl::ArraySlice start_indices) const { - auto result_literal = MakeUnique(result_shape); + auto result_literal = absl::make_unique(result_shape); DimensionVector new_indices(ShapeUtil::Rank(result_shape)); result_literal->EachCell( [&](tensorflow::gtl::ArraySlice indices, NativeT /*value*/) { @@ -750,7 +757,7 @@ Literal LiteralBase::Clone() const { } std::unique_ptr LiteralBase::CloneToUnique() const { - auto result = MakeUnique(shape()); + auto result = absl::make_unique(shape()); TF_CHECK_OK(result->CopyFrom(*this)); return result; } @@ -895,8 +902,8 @@ size_t LiteralBase::Hash() const { return hash_value; } -Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, - int64 value) { +Status MutableLiteralBase::SetIntegralAsS64( + tensorflow::gtl::ArraySlice multi_index, int64 value) { CHECK(LayoutUtil::IsDenseArray(shape())); switch (shape().element_type()) { case PRED: @@ -933,7 +940,7 @@ tensorflow::gtl::ArraySlice LiteralBase::GetSparseIndex( return p.sparse_indices()->At(sparse_element_number); } -void Literal::SortSparseElements(const ShapeIndex& shape_index) { +void MutableLiteralBase::SortSparseElements(const ShapeIndex& shape_index) { piece(shape_index).SortSparseElements(); } @@ -1197,7 +1204,7 @@ template std::unique_ptr ConvertBetweenNativeTypesWithConverter( const LiteralBase& src_literal, const ConverterType& converter) { CHECK(ShapeUtil::IsArray(src_literal.shape())); - auto result_literal = MakeUnique(ShapeUtil::ChangeElementType( + auto result_literal = absl::make_unique(ShapeUtil::ChangeElementType( src_literal.shape(), primitive_util::NativeToPrimitiveType())); auto src_data = src_literal.data(); @@ -1243,7 +1250,7 @@ BitcastBetweenNativeTypes(const LiteralBase& src_literal) { template std::unique_ptr ConvertToC64(const LiteralBase& src_literal) { CHECK(ShapeUtil::IsArray(src_literal.shape())); - auto result_literal = MakeUnique( + auto result_literal = absl::make_unique( ShapeUtil::ChangeElementType(src_literal.shape(), C64)); using NativeSrcT = typename primitive_util::PrimitiveTypeToNative::type; @@ -1390,12 +1397,12 @@ StatusOr> LiteralBase::ConvertToShape( element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i}))); elements.push_back(std::move(*new_element)); } - auto converted = MakeUnique(); - *converted = Literal::MoveIntoTuple(&elements); + auto converted = absl::make_unique(); + *converted = MutableLiteralBase::MoveIntoTuple(&elements); return std::move(converted); } -/* static */ Literal Literal::MoveIntoTuple( +/* static */ Literal MutableLiteralBase::MoveIntoTuple( tensorflow::gtl::MutableArraySlice elements) { std::vector element_shapes; for (const Literal& element : elements) { @@ -1808,7 +1815,8 @@ Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice dest, } // namespace Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { - // These conditions should have been checked in Literal::CreateFromProto. + // These conditions should have been checked in + // MutableLiteralBase::CreateFromProto. TF_RET_CHECK(proto.has_shape()); TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape())); TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape())); @@ -1900,7 +1908,7 @@ const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const { return piece(shape_index).untyped_data(); } -void* Literal::untyped_data(const ShapeIndex& shape_index) { +void* MutableLiteralBase::untyped_data(const ShapeIndex& shape_index) { return piece(shape_index).untyped_data(); } @@ -1916,6 +1924,127 @@ string LiteralBase::GetR1U8AsString() const { ShapeUtil::ElementsIn(shape())); } +void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape, + Piece* src_piece, + Piece* dest_piece) { + DCHECK(ShapeUtil::Equal(src_piece->subshape(), dest_piece->subshape())) + << "src_piece has shape: " + << ShapeUtil::HumanString(src_piece->subshape()) + << "dest_piece has shape: " + << ShapeUtil::HumanString(dest_piece->subshape()); + if (ShapeUtil::IsTuple(shape)) { + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& subshape = shape.tuple_shapes(i); + + auto child_piece = Piece(); + child_piece.set_subshape(&subshape); + + CopyPieceSubtree(subshape, &src_piece->child(i), &child_piece); + + dest_piece->emplace_back(std::move(child_piece)); + } + } else if (ShapeUtil::IsArray(shape)) { + dest_piece->set_buffer(src_piece->buffer()); + } else { + // If the shape is neither an array nor tuple, then it must be + // zero-sized. Otherwise, some memory needs to be allocated for it. + CHECK_EQ(dest_piece->size_bytes(), 0); + } +} + +MutableLiteralBase::~MutableLiteralBase() {} + +MutableBorrowingLiteral::MutableBorrowingLiteral( + const MutableBorrowingLiteral& literal) + : MutableLiteralBase() { + shape_ = absl::make_unique(literal.shape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_); +} + +MutableBorrowingLiteral& MutableBorrowingLiteral::operator=( + const MutableBorrowingLiteral& literal) { + shape_ = absl::make_unique(literal.shape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_); + + return *this; +} + +MutableBorrowingLiteral::MutableBorrowingLiteral( + const MutableLiteralBase& literal) + : MutableLiteralBase() { + shape_ = absl::make_unique(literal.shape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_); +} + +MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal) + : MutableLiteralBase() { + shape_ = absl::make_unique(literal->shape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal->root_piece(), root_piece_); +} + +MutableBorrowingLiteral::MutableBorrowingLiteral( + MutableBorrowingLiteral literal, const ShapeIndex& view_root) + : MutableLiteralBase() { + shape_ = absl::make_unique(literal.piece(view_root).subshape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal.piece(view_root), root_piece_); +} + +MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr, + const Shape& shape) + : MutableLiteralBase() { + shape_ = absl::make_unique(shape); + CHECK(LayoutUtil::HasLayout(*shape_)); + CHECK(!ShapeUtil::IsTuple(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_buffer(const_cast(src_buf_ptr)); + root_piece_->set_subshape(shape_.get()); +} + +MutableBorrowingLiteral::~MutableBorrowingLiteral() { + if (root_piece_ != nullptr) { + root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* piece) { + if (piece->buffer() != nullptr) { + delete piece->sparse_indices(); + } + }); + delete root_piece_; + } +} + +LiteralSlice::LiteralSlice(const LiteralBase& literal) + : LiteralBase(), root_piece_(&literal.root_piece()) {} + +LiteralSlice::LiteralSlice(const LiteralBase& literal, + const ShapeIndex& view_root) + : LiteralBase(), root_piece_(&literal.piece(view_root)) {} + void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) { CHECK(ShapeUtil::IsTuple(shape)); for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { @@ -1932,15 +2061,8 @@ void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) { } } -LiteralSlice::LiteralSlice(const LiteralBase& literal) - : LiteralBase(), root_piece_(&literal.root_piece()) {} - -LiteralSlice::LiteralSlice(const LiteralBase& literal, - const ShapeIndex& view_root) - : LiteralBase(), root_piece_(&literal.piece(view_root)) {} - BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) - : LiteralBase(), shape_(MakeUnique(shape)) { + : LiteralBase(), shape_(absl::make_unique(shape)) { CHECK(ShapeUtil::IsArray(*shape_)); CHECK(LayoutUtil::HasLayout(*shape_)); @@ -1951,7 +2073,7 @@ BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) BorrowingLiteral::BorrowingLiteral( tensorflow::gtl::ArraySlice src_buf_ptrs, const Shape& shape) - : LiteralBase(), shape_(MakeUnique(shape)) { + : LiteralBase(), shape_(absl::make_unique(shape)) { CHECK(ShapeUtil::IsTuple(*shape_)); CHECK(!ShapeUtil::IsNestedTuple(*shape_)); CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_)); diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index dd67dfa8d4a556aea179bc47abfdc9a9c8872c45..ed9de652994bc948efe38a8fcc3ba9bed36c9f3a 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -25,13 +25,13 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/primitive_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/sparse_index_array.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -310,9 +310,10 @@ class LiteralBase { // type of literal itself (0 for numeric types, and false for predicates). // // Note: It's an antipattern to use this method then immediately call - // Literal::Populate on the result (since that results in zero initialization, - // then reinitialization. Conside if a call to MakeUnique(shape), - // followed by the call to Literal::Populate can be used instead. + // MutableLiteralBase::Populate on the result (since that results in zero + // initialization, then reinitialization. Conside if a call to + // absl::make_unique(shape), followed by the call to + // MutableLiteralBase::Populate can be used instead. static std::unique_ptr CreateFromShape(const Shape& shape); protected: @@ -534,7 +535,7 @@ class LiteralBase { virtual const Piece& root_piece() const = 0; // LiteralSlice and Literal must access Pieces of other Literals. - friend class Literal; + friend class MutableLiteralBase; friend class LiteralSlice; friend class BorrowingLiteral; @@ -545,33 +546,10 @@ class LiteralBase { tensorflow::gtl::ArraySlice start_indices) const; }; -// Class representing literal values in XLA. -// -// The underlying buffer and shape is always owned by this class. -class Literal : public LiteralBase { +// Abstract base class representing a mutable literal in XLA. +class MutableLiteralBase : public LiteralBase { public: - Literal() : Literal(ShapeUtil::MakeNil()) {} - - // Create a literal of the given shape. The literal is allocated sufficient - // memory to hold the shape. Memory is uninitialized. - explicit Literal(const Shape& shape); - virtual ~Literal(); - - // Literals are moveable, but not copyable. To copy a literal use - // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies - // of literals which can be expensive. - Literal(const Literal& other) = delete; - Literal& operator=(const Literal& other) = delete; - Literal(Literal&& other); - // 'allocate_arrays' indicates whether to allocate memory for the arrays in - // the shape. If false, buffer pointers inside of the Literal::Pieces are set - // to nullptr. - Literal(const Shape& shape, bool allocate_arrays); - Literal& operator=(Literal&& other); - - // TODO(b/67651157): Remove this accessor. Literal users should not be able to - // mutate the shape as this can produce malformed Literals. - Shape* mutable_shape_do_not_use() { return shape_.get(); } + virtual ~MutableLiteralBase() = 0; // Returns a MutableArraySlice view of the array for this literal for the // given NativeT (e.g., float). CHECKs if the subshape of the literal at the @@ -587,6 +565,10 @@ class Literal : public LiteralBase { // is not a sparse array. SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); + // TODO(b/67651157): Remove this accessor. Literal users should not be able to + // mutate the shape as this can produce malformed Literals. + Shape* mutable_shape_do_not_use() { return shape_.get(); } + // Returns a pointer to the underlying buffer holding the array at the given // shape index. CHECKs if the subshape of the literal at the given ShapeIndex // is not array. @@ -613,21 +595,6 @@ class Literal : public LiteralBase { const ShapeIndex& dest_shape_index = {}, const ShapeIndex& src_shape_index = {}); - // Returns a vector containing the tuple elements of this Literal as separate - // Literals. This Literal must be tuple-shaped and can be a nested tuple. The - // elements are moved into the new Literals; no data is copied. Upon return - // this Literal is set to a nil shape (empty tuple) - std::vector DecomposeTuple(); - - // Similar to CopyFrom, but with move semantincs. The subshape of this literal - // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' - // (layouts and shapes must match), but need not be arrays. The memory - // allocated in this literal for the subshape at dest_shape_index is - // deallocated, and the respective buffers are replaced with those in - // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). - Status MoveFrom(Literal&& src_literal, - const ShapeIndex& dest_shape_index = {}); - // Copies the values from src_literal, starting at src_base shape indexes, // to this literal, starting at dest_base, where the copy size in each // dimension is specified by copy_size. @@ -730,12 +697,7 @@ class Literal : public LiteralBase { static StatusOr> CreateFromProto( const LiteralProto& proto); - private: - // Recursively sets the subshapes and buffers of all subpieces rooted at - // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in - // the shape. - void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays); - + protected: // Returns the piece at the given ShapeIndex. Piece& piece(const ShapeIndex& shape_index) { return const_cast(LiteralBase::piece(shape_index)); @@ -783,12 +745,83 @@ class Literal : public LiteralBase { template Status PopulateInternal(const FnType& generator, bool parallel); + friend class LiteralBase; + friend class MutableBorrowingLiteral; +}; +std::ostream& operator<<(std::ostream& out, const Literal& literal); + +// The underlying buffer and shape is always owned by this class. +class Literal : public MutableLiteralBase { + public: + Literal() : Literal(ShapeUtil::MakeNil()) {} + + // Create a literal of the given shape. The literal is allocated sufficient + // memory to hold the shape. Memory is uninitialized. + explicit Literal(const Shape& shape); + virtual ~Literal(); + + // Literals are moveable, but not copyable. To copy a literal use + // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies + // of literals which can be expensive. + Literal(const Literal& other) = delete; + Literal& operator=(const Literal& other) = delete; + Literal(Literal&& other); + // 'allocate_arrays' indicates whether to allocate memory for the arrays in + // the shape. If false, buffer pointers inside of the Literal::Pieces are set + // to nullptr. + Literal(const Shape& shape, bool allocate_arrays); + Literal& operator=(Literal&& other); + + // Similar to CopyFrom, but with move semantincs. The subshape of this literal + // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' + // (layouts and shapes must match), but need not be arrays. The memory + // allocated in this literal for the subshape at dest_shape_index is + // deallocated, and the respective buffers are replaced with those in + // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). + virtual Status MoveFrom(Literal&& src_literal, + const ShapeIndex& dest_shape_index = {}); + + // Returns a vector containing the tuple elements of this Literal as separate + // Literals. This Literal must be tuple-shaped and can be a nested tuple. The + // elements are moved into the new Literals; no data is copied. Upon return + // this Literal is set to a nil shape (empty tuple) + std::vector DecomposeTuple(); + + private: // Deallocate the buffers held by this literal. void DeallocateBuffers(); - friend class LiteralBase; + // Recursively sets the subshapes and buffers of all subpieces rooted at + // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in + // the shape. + void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays); +}; + +// The underlying buffer is not owned by this class and is always owned by +// others. The shape is not owned by this class and not mutable. +class MutableBorrowingLiteral : public MutableLiteralBase { + public: + virtual ~MutableBorrowingLiteral(); + + MutableBorrowingLiteral() : MutableLiteralBase() {} + + MutableBorrowingLiteral(const MutableBorrowingLiteral& literal); + MutableBorrowingLiteral& operator=(const MutableBorrowingLiteral& literal); + + // Implicit conversion constructors. + MutableBorrowingLiteral(const MutableLiteralBase& literal); + MutableBorrowingLiteral(MutableLiteralBase* literal); + MutableBorrowingLiteral(MutableBorrowingLiteral literal, + const ShapeIndex& view_root); + MutableBorrowingLiteral(const char* src_buf_ptr, const Shape& shape); + + private: + // Recursively copies the subtree from the `src_piece` at the given child + // index to the `dest_piece`. For buffers only the pointers are copied, but + // not the content. + void CopyPieceSubtree(const Shape& shape, Piece* src_piece, + Piece* dest_piece); }; -std::ostream& operator<<(std::ostream& out, const Literal& literal); // A read-only view of a Literal. A LiteralSlice contains pointers to shape and // literal buffers always owned by others. @@ -831,9 +864,9 @@ class BorrowingLiteral : public LiteralBase { const Piece& root_piece() const override { return root_piece_; }; Piece root_piece_; - // Shape of this literal. Stored as unique_ptr so such that the (default) - // move construction of this class would be trivially correct: the pointer to - // Shape root_piece_ stores will still point to the correct address. + // Shape of this literal. Stored as unique_ptr such that the (default) move + // construction of this class would be trivially correct: the pointer to Shape + // root_piece_ stores will still point to the correct address. std::unique_ptr shape_; }; @@ -886,7 +919,7 @@ tensorflow::gtl::ArraySlice LiteralBase::data( } template -tensorflow::gtl::MutableArraySlice Literal::data( +tensorflow::gtl::MutableArraySlice MutableLiteralBase::data( const ShapeIndex& shape_index) { return piece(shape_index).data(); } @@ -904,14 +937,15 @@ inline NativeT LiteralBase::Get( } template -inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index, NativeT value) { +inline void MutableLiteralBase::Set( + tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index, NativeT value) { return piece(shape_index).Set(multi_index, value); } template -inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, - NativeT value) { +inline void MutableLiteralBase::Set( + tensorflow::gtl::ArraySlice multi_index, NativeT value) { return root_piece().Set(multi_index, value); } @@ -929,7 +963,7 @@ NativeT LiteralBase::GetSparseElement(int64 sparse_element_number, } template -void Literal::AppendSparseElement( +void MutableLiteralBase::AppendSparseElement( tensorflow::gtl::ArraySlice multi_index, NativeT value, const ShapeIndex& shape_index) { Piece& p = piece(shape_index); @@ -959,7 +993,8 @@ void LiteralBase::EachCell( } template -inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice values) { +inline void MutableLiteralBase::PopulateR1( + tensorflow::gtl::ArraySlice values) { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(ShapeUtil::Rank(shape()), 1); CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size()); @@ -971,7 +1006,7 @@ inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice values) { } template -void Literal::PopulateR2( +void MutableLiteralBase::PopulateR2( std::initializer_list> values) { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(ShapeUtil::Rank(shape()), 2); @@ -996,7 +1031,7 @@ void Literal::PopulateR2( } template -void Literal::PopulateFromArray(const Array& values) { +void MutableLiteralBase::PopulateFromArray(const Array& values) { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(shape().element_type(), primitive_util::NativeToPrimitiveType()); @@ -1009,24 +1044,24 @@ void Literal::PopulateFromArray(const Array& values) { } template -void Literal::PopulateR2FromArray2D(const Array2D& values) { +void MutableLiteralBase::PopulateR2FromArray2D(const Array2D& values) { PopulateFromArray(values); } template -void Literal::PopulateR3FromArray3D(const Array3D& values) { +void MutableLiteralBase::PopulateR3FromArray3D(const Array3D& values) { PopulateFromArray(values); } template -void Literal::PopulateR4FromArray4D(const Array4D& values) { +void MutableLiteralBase::PopulateR4FromArray4D(const Array4D& values) { PopulateFromArray(values); } template -void Literal::PopulateSparse(SparseIndexArray indices, - tensorflow::gtl::ArraySlice values, - bool sort) { +void MutableLiteralBase::PopulateSparse( + SparseIndexArray indices, tensorflow::gtl::ArraySlice values, + bool sort) { CHECK(LayoutUtil::IsSparseArray(shape())); int rank = ShapeUtil::Rank(shape()); CHECK_EQ(indices.rank(), rank); @@ -1049,7 +1084,8 @@ void Literal::PopulateSparse(SparseIndexArray indices, } template -Status Literal::PopulateInternal(const FnType& generator, bool parallel) { +Status MutableLiteralBase::PopulateInternal(const FnType& generator, + bool parallel) { const Shape& this_shape = shape(); const int64 rank = ShapeUtil::Rank(this_shape); TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape)); @@ -1092,17 +1128,17 @@ Status Literal::PopulateInternal(const FnType& generator, bool parallel) { return Status::OK(); } template -Status Literal::Populate(const FnType& generator) { +Status MutableLiteralBase::Populate(const FnType& generator) { return PopulateInternal(generator, /*parallel=*/false); } template -Status Literal::PopulateParallel(const FnType& generator) { +Status MutableLiteralBase::PopulateParallel(const FnType& generator) { return PopulateInternal(generator, /*parallel=*/true); } template -void Literal::PopulateWithValue(NativeT value) { +void MutableLiteralBase::PopulateWithValue(NativeT value) { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(shape().element_type(), primitive_util::NativeToPrimitiveType()); @@ -1118,8 +1154,8 @@ std::unique_ptr LiteralBase::Replicate(int64 times) const { for (int64 bound : shape().dimensions()) { bounds.push_back(bound); } - auto literal = - MakeUnique(ShapeUtil::MakeShape(shape().element_type(), bounds)); + auto literal = absl::make_unique( + ShapeUtil::MakeShape(shape().element_type(), bounds)); int64 elements = ShapeUtil::ElementsIn(literal->shape()); if (elements == 0) { return literal; diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index 94993cc87443ba8c22fd7c2eacfc8756d3f48edc..6883a6bbab4de252ba47c6d34bcecd2e75c80818 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -38,7 +38,8 @@ namespace { // between the left-hand-side and right-hand-side, by bit-casting to UnsignedT // -- on miscompare, a nice error message is given in the AssertionFailure. template -Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) { +Status CompareFloatsBitwiseEqual( + FloatT lhs, FloatT rhs, tensorflow::gtl::ArraySlice multi_index) { auto ulhs = tensorflow::bit_cast(lhs); auto urhs = tensorflow::bit_cast(rhs); auto lhs_double = static_cast(lhs); @@ -46,9 +47,10 @@ Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) { if (ulhs != urhs) { return InvalidArgument( "floating values are not bitwise-equal; and equality testing " - "was requested: %s=%g=%a vs %s=%g=%a", + "was requested: %s=%g=%a vs %s=%g=%a at index %s", StrCat(tensorflow::strings::Hex(ulhs)).c_str(), lhs_double, lhs_double, - StrCat(tensorflow::strings::Hex(urhs)).c_str(), rhs_double, rhs_double); + StrCat(tensorflow::strings::Hex(urhs)).c_str(), rhs_double, rhs_double, + LiteralUtil::MultiIndexAsString(multi_index).c_str()); } return Status::OK(); } @@ -57,39 +59,48 @@ Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) { // bitwise helper above (this is the un-specialized fallback, to just use the // default gunit implementation). template -Status CompareEqual(NativeT lhs, NativeT rhs) { +Status CompareEqual(NativeT lhs, NativeT rhs, + tensorflow::gtl::ArraySlice multi_index) { if (lhs == rhs) { return Status::OK(); } - return InvalidArgument("Expected equality of these values:\n %s\n %s", - StrCat(lhs).c_str(), StrCat(rhs).c_str()); + return InvalidArgument( + "Expected equality of these values:\n %s\n %s\nat index %s", + StrCat(lhs).c_str(), StrCat(rhs).c_str(), + LiteralUtil::MultiIndexAsString(multi_index).c_str()); } // Specializations for floating types that do bitwise comparisons when equality // comparison is requested. template <> -Status CompareEqual(bfloat16 lhs, bfloat16 rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); +Status CompareEqual(bfloat16 lhs, bfloat16 rhs, + tensorflow::gtl::ArraySlice multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); } template <> -Status CompareEqual(Eigen::half lhs, Eigen::half rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); +Status CompareEqual( + Eigen::half lhs, Eigen::half rhs, + tensorflow::gtl::ArraySlice multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); } template <> -Status CompareEqual(float lhs, float rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); +Status CompareEqual(float lhs, float rhs, + tensorflow::gtl::ArraySlice multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); } template <> -Status CompareEqual(double lhs, double rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); +Status CompareEqual(double lhs, double rhs, + tensorflow::gtl::ArraySlice multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); } template <> -Status CompareEqual(complex64 lhs, complex64 rhs) { - auto res = CompareEqual(lhs.real(), rhs.real()); +Status CompareEqual(complex64 lhs, complex64 rhs, + tensorflow::gtl::ArraySlice multi_index) { + auto res = CompareEqual(lhs.real(), rhs.real(), multi_index); if (!res.ok()) { return res; } - return CompareEqual(lhs.imag(), rhs.imag()); + return CompareEqual(lhs.imag(), rhs.imag(), multi_index); } // A recursive function which iterates through every index of expected and @@ -102,7 +113,7 @@ Status Equal(LiteralSlice expected, LiteralSlice actual, if (dimension == expected.shape().dimensions_size()) { NativeT expected_value = expected.Get(multi_index); NativeT actual_value = actual.Get(multi_index); - return CompareEqual(expected_value, actual_value); + return CompareEqual(expected_value, actual_value, multi_index); } Status result; @@ -720,12 +731,10 @@ Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) { return Status::OK(); } - return AppendStatus(result, - tensorflow::strings::Printf( - "\nat index: %s\nexpected: %s\nactual: %s", - LiteralUtil::MultiIndexAsString(multi_index).c_str(), - ToStringTruncated(expected).c_str(), - ToStringTruncated(actual).c_str())); + return AppendStatus( + result, tensorflow::strings::Printf("\nexpected: %s\nactual: %s", + ToStringTruncated(expected).c_str(), + ToStringTruncated(actual).c_str())); } Status Near(const LiteralSlice& expected, const LiteralSlice& actual, diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index e8f919950f0efc8b508f7ad4aee5233176bc0abd..c5d0c2c267e06f7d10651f57496c4d1dd76eff52 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -355,15 +356,15 @@ TEST_F(LiteralUtilTest, TokenEquality) { TEST_F(LiteralUtilTest, DifferentLayoutEquality) { // Test equality with literals which have different layouts. - auto colmajor = - MakeUnique(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})); + auto colmajor = absl::make_unique( + ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})); colmajor->Set({0, 0}, 1.0); colmajor->Set({0, 1}, 2.0); colmajor->Set({1, 0}, 3.0); colmajor->Set({1, 1}, 4.0); - auto rowmajor = - MakeUnique(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})); + auto rowmajor = absl::make_unique( + ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})); rowmajor->Set({0, 0}, 1.0); rowmajor->Set({0, 1}, 2.0); rowmajor->Set({1, 0}, 3.0); @@ -1089,7 +1090,7 @@ TEST_F(LiteralUtilTest, Populate) { Shape shape = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = MakeUnique(shape); + auto literal = absl::make_unique(shape); auto generator = [&](ArraySlice indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. @@ -1131,7 +1132,7 @@ TEST_F(LiteralUtilTest, PopulateParallel) { Shape shape = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = MakeUnique(shape); + auto literal = absl::make_unique(shape); auto generator = [&](ArraySlice indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 548fbe8a83a3797aa8ac32dc1f6c085fc0100197..d4c7b76b2819d8b6b07297351d7cd9180e764c25 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -34,9 +35,9 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/types.h" -using tensorflow::strings::Printf; using tensorflow::strings::StrCat; namespace xla { @@ -57,7 +58,7 @@ std::unique_ptr ConvertType(LiteralSlice literal) { primitive_util::NativeToPrimitiveType()); } }); - auto result = MakeUnique(result_shape); + auto result = absl::make_unique(result_shape); // Then copy over the data from 'literal' converting FromNativeT values to // ToNativeT values as necessary. @@ -102,7 +103,7 @@ std::unique_ptr ConvertType(LiteralSlice literal) { } /* static */ std::unique_ptr LiteralUtil::CreateToken() { - return MakeUnique(ShapeUtil::MakeTokenShape()); + return absl::make_unique(ShapeUtil::MakeTokenShape()); } /* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) { @@ -279,7 +280,7 @@ std::unique_ptr ConvertType(LiteralSlice literal) { /* static */ std::unique_ptr LiteralUtil::CreateR1( const tensorflow::core::Bitmap& values) { - auto literal = MakeUnique( + auto literal = absl::make_unique( ShapeUtil::MakeShape(PRED, {static_cast(values.bits())})); literal->PopulateR1(values); return literal; @@ -287,7 +288,7 @@ std::unique_ptr ConvertType(LiteralSlice literal) { /* static */ std::unique_ptr LiteralUtil::CreateR1U8( tensorflow::StringPiece value) { - auto literal = MakeUnique( + auto literal = absl::make_unique( ShapeUtil::MakeShape(U8, {static_cast(value.size())})); for (int i = 0; i < value.size(); ++i) { literal->Set({i}, value[i]); @@ -312,7 +313,7 @@ std::unique_ptr ConvertType(LiteralSlice literal) { CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements); CHECK_EQ(new_dimensions.size(), minor_to_major.size()); - auto new_literal = MakeUnique( + auto new_literal = absl::make_unique( ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions)); // Create a new shape with the given minor-to-major layout. This shape is used @@ -436,7 +437,8 @@ std::unique_ptr ConvertType(LiteralSlice literal) { for (const auto* element : elements) { element_shapes.push_back(element->shape()); } - auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); + auto literal = + absl::make_unique(ShapeUtil::MakeTupleShape(element_shapes)); for (int i = 0; i < elements.size(); ++i) { TF_CHECK_OK(literal->CopyFrom(*elements[i], /*dest_shape_index=*/{i})); } @@ -449,7 +451,8 @@ std::unique_ptr ConvertType(LiteralSlice literal) { for (const auto& element : elements) { element_shapes.push_back(element.shape()); } - auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); + auto literal = + absl::make_unique(ShapeUtil::MakeTupleShape(element_shapes)); for (int i = 0; i < elements.size(); ++i) { TF_CHECK_OK(literal->CopyFrom(elements[i], /*dest_shape_index=*/{i})); } @@ -463,7 +466,8 @@ std::unique_ptr ConvertType(LiteralSlice literal) { for (const auto& element : elements) { element_shapes.push_back(element->shape()); } - auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); + auto literal = + absl::make_unique(ShapeUtil::MakeTupleShape(element_shapes)); for (int64 i = 0; i < elements.size(); ++i) { TF_CHECK_OK( literal->MoveFrom(std::move(*elements[i]), /*dest_shape_index=*/{i})); diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index e3737a9d0051b32dc0becc19e1849c856a50e52e..1109021ea892a38c1134b3fee6c608c25167c675 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -27,6 +27,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/sparse_index_array.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -327,7 +327,7 @@ std::ostream& operator<<(std::ostream& out, const Literal& literal); template /* static */ std::unique_ptr LiteralUtil::CreateR0(NativeT value) { - auto literal = MakeUnique(ShapeUtil::MakeShape( + auto literal = absl::make_unique(ShapeUtil::MakeShape( primitive_util::NativeToPrimitiveType(), {})); literal->Set({}, value); return literal; @@ -336,7 +336,7 @@ template template /* static */ std::unique_ptr LiteralUtil::CreateR1( tensorflow::gtl::ArraySlice values) { - auto literal = MakeUnique( + auto literal = absl::make_unique( ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {static_cast(values.size())})); literal->PopulateR1(values); @@ -347,7 +347,7 @@ template /* static */ std::unique_ptr LiteralUtil::CreateR2WithLayout( std::initializer_list> values, const Layout& layout) { - auto literal = MakeUnique(ShapeUtil::MakeShapeWithLayout( + auto literal = absl::make_unique(ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), {static_cast(values.size()), static_cast(values.begin()->size())}, @@ -433,9 +433,10 @@ template int64 rank = dimensions.size(); CHECK_EQ(num_elements, indices.index_count()); CHECK_EQ(rank, indices.rank()); - auto literal = MakeUnique(ShapeUtil::MakeShapeWithSparseLayout( - primitive_util::NativeToPrimitiveType(), dimensions, - indices.max_indices())); + auto literal = + absl::make_unique(ShapeUtil::MakeShapeWithSparseLayout( + primitive_util::NativeToPrimitiveType(), dimensions, + indices.max_indices())); literal->PopulateSparse(indices, values, sort); return literal; } @@ -451,7 +452,7 @@ template template /* static */ std::unique_ptr LiteralUtil::CreateFromArrayWithLayout( const Array& values, const Layout& layout) { - auto literal = MakeUnique(ShapeUtil::MakeShapeWithLayout( + auto literal = absl::make_unique(ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), values.dimensions(), AsInt64Slice(layout.minor_to_major()))); literal->PopulateFromArray(values); @@ -571,8 +572,9 @@ template /* static */ std::unique_ptr LiteralUtil::CreateFullWithDescendingLayout( tensorflow::gtl::ArraySlice dimensions, NativeT value) { - auto literal = MakeUnique(ShapeUtil::MakeShapeWithDescendingLayout( - primitive_util::NativeToPrimitiveType(), dimensions)); + auto literal = + absl::make_unique(ShapeUtil::MakeShapeWithDescendingLayout( + primitive_util::NativeToPrimitiveType(), dimensions)); literal->PopulateWithValue(value); return literal; } @@ -584,7 +586,7 @@ LiteralUtil::CreateRandomLiteral( const std::function)>& generator) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; TF_RET_CHECK(shape.element_type() == type); - auto literal = MakeUnique(shape); + auto literal = absl::make_unique(shape); TF_RETURN_IF_ERROR(literal.get()->Populate( [&](tensorflow::gtl::ArraySlice indexes) { return generator(indexes); diff --git a/tensorflow/compiler/xla/metric_table_report.cc b/tensorflow/compiler/xla/metric_table_report.cc index fed0e58e66a04df2ff9554cb0dd0053b7c669803..69ef4f7a2f3ea559a334a11cbe8392b610742bab 100644 --- a/tensorflow/compiler/xla/metric_table_report.cc +++ b/tensorflow/compiler/xla/metric_table_report.cc @@ -134,8 +134,7 @@ void MetricTableReport::AppendHeader() { void MetricTableReport::AppendCategoryTable() { const std::vector categories = MakeCategories(&entries_); - AppendLine("********** categories table **********"); - AppendLine("The left hand side numbers are ", metric_name_, "."); + AppendLine("********** categories table for ", metric_name_, " **********"); AppendLine(); double metric_sum = UnaccountedMetric(); @@ -185,8 +184,8 @@ void MetricTableReport::AppendCategoryTable() { } void MetricTableReport::AppendEntryTable() { - AppendLine("********** ", entry_name_, " table **********"); - AppendLine("The left hand side numbers are ", metric_name_, "."); + AppendLine("********** ", entry_name_, " table for ", metric_name_, + " **********"); AppendLine(); double metric_sum = UnaccountedMetric(); diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index 6b7fd10d63f8f97b0e0bf7570488c06323368d75..55c4a80e29b7d493e676e412dfd259677169b417 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -19,9 +19,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -57,7 +57,7 @@ StatusOr> PackedLiteralReader::Read( PrimitiveType_Name(shape.element_type()).c_str()); } - auto result = MakeUnique(literal_shape); + auto result = absl::make_unique(literal_shape); result->PopulateWithValue(std::numeric_limits::quiet_NaN()); int64 elements = ShapeUtil::ElementsIn(shape); diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index e26e35eb119f1a7c89088e91f959846abbe739f3..a91336c3ac920bc1f28a17e2b9835eba81c94d75 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -53,12 +53,13 @@ cc_library( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:executable_build_options", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:math", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core:framework_lite", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index fbcf0f19698bd7b710cb52f1da7aaa81227ad56f..c133a2041978f5b13d94a5a579525c8e2d11fbeb 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -14,11 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/python/local_computation_builder.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/executable_run_options.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -576,6 +575,16 @@ StatusOr LocalComputationBuilder::IsConstant(const LocalOp& operand) { return builder_.IsConstant(operand.op()); } +LocalOp LocalComputationBuilder::Sort(const LocalOp& operand, int64 dimension) { + return xla::Sort(operand.op(), tensorflow::gtl::nullopt, dimension); +} + +LocalOp LocalComputationBuilder::SortKeyVal(const LocalOp& keys, + const LocalOp& values, + int64 dimension) { + return xla::Sort(keys.op(), values.op(), dimension); +} + StatusOr LocalComputationBuilder::BuildConstantSubGraph( const LocalOp& operand) { TF_ASSIGN_OR_RETURN(XlaComputation computation, @@ -625,6 +634,7 @@ _FORWARD_BINOP(ShiftRightArithmetic) _FORWARD_BINOP(ShiftRightLogical) _FORWARD_BINOP(Atan2) _FORWARD_BINOP(Pow) +_FORWARD_BINOP(Complex) _FORWARD_UNOP(Not) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) @@ -640,7 +650,6 @@ _FORWARD_UNOP(Sin) _FORWARD_UNOP(Tanh) _FORWARD_UNOP(IsFinite) _FORWARD_UNOP(Neg) -_FORWARD_UNOP(Sort) _FORWARD_UNOP(Sqrt) _FORWARD_UNOP(Rsqrt) _FORWARD_UNOP(Square) @@ -659,6 +668,9 @@ _FORWARD_UNOP(Asinh) _FORWARD_UNOP(Atanh) _FORWARD_UNOP(Cosh) _FORWARD_UNOP(Sinh) +_FORWARD_UNOP(Real) +_FORWARD_UNOP(Imag) +_FORWARD_UNOP(Conj) #undef _FORWARD #undef _FORWARD_UNOP diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 57da7e53d5b2da3859cc2d29b39eb75e66e93047..5f9078ab847264663c46c294b2f0b65c2b154750 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -301,6 +301,11 @@ class LocalComputationBuilder { StatusOr IsConstant(const LocalOp& operand); + LocalOp Sort(const LocalOp& operand, int64 dimension); + + LocalOp SortKeyVal(const LocalOp& keys, const LocalOp& values, + int64 dimension); + StatusOr BuildConstantSubGraph(const LocalOp& operand); #define _FORWARD(method_name, return_sig, args_sig) \ @@ -341,6 +346,7 @@ class LocalComputationBuilder { _FORWARD_BINOP(ShiftRightLogical) _FORWARD_BINOP(Atan2) _FORWARD_BINOP(Pow) + _FORWARD_BINOP(Complex) _FORWARD_UNOP(Not) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) @@ -356,7 +362,6 @@ class LocalComputationBuilder { _FORWARD_UNOP(Tanh) _FORWARD_UNOP(IsFinite) _FORWARD_UNOP(Neg) - _FORWARD_UNOP(Sort) _FORWARD_UNOP(Sqrt) _FORWARD_UNOP(Rsqrt) _FORWARD_UNOP(Square) @@ -375,6 +380,9 @@ class LocalComputationBuilder { _FORWARD_UNOP(Atanh) _FORWARD_UNOP(Cosh) _FORWARD_UNOP(Sinh) + _FORWARD_UNOP(Real) + _FORWARD_UNOP(Imag) + _FORWARD_UNOP(Conj) #undef _FORWARD #undef _FORWARD_UNOP diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 9b8b0aa7f28e64f434bb24f88a3a9cbe177f8a78..fa5d75908f93683c3e02f97e8e39edf5d111a9e3 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -1011,6 +1011,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Pow; %unignore xla::swig::LocalComputationBuilder::Neg; %unignore xla::swig::LocalComputationBuilder::Sort; +%unignore xla::swig::LocalComputationBuilder::SortKeyVal; %unignore xla::swig::LocalComputationBuilder::Sqrt; %unignore xla::swig::LocalComputationBuilder::Rsqrt; %unignore xla::swig::LocalComputationBuilder::Square; @@ -1029,6 +1030,10 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Atanh; %unignore xla::swig::LocalComputationBuilder::Cosh; %unignore xla::swig::LocalComputationBuilder::Sinh; +%unignore xla::swig::LocalComputationBuilder::Real; +%unignore xla::swig::LocalComputationBuilder::Imag; +%unignore xla::swig::LocalComputationBuilder::Conj; +%unignore xla::swig::LocalComputationBuilder::Complex; %unignore xla::swig::DestructureLocalShapedBufferTuple; %unignore xla::swig::DeleteLocalShapedBuffer; %unignore xla::swig::DeleteLocalComputation; diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index 71351abd593d45fb5080112438a91df368eee173..6f665faf61b25b23a32ce4d0a012543ba18d7e64 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -50,6 +50,8 @@ int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) { return NPY_FLOAT32; case F64: return NPY_FLOAT64; + case C64: + return NPY_COMPLEX64; case TUPLE: return NPY_OBJECT; default: @@ -83,6 +85,8 @@ PrimitiveType NumpyTypeToPrimitiveType(int np_type) { return F32; case NPY_FLOAT64: return F64; + case NPY_COMPLEX64: + return C64; case NPY_OBJECT: return TUPLE; default: @@ -104,6 +108,7 @@ bool NumpyTypeIsValid(int np_type) { case NPY_FLOAT16: case NPY_FLOAT32: case NPY_FLOAT64: + case NPY_COMPLEX64: case NPY_OBJECT: return true; default: @@ -425,6 +430,9 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, case NPY_FLOAT64: CopyNumpyArrayToLiteral(py_array, literal); break; + case NPY_COMPLEX64: + CopyNumpyArrayToLiteral(py_array, literal); + break; default: return InvalidArgument( "No XLA literal container for Numpy type number: %d", np_type); @@ -462,6 +470,9 @@ void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, case NPY_FLOAT64: CopyLiteralToNumpyArray(literal, py_array); break; + case NPY_COMPLEX64: + CopyLiteralToNumpyArray(literal, py_array); + break; default: LOG(FATAL) << "No XLA literal container for Numpy type" << np_type; } diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index c0105b385b02e13b360ad1fb5af734d2209a92c2..fa4366ff0789a3d05c26479a746a18dfcf7e902b 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -105,7 +105,6 @@ _UNARY_OPS = [ 'Square', 'Reciprocal', 'Neg', - 'Sort', 'Erf', 'Erfc', 'ErfInv', @@ -120,6 +119,9 @@ _UNARY_OPS = [ 'Atanh', 'Cosh', 'Sinh', + 'Real', + 'Imag', + 'Conj', ] _BINARY_OPS = [ @@ -144,6 +146,7 @@ _BINARY_OPS = [ 'ShiftRightArithmetic', 'ShiftRightLogical', 'Atan2', + 'Complex', ] @@ -1214,6 +1217,14 @@ class ComputationBuilder(object): lhs_dilation, rhs_dilation, dimension_numbers) + def Sort(self, operand, dimension=-1): + """Enqueues a sort operation onto the computation.""" + return self._client.Sort(operand, dimension) + + def SortKeyVal(self, keys, values, dimension=-1): + """Enqueues a key-value sort operation onto the computation.""" + return self._client.SortKeyVal(keys, values, dimension) + def _forward_methods_to_local_builder(): """Forward remaining ComputationBuilder methods to the C API. diff --git a/tensorflow/compiler/xla/python_api/BUILD b/tensorflow/compiler/xla/python_api/BUILD index 8999cda5ef852d1246bea45a3312575ec1ac0721..d790c4db6c466a2bf4d2cf30365749fb901f74a0 100644 --- a/tensorflow/compiler/xla/python_api/BUILD +++ b/tensorflow/compiler/xla/python_api/BUILD @@ -10,6 +10,8 @@ py_library( srcs = ["types.py"], deps = [ "//tensorflow/compiler/xla:xla_data_proto_py", + "//tensorflow/python:dtypes", + "//tensorflow/python:platform", "//third_party/py/numpy", ], ) diff --git a/tensorflow/compiler/xla/python_api/types.py b/tensorflow/compiler/xla/python_api/types.py index b60f8dce92ace1b2c682374a2605b3a477936bbc..57dfce3971b829d2a3052d347e5d2d322db0c841 100644 --- a/tensorflow/compiler/xla/python_api/types.py +++ b/tensorflow/compiler/xla/python_api/types.py @@ -20,9 +20,10 @@ from __future__ import print_function import collections -import numpy as np +import numpy as _np # Avoids becoming a part of public Tensorflow API. from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.python.framework import dtypes # Records corresponsence between a XLA primitive type and Python/Numpy types. # @@ -40,76 +41,82 @@ TypeConversionRecord = collections.namedtuple('TypeConversionRecord', [ # Maps from XLA primitive types to TypeConversionRecord. MAP_XLA_TYPE_TO_RECORD = { + xla_data_pb2.BF16: + TypeConversionRecord( + primitive_type=xla_data_pb2.BF16, + numpy_dtype=dtypes.bfloat16.as_numpy_dtype, + literal_field_name='bf16s', + literal_field_type=float), xla_data_pb2.F16: TypeConversionRecord( primitive_type=xla_data_pb2.F16, - numpy_dtype=np.float16, + numpy_dtype=_np.float16, literal_field_name='f16s', literal_field_type=float), xla_data_pb2.F32: TypeConversionRecord( primitive_type=xla_data_pb2.F32, - numpy_dtype=np.float32, + numpy_dtype=_np.float32, literal_field_name='f32s', literal_field_type=float), xla_data_pb2.F64: TypeConversionRecord( primitive_type=xla_data_pb2.F64, - numpy_dtype=np.float64, + numpy_dtype=_np.float64, literal_field_name='f64s', literal_field_type=float), xla_data_pb2.S8: TypeConversionRecord( primitive_type=xla_data_pb2.S8, - numpy_dtype=np.int8, + numpy_dtype=_np.int8, literal_field_name='s8s', literal_field_type=int), xla_data_pb2.S16: TypeConversionRecord( primitive_type=xla_data_pb2.S16, - numpy_dtype=np.int16, + numpy_dtype=_np.int16, literal_field_name='s16s', literal_field_type=int), xla_data_pb2.S32: TypeConversionRecord( primitive_type=xla_data_pb2.S32, - numpy_dtype=np.int32, + numpy_dtype=_np.int32, literal_field_name='s32s', literal_field_type=int), xla_data_pb2.S64: TypeConversionRecord( primitive_type=xla_data_pb2.S64, - numpy_dtype=np.int64, + numpy_dtype=_np.int64, literal_field_name='s64s', literal_field_type=int), xla_data_pb2.U8: TypeConversionRecord( primitive_type=xla_data_pb2.U8, - numpy_dtype=np.uint8, + numpy_dtype=_np.uint8, literal_field_name='s8s', literal_field_type=int), xla_data_pb2.U16: TypeConversionRecord( primitive_type=xla_data_pb2.U16, - numpy_dtype=np.uint16, + numpy_dtype=_np.uint16, literal_field_name='s16s', literal_field_type=int), xla_data_pb2.U32: TypeConversionRecord( primitive_type=xla_data_pb2.U32, - numpy_dtype=np.uint32, + numpy_dtype=_np.uint32, literal_field_name='s32s', literal_field_type=int), xla_data_pb2.U64: TypeConversionRecord( primitive_type=xla_data_pb2.U64, - numpy_dtype=np.uint64, + numpy_dtype=_np.uint64, literal_field_name='s64s', literal_field_type=int), xla_data_pb2.PRED: TypeConversionRecord( primitive_type=xla_data_pb2.PRED, - numpy_dtype=np.bool, + numpy_dtype=_np.bool, literal_field_name='preds', literal_field_type=bool) } @@ -119,6 +126,6 @@ MAP_XLA_TYPE_TO_RECORD = { # doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus, # when keying by dtype in this dict, we use the string form of dtypes. MAP_DTYPE_TO_RECORD = { - str(np.dtype(record.numpy_dtype)): record + str(_np.dtype(record.numpy_dtype)): record for record in MAP_XLA_TYPE_TO_RECORD.values() } diff --git a/tensorflow/compiler/xla/python_api/xla_literal.py b/tensorflow/compiler/xla/python_api/xla_literal.py index b040098c294ffaae92b72f678947f99289239314..757e41a78ad2b57d2ef6e1f3055160be22c7b3ed 100644 --- a/tensorflow/compiler/xla/python_api/xla_literal.py +++ b/tensorflow/compiler/xla/python_api/xla_literal.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np +import numpy as _np # Avoids becoming a part of public Tensorflow API. from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python_api import types @@ -35,7 +35,7 @@ def ConvertLiteralToNumpyArray(literal): type_record = types.MAP_XLA_TYPE_TO_RECORD[element_type] if not literal.shape.dimensions: - return np.array( + return _np.array( getattr(literal, type_record.literal_field_name)[0], type_record.numpy_dtype) else: @@ -54,7 +54,7 @@ def ConvertLiteralToNumpyArray(literal): numpy_reshaper = lambda arr: arr.reshape(numpy_shape, order='C') else: raise NotImplementedError('Unsupported layout: {0}'.format(layout_order)) - ndarray = np.array( + ndarray = _np.array( getattr(literal, type_record.literal_field_name), copy=False, dtype=type_record.numpy_dtype) @@ -69,11 +69,11 @@ def _ConvertNumpyArrayToLiteral(ndarray): if ndarray.ndim == 0: getattr(literal, type_record.literal_field_name).append( - np.asscalar(ndarray.astype(type_record.literal_field_type))) + _np.asscalar(ndarray.astype(type_record.literal_field_type))) else: # Ndarrays with boolean dtypes need special type conversion with protobufs - if ndarray.dtype in {np.bool_, np.dtype('bool')}: - for element in np.nditer(ndarray): + if ndarray.dtype in {_np.bool_, _np.dtype('bool')}: + for element in _np.nditer(ndarray): getattr(literal, type_record.literal_field_name).append( type_record.literal_field_type(element)) else: diff --git a/tensorflow/compiler/xla/python_api/xla_shape.py b/tensorflow/compiler/xla/python_api/xla_shape.py index 6af28958035bbb03e7e1dbb0d0c7bb2c2f25b96d..f158f6b2410352432445f669155aff0af5526abf 100644 --- a/tensorflow/compiler/xla/python_api/xla_shape.py +++ b/tensorflow/compiler/xla/python_api/xla_shape.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np +import numpy as _np # Avoids becoming a part of public Tensorflow API. from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python_api import types @@ -111,7 +111,7 @@ def _CreateShapeFromNumpy(ndarray): # pylint: disable=invalid-name # Set the shape's layout based on the ordering of ndarray. # Numpy arrays come in two orders: Fortran (column-major) and C (row-major). - if np.isfortran(ndarray): + if _np.isfortran(ndarray): # Column-major layout. This corresponds to a "dimension order is # minor-to-major" layout in XLA. layout = range(ndarray.ndim) diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 6397f1f47915aaa559beda467c26c66795c98f60..3de7ee2bc8c936680735102607436af77a17769c 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -18,7 +18,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "absl/memory/memory.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" @@ -43,7 +44,7 @@ std::unique_ptr> MatmulArray2DImpl( int m = lhs.height(); int n = rhs.width(); int k = lhs.width(); - auto result = MakeUnique>(m, n); + auto result = absl::make_unique>(m, n); // Because Eigen is a header-oriented library, make sure that the Eigen code // is the same as the code used by the CPU backend (otherwise the linker will // randomly pick *some* definition). @@ -77,7 +78,8 @@ std::unique_ptr> MatmulArray2DImpl( /* static */ std::unique_ptr> ReferenceUtil::Array2DF32ToF64( const Array2D& input) { - auto result = MakeUnique>(input.height(), input.width()); + auto result = + absl::make_unique>(input.height(), input.width()); for (int64 rowno = 0; rowno < input.height(); ++rowno) { for (int64 colno = 0; colno < input.height(); ++colno) { (*result)(rowno, colno) = input(rowno, colno); @@ -126,8 +128,8 @@ ReferenceUtil::ConvArray3DGeneralDimensionsDilated( a4dlhs, a4drhs, {kernel_stride, 1}, padding, {lhs_dilation, 1}, {rhs_dilation, 1}, dnums2d); - auto convr3 = MakeUnique>(convr4->planes(), convr4->depth(), - convr4->height()); + auto convr3 = absl::make_unique>( + convr4->planes(), convr4->depth(), convr4->height()); convr4->Each( [&](tensorflow::gtl::ArraySlice indices, float* value_ptr) { CHECK_EQ(indices[3], 0); @@ -201,7 +203,7 @@ ReferenceUtil::ReduceWindow1DGeneric( window_util::StridedBound(padded_width, window[i], stride[i]); pad_low[i] = padding[i].first; } - auto result = MakeUnique>(window_counts[0]); + auto result = absl::make_unique>(window_counts[0]); // Do a full 1D reduce window. for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { @@ -247,7 +249,8 @@ ReferenceUtil::ReduceWindow2DGeneric( window_util::StridedBound(padded_width, window[i], stride[i]); pad_low[i] = padding[i].first; } - auto result = MakeUnique>(window_counts[0], window_counts[1]); + auto result = + absl::make_unique>(window_counts[0], window_counts[1]); // Do a full 2D reduce window. for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { @@ -296,8 +299,8 @@ ReferenceUtil::ReduceWindow2DGeneric( WindowCount(dim_lengths[i], window[i], stride[i], padding); pad_low[i] = padding_both[i].first; } - auto result = MakeUnique>(window_counts[0], window_counts[1], - window_counts[2]); + auto result = absl::make_unique>( + window_counts[0], window_counts[1], window_counts[2]); for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { @@ -358,8 +361,8 @@ ReferenceUtil::ReduceWindow4DGeneric( window_util::StridedBound(padded_width, window[i], stride[i]); pad_low[i] = padding[i].first; } - auto result = MakeUnique>(window_counts[0], window_counts[1], - window_counts[2], window_counts[3]); + auto result = absl::make_unique>( + window_counts[0], window_counts[1], window_counts[2], window_counts[3]); // Do a full 4D reduce window. for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { @@ -426,8 +429,8 @@ ReferenceUtil::SelectAndScatter4DGePlus( const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, bool same_padding) { Padding padding = same_padding ? Padding::kSame : Padding::kValid; - auto result = MakeUnique>(operand.n1(), operand.n2(), - operand.n3(), operand.n4()); + auto result = absl::make_unique>(operand.n1(), operand.n2(), + operand.n3(), operand.n4()); std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3(), operand.n4()}; auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); @@ -583,10 +586,10 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4); auto result = - MakeUnique>(result_literal->shape().dimensions(0), - result_literal->shape().dimensions(1), - result_literal->shape().dimensions(2), - result_literal->shape().dimensions(3)); + absl::make_unique>(result_literal->shape().dimensions(0), + result_literal->shape().dimensions(1), + result_literal->shape().dimensions(2), + result_literal->shape().dimensions(3)); result->Each([&](tensorflow::gtl::ArraySlice indices, float* value) { *value = result_literal->Get(indices); @@ -601,7 +604,7 @@ ReferenceUtil::ReduceToColArray2D( const std::function& reduce_function) { int64 rows = matrix.height(); int64 cols = matrix.width(); - auto result = MakeUnique>(); + auto result = absl::make_unique>(); for (int64 i = 0; i < rows; ++i) { float acc = init; for (int64 j = 0; j < cols; ++j) { @@ -618,7 +621,7 @@ ReferenceUtil::ReduceToRowArray2D( const std::function& reduce_function) { int64 rows = matrix.height(); int64 cols = matrix.width(); - auto result = MakeUnique>(); + auto result = absl::make_unique>(); for (int64 i = 0; i < cols; ++i) { float acc = init; for (int64 j = 0; j < rows; ++j) { @@ -674,8 +677,8 @@ ReferenceUtil::ReduceToRowArray2D( /* static */ std::unique_ptr> ReferenceUtil::Broadcast1DTo4D( const std::vector& array, const std::vector& bounds, int64 broadcast_from_dim) { - auto result = - MakeUnique>(bounds[0], bounds[1], bounds[2], bounds[3]); + auto result = absl::make_unique>(bounds[0], bounds[1], + bounds[2], bounds[3]); for (int64 i = 0; i < result->n1(); ++i) { for (int64 j = 0; j < result->n2(); ++j) { for (int64 k = 0; k < result->n3(); ++k) { @@ -710,7 +713,7 @@ ReferenceUtil::ReduceToRowArray2D( CHECK_EQ(dims.size(), 1); int64 rows = dims[0] == 0 ? array.n2() : array.n1(); int64 cols = dims[0] == 2 ? array.n2() : array.n3(); - auto result = MakeUnique>(rows, cols); + auto result = absl::make_unique>(rows, cols); result->Fill(init); for (int i0 = 0; i0 < array.n1(); ++i0) { for (int i1 = 0; i1 < array.n2(); ++i1) { @@ -730,7 +733,7 @@ ReferenceUtil::ReduceToRowArray2D( const std::function& map_function) { int64 rows = matrix.height(); int64 cols = matrix.width(); - auto result = MakeUnique>(rows, cols); + auto result = absl::make_unique>(rows, cols); for (int64 i = 0; i < rows; ++i) { for (int64 j = 0; j < cols; ++j) { (*result)(i, j) = map_function(matrix(i, j)); @@ -746,7 +749,7 @@ ReferenceUtil::ReduceToRowArray2D( CHECK_EQ(lhs.width(), rhs.width()); int64 rows = lhs.height(); int64 cols = rhs.width(); - auto result = MakeUnique>(rows, cols); + auto result = absl::make_unique>(rows, cols); for (int64 i = 0; i < rows; ++i) { for (int64 j = 0; j < cols; ++j) { (*result)(i, j) = map_function(lhs(i, j), rhs(i, j)); @@ -760,7 +763,7 @@ ReferenceUtil::ReduceToRowArray2D( const std::function& map_function) { int64 rows = matrix.height(); int64 cols = matrix.width(); - auto result = MakeUnique>(rows, cols); + auto result = absl::make_unique>(rows, cols); for (int64 i = 0; i < rows; ++i) { for (int64 j = 0; j < cols; ++j) { (*result)(i, j) = map_function(matrix(i, j), i, j); diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 8fa6961d197dce519cf151283b8bc0836a4615c0..88f853a3591c25289a8022909da8cdd4437883a6 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -22,11 +22,11 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -42,7 +42,8 @@ class ReferenceUtil { template static std::unique_ptr> TransposeArray2D( const Array2D& operand) { - auto result = MakeUnique>(operand.width(), operand.height()); + auto result = + absl::make_unique>(operand.width(), operand.height()); for (int64 w = 0; w < operand.width(); ++w) { for (int64 h = 0; h < operand.height(); ++h) { (*result)(w, h) = operand(h, w); @@ -242,7 +243,7 @@ class ReferenceUtil { const Array2D& rhs, int concatenate_dimension) { CHECK(0 <= concatenate_dimension && concatenate_dimension < 2); - auto result = MakeUnique>( + auto result = absl::make_unique>( concatenate_dimension == 0 ? lhs.n1() + rhs.n1() : lhs.n1(), concatenate_dimension == 1 ? lhs.n2() + rhs.n2() : lhs.n2()); for (int64 i0 = 0; i0 < result->n1(); ++i0) { @@ -276,7 +277,8 @@ class ReferenceUtil { out_dims[i] = lhs_dims[i] + rhs_dims[i]; } } - auto result = MakeUnique>(out_dims[0], out_dims[1], out_dims[2]); + auto result = + absl::make_unique>(out_dims[0], out_dims[1], out_dims[2]); for (int64 i0 = 0; i0 < result->n1(); ++i0) { for (int64 i1 = 0; i1 < result->n2(); ++i1) { for (int64 i2 = 0; i2 < result->n3(); ++i2) { @@ -310,8 +312,8 @@ class ReferenceUtil { out_dims[i] = lhs_dims[i] + rhs_dims[i]; } } - auto result = MakeUnique>(out_dims[0], out_dims[1], out_dims[2], - out_dims[3]); + auto result = absl::make_unique>(out_dims[0], out_dims[1], + out_dims[2], out_dims[3]); for (int64 i0 = 0; i0 < result->n1(); ++i0) { for (int64 i1 = 0; i1 < result->n2(); ++i1) { for (int64 i2 = 0; i2 < result->n3(); ++i2) { @@ -355,9 +357,9 @@ class ReferenceUtil { CHECK_LE(limits[1], input.n2()); CHECK_GE(strides[0], 1); CHECK_GE(strides[1], 1); - auto result = - MakeUnique>(CeilOfRatio(limits[0] - starts[0], strides[0]), - CeilOfRatio(limits[1] - starts[1], strides[1])); + auto result = absl::make_unique>( + CeilOfRatio(limits[0] - starts[0], strides[0]), + CeilOfRatio(limits[1] - starts[1], strides[1])); for (int64 i0 = 0; i0 < result->n1(); ++i0) { for (int64 i1 = 0; i1 < result->n2(); ++i1) { (*result)(i0, i1) = @@ -381,10 +383,10 @@ class ReferenceUtil { CHECK_GE(strides[0], 1); CHECK_GE(strides[1], 1); CHECK_GE(strides[2], 1); - auto result = - MakeUnique>(CeilOfRatio(limits[0] - starts[0], strides[0]), - CeilOfRatio(limits[1] - starts[1], strides[1]), - CeilOfRatio(limits[2] - starts[2], strides[2])); + auto result = absl::make_unique>( + CeilOfRatio(limits[0] - starts[0], strides[0]), + CeilOfRatio(limits[1] - starts[1], strides[1]), + CeilOfRatio(limits[2] - starts[2], strides[2])); for (int64 i0 = 0; i0 < result->n1(); ++i0) { for (int64 i1 = 0; i1 < result->n2(); ++i1) { @@ -415,11 +417,11 @@ class ReferenceUtil { CHECK_GE(strides[1], 1); CHECK_GE(strides[2], 1); CHECK_GE(strides[3], 1); - auto result = - MakeUnique>(CeilOfRatio(limits[0] - starts[0], strides[0]), - CeilOfRatio(limits[1] - starts[1], strides[1]), - CeilOfRatio(limits[2] - starts[2], strides[2]), - CeilOfRatio(limits[3] - starts[3], strides[3])); + auto result = absl::make_unique>( + CeilOfRatio(limits[0] - starts[0], strides[0]), + CeilOfRatio(limits[1] - starts[1], strides[1]), + CeilOfRatio(limits[2] - starts[2], strides[2]), + CeilOfRatio(limits[3] - starts[3], strides[3])); for (int64 i0 = 0; i0 < result->n1(); ++i0) { for (int64 i1 = 0; i1 < result->n2(); ++i1) { for (int64 i2 = 0; i2 < result->n3(); ++i2) { @@ -460,8 +462,8 @@ class ReferenceUtil { template static std::unique_ptr> MapWithIndexArray4D( const Array4D& input, F&& map_function) { - auto result = MakeUnique>(input.planes(), input.depth(), - input.height(), input.width()); + auto result = absl::make_unique>( + input.planes(), input.depth(), input.height(), input.width()); for (int64 plane = 0; plane < input.planes(); ++plane) { for (int64 depth = 0; depth < input.depth(); ++depth) { for (int64 height = 0; height < input.height(); ++height) { @@ -495,8 +497,8 @@ class ReferenceUtil { template static std::unique_ptr> MapWithIndexArray4D( const Array4D& lhs, const Array4D& rhs, F&& map_function) { - auto result = MakeUnique>(lhs.planes(), lhs.depth(), - lhs.height(), lhs.width()); + auto result = absl::make_unique>(lhs.planes(), lhs.depth(), + lhs.height(), lhs.width()); for (int64 plane = 0; plane < lhs.planes(); ++plane) { for (int64 depth = 0; depth < lhs.depth(); ++depth) { for (int64 height = 0; height < lhs.height(); ++height) { @@ -530,7 +532,7 @@ class ReferenceUtil { int64 out1 = in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1; - auto result = MakeUnique>(out0, out1); + auto result = absl::make_unique>(out0, out1); result->Fill(pad); int64 o0 = low_padding0; for (int64 i0 = 0; i0 < in0; ++i0) { @@ -669,7 +671,7 @@ class ReferenceUtil { static std::unique_ptr> ApplyElementwise2D( F&& f, const Array2D& array1, const Array2D&... arrays) { AssertSameSize2D(array1, arrays...); - auto result = MakeUnique>(array1.n1(), array1.n2()); + auto result = absl::make_unique>(array1.n1(), array1.n2()); for (int64 i = 0; i < array1.n1(); ++i) { for (int64 j = 0; j < array1.n2(); ++j) { (*result)(i, j) = f(array1(i, j), arrays(i, j)...); diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc index 8091bed4996a753649a5ecedda69a1ae48fb5897..3ec0192148492c2516bf1c14fd4b960b08014388 100644 --- a/tensorflow/compiler/xla/reference_util_test.cc +++ b/tensorflow/compiler/xla/reference_util_test.cc @@ -18,12 +18,12 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -36,7 +36,7 @@ namespace { class ReferenceUtilTest : public ::testing::Test { protected: ReferenceUtilTest() { - matrix_ = MakeUnique>(rows_, cols_); + matrix_ = absl::make_unique>(rows_, cols_); // [1.f 2.f 3.f] // [4.f 5.f 6.f] for (int64 i = 0; i < rows_; ++i) { @@ -112,8 +112,8 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) { } TEST_F(ReferenceUtilTest, MapArray4D) { - auto input = MakeUnique>(/*planes=*/2, /*depth=*/3, - /*height=*/4, /*width=*/5); + auto input = absl::make_unique>(/*planes=*/2, /*depth=*/3, + /*height=*/4, /*width=*/5); input->FillWithMultiples(1.0f); auto multiply_by_two = [](float value) { return 2 * value; }; auto result = ReferenceUtil::MapArray4D(*input, multiply_by_two); @@ -126,8 +126,8 @@ TEST_F(ReferenceUtilTest, MapArray4D) { } TEST_F(ReferenceUtilTest, MapWithIndexArray4D) { - auto input = MakeUnique>(/*planes=*/2, /*depth=*/3, - /*height=*/4, /*width=*/5); + auto input = absl::make_unique>(/*planes=*/2, /*depth=*/3, + /*height=*/4, /*width=*/5); input->FillWithMultiples(1.0f); auto subtract_index = [](float value, int64 plane, int64 depth, int64 height, int64 width) { diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD index 0b1cec1925d4424db086f8a3f62c91ede090189c..44b22a5586dee3f7dd8ea0edbf9deb2090986ac8 100644 --- a/tensorflow/compiler/xla/rpc/BUILD +++ b/tensorflow/compiler/xla/rpc/BUILD @@ -56,7 +56,7 @@ tf_cc_test( ":grpc_stub", "//tensorflow:grpc++", "//tensorflow/compiler/xla/client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc index 90efee50b4f19056fac8ef1b341b48175903ff83..67886761813f0bb45a600661b017be91ffeade73 100644 --- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc +++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "grpcpp/security/credentials.h" #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/rpc/grpc_stub.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/lib/io/path.h" diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 2305dd4318bc1f5201269b61e801fd493852bbd3..01f273ad1f7c70250b7530d591f59a53feee45d9 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -175,6 +175,7 @@ cc_library( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -237,6 +238,8 @@ cc_library( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", ], ) @@ -256,13 +259,14 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo_element_type_converter", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -311,6 +315,8 @@ cc_library( "//tensorflow/core:human_readable_json", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", ], ) @@ -449,6 +455,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -517,6 +524,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -564,16 +572,17 @@ cc_library( ":computation_placer", ":device_memory_allocator", ":platform_util", - ":pool", + ":stream_pool", ":transfer_manager", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:core_cpu_lib", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", + "@com_google_absl//absl/memory", ], ) @@ -598,6 +607,7 @@ cc_library( ":hlo_proto_util", ":platform_util", ":source_map_util", + ":stream_pool", ":transfer_manager", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:execution_options_util", @@ -612,7 +622,9 @@ cc_library( "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", + "//tensorflow/core:ptr_util", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], alwayslink = 1, ) @@ -645,6 +657,7 @@ cc_library( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -717,6 +730,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -734,6 +748,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:ptr_util", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -751,8 +766,8 @@ cc_library( ":hlo_execution_profile", ":hlo_graph_dumper", ":hlo_proto", - ":pool", ":shaped_buffer", + ":stream_pool", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", @@ -764,6 +779,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/stream_executor", + "@com_google_absl//absl/memory", ], ) @@ -811,6 +827,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -829,6 +846,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -838,13 +856,14 @@ cc_library( hdrs = ["execution_tracker.h"], deps = [ ":backend", - ":pool", + ":stream_pool", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -862,6 +881,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", ], ) @@ -921,6 +941,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/memory", ], ) @@ -946,9 +967,9 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", ], ) @@ -976,6 +997,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1030,6 +1052,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1048,6 +1071,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1064,6 +1088,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1081,6 +1106,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1141,6 +1167,7 @@ cc_library( ":hlo_pass", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1180,6 +1207,8 @@ cc_library( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", ], ) @@ -1197,6 +1226,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1230,6 +1260,22 @@ cc_library( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "@com_google_absl//absl/algorithm:container", + ], +) + +cc_library( + name = "scatter_expander", + srcs = ["scatter_expander.cc"], + hdrs = ["scatter_expander.h"], + deps = [ + ":hlo", + ":hlo_creation_utils", + ":hlo_pass", + ":while_util", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:statusor", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1252,6 +1298,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1274,6 +1321,8 @@ cc_library( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", ], ) @@ -1297,6 +1346,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1308,8 +1358,7 @@ cc_library( ":hlo", ":hlo_creation_utils", ":hlo_pass", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1384,6 +1433,52 @@ tf_cc_test( ], ) +cc_library( + name = "convolution_feature_group_converter", + srcs = ["convolution_feature_group_converter.cc"], + hdrs = ["convolution_feature_group_converter.h"], + deps = [ + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + ], +) + +tf_cc_test( + name = "convolution_feature_group_converter_test", + size = "small", + srcs = ["convolution_feature_group_converter_test.cc"], + deps = [ + ":convolution_feature_group_converter", + ":hlo", + ":hlo_matchers", + ":hlo_parser", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/tests:hlo_test_base", + ], +) + +cc_library( + name = "while_loop_analysis", + srcs = ["while_loop_analysis.cc"], + hdrs = ["while_loop_analysis.h"], + deps = [ + ":hlo", + ":hlo_evaluator", + "//tensorflow/compiler/xla:literal", + "//tensorflow/core:lib", + ], +) + cc_library( name = "while_loop_simplifier", srcs = ["while_loop_simplifier.cc"], @@ -1391,8 +1486,8 @@ cc_library( deps = [ ":call_inliner", ":hlo", - ":hlo_evaluator", ":hlo_pass", + ":while_loop_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", ], @@ -1522,6 +1617,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1542,6 +1638,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1575,6 +1672,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/memory", ], ) @@ -1594,6 +1692,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], alwayslink = True, # Contains per-platform computation placer registration ) @@ -1663,8 +1762,8 @@ tf_cc_test( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -1684,6 +1783,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", ], ) @@ -1729,6 +1830,7 @@ tf_cc_binary( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1745,6 +1847,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1804,6 +1907,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1822,6 +1926,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1863,6 +1968,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1956,6 +2062,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", ], ) @@ -1968,7 +2075,6 @@ cc_library( ":hlo_dataflow_analysis", ":logical_buffer", ":logical_buffer_analysis", - "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -1976,6 +2082,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -2026,6 +2133,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -2116,6 +2224,7 @@ cc_library( ":shape_inference", "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -2198,6 +2307,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -2279,6 +2389,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -2316,6 +2427,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -2332,6 +2444,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -2363,6 +2476,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -2377,6 +2491,7 @@ cc_library( "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -2437,6 +2552,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -2505,6 +2621,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/algorithm:container", "@llvm//:core", "@llvm//:transform_utils", ], @@ -2536,10 +2653,10 @@ cc_library( ":computation_layout", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -2670,7 +2787,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -2707,7 +2824,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -2715,21 +2832,25 @@ tf_cc_test( ) cc_library( - name = "pool", - hdrs = ["pool.h"], + name = "stream_pool", + srcs = ["stream_pool.cc"], + hdrs = ["stream_pool.h"], deps = [ - "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) tf_cc_test( - name = "pool_test", - srcs = ["pool_test.cc"], + name = "stream_pool_test", + srcs = ["stream_pool_test.cc"], deps = [ - ":pool", + ":stream_pool", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:stream_executor_no_cuda", ], ) @@ -2816,6 +2937,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", + "@com_google_absl//absl/memory", ], ) @@ -2863,6 +2985,7 @@ cc_library( ":tuple_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -2876,6 +2999,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/algorithm:container", ], ) @@ -2891,6 +3015,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -2918,6 +3043,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -2972,6 +3098,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:ptr_util", + "@com_google_absl//absl/algorithm:container", ], ) @@ -3005,6 +3132,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", ], ) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 505c0e8dff44ace09bd67f54ecb3f2716a2fb167..1d26e306519a95500a91d982ca59918a3ab64174 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -22,6 +22,8 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -150,6 +152,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleDynamicUpdateSlice( HloInstruction* dynamic_update_slice) override; + Status HandleSort(HloInstruction* sort) override; + Status HandleTranspose(HloInstruction* transpose) override; Status HandleSubtract(HloInstruction* sub) override; @@ -538,7 +542,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { // If a literal is all the same element replace it with a scalar broadcast. if (ShapeUtil::ElementsIn(constant->shape()) > 1 && constant->literal().IsAllFirst()) { - std::unique_ptr unique_scalar = MakeUnique( + std::unique_ptr unique_scalar = absl::make_unique( LiteralUtil::GetFirstScalarLiteral(constant->literal())); HloInstruction* scalar = computation_->AddInstruction( HloInstruction::CreateConstant(std::move(unique_scalar))); @@ -1703,6 +1707,10 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { reshape, HloInstruction::CreateReshape(reshape->shape(), operand->mutable_operand(0))); } + if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) { + *operand->mutable_shape() = reshape->shape(); + return ReplaceInstruction(reshape, operand); + } if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) { auto opt_dims = ReshapeLeavesDimensionsUnmodified( @@ -1746,8 +1754,8 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { } auto is_unstrided_slice = [](const HloInstruction* hlo) { - return c_all_of(hlo->slice_strides(), - [](int64 stride) { return stride == 1; }); + return absl::c_all_of(hlo->slice_strides(), + [](int64 stride) { return stride == 1; }); }; if (slice->operand(0)->opcode() == HloOpcode::kSlice && is_unstrided_slice(slice) && is_unstrided_slice(slice->operand(0))) { @@ -1801,6 +1809,12 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( } Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { + // TODO(b/112040122): Most of those optimizations can be done for multi-output + // reduces. + if (ShapeUtil::IsTuple(reduce->shape())) { + return Status::OK(); + } + auto arg = reduce->mutable_operand(0); auto init_value = reduce->mutable_operand(1); tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); @@ -1918,7 +1932,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { // This should make fusion easier or use less memory bandwidth in the unfused // case. if (arg->opcode() == HloOpcode::kConcatenate && - c_linear_search(reduce->dimensions(), arg->concatenate_dimension())) { + absl::c_linear_search(reduce->dimensions(), + arg->concatenate_dimension())) { HloInstruction* old_reduce = nullptr; for (HloInstruction* operand : arg->operands()) { HloInstruction* new_reduce = computation_->AddInstruction( @@ -2105,6 +2120,21 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( /*reduce_computation=*/function)); } +Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) { + auto operand = sort->mutable_operand(0); + int64 dimension_to_sort = sort->dimensions(0); + if (ShapeUtil::IsZeroElementArray(operand->shape()) || + operand->shape().dimensions(dimension_to_sort) <= 1) { + if (sort->operand_count() == 1) { + return ReplaceInstruction(sort, operand); + } + // If it is key/value sort, the output of sort is a tuple. + return ReplaceWithNewInstruction( + sort, HloInstruction::CreateTuple({operand, sort->mutable_operand(1)})); + } + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { auto operand = transpose->mutable_operand(0); if (std::is_sorted(transpose->dimensions().begin(), @@ -2121,6 +2151,11 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { transpose->dimensions()))); } + if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) { + *operand->mutable_shape() = transpose->shape(); + return ReplaceInstruction(transpose, operand); + } + if (is_layout_sensitive_ && TransposeIsBitcast(transpose)) { ReplaceWithBitcast(transpose); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 8b81b4c97ef373bcfb89bf0761ebb16b6e14e3fc..427069af5f49866d4e7c818696a6912302643b54 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -1428,6 +1428,37 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); } +// Test transforming reshapes and transposes of rng. +TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) { + HloComputation::Builder builder(TestName()); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction* one = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction* rng0 = builder.AddInstruction( + HloInstruction::CreateRng(ShapeUtil::MakeShape(F32, {2, 2}), + RandomDistribution::RNG_UNIFORM, {zero, one})); + + HloInstruction* transpose = builder.AddInstruction( + HloInstruction::CreateTranspose(rng0->shape(), rng0, {1, 0})); + Shape reshape_shape = builder + .AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {4}), transpose)) + ->shape(); + + auto computation = module().AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + bitcasting_callback()); + EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie()); + + // Verify that that reshape(transpose(rng)) is replace by a single rng of the + // same shape as the reshape. + EXPECT_THAT(computation->root_instruction(), op::Rng()); + EXPECT_TRUE(ShapeUtil::Equal(computation->root_instruction()->shape(), + reshape_shape)); +} + // Test transforming reshapes to bitcasts under various conditions. TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { HloComputation::Builder builder(TestName()); @@ -1941,6 +1972,40 @@ TEST_F(AlgebraicSimplifierTest, SliceOfSliceToSlice) { EXPECT_EQ(computation->root_instruction()->slice_limits(1), dim1 - 4); } +TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) { + auto builder = HloComputation::Builder(TestName()); + + Shape keys_shape = ShapeUtil::MakeShape(F32, {1}); + auto keys = builder.AddInstruction( + HloInstruction::CreateParameter(0, keys_shape, "keys")); + builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys)); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), keys); +} + +TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { + auto builder = HloComputation::Builder(TestName()); + + Shape keys_shape = ShapeUtil::MakeShape(F32, {5, 0}); + Shape values_shape = ShapeUtil::MakeShape(S32, {5, 0}); + auto keys = builder.AddInstruction( + HloInstruction::CreateParameter(0, keys_shape, "keys")); + auto values = builder.AddInstruction( + HloInstruction::CreateParameter(1, values_shape, "values")); + builder.AddInstruction(HloInstruction::CreateSort( + ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values)); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), op::Tuple(keys, values)); +} + TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { struct ConvTestOptions { int in_batch = 10; @@ -1972,7 +2037,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { // Builds a convolution from and runs algebraic simplification on // the computation. Returns a string description of the result of // simplification. - auto build_and_simplify = [&options, this]() -> string { + auto build_and_simplify = [&]() -> string { HloComputation::Builder b(TestName()); Window window; diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index 95b4cb6d2e694063b648b264bd2454ae0a5469ff..d0806d24a22ce57af3116b9aaddb487ec24bfbae 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -17,8 +17,8 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -91,8 +91,9 @@ StatusOr AllocationTracker::RegisterInternal( // If ShapedBufferTy is ScopedShapedBuffer, release the ScopedShapedBuffer // into a regular ShapedBuffer, which is stored in // handle_to_shaped_buffers_. - handle_to_shaped_buffers_[handle].emplace_back(MakeUnique( - ReleaseIfScopedShapedBuffer(std::move(shaped_buffer)))); + handle_to_shaped_buffers_[handle].emplace_back( + absl::make_unique( + ReleaseIfScopedShapedBuffer(std::move(shaped_buffer)))); } GlobalDataHandle result; @@ -109,11 +110,11 @@ Status AllocationTracker::Unregister(const GlobalDataHandle& data) { ResolveInternal(data)); for (const auto& shaped_buffer : replicated_buffers) { std::vector shape_indices; - ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(), - [this, &shape_indices](const Shape& /*subshape*/, - const ShapeIndex& index) { - shape_indices.push_back(index); - }); + ShapeUtil::ForEachSubshape( + shaped_buffer->on_device_shape(), + [&shape_indices](const Shape& /*subshape*/, const ShapeIndex& index) { + shape_indices.push_back(index); + }); for (const ShapeIndex& index : shape_indices) { TF_RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index), shaped_buffer->device_ordinal())); diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index 349b32451a697dbd6804b44cd1a36419c753bb14..841d0fa85bb9c548cd737e21bb988886f43378bd 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/platform_util.h" @@ -96,24 +97,19 @@ Backend::CreateDefaultBackend() { return CreateBackend(backend_options); } -StatusOr Backend::BorrowStream(int device_ordinal) { - TF_ASSIGN_OR_RETURN(auto exec, stream_executor(device_ordinal)); - return BorrowStream(exec); +StatusOr Backend::BorrowStream(int device_ordinal) { + TF_ASSIGN_OR_RETURN(auto executor, stream_executor(device_ordinal)); + return BorrowStream(executor); } -StatusOr Backend::BorrowStream( - se::StreamExecutor* executor) { +StatusOr Backend::BorrowStream(se::StreamExecutor* executor) { tensorflow::mutex_lock l(mu_); if (0 == stream_pools_.count(executor)) { stream_pools_.emplace(std::piecewise_construct, std::forward_as_tuple(executor), - std::forward_as_tuple([executor]() { - auto stream = MakeUnique(executor); - stream->Init(); - return stream; - })); + std::forward_as_tuple()); } - return stream_pools_.at(executor).Allocate(); + return stream_pools_.at(executor).BorrowStream(executor); } Backend::Backend( @@ -132,8 +128,8 @@ Backend::Backend( } } // Create a memory allocator for the valid stream executors. - memory_allocator_ = - MakeUnique(platform, stream_executors); + memory_allocator_ = absl::make_unique( + platform, stream_executors); CHECK(!stream_executors_.empty()) << "Service found no devices for backend " << platform_->Name() << '.'; diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index 6546602473e3381cf13879ddebd05d34d1f7a055..1bc3796fa48c1627538474d04ef5358ba64dfce9 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" -#include "tensorflow/compiler/xla/service/pool.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -63,11 +63,9 @@ class BackendOptions { // // It also offers a pooling API for creation/use of initialized streams: // -// StreamPtr stream = backend->BorrowStream().ConsumeValueOrDie(); +// StreamPool::Ptr stream = backend->BorrowStream().ConsumeValueOrDie(); class Backend { public: - using StreamPtr = Pool::SmartPtr; - // Creates a new backend. static StatusOr> CreateBackend( const BackendOptions& options); @@ -114,13 +112,13 @@ class Backend { // Borrows a stream for use by the caller, either by grabbing it from an // internal pool, or by constructing/initializating it, and returns the result // to the caller. - StatusOr BorrowStream(int device_ordinal); - StatusOr BorrowStream(se::StreamExecutor* executor); + StatusOr BorrowStream(int device_ordinal); + StatusOr BorrowStream(se::StreamExecutor* executor); // Returns a function to borrow a stream, as `BorrowStream` above does. // Purely for convenience, the caller could rather make this anonymous // function itself. - std::function(int)> StreamBorrower() { + std::function(int)> StreamBorrower() { return [this](int device_ordinal) { return BorrowStream(device_ordinal); }; } @@ -169,7 +167,7 @@ class Backend { tensorflow::mutex mu_; // Mapping from stream executor to stream pools, used by `BorrowStream` above. - std::map> stream_pools_ GUARDED_BY(mu_); + std::map stream_pools_ GUARDED_BY(mu_); // The default memory allocator to use. std::unique_ptr memory_allocator_; diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc index 2099916509acdbc2680cc2b5bd405e96f2f7bfb8..b226e7ecb09c207645451073f7300d1278df4d28 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/batch_dot_simplification.h" +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" @@ -84,10 +85,10 @@ StatusOr BatchDotSimplification::Run(HloModule* module) { bool changed = false; std::vector dot_instrs; for (HloComputation* computation : module->MakeNonfusionComputations()) { - c_copy_if(computation->instructions(), std::back_inserter(dot_instrs), - [](HloInstruction* instr) { - return instr->opcode() == HloOpcode::kDot; - }); + absl::c_copy_if(computation->instructions(), std::back_inserter(dot_instrs), + [](HloInstruction* instr) { + return instr->opcode() == HloOpcode::kDot; + }); } for (HloInstruction* dot_instr : dot_instrs) { TF_ASSIGN_OR_RETURN(bool elided_batch_dim_from_one, diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc index 32f785a70adf0e7ea3ce281f7ff73224be8d424e..f62ab12319bf2cf6d37a5133b8e07dc4052179d0 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -137,9 +137,9 @@ ENTRY entry { if (instruction->opcode() == HloOpcode::kParameter) { continue; } - ASSERT_TRUE(instruction->has_sharding()); - TF_ASSERT_OK_AND_ASSIGN(int device, instruction->sharding().UniqueDevice()); - EXPECT_EQ(device, 1); + auto device = instruction->sharding_unique_device(); + ASSERT_TRUE(device); + EXPECT_EQ(*device, 1); } } diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index f7b4c1405dbc8719d8fba5476e6e41d2921ea877..7cf05ca443c00c3b40eeb7d756cf216b45c45c39 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -235,7 +235,8 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, - sum, /*replica_group_ids=*/{}, /*barrier=*/"")); + sum, /*replica_group_ids=*/{}, /*barrier=*/"", + /*all_reduce_id=*/tensorflow::gtl::nullopt)); HloInstruction* gte_a = builder.AddInstruction( HloInstruction::CreateGetTupleElement(f32_shape, crs, 0)); HloInstruction* gte_b = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc index 14c54ddd135af024327f63418b410da1ed3c4fd4..16e99b57220cc185fbfaa75d30a0de709cf61ee7 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -34,8 +34,10 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { Status DefaultAction(HloInstruction* hlo) override; - // Special handling for cross-replica-sum which can have a tuple output. + // Special handling for cross-replica-sum and sort which can have a tuple + // output. Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleSort(HloInstruction* sort) override; static bool Run(HloComputation* computation, const BFloat16Support* bfloat16_support) { @@ -49,6 +51,10 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { // conversions between F32 and BF16 to make it supported. Status HandleInstruction(HloInstruction* hlo); + // Handle instructions with tuple outputs by examining each output + // independently. + Status HandleMultipleOutputs(HloInstruction* hlo); + // Inserts a conversion HLO that changes the given HLO's output type. Status InsertConvertAfterOutput(HloInstruction* hlo, PrimitiveType to, HloComputation* computation); @@ -148,22 +154,35 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( HloInstruction* crs) { if (!ShapeUtil::IsTuple(crs->shape())) { return HandleInstruction(crs); + } else { + return HandleMultipleOutputs(crs); } +} + +Status BFloat16NormalizationVisitor::HandleSort(HloInstruction* sort) { + if (!ShapeUtil::IsTuple(sort->shape())) { + return HandleInstruction(sort); + } else { + return HandleMultipleOutputs(sort); + } +} - std::vector operand_types(crs->operand_count()); - std::vector output_types(crs->operand_count()); +Status BFloat16NormalizationVisitor::HandleMultipleOutputs( + HloInstruction* hlo) { + std::vector operand_types(hlo->operand_count()); + std::vector output_types(hlo->operand_count()); int64 f32_count = 0; int64 bf16_count = 0; bool has_unsupported_bf16_operand = false; bool has_unsupported_bf16_output = false; - for (int64 i = 0; i < crs->operand_count(); ++i) { - operand_types[i] = crs->operand(i)->shape().element_type(); - output_types[i] = ShapeUtil::GetSubshape(crs->shape(), {i}).element_type(); + for (int64 i = 0; i < hlo->operand_count(); ++i) { + operand_types[i] = hlo->operand(i)->shape().element_type(); + output_types[i] = ShapeUtil::GetSubshape(hlo->shape(), {i}).element_type(); if (operand_types[i] == F32) { f32_count += 1; } else if (operand_types[i] == BF16) { bf16_count += 1; - if (!bfloat16_support_->SupportsBF16Operand(*crs, i)) { + if (!bfloat16_support_->SupportsBF16Operand(*hlo, i)) { has_unsupported_bf16_operand = true; } } @@ -171,7 +190,7 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( f32_count += 1; } else if (output_types[i] == BF16) { bf16_count += 1; - if (!bfloat16_support_->SupportsBF16Output(*crs)) { + if (!bfloat16_support_->SupportsBF16Output(*hlo)) { has_unsupported_bf16_output = true; } } @@ -185,43 +204,43 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( if (operand_types[i] != BF16) { return false; } - if (!bfloat16_support_->SupportsBF16Operand(*crs, i)) { + if (!bfloat16_support_->SupportsBF16Operand(*hlo, i)) { return true; } - if (bfloat16_support_->SupportsMixedPrecisions(*crs)) { + if (bfloat16_support_->SupportsMixedPrecisions(*hlo)) { return false; } return has_unsupported_bf16_operand || has_unsupported_bf16_output || f32_count > 0; }; - for (int64 i = 0; i < crs->operand_count(); ++i) { + for (int64 i = 0; i < hlo->operand_count(); ++i) { if (should_convert_operand(i)) { - TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(crs, i, F32, computation_)); + TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_)); f32_count += 1; bf16_count -= 1; } } if (!has_unsupported_bf16_output && - (bfloat16_support_->SupportsMixedPrecisions(*crs) || f32_count == 0 || + (bfloat16_support_->SupportsMixedPrecisions(*hlo) || f32_count == 0 || bf16_count == 0)) { return Status::OK(); } - std::vector materialized_users = crs->users(); - std::vector output_elements(crs->operand_count()); - auto original_shape = crs->shape(); - for (int64 i = 0; i < crs->operand_count(); ++i) { - auto subshape = ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i}); + std::vector materialized_users = hlo->users(); + std::vector output_elements(hlo->operand_count()); + auto original_shape = hlo->shape(); + for (int64 i = 0; i < hlo->operand_count(); ++i) { + auto subshape = ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), {i}); if (output_types[i] != BF16) { output_elements[i] = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(*subshape, crs, i)); + HloInstruction::CreateGetTupleElement(*subshape, hlo, i)); continue; } subshape->set_element_type(F32); auto gte = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(*subshape, crs, i)); + HloInstruction::CreateGetTupleElement(*subshape, hlo, i)); output_elements[i] = computation_->AddInstruction(HloInstruction::CreateConvert( ShapeUtil::ChangeElementType(*subshape, BF16), gte)); @@ -229,11 +248,11 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( auto tuple = computation_->AddInstruction( HloInstruction::CreateTuple(output_elements)); - // Use the crs' shape temporarily, in order to pass checks in + // Use the hlo' shape temporarily, in order to pass checks in // ReplaceUseWith. - *tuple->mutable_shape() = crs->shape(); + *tuple->mutable_shape() = hlo->shape(); for (auto* user : materialized_users) { - TF_RETURN_IF_ERROR(crs->ReplaceUseWith(user, tuple)); + TF_RETURN_IF_ERROR(hlo->ReplaceUseWith(user, tuple)); } *tuple->mutable_shape() = original_shape; return Status::OK(); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 830f26422bdc2b3bd789e7d5926bcebac815d34a..f9f1f64998f5b925102dc238941897ff6d441b3f 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -251,7 +251,8 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction, - /*replica_group_ids=*/{}, /*barrier=*/"")); + /*replica_group_ids=*/{}, /*barrier=*/"", + /*all_reduce_id=*/tensorflow::gtl::nullopt)); HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1)); @@ -265,6 +266,33 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), F32); } +TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {1024}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {1024}); + Shape s32_shape = ShapeUtil::MakeShape(BF16, {1024}); + + HloInstruction* key = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "key")); + HloInstruction* value = builder.AddInstruction( + HloInstruction::CreateParameter(1, s32_shape, "value")); + + HloInstruction* sort = builder.AddInstruction(HloInstruction::CreateSort( + ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}), 0, key, value)); + HloInstruction* gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(bf16_shape, sort, 0)); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(Normalize(module.get())); + + EXPECT_EQ(computation->root_instruction(), gte); + EXPECT_EQ(gte->shape().element_type(), BF16); + EXPECT_EQ(sort->operand(0)->shape().element_type(), F32); + EXPECT_EQ(ShapeUtil::GetSubshape(sort->shape(), {0}).element_type(), F32); +} + // Tests that the normalization should not cause unsupported mixed precision due // to resolving unsupported BF16 operand. TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) { diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index b21c83a07f69d6ec93cf9305802e4d3af2783bdc..2fb401c4289728f3f59538464c5b8ad49957985b 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -215,7 +215,12 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, if (ContainsKey(values_that_must_be_kept_as_f32_, value)) { return false; } - if (ValueTypeAfterChange(value) == BF16) { + // We use the original type for the value because we are going to examine + // the uses of it, instead of the value itself. If ValueTypeAfterChange() + // were used, it would cause problems when there are aliasing buffers, i.e., + // ResolveInconsistencyOfAliasingBuffers() would fail to revert the + // tentative change to BF16 even if the uses require F32. + if (value->shape().element_type() == BF16) { continue; } for (const HloUse& use : value->uses()) { @@ -566,6 +571,9 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( } visited_computations->insert(visited_in_while.begin(), visited_in_while.end()); + } else if (hlo->opcode() == HloOpcode::kFusion) { + ResolveInconsistencyOfAliasingBuffersHelper( + hlo->fused_instructions_computation(), visited_computations); } } // Now adjust parameters of called computations. @@ -769,8 +777,7 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { // propagation in reverse topological order. for (auto comp_it = computations_topological_order.rbegin(); comp_it != computations_topological_order.rend(); ++comp_it) { - if ((*comp_it)->IsFusionComputation()) { - // Fusion computations are handled when visiting the fusion instruction. + if (ContainsKey(computations_visited_in_backward_pass_, *comp_it)) { continue; } auto insts = (*comp_it)->MakeInstructionPostOrder(); @@ -778,6 +785,7 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/true); } + computations_visited_in_backward_pass_.insert(*comp_it); } // It's possible that an instruction does not define a buffer, but the diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index aeafb25ad7215ea3d297e4a8bf7e1ba72d33d528..69b654d30e42b1ed69304206f09120e86831d468 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -508,6 +508,63 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { EXPECT_FALSE(OutputsBF16(dot)); } +// Tests that if the while condition prevents using BF16, no changes should be +// made to the while body and thus the fusion node inside it. +TEST_F(BFloat16PropagationTest, + ConditionPreventsPropagationForFusionInsideWhile) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); + + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, shape, "param1")); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + + auto builder_cond = HloComputation::Builder("cond"); + auto cond_param = builder_cond.AddInstruction( + HloInstruction::CreateParameter(0, shape, "cond_param")); + builder_cond.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond_param, {0, 0}, {1, 1}, {1, 1})), + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond_param, {1, 1}, {2, 2}, {1, 1})))); + auto cond = module->AddEmbeddedComputation(builder_cond.Build()); + + auto builder_body = HloComputation::Builder("body"); + auto body_param = builder_body.AddInstruction( + HloInstruction::CreateParameter(0, shape, "body_param")); + auto body_transpose = builder_body.AddInstruction( + HloInstruction::CreateTranspose(shape, body_param, {0, 1})); + + auto builder_f = HloComputation::Builder("fusion"); + HloInstruction* a_f = + builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + builder_f.AddInstruction(HloInstruction::CreateTranspose(shape, a_f, {0, 1})); + auto comp_f = module->AddEmbeddedComputation(builder_f.Build()); + auto body_fusion = builder_body.AddInstruction(HloInstruction::CreateFusion( + shape, HloInstruction::FusionKind::kCustom, {body_transpose}, comp_f)); + auto body = module->AddEmbeddedComputation(builder_body.Build()); + + auto while_hlo = builder.AddInstruction( + HloInstruction::CreateWhile(shape, cond, body, add)); + + auto dot = builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, while_hlo, while_hlo)); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_FALSE(PropagatePrecision(module.get())); + EXPECT_EQ(computation->root_instruction(), dot); + EXPECT_FALSE(OutputsBF16(add)); + EXPECT_FALSE(OutputsBF16(body_fusion)); + EXPECT_FALSE(OutputsBF16(body_param)); + EXPECT_FALSE(OutputsBF16(body_transpose)); + EXPECT_FALSE(OutputsBF16(a_f)); +} + // Tests that BF16 is propagated properly through while computations with // tuple-shaped input/output. TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { @@ -553,10 +610,14 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { HloInstruction::CreateGetTupleElement(shape, body_param, 0)); auto body_rhs = builder_body.AddInstruction( HloInstruction::CreateGetTupleElement(shape, body_param, 1)); - auto body_dot = builder_body.AddInstruction( + auto body_dot1 = builder_body.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs)); + auto body_dot2 = builder_body.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_rhs, body_lhs)); + auto body_transpose = builder_body.AddInstruction( + HloInstruction::CreateTranspose(shape, body_dot2, {0, 1})); builder_body.AddInstruction( - HloInstruction::CreateTuple({body_dot, body_rhs})); + HloInstruction::CreateTuple({body_dot1, body_transpose})); auto body = module->AddEmbeddedComputation(builder_body.Build()); auto while_hlo = builder.AddInstruction( @@ -575,9 +636,11 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(lhs)); EXPECT_FALSE(OutputsBF16(rhs)); - EXPECT_TRUE(OutputsBF16(body_dot)); + EXPECT_TRUE(OutputsBF16(body_dot1)); EXPECT_TRUE(OutputsBF16(body_lhs)); EXPECT_FALSE(OutputsBF16(body_rhs)); + EXPECT_FALSE(OutputsBF16(body_dot2)); + EXPECT_FALSE(OutputsBF16(body_transpose)); EXPECT_TRUE(OutputsBF16(cond_lhs)); EXPECT_FALSE(OutputsBF16(cond_rhs)); EXPECT_TRUE(OutputsBF16(add0)); diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index b4c7cf0dd8d3520077f2131b65192865d3701602..cc15c7122fc0776d38461569c5ced1eede2439f3 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -22,8 +22,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_value_containers.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -139,6 +139,7 @@ Status GatherComputationsByAllocationType( case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: + case HloOpcode::kScatter: case HloOpcode::kSelectAndScatter: case HloOpcode::kFusion: // Map/reduce etc computations are always thread-local. @@ -817,8 +818,7 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, } Status BufferAssigner::AssignBuffersForComputation( - const HloComputation* computation, const DebugOptions& debug_options, - bool is_thread_local, + const HloComputation* computation, bool is_thread_local, const FlatSet& colocated_buffers, const FlatSet& colocated_allocations, FlatMap>* @@ -878,8 +878,8 @@ Status BufferAssigner::AssignBuffersForComputation( // important reuse case where an elementwise instruction reuses one of its // operand's buffer. This improves locality. std::sort(sorted_buffers.begin(), sorted_buffers.end(), - [this, has_sequential_order, &liveness, &post_order_position, - assignment](const LogicalBuffer* a, const LogicalBuffer* b) { + [has_sequential_order, &liveness, &post_order_position, assignment]( + const LogicalBuffer* a, const LogicalBuffer* b) { // Primary sort is by decreasing buffer size. const int64 a_size = assignment->buffer_size_(*a); const int64 b_size = assignment->buffer_size_(*b); @@ -1100,8 +1100,8 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( options.buffers_to_assign = &buffer_value_set; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique( - MakeUnique(alignment)), + HeapSimulator::Run(absl::make_unique( + absl::make_unique(alignment)), assignment->module(), module_sequence, assignment->points_to_analysis(), assignment->buffer_size_, options)); @@ -1130,11 +1130,12 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( options.buffers_to_assign = &buffer_value_set; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique( - MakeUnique(alignment)), - *computation, *instruction_sequence, - assignment->points_to_analysis(), - assignment->buffer_size_, options)); + HeapSimulator::Run( + absl::make_unique( + absl::make_unique(alignment)), + *computation, *instruction_sequence, + assignment->points_to_analysis(), assignment->buffer_size_, + options)); AssignBuffersFromHeapSimulator(result, assignment, single_colored_set.first); } @@ -1342,11 +1343,25 @@ BufferAssigner::MergeColocatedBufferSets( auto cannot_merge_buffer_sets = [&colocated_buffer_sets, &buffer_liveness, &buffer_size, &is_entry_parameter](int64 i, int64 j) { - // Do not merge if one of the sets includes live outs or entry parameters. + // Do not merge if one of the sets includes live outs, entry parameters or + // constants. + // + // Buffer liveness does not report the correct live range for entry + // parameter and live out buffers so we have to special case them here. On + // backends that support constant buffer allocations, constant buffers are + // assigned globals in readonly storage so we can't merge colocated buffer + // sets containing constants with colocated buffer sets containing writing + // instructions or other constants. + // + // Moreover (on the CPU/GPU backends) the entry parameter buffers belong to + // the caller of the executable so we can't write to entry parameters + // either, and the argument for not merging constants also applies to entry + // parameters. for (int64 key : {i, j}) { for (auto& buffer : colocated_buffer_sets[key]) { if (buffer_liveness.MaybeLiveOut(*buffer) || - is_entry_parameter(*buffer)) { + is_entry_parameter(*buffer) || + buffer->instruction()->opcode() == HloOpcode::kConstant) { return true; } } @@ -1428,9 +1443,9 @@ void BufferAssigner::BuildColocatedBufferSets( const HloInstruction* while_hlo = instruction; ShapeUtil::ForEachSubshape( while_hlo->shape(), - [this, while_hlo, &points_to_analysis, &buffer_liveness, - buffer_size, computation, colocated_buffer_sets]( - const Shape& /*subshape*/, const ShapeIndex& index) { + [this, while_hlo, &points_to_analysis, buffer_size, + colocated_buffer_sets](const Shape& /*subshape*/, + const ShapeIndex& index) { std::vector colocated_set; // Add while.init. AddBufferToColocatedSet(while_hlo->operand(0), index, @@ -1632,7 +1647,8 @@ StatusOr> BufferAssigner::CreateAssignment( XLA_VLOG_LINES(3, liveness->ToString()); XLA_VLOG_LINES(3, liveness->points_to_analysis().ToString()); - // Can't use MakeUnique because BufferAssignment constructor is private. + // Can't use absl::make_unique because BufferAssignment constructor is + // private. std::unique_ptr assignment( new BufferAssignment(module, std::move(liveness), std::move(buffer_size), std::move(color_alignment))); @@ -1664,7 +1680,7 @@ StatusOr> BufferAssigner::CreateAssignment( buffers_to_assign_sequentially; for (auto* computation : global_computations) { TF_RETURN_IF_ERROR(AssignBuffersForComputation( - computation, module->config().debug_options(), + computation, /*is_thread_local=*/false, colocated_buffers, colocated_allocations, &buffers_to_assign_sequentially, assignment.get())); } @@ -1685,7 +1701,7 @@ StatusOr> BufferAssigner::CreateAssignment( continue; } TF_RETURN_IF_ERROR(AssignBuffersForComputation( - computation, module->config().debug_options(), + computation, /*is_thread_local=*/true, colocated_buffers, colocated_allocations, /*buffers_to_assign_sequentially=*/nullptr, assignment.get())); } diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 4fcf1fc73defcfba16c33224bd9c785675674408..94495290c131e22392079dc2d0237d990b646d3e 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -32,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" @@ -543,8 +542,7 @@ class BufferAssigner { // true, then all assigned buffers have the is_thread_local flag set to // true. Status AssignBuffersForComputation( - const HloComputation* computation, const DebugOptions& debug_options, - bool is_thread_local, + const HloComputation* computation, bool is_thread_local, const tensorflow::gtl::FlatSet& colocated_buffers, const tensorflow::gtl::FlatSet& colocated_allocations, diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index dea855d39ad759e7c3c13fcd3ccf06e4a1089df7..52abda16c4ee8e494b596e0690a8067743380054 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/copy_insertion.h" @@ -87,7 +87,7 @@ class BufferAssignmentTest : public HloTestBase { std::unique_ptr RunBufferAssignment(HloModule* module, int64 alignment = 1) { return BufferAssigner::Run( - module, xla::MakeUnique(module), + module, absl::make_unique(module), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -98,7 +98,7 @@ class BufferAssignmentTest : public HloTestBase { std::unique_ptr RunBufferAssignmentNoBuffersForConstants( HloModule* module, int64 alignment = 1) { return BufferAssigner::Run( - module, xla::MakeUnique(module), + module, absl::make_unique(module), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -109,7 +109,7 @@ class BufferAssignmentTest : public HloTestBase { std::unique_ptr RunColoredBufferAssignment( HloModule* module, BufferLiveness::Colorer colorer, int64 alignment = 1) { return BufferAssigner::Run( - module, xla::MakeUnique(module), + module, absl::make_unique(module), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -127,7 +127,8 @@ class BufferAssignmentTest : public HloTestBase { instruction_sequence.end()); return BufferAssigner::Run( module, - xla::MakeUnique(module, module_sequence), + absl::make_unique(module, + module_sequence), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -1769,7 +1770,8 @@ class WhileBufferAssignmentTest : public HloTestBase { auto sequence = ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie(); return BufferAssigner::Run( - module, xla::MakeUnique(module, sequence), + module, + absl::make_unique(module, sequence), ByteSizeOf, [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -1923,6 +1925,74 @@ ENTRY %test_module { EXPECT_NE(slice_param, slice_while1); } +TEST_F(WhileBufferAssignmentTest, ColocatedBufferWithConstant) { + const Shape r0s32 = ShapeUtil::MakeShape(S32, {}); + + const char* module_str = R"( +HloModule test_module + +%cond.v0 { + %param = s32[] parameter(0) + ROOT %constant = pred[] constant(true) +} + +%cond.v1 { + %param.0 = s32[] parameter(0) + ROOT %constant.0 = pred[] constant(true) +} + +%body.v0 { + ROOT %param.1 = s32[] parameter(0) +} + +%body.v1 { + %param.2 = s32[] parameter(0) + ROOT add = s32[] add(%param.2, %param.2) +} + +ENTRY %test_module { + %constant.42 = s32[] constant(42) + %while.0 = s32[] while(%constant.42), condition=%cond.v0, body=%body.v0 + %mul = s32[] multiply(%while.0, %while.0) + %while.1 = s32[] while(%mul), condition=%cond.v1, body=%body.v1 + ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + // Run CopyInsertion and check if the graph constructed above doesn't need + // any copies inserted for BufferAssignment to run. + int64 instruction_count = module->instruction_count(); + CopyInsertion copy_insertion; + ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); + ASSERT_EQ(instruction_count, module->instruction_count()); + + // Get the instructions in the module. + const HloInstruction* bcast = module->entry_computation()->root_instruction(); + const HloInstruction* constant = + module->entry_computation()->GetInstructionWithName("constant.42"); + ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast); + const HloInstruction* while1 = bcast->operand(0); + ASSERT_EQ(while1->opcode(), HloOpcode::kWhile); + const HloInstruction* while0 = while1->operand(0)->operand(0); + ASSERT_EQ(while0->opcode(), HloOpcode::kWhile); + + // Run buffer assignment. + auto assignment = RunBufferAssignment(module.get()); + TF_ASSERT_OK_AND_ASSIGN(auto slice_constant, + assignment->GetUniqueSlice(constant, {})); + TF_ASSERT_OK_AND_ASSIGN(auto slice_while0, + assignment->GetUniqueSlice(while0, {})); + TF_ASSERT_OK_AND_ASSIGN(auto slice_while1, + assignment->GetUniqueSlice(while1, {})); + + // The constant slice is part of the while0's colocation set (init value), but + // not merged into the while1's colocation set. + EXPECT_EQ(slice_constant, slice_while0); + EXPECT_NE(slice_constant, slice_while1); +} + // Tests that the colocated buffers for while instructions are properly assigned // during buffer assignment such that the result tuple elements are not assigned // to the same buffer. @@ -2015,7 +2085,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { auto assignment, BufferAssigner::Run( module.get(), - xla::MakeUnique(module.get(), sequence), + absl::make_unique(module.get(), sequence), backend().compiler()->BufferSizeBytesFunction(), [](LogicalBuffer::Color) { return 1; }, /*allow_input_output_aliasing=*/false, @@ -2272,7 +2342,7 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { auto assignment = BufferAssigner::Run( module.get(), - xla::MakeUnique(module.get(), sequence), + absl::make_unique(module.get(), sequence), ByteSizeOf, [](LogicalBuffer::Color) { return 1; }, /*allow_input_output_aliasing=*/false, /*allocate_buffers_for_constants=*/true) diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 4a927b57674345f8b3493c098778182a299c5902..3ffb7de65fb63b24e8be4978063d3f9f78f3e9ac 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -119,8 +119,8 @@ TEST_F(BufferLivenessTest, ElementwiseChain) { module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate)); @@ -167,10 +167,10 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) { SequentialHloOrdering::HloModuleSequence sequence; sequence.insert({entry, {param0, negate, param1, exp, add}}); - auto liveness = - BufferLiveness::Run(module.get(), xla::MakeUnique( - module.get(), sequence)) - .ConsumeValueOrDie(); + auto liveness = BufferLiveness::Run(module.get(), + absl::make_unique( + module.get(), sequence)) + .ConsumeValueOrDie(); // Entry parameters interfere as if they are defined simultaneously at // the very beginning. @@ -215,8 +215,8 @@ TEST_F(BufferLivenessTest, NonElementwiseOperand) { module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp)); @@ -249,8 +249,8 @@ TEST_F(BufferLivenessTest, OverlappedBuffers) { module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); @@ -293,10 +293,10 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { SequentialHloOrdering::HloModuleSequence module_sequence; std::vector order = {param, negate, exp, add}; module_sequence.emplace(computation, order); - auto liveness = - BufferLiveness::Run(module.get(), xla::MakeUnique( - module.get(), module_sequence)) - .ConsumeValueOrDie(); + auto liveness = BufferLiveness::Run(module.get(), + absl::make_unique( + module.get(), module_sequence)) + .ConsumeValueOrDie(); EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp)); @@ -342,10 +342,10 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) { std::vector order = {param, add, recv, recv_done, send, send_done}; module_sequence.emplace(computation, order); - auto liveness = - BufferLiveness::Run(module.get(), xla::MakeUnique( - module.get(), module_sequence)) - .ConsumeValueOrDie(); + auto liveness = BufferLiveness::Run(module.get(), + absl::make_unique( + module.get(), module_sequence)) + .ConsumeValueOrDie(); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add)); // Check the root instruction (add) buffer interferes with the recv buffer. @@ -376,8 +376,8 @@ TEST_F(BufferLivenessTest, TupleLiveOut) { module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); // All buffers should be live out except the param @@ -412,8 +412,8 @@ TEST_F(BufferLivenessTest, EmbeddedComputation) { module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); // Buffers in different computations should always interfere. @@ -453,8 +453,8 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) { module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); // Only the element buffers of the tuple constant which are pointed to by @@ -518,8 +518,8 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) { module->AddEmbeddedComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); // We compare tuple element pairs that are input/output to the computation: @@ -580,8 +580,8 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { module->AddEmbeddedComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); // We compare tuple element pairs that are input/output to the computation: @@ -668,10 +668,10 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { } // Run BufferLiveness on 'module'. - auto liveness = - BufferLiveness::Run( - module.get(), xla::MakeUnique(module.get())) - .ConsumeValueOrDie(); + auto liveness = BufferLiveness::Run( + module.get(), + absl::make_unique(module.get())) + .ConsumeValueOrDie(); // Return whether or not buffers interference is detected between // 'tuple_param0' and 'tuple_root' at shape index '{1}'. return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1}); @@ -780,10 +780,10 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { module->AddEntryComputation(BuildDummyComputation()); module->AddEmbeddedComputation(builder.Build()); // Run BufferLiveness on 'module'. - auto liveness = - BufferLiveness::Run( - module.get(), xla::MakeUnique(module.get())) - .ConsumeValueOrDie(); + auto liveness = BufferLiveness::Run( + module.get(), + absl::make_unique(module.get())) + .ConsumeValueOrDie(); // Return whether or not buffers interference is detected between // 'tuple_param0' and 'tuple_root' at shape index '{1}'. return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1}); diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index a23427f00ccd88bb0fe1d973a667f80ca54b14cd..d6efef5f12f62733ddd3a5314249ee9262571f97 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -17,8 +17,8 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -61,6 +61,7 @@ CallContext GetInstructionCallContext(HloOpcode opcode) { case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: + case HloOpcode::kScatter: case HloOpcode::kSelectAndScatter: case HloOpcode::kFusion: return CallContext::kParallel; @@ -236,8 +237,8 @@ void CallGraph::SetCallContexts() { /* static */ std::unique_ptr CallGraph::Build(const HloModule* module) { - // Constructor for CallGraph is private so MakeUnique can't be used. - auto call_graph = WrapUnique(new CallGraph(module)); + // Constructor for CallGraph is private so absl::make_unique can't be used. + auto call_graph = absl::WrapUnique(new CallGraph(module)); VLOG(2) << "Building call graph for:"; XLA_VLOG_LINES(2, module->ToString()); diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc index ff968bca297077c7cf869ff8d2becb8bf739dce3..e75f6f146d7c5896cfe6566fdec212a60e9f8457 100644 --- a/tensorflow/compiler/xla/service/call_inliner_test.cc +++ b/tensorflow/compiler/xla/service/call_inliner_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" diff --git a/tensorflow/compiler/xla/service/channel_tracker.cc b/tensorflow/compiler/xla/service/channel_tracker.cc index 13008efed1494402eaff47904c2e4797334381a1..9c9e373821d7f84f3468ef6c6a4f7dae9715b9f8 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.cc +++ b/tensorflow/compiler/xla/service/channel_tracker.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/channel_tracker.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/status.h" diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 99abb9bae32b35652e84cddc7c38dbd97ecb5006..34f7fe12cac5a4dcd3822865bee903d6eabc25c0 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -48,11 +48,6 @@ namespace xla { // compuation. using ObjectFileData = std::vector; -// Contains the buffer sizes information needed to allocate buffers to execute -// an ahead-of-time computation. Entries which contain -1 designate a parameter -// which should be skipped over during allocation. -using BufferSizes = std::vector; - // Abstract superclass describing the result of an ahead-of-time compilation. class AotCompilationResult { public: diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc index d26486fcfe0b1bc51867de5113cc5e42a0d7b4f0..afbbea35b893b8c14dbc0454e0a01fcb451cb709 100644 --- a/tensorflow/compiler/xla/service/computation_placer.cc +++ b/tensorflow/compiler/xla/service/computation_placer.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -29,9 +29,13 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +using tensorflow::strings::StrAppend; +using tensorflow::strings::StrCat; + namespace xla { Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const { @@ -56,8 +60,8 @@ DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) { "computation_count=%d", proto.replica_count(), proto.computation_count()); } - auto assignment = MakeUnique(proto.replica_count(), - proto.computation_count()); + auto assignment = absl::make_unique( + proto.replica_count(), proto.computation_count()); for (int computation = 0; computation < proto.computation_count(); ++computation) { const auto& computation_device = proto.computation_devices(computation); @@ -71,6 +75,19 @@ DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) { return std::move(assignment); } +string DeviceAssignment::ToString() const { + string output = StrCat("Computations: ", computation_count(), + " Replicas: ", replica_count(), "\n"); + for (int computation = 0; computation < computation_count(); ++computation) { + StrAppend(&output, "Computation ", computation, ": "); + for (int replica = 0; replica < replica_count(); ++replica) { + StrAppend(&output, operator()(replica, computation), " "); + } + StrAppend(&output, "\n"); + } + return output; +} + StatusOr ComputationPlacer::DeviceId(int replica, int computation, int replica_count, int computation_count) { @@ -139,7 +156,7 @@ ComputationPlacer::GetPlatformComputationPlacers() { } // namespace xla static std::unique_ptr CreateComputationPlacer() { - return xla::MakeUnique(); + return absl::make_unique(); } static bool InitModule() { diff --git a/tensorflow/compiler/xla/service/computation_placer.h b/tensorflow/compiler/xla/service/computation_placer.h index 737d00e93ecb51a9bd544bbcbe99d93374d108fb..c899ffb9dc562426ef14c0d414469c04debeec70 100644 --- a/tensorflow/compiler/xla/service/computation_placer.h +++ b/tensorflow/compiler/xla/service/computation_placer.h @@ -55,6 +55,8 @@ class DeviceAssignment : public Array2D { // due to a StatusOr of an incomplete type (DeviceAssignment). static StatusOr> Deserialize( const DeviceAssignmentProto& proto); + + string ToString() const; }; // A generic implementation of the XLA computation placer, which assigns device diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc new file mode 100644 index 0000000000000000000000000000000000000000..8affa08b6529e33cb5f1b13103dc4d69bcaa0e9c --- /dev/null +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc @@ -0,0 +1,248 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/convolution_feature_group_converter.h" + +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace { + +// ConvolutionVisitor traverses the HLO computation and rewrites Convolution +// operations with feature_group_count > 1 into convolutions with +// feature_group_count = 1. +class ConvolutionVisitor : public DfsHloVisitorWithDefault { + public: + // Default visitor action is to do nothing and return OK. + Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { + return Status::OK(); + } + + Status HandleConvolution(HloInstruction* convolution) override; + + // Runs the visitor on a computation. + static bool Run(HloComputation* computation); + + // Returns whether any convolution ops were rewritten. + const bool changed() const { return changed_; } + + ~ConvolutionVisitor() override = default; + + private: + explicit ConvolutionVisitor(HloComputation* computation) + : computation_(computation) {} + + // Current HloComputation instance the ConvolutionVisitor is traversing. + HloComputation* computation_; + + // Whether rewrite has occurred. + bool changed_ = false; +}; + +bool ConvolutionVisitor::Run(HloComputation* computation) { + ConvolutionVisitor visitor(computation); + TF_CHECK_OK(computation->Accept(&visitor)); + return visitor.changed_; +} + +Shape ExpandedFilterShape(const Shape& shape, int64 group_count, + int64 input_feature_dim) { + int64 num_dims = shape.dimensions_size(); + CHECK_GE(num_dims, 2); + Shape expanded_shape = shape; + expanded_shape.set_dimensions( + input_feature_dim, shape.dimensions(input_feature_dim) * group_count); + return expanded_shape; +} + +// Returns a vector with 'group_count' many groups, where the i-th group +// consists of 'group_size' times the value i. +std::vector GetMaskIds(int64 group_size, int64 group_count) { + std::vector values; + for (int i = 0; i < group_count; ++i) { + for (int j = 0; j < group_size; ++j) { + values.push_back(i); + } + } + return values; +} + +// Create a mask for grouped convolution that will make a normal convolution +// produce the same results as a grouped convolution. For a [2, 1, 6] +// filter this returns a [2, 3, 6] mask +// 1 1 0 0 0 0 +// 0 0 1 1 0 0 +// 0 0 0 0 1 1 +// +// 1 1 0 0 0 0 +// 0 0 1 1 0 0 +// 0 0 0 0 1 1 +// +// The first step is to create a rank 1 constant: +// 0 1 2 +// +// This is broadcasted to +// 0 0 0 0 0 0 +// 1 1 1 1 1 1 +// 2 2 2 2 2 2 +// +// 0 0 0 0 0 0 +// 1 1 1 1 1 1 +// 2 2 2 2 2 2 +// +// Then we create another rank 1 constant +// 0 0 1 1 2 2 +// +// This is broadcasted to +// 0 0 1 1 2 2 +// 0 0 1 1 2 2 +// 0 0 1 1 2 2 +// +// 0 0 1 1 2 2 +// 0 0 1 1 2 2 +// 0 0 1 1 2 2 +// +// Finally we use the Eq op of these two broadcasted constants and get the +// desired mask. +HloInstruction* GetExpandedFilterMask( + const Shape& filter_shape, int64 input_feature_dim, + int64 output_feature_dim, int64 group_count, + const std::function)>& + add_instruction) { + Shape expanded_filter_shape = + ExpandedFilterShape(filter_shape, group_count, input_feature_dim); + Shape mask_shape = ShapeUtil::MakeShape( + S32, AsInt64Slice(expanded_filter_shape.dimensions())); + int64 output_feature = filter_shape.dimensions(output_feature_dim); + int64 group_size = filter_shape.dimensions(input_feature_dim); + + // Create a 'input_feature' sized linspace and 'output_feature' sized linspace + // that will be broadcasted into perpendicular dimensions and compared. + const std::vector input_feature_filter_mask = + GetMaskIds(group_size, group_count); + const std::vector output_feature_filter_mask = + GetMaskIds(output_feature / group_count, group_count); + + auto mask1 = add_instruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1(input_feature_filter_mask))); + auto broadcasted_mask1 = add_instruction( + HloInstruction::CreateBroadcast(mask_shape, mask1, {input_feature_dim})); + auto mask2 = add_instruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1(output_feature_filter_mask))); + auto broadcasted_mask2 = add_instruction( + HloInstruction::CreateBroadcast(mask_shape, mask2, {output_feature_dim})); + + // Compare the broadcasted output feature linspace to the input feature + // linspace to create a diagonal predicate. + Shape predicate_shape = ShapeUtil::MakeShape( + PRED, AsInt64Slice(expanded_filter_shape.dimensions())); + return add_instruction(HloInstruction::CreateBinary( + predicate_shape, HloOpcode::kEq, broadcasted_mask1, broadcasted_mask2)); +} + +Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { + int64 group_count = convolution->feature_group_count(); + if (group_count == 1) { + return Status::OK(); + } + auto filter = convolution->mutable_operand(1); + changed_ = true; + auto add = [&](std::unique_ptr inst) { + return computation_->AddInstruction(std::move(inst)); + }; + + auto dim_numbers = convolution->convolution_dimension_numbers(); + int64 input_feature_dim = dim_numbers.kernel_input_feature_dimension(); + int64 group_size = filter->shape().dimensions(input_feature_dim); + int64 output_feature_dim = dim_numbers.kernel_output_feature_dimension(); + auto expanded_filter_shape = + ExpandedFilterShape(filter->shape(), group_count, input_feature_dim); + HloInstruction* filter_mask = GetExpandedFilterMask( + filter->shape(), input_feature_dim, output_feature_dim, group_count, add); + HloInstruction* expanded_filter; + // We want to repeat 'filter' in the 'input_feature_dim' dimension + // 'group_count' times. + if (group_size == 1) { + Shape reshaped_filter_shape = + ShapeUtil::DeleteDimension(input_feature_dim, filter->shape()); + auto reshaped_filter = + add(HloInstruction::CreateReshape(reshaped_filter_shape, filter)); + std::vector broadcast_dims; + for (int64 i = 0; i < filter->shape().dimensions_size(); ++i) { + if (i == input_feature_dim) { + continue; + } + broadcast_dims.push_back(i); + } + expanded_filter = add(HloInstruction::CreateBroadcast( + expanded_filter_shape, reshaped_filter, broadcast_dims)); + } else { + // We could possibly also use reshape, broadcast, reshape instead of concat + // here, but it would require more complex code, and for depthwise + // convolution we would never end up in this branch. + std::vector concat_operands(group_count, filter); + expanded_filter = add(HloInstruction::CreateConcatenate( + expanded_filter_shape, concat_operands, input_feature_dim)); + } + auto zero = add(HloInstruction::CreateConstant(absl::make_unique( + LiteralUtil::Zero(expanded_filter_shape.element_type())))); + auto zero_filter = + add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {})); + auto new_filter = add( + HloInstruction::CreateTernary(expanded_filter_shape, HloOpcode::kSelect, + filter_mask, expanded_filter, zero_filter)); + auto new_convolution = HloInstruction::CreateConvolve( + convolution->shape(), convolution->mutable_operand(0), new_filter, + convolution->window(), dim_numbers, /*feature_group_count=*/1); + TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( + convolution, std::move(new_convolution))); + return Status::OK(); +} + +} // namespace + +StatusOr ConvolutionFeatureGroupConverter::Run(HloModule* module) { + XLA_VLOG_LINES(2, "ConvolutionFeatureGroupConverter::Run(), before:\n" + + module->ToString()); + bool changed = false; + for (auto* comp : module->MakeNonfusionComputations()) { + if (ConvolutionVisitor::Run(comp)) { + changed = true; + } + } + XLA_VLOG_LINES(2, "ConvolutionFeatureGroupConverter::Run(), after:\n" + + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h new file mode 100644 index 0000000000000000000000000000000000000000..f213cc870918d476e839f97ae067504038f8cacc --- /dev/null +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace xla { + +// A pass which rewrites convolutions with feature_group_count > 1 into +// convolutions with feature_group_count = 1. +class ConvolutionFeatureGroupConverter : public HloPassInterface { + public: + ConvolutionFeatureGroupConverter() {} + + tensorflow::StringPiece name() const override { + return "convolution-feature-group-converter"; + } + + // Run convolution rewriting on the given computation. Returns whether the + // computation was changed. + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_ diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..28373ebf636c7b6b3059dcf6cd931901ebc87fc2 --- /dev/null +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc @@ -0,0 +1,100 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/convolution_feature_group_converter.h" + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace { + +using ConvolutionFeatureGroupConverterTest = HloTestBase; +namespace op = testing::opcode_matchers; + +TEST_F(ConvolutionFeatureGroupConverterTest, + ConvertFeatureGroupCountEqualToInputFeatureDim) { + string hlo_string = R"(HloModule Convolve1D1Window_0_module + +ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,2], filter: f32[1,1,2]) -> f32[1,2,2] { + %input = f32[1,2,2]{2,1,0} parameter(0) + %copy = f32[1,2,2]{2,0,1} copy(f32[1,2,2]{2,1,0} %input) + %filter = f32[1,1,2]{2,1,0} parameter(1) + ROOT %convolution = f32[1,2,2]{2,0,1} convolution(f32[1,2,2]{2,0,1} %copy, f32[1,1,2]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=2 +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + auto computation = module->entry_computation(); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); + ConvolutionFeatureGroupConverter converter; + ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + // Make sure the convolution is converted to one with feature_group_count = 1. + EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); + EXPECT_EQ(root->feature_group_count(), 1); + // Verify that the filter operand has been replaced. + EXPECT_THAT(root->operand(1), + op::Select(op::Eq(op::Broadcast(op::Constant()), + op::Broadcast(op::Constant())), + op::Broadcast(op::Reshape(op::Parameter())), + op::Broadcast(op::Constant()))); +} + +TEST_F(ConvolutionFeatureGroupConverterTest, + ConvertFeatureGroupCountDivisorOfInputFeatureDim) { + string hlo_string = R"(HloModule Convolve1D1Window_0_module + +ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,4], filter: f32[1,2,2]) -> f32[1,2,2] { + %input = f32[1,2,4]{2,1,0} parameter(0) + %copy = f32[1,2,4]{2,0,1} copy(f32[1,2,4]{2,1,0} %input) + %filter = f32[1,2,2]{2,1,0} parameter(1) + ROOT %convolution = f32[1,2,2]{2,0,1} convolution(f32[1,2,4]{2,0,1} %copy, f32[1,2,2]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=2 +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + auto computation = module->entry_computation(); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); + ConvolutionFeatureGroupConverter converter; + ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + // Make sure the convolution is converted to one with feature_group_count = 1. + EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); + EXPECT_EQ(root->feature_group_count(), 1); + // Verify that the filter operand has been replaced. + EXPECT_THAT(root->operand(1), + op::Select(op::Eq(op::Broadcast(op::Constant()), + op::Broadcast(op::Constant())), + // We expect to see Concatenate here instead of + // Broadcast, because feature_group_count < input + // feature dimension. + op::Concatenate(op::Parameter(), op::Parameter()), + op::Broadcast(op::Constant()))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 36fb9b43aa20bad788a0638b4fed6c88fc9023f0..3e39c1bab1e07d192a8c145be5103085fd3c189b 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -312,7 +312,7 @@ Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis, return Status::OK(); } -// We add copies for all the indices of the true and false computaiton roots, +// We add copies for all the indices of the true and false computation roots, // in order to resolve interference. We later rely on the CopyRemover to drop // the unnecessary ones. Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis, @@ -648,7 +648,12 @@ class CopyRemover { // We can only perform copy elision if the resulting merged values have // totally ordered live ranges; otherwise the merged buffer would have // live range interference. - if (IsHead(*dest)) { + if (src->next == dest) { + // In the process of eliding copies, its possible for a copy to have the + // same source and destination buffer. In this case, the copy can be + // safely removed. + VLOG(2) << copy->name() << " source and destination buffers are same."; + } else if (IsHead(*dest)) { // The copy copies an arbitrary value in the source buffer (call it s_x) // and defines d_0, the first value in the destination buffer. After // merging, the values in the combined buffer must be strictly ordered diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index cd735256b83f5f1d69a89e693de6064d460a36e5..892d0d7b547aaf1e7f1c55e4163d1e1fd9518def 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -2007,5 +2007,46 @@ ENTRY TestComputation { InsertCopies(module.get()); } +TEST_F(CopyInsertionTest, NestedWhiles) { + // Verify that only no unnecessary copies remain after copy insertion for + // trivial nested whiles (b/112472605). + const string& hlo_string = R"( +HloModule TestModule + +cond.inner { + ROOT param.cond.inner = pred[] parameter(0) +} + +body.inner { + param.body.inner = pred[] parameter(0) + ROOT neg = pred[] negate(param.body.inner) +} + +cond.outer { + ROOT param.cond.outer = pred[] parameter(0) +} + +body.outer { + param.cond.outer = pred[] parameter(0) + ROOT while = pred[] while(param.cond.outer), condition=cond.inner, body=body.inner +} + +ENTRY TestComputation { + entry_param = pred[] parameter(0) + ROOT while = pred[] while(entry_param), condition=cond.outer, body=body.outer +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + InsertCopies(module.get()); + + // There should only be a single copy inserted, and it's in the entry + // computation. + EXPECT_EQ(CountCopies(*module), 1); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::While(op::Copy(op::Parameter()))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index bcac65ecda0770798a7d2b14e088dda180de4981..850948b54b8c8ef7ac4e5da4c64e7ce018e31624 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -20,7 +20,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS") load( "//third_party/mkl:build_defs.bzl", - "if_mkl", + "mkl_deps", ) # Filegroup used to collect source files for dependency checking. @@ -50,16 +50,29 @@ cc_library( "//tensorflow/compiler/xla/service/cpu:cpu_runtime", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], alwayslink = True, # Contains per-platform transfer manager registration ) +cc_library( + name = "buffer_info_util", + srcs = ["buffer_info_util.cc"], + hdrs = ["buffer_info_util.h"], + deps = [ + "//tensorflow/compiler/tf2xla:cpu_function_runtime", + "//tensorflow/compiler/xla/service:buffer_assignment", + "//tensorflow/core:lib", + ], +) + cc_library( name = "cpu_compiler", srcs = ["cpu_compiler.cc"], hdrs = ["cpu_compiler.h"], deps = [ ":compiler_functor", + ":buffer_info_util", ":conv_canonicalization", ":cpu_copy_insertion", ":cpu_executable", @@ -73,6 +86,9 @@ cc_library( ":ir_emitter", ":parallel_task_assignment", ":simple_orc_jit", + "@com_google_absl//absl/memory", + "//tensorflow/compiler/tf2xla:cpu_function_runtime", + "//tensorflow/compiler/xla/service:scatter_expander", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:status_macros", @@ -87,6 +103,7 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", "//tensorflow/compiler/xla/service:conditional_simplifier", + "//tensorflow/compiler/xla/service:convolution_feature_group_converter", "//tensorflow/compiler/xla/service:dot_decomposer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", @@ -163,6 +180,7 @@ cc_library( ":runtime_single_threaded_conv2d", ":runtime_single_threaded_fft", ":runtime_single_threaded_matmul", + "@com_google_absl//absl/memory", "@llvm//:execution_engine", "@llvm//:core", "@llvm//:mc", # fixdeps: keep @@ -252,6 +270,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/compiler/xla/service/llvm_ir:alias_analysis", + "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", "//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", @@ -363,8 +382,8 @@ tf_cc_binary( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:lib", ], ) @@ -402,6 +421,7 @@ cc_library( "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", "@llvm//:analysis", "@llvm//:core", "@llvm//:ipo", @@ -483,10 +503,7 @@ cc_library( "//tensorflow/core:framework_lite", "//tensorflow/core/kernels:eigen_helpers", "//third_party/eigen3", - ] + if_mkl([ - "@mkl_dnn", - "//third_party/mkl:intel_binary_blob", - ]), + ] + mkl_deps(), ) cc_library( @@ -540,10 +557,7 @@ cc_library( "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/core:framework_lite", "//third_party/eigen3", - ] + if_mkl([ - "//third_party/mkl:intel_binary_blob", - "@mkl_dnn", - ]), + ] + mkl_deps(), ) cc_library( @@ -624,6 +638,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//third_party/eigen3", + "@com_google_absl//absl/memory", ], ) @@ -800,6 +815,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_pass", + "@com_google_absl//absl/memory", ], ) @@ -883,6 +899,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", "@llvm//:core", "@llvm//:support", ], diff --git a/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc b/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..408fe0f5bf5d729165eadd532d4740211620645d --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc @@ -0,0 +1,57 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h" + +namespace xla { +namespace cpu { + +using BufferInfo = ::tensorflow::cpu_function_runtime::BufferInfo; + +std::vector CreateBufferInfosFromBufferAssignment( + const BufferAssignment& buffer_assignment) { + std::vector buffer_infos; + for (const BufferAllocation& allocation : buffer_assignment.Allocations()) { + if (allocation.is_thread_local()) { + buffer_infos.push_back(BufferInfo::MakeOnStackBuffer(allocation.size())); + } else if (allocation.is_constant()) { + buffer_infos.push_back(BufferInfo::MakeConstant(allocation.size())); + } else if (allocation.is_entry_computation_parameter()) { + buffer_infos.push_back(BufferInfo::MakeEntryParameter( + /*size=*/allocation.size(), + /*param_number=*/allocation.parameter_number())); + } else { + buffer_infos.push_back(BufferInfo::MakeTempBuffer(allocation.size())); + } + } + return buffer_infos; +} + +std::vector CreateArgIndexTableFromBufferInfos( + tensorflow::gtl::ArraySlice buffer_infos) { + std::vector result; + for (int64 i = 0; i < buffer_infos.size(); i++) { + if (buffer_infos[i].is_entry_parameter()) { + if (buffer_infos[i].entry_parameter_number() >= result.size()) { + result.resize(buffer_infos[i].entry_parameter_number() + 1); + } + result[buffer_infos[i].entry_parameter_number()] = i; + } + } + return result; +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/buffer_info_util.h b/tensorflow/compiler/xla/service/cpu/buffer_info_util.h new file mode 100644 index 0000000000000000000000000000000000000000..05de70c72686dcbdaf0b47c46cde23ed45abdb42 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/buffer_info_util.h @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_ + +#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { +namespace cpu { +// Creates and returns a list of BufferInfo instances containing relevant +// information from `buffer_assignment`. +std::vector<::tensorflow::cpu_function_runtime::BufferInfo> +CreateBufferInfosFromBufferAssignment( + const BufferAssignment& buffer_assignment); + +// Creates and returns a table containing the mapping from entry computation +// parameters to buffer allocation indices. +// +// If this function returns V then entry parameter i has buffer allocation index +// V[i]. +std::vector CreateArgIndexTableFromBufferInfos( + tensorflow::gtl::ArraySlice<::tensorflow::cpu_function_runtime::BufferInfo> + buffer_infos); +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 6a7eb85e3baec3517b8f3ddef6a8dcfae9c9e614..73b03440cbb936017257b8a92f16dcc25d41e21c 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -35,7 +36,6 @@ limitations under the License. #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/AlwaysInliner.h" #include "llvm/Transforms/IPO/PassManagerBuilder.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -156,9 +156,26 @@ std::unique_ptr CompilerFunctor::operator()( target_machine_->addPassesToEmitMC(codegen_passes, mc_context, ostream); codegen_passes.run(module); - // Construct ObjectFile from machine code buffer. - return std::unique_ptr( + std::unique_ptr memory_buffer( new llvm::SmallVectorMemoryBuffer(std::move(stream_buffer))); + + if (VLOG_IS_ON(2)) { + llvm::Expected> obj_file = + llvm::object::ObjectFile::createObjectFile(*memory_buffer); + if (obj_file) { + StatusOr disasm_result = + disassembler_->DisassembleObjectFile(*obj_file.get()); + if (disasm_result.ok()) { + XLA_VLOG_LINES(2, disasm_result.ValueOrDie().text); + } else { + LOG(WARNING) << "Could not disassemble object file!"; + } + } else { + LOG(WARNING) << "Could convert memory buffer to object file!"; + } + } + + return memory_buffer; } static std::vector VectorFunctionsForTargetLibraryInfoImpl() { @@ -188,7 +205,7 @@ void CompilerFunctor::AddTargetInfoPasses( llvm::legacy::PassManagerBase* passes) const { llvm::Triple target_triple(target_machine_->getTargetTriple()); auto target_library_info_impl = - MakeUnique(target_triple); + absl::make_unique(target_triple); target_library_info_impl->addVectorizableFunctions( VectorFunctionsForTargetLibraryInfoImpl()); passes->add( diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 29fa29d33ad62a76191cef2de22ccc094b0cf35b..5116f926f50bf0344951ebb67def7eddd0919f2b 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -26,6 +26,7 @@ limitations under the License. // IWYU pragma: no_include "llvm/Config/Disassemblers.def.inc" // IWYU pragma: no_include "llvm/Config/Targets.def.inc" +#include "absl/memory/memory.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Triple.h" #include "llvm/IR/Function.h" @@ -42,7 +43,6 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/batch_dot_simplification.h" #include "tensorflow/compiler/xla/service/batchnorm_expander.h" @@ -50,6 +50,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h" +#include "tensorflow/compiler/xla/service/convolution_feature_group_converter.h" +#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h" #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h" #include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h" @@ -87,6 +89,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "tensorflow/compiler/xla/service/scatter_expander.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" @@ -103,6 +106,7 @@ limitations under the License. namespace xla { namespace cpu { +using BufferInfo = ::tensorflow::cpu_function_runtime::BufferInfo; CpuAotCompilationOptions::CpuAotCompilationOptions( string triple, string cpu_name, string features, string entry_point_name, @@ -120,11 +124,11 @@ se::Platform::Id CpuAotCompilationOptions::PlatformId() const { } CpuAotCompilationResult::CpuAotCompilationResult( - ObjectFileData object_file_data, BufferSizes buffer_sizes, + ObjectFileData object_file_data, std::vector buffer_infos, int64 result_buffer_index, std::unique_ptr hlo_profile_printer_data) : object_file_data_(std::move(object_file_data)), - buffer_sizes_(std::move(buffer_sizes)), + buffer_infos_(std::move(buffer_infos)), result_buffer_index_(result_buffer_index), hlo_profile_printer_data_(std::move(hlo_profile_printer_data)) {} @@ -255,6 +259,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(&target_machine_features); { auto& pass = @@ -273,7 +278,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, // BatchNormExpander can create zero-sized ops, so zero-sized HLO // elimination has to come after that pass. - pipeline.AddPass(); + pass.AddPass(); pass.AddPass(); pass.AddPass(); @@ -297,6 +302,8 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, pipeline.AddPass(/*is_layout_sensitive=*/false); pipeline.AddPass(); + pipeline.AddPass(); + ReducePrecisionInsertion::AddPasses( &pipeline, module->config().debug_options(), ReducePrecisionInsertion::PassTiming::AFTER_FUSION); @@ -354,7 +361,7 @@ llvm::TargetOptions CompilerTargetOptions( llvm::TargetOptions target_options; llvm_ir::SetTargetOptions( /*fast_math_enabled=*/module_config.debug_options() - .xla_enable_fast_math(), + .xla_cpu_enable_fast_math(), &target_options); return target_options; } @@ -446,7 +453,7 @@ Status CreateHloProfilingArtifacts( computation_to_profile_idx, std::unique_ptr* hlo_profile_index_map, std::unique_ptr* hlo_profile_printer_data) { - *hlo_profile_index_map = MakeUnique(module); + *hlo_profile_index_map = absl::make_unique(module); const HloComputation& entry_computation = *module.entry_computation(); TF_ASSIGN_OR_RETURN( @@ -513,15 +520,15 @@ StatusOr> CpuCompiler::RunBackend( &pre_optimization_ir_hook, &post_optimization_ir_hook)); // Compile must be thread-safe so create a new LLVM context for the module. - auto llvm_context = xla::MakeUnique(); + auto llvm_context = absl::make_unique(); auto llvm_module = - xla::MakeUnique("__compute_module", *llvm_context); + absl::make_unique("__compute_module", *llvm_context); - auto jit = xla::MakeUnique( + auto jit = absl::make_unique( CompilerTargetOptions(module->config()), CodeGenOptLevel(module->config()), options::OptimizeForSizeRequested(module->config()), - module->config().debug_options().xla_enable_fast_math(), + module->config().debug_options().xla_cpu_enable_fast_math(), module->config().debug_options().xla_llvm_disable_expensive_passes(), pre_optimization_ir_hook, post_optimization_ir_hook); llvm_module->setDataLayout(jit->data_layout()); @@ -559,10 +566,12 @@ StatusOr> CpuCompiler::RunBackend( // temporary buffers are required to run the computation. TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, - BufferAssigner::Run( - module.get(), - xla::MakeUnique(module.get(), module_sequence), - BufferSizeBytesFunction(), memory_alignment)); + BufferAssigner::Run(module.get(), + absl::make_unique( + module.get(), module_sequence), + BufferSizeBytesFunction(), memory_alignment, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); @@ -584,6 +593,8 @@ StatusOr> CpuCompiler::RunBackend( std::move(computation_to_profile_idx), &target_machine_features); + TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); + for (auto embedded_computation : entry_computation->MakeEmbeddedComputationsList()) { if (embedded_computation->IsFusionComputation()) { @@ -647,9 +658,9 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, // so we bail if the configs have conflicting flags. At the moment, the only // flag that needs to be consistent is fast-math. const bool fast_math_enabled = - modules[0]->config().debug_options().xla_enable_fast_math(); + modules[0]->config().debug_options().xla_cpu_enable_fast_math(); for (const auto& module : modules) { - if (module->config().debug_options().xla_enable_fast_math() != + if (module->config().debug_options().xla_cpu_enable_fast_math() != fast_math_enabled) { return InvalidArgument( "All HLO module configs must have the same value for " @@ -705,7 +716,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, llvm::StringRef cpu_name = llvm_ir::AsStringRef(options.cpu_name()); llvm::StringRef features = llvm_ir::AsStringRef(options.features()); llvm::CodeGenOpt::Level opt_level = CodeGenOptLevel(modules[0]->config()); - std::unique_ptr target_machine = WrapUnique( + std::unique_ptr target_machine = absl::WrapUnique( target->createTargetMachine(triple.getTriple(), cpu_name, features, CompilerTargetOptions(modules[0]->config()), reloc_model, llvm::None, opt_level)); @@ -746,8 +757,10 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, std::unique_ptr assignment, BufferAssigner::Run( module, - xla::MakeUnique(module, module_sequence), - BufferSizeBytesFunction(), memory_alignment)); + absl::make_unique(module, module_sequence), + BufferSizeBytesFunction(), memory_alignment, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); @@ -776,6 +789,9 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), &target_machine_features); + + TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); + HloComputation* computation = module->entry_computation(); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { @@ -821,7 +837,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, CompilerFunctor compiler_functor( target_machine.get(), &disassembler, opt_level, options::OptimizeForSizeRequested(module->config()), - module->config().debug_options().xla_enable_fast_math(), + module->config().debug_options().xla_cpu_enable_fast_math(), module->config().debug_options().xla_llvm_disable_expensive_passes(), pre_optimization_ir_dump_hook, post_optimization_ir_dump_hook); std::unique_ptr object_file = @@ -829,27 +845,14 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, ObjectFileData object_file_data(object_file->getBufferStart(), object_file->getBufferEnd()); - BufferSizes buffer_sizes; - for (const BufferAllocation& allocation : assignment->Allocations()) { - // Callers don't need to allocate temporary buffers for parameters. - if (allocation.is_entry_computation_parameter()) { - buffer_sizes.push_back(-1); - continue; - } - // Callers don't need to allocate anything for thread-local temporary - // buffers. They are lowered to allocas. - if (allocation.is_thread_local()) { - buffer_sizes.push_back(-1); - continue; - } - buffer_sizes.push_back(allocation.size()); - } + std::vector buffer_infos = + CreateBufferInfosFromBufferAssignment(*assignment); TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, assignment->GetUniqueTopLevelOutputSlice()); - results.emplace_back(MakeUnique( - std::move(object_file_data), std::move(buffer_sizes), + results.emplace_back(absl::make_unique( + std::move(object_file_data), std::move(buffer_infos), result_slice.index(), std::move(hlo_profile_printer_data))); } @@ -871,7 +874,7 @@ HloCostAnalysis::ShapeSizeFunction CpuCompiler::ShapeSizeBytesFunction() const { static bool InitModule() { xla::Compiler::RegisterCompilerFactory( stream_executor::host::kHostPlatformId, - []() { return xla::MakeUnique(); }); + []() { return absl::make_unique(); }); return true; } static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index e56f9f01134f84b4698c078b750b0c1fdca7748e..04e1c48872ed55ca7f2aa3bec08c44a1666b90f1 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "llvm/Target/TargetMachine.h" +#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/llvm_compiler.h" @@ -78,7 +79,8 @@ class CpuAotCompilationOptions : public AotCompilationOptions { class CpuAotCompilationResult : public AotCompilationResult { public: CpuAotCompilationResult( - ObjectFileData object_file_data, BufferSizes buffer_sizes, + ObjectFileData object_file_data, + std::vector<::tensorflow::cpu_function_runtime::BufferInfo> buffer_infos, int64 result_buffer_index, std::unique_ptr hlo_profile_printer_data); ~CpuAotCompilationResult(); @@ -88,17 +90,20 @@ class CpuAotCompilationResult : public AotCompilationResult { } const ObjectFileData& object_file_data() const { return object_file_data_; } - const BufferSizes& buffer_sizes() const { return buffer_sizes_; } + const std::vector<::tensorflow::cpu_function_runtime::BufferInfo>& + buffer_infos() const { + return buffer_infos_; + } int64 result_buffer_index() const { return result_buffer_index_; } private: // Contains the compiled computation: an object file. const ObjectFileData object_file_data_; - // The list of buffer sizes which should be allocated in order to execute the - // compiled computation. These buffers are used for temporary buffers used - // ephemerally during computation as well as the output result. - const BufferSizes buffer_sizes_; + // A list of BufferInfo objects describing the buffers used by the XLA + // computation. + const std::vector<::tensorflow::cpu_function_runtime::BufferInfo> + buffer_infos_; // Contains which buffer index into |buffer_sizes| was designated to the // result of the computation. This buffer should be passed into the output diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 1093559892ddb9c238fd9c1f7e3d419ec7022776..c376864c3e1f882e11bc05f8cf93f2fb1c88e4ec 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -69,12 +69,19 @@ CpuExecutable::CpuExecutable( // guarded by the mutex. compute_function_ = reinterpret_cast(cantFail(sym.getAddress())); + VLOG(1) << "compute_function_ at address " + << reinterpret_cast(compute_function_); } -Status CpuExecutable::AllocateBuffers( +StatusOr, + std::vector>> +CpuExecutable::CreateTempArray( DeviceMemoryAllocator* memory_allocator, int device_ordinal, - std::vector* buffers) { - CHECK_EQ(buffers->size(), assignment_->Allocations().size()); + tensorflow::gtl::ArraySlice arguments) { + std::vector unowning_buffers( + assignment_->Allocations().size()); + std::vector owning_buffers( + assignment_->Allocations().size()); VLOG(3) << "Allocating " << assignment_->Allocations().size() << " allocations for module " << module().name(); for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); @@ -84,44 +91,51 @@ Status CpuExecutable::AllocateBuffers( VLOG(3) << allocation.ToString(); if (allocation.is_entry_computation_parameter()) { + unowning_buffers[i] = arguments[allocation.parameter_number()]->buffer( + allocation.param_shape_index()); VLOG(3) << "allocation #" << i << " is a parameter"; continue; } + if (allocation.is_constant()) { + VLOG(3) << "allocation #" << i << " is a constant"; + continue; + } + if (allocation.is_thread_local()) { VLOG(3) << "buffer #" << i << " is thread-local"; continue; } int64 buffer_size = allocation.size(); - if (!(*buffers)[i].is_null()) { + if (!owning_buffers[i].is_null()) { VLOG(3) << "buffer #" << i << " is in the preallocated result ShapedBuffer"; } else { - TF_ASSIGN_OR_RETURN((*buffers)[i], memory_allocator->Allocate( - device_ordinal, buffer_size)); + TF_ASSIGN_OR_RETURN(owning_buffers[i], memory_allocator->Allocate( + device_ordinal, buffer_size)); + unowning_buffers[i] = owning_buffers[i].AsDeviceMemoryBase(); VLOG(3) << "buffer #" << i << " allocated " << buffer_size << " bytes [" - << (*buffers)[i].opaque() << "]"; + << owning_buffers[i].opaque() << "]"; } // Since the output buffer and all the temporary buffers were written into // by the JITed code, msan has no way of knowing their memory was // initialized. Mark them initialized so that msan doesn't flag loads from // these buffers. - TF_ANNOTATE_MEMORY_IS_INITIALIZED((*buffers)[i].opaque(), buffer_size); + TF_ANNOTATE_MEMORY_IS_INITIALIZED(owning_buffers[i].opaque(), buffer_size); } TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, assignment_->GetUniqueTopLevelOutputSlice()); VLOG(3) << "result index: " << result_slice.index(); - return Status::OK(); + return {{std::move(unowning_buffers), std::move(owning_buffers)}}; } Status CpuExecutable::ExecuteComputeFunction( const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile) { // The calling convention for JITed functions is: @@ -131,17 +145,11 @@ Status CpuExecutable::ExecuteComputeFunction( // // result: Points at the result. // run_options: the ExecutableRunOptions object. - // args_array: An array of pointers, each of which points to a parameter. - // The size of this array is determined by the function's arity - // (ProgramShape). - // temps_array: An array of pointers, each of which points to a temporary - // buffer the computation needs. The size of this array is - // determined by buffer analysis. + // args_array: null + // temps_array: An array of pointers, containing pointers to temporary buffers + // required by the executable adn pointers to entry computation + // parameters. // - std::vector args_array; - for (const ShapedBuffer* argument : arguments) { - args_array.push_back(argument->root_buffer().opaque()); - } uint64 start_micros = tensorflow::Env::Default()->NowMicros(); @@ -164,16 +172,14 @@ Status CpuExecutable::ExecuteComputeFunction( if (VLOG_IS_ON(3)) { VLOG(3) << "Executing compute function:"; VLOG(3) << tensorflow::strings::Printf( - " func(void* result, void* params[%zu], void* temps[%zu], " + " func(void* result, void* params[null], void* temps[%zu], " "uint64 profile_counters[%zu])", - args_array.size(), buffer_pointers.size(), profile_counters_size); + buffer_pointers.size(), profile_counters_size); VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer); auto ptr_printer = [](string* out, const void* p) { tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p)); }; - VLOG(3) << tensorflow::strings::Printf( - " params = [%s]", - tensorflow::str_util::Join(args_array, ", ", ptr_printer).c_str()); + VLOG(3) << " params = nullptr"; VLOG(3) << tensorflow::strings::Printf( " temps = [%s]", tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str()); @@ -181,8 +187,8 @@ Status CpuExecutable::ExecuteComputeFunction( profile_counters); } - compute_function_(result_buffer, run_options, args_array.data(), - buffer_pointers.data(), profile_counters); + compute_function_(result_buffer, run_options, nullptr, buffer_pointers.data(), + profile_counters); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); @@ -243,27 +249,11 @@ StatusOr CpuExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) { - if (GetRootPointsToSet().IsAmbiguous()) { - return Unimplemented("Points-to set of root instruction is ambiguous"); - } - - se::Stream* stream = run_options->stream(); - DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - std::vector buffers(assignment_->Allocations().size()); - - TF_RETURN_IF_ERROR(AllocateBuffers( - memory_allocator, stream->parent()->device_ordinal(), &buffers)); - - std::vector unowning_buffers; - unowning_buffers.reserve(buffers.size()); - for (auto& buffer : buffers) { - unowning_buffers.push_back(buffer.AsDeviceMemoryBase()); - } - TF_RETURN_IF_ERROR(ExecuteComputeFunction(&run_options->run_options(), - arguments, unowning_buffers, - hlo_execution_profile)); - - return CreateResultShapedBuffer(run_options, &buffers); + TF_ASSIGN_OR_RETURN( + auto result, + ExecuteAsyncOnStreamImpl(run_options, arguments, hlo_execution_profile)); + TF_RETURN_IF_ERROR(run_options->stream()->BlockHostUntilDone()); + return std::move(result); } StatusOr CpuExecutable::ExecuteAsyncOnStream( @@ -274,22 +264,30 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( "Asynchronous execution on stream with hlo profiling is not yet " "supported on CPU."); } + return ExecuteAsyncOnStreamImpl(run_options, arguments, nullptr); +} + +StatusOr CpuExecutable::ExecuteAsyncOnStreamImpl( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + HloExecutionProfile* hlo_execution_profile) { + if (GetRootPointsToSet().IsAmbiguous()) { + return Unimplemented("Points-to set of root instruction is ambiguous"); + } auto* host_stream = dynamic_cast( run_options->stream()->implementation()); se::Stream* stream = run_options->stream(); DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - std::vector buffers(assignment_->Allocations().size()); - TF_RETURN_IF_ERROR(AllocateBuffers( - memory_allocator, stream->parent()->device_ordinal(), &buffers)); - + std::vector owning_buffers; std::vector unowning_buffers; - unowning_buffers.reserve(buffers.size()); - for (auto& buffer : buffers) { - unowning_buffers.push_back(buffer.AsDeviceMemoryBase()); - } + TF_ASSIGN_OR_RETURN( + std::tie(unowning_buffers, owning_buffers), + CreateTempArray(memory_allocator, stream->parent()->device_ordinal(), + arguments)); + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, - CreateResultShapedBuffer(run_options, &buffers)); + CreateResultShapedBuffer(run_options, &owning_buffers)); // At this point, `unowning_buffers` contains unowning pointers to all of our // buffers, and `buffers` contains owning pointers to the non-live-out @@ -307,23 +305,22 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( struct AsyncRunTask { CpuExecutable* executable; ServiceExecutableRunOptions run_options; - std::vector arguments; std::vector unowning_buffers; std::shared_ptr> buffers; + HloExecutionProfile* hlo_execution_profile; void operator()() { // Failing a CHECK here is not great, but I don't see an obvious way to // return a failed Status asynchronously. TF_CHECK_OK(executable->ExecuteComputeFunction( - &run_options.run_options(), arguments, unowning_buffers, - /*hlo_execution_profile=*/nullptr)); + &run_options.run_options(), unowning_buffers, hlo_execution_profile)); } }; - host_stream->EnqueueTask(AsyncRunTask{ - this, *run_options, - std::vector(arguments.begin(), arguments.end()), - unowning_buffers, - std::make_shared>(std::move(buffers))}); + host_stream->EnqueueTask( + AsyncRunTask{this, *run_options, std::move(unowning_buffers), + std::make_shared>( + std::move(owning_buffers)), + hlo_execution_profile}); return std::move(result); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 8dd47bfb865e8a0552542f510d3365cff0d111e0..96e53de57eee013fe6f847c10e23a38f5beb9adc 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -85,20 +85,39 @@ class CpuExecutable : public Executable { const BufferAssignment& buffer_assignment() const { return *assignment_; } private: - // Allocate buffers required for execution and assign them to the elements of - // "buffers". "buffers" should be sized to the number of buffers in buffer - // assignment. Each vector element corresponds to a particular Index. If - // a vector element already contains a non-null DeviceMemoryBase, then no - // buffer is assigned for this element. - Status AllocateBuffers(DeviceMemoryAllocator* memory_allocator, - int device_ordinal, - std::vector* buffers); + // This is for sharing the code between ExecuteOnStream and + // ExecuteAsyncOnStream. + // + // Notice that it's tricky to use correctly, as the profile object (when it + // exists) must out-live the task. + StatusOr ExecuteAsyncOnStreamImpl( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + HloExecutionProfile* hlo_execution_profile); + + // Creates an array suitable for passing as the "temps" argument to the JIT + // compiled function pointer. + // + // Returns (unowning_buffers, owning_buffers) where: + // + // - unowning_buffers.data() can be passed as the temps argument as-is and + // includes pointers to the scratch storage required by the computation, + // the live-out buffer into which the result will be written and entry + // computation parameters. + // + // - owning_buffers contains owning pointers to the buffers that were + // allocated by this routine. This routine allocates buffers for temporary + // storage and the live-out buffer into which the computation writes it + // result. + StatusOr, + std::vector>> + CreateTempArray(DeviceMemoryAllocator* memory_allocator, int device_ordinal, + tensorflow::gtl::ArraySlice arguments); // Calls the generated function performing the computation with the given // arguments using the supplied buffers. Status ExecuteComputeFunction( const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 991b14f17dbc8cd061af98e032824d3f7075e78b..e6130c7d76e0383d03fe56d19aee239c5992309d 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -697,8 +697,9 @@ void CreateComputationForDotAddOutputFusionTest(const string& test_name, HloInstruction::CreateBinary(dot_shape, HloOpcode::kAdd, dot, addend)); if (add_extra_use_for_dot) { + auto* token = builder.AddInstruction(HloInstruction::CreateToken()); builder.AddInstruction( - HloInstruction::CreateOutfeed(dot_shape, dot, "no_config")); + HloInstruction::CreateOutfeed(dot_shape, dot, token, "no_config")); } module->AddEntryComputation(builder.Build()); @@ -791,11 +792,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) gather = s32[3,2] gather(operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3, 1} + slice_sizes={3, 1} one = s32[] constant(1) one_broadcasted = s32[3,2] broadcast(one), dimensions={} ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted) @@ -807,11 +808,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,3,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=2, - window_bounds={3, 1} + slice_sizes={3, 1} one = s32[] constant(1) one_broadcasted = s32[2,3,2] broadcast(one), dimensions={} ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted) @@ -823,11 +824,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2,2] parameter(1) gather = s32[2,2] gather(operand, indices), - output_window_dims={}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=2, - window_bounds={1, 1} + slice_sizes={1, 1} one = s32[] constant(1) one_broadcasted = s32[2,2] broadcast(one), dimensions={} ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) @@ -839,11 +840,11 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1,2} + slice_sizes={1,1,2} one = s32[] constant(1) one_broadcasted = s32[2,2] broadcast(one), dimensions={} ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) @@ -855,11 +856,11 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1,2} + slice_sizes={1,1,2} one = s32[] constant(1) one_broadcasted = s32[2,2] broadcast(one), dimensions={} ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) @@ -871,11 +872,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) gather = s32[1,1] gather(operand, indices), - output_window_dims={0,1}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={0,1}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} one = s32[] constant(1) one_broadcasted = s32[1,1] broadcast(one), dimensions={} ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted) @@ -887,11 +888,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,1,1] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} one = s32[] constant(1) one_broadcasted = s32[2,1,1] broadcast(one), dimensions={} ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted) diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 54c52bc08f9c53b8c6898689b18c4cb7f4bdcfd0..639064040f521a9e84bd87c5d05f674204e4d6e2 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -92,9 +92,10 @@ tensorflow::string ShapeString(const void* shape_ptr, xla::int32 shape_length) { } // namespace -void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length, - const void* shape, - xla::int32 shape_length) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void* +__xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length, + const void* shape, + xla::int32 shape_length) { if (VLOG_IS_ON(2)) { LOG(INFO) << "AcquireInfeedBufferForDequeue: " << ShapeString(shape, shape_length); @@ -111,9 +112,11 @@ void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length, return buffer->data(); } -void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue( - xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr, - xla::int32 shape_length) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(xla::int32 buffer_length, + void* buffer_ptr, + const void* shape_ptr, + xla::int32 shape_length) { if (VLOG_IS_ON(2)) { LOG(INFO) << "ReleaseInfeedBufferAfterDeque: " << ShapeString(shape_ptr, shape_length); @@ -125,8 +128,10 @@ void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue( std::move(shape)); } -void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation( - xla::int32 buffer_length, const void* shape_ptr, xla::int32 shape_length) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void* +__xla_cpu_runtime_AcquireOutfeedBufferForPopulation(xla::int32 buffer_length, + const void* shape_ptr, + xla::int32 shape_length) { if (VLOG_IS_ON(2)) { LOG(INFO) << "AcquireOutfeedBufferForPopulation: " << ShapeString(shape_ptr, shape_length); @@ -143,9 +148,11 @@ void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation( return buffer->data(); } -void __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation( - xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr, - xla::int32 shape_length) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(xla::int32 buffer_length, + void* buffer_ptr, + const void* shape_ptr, + xla::int32 shape_length) { if (VLOG_IS_ON(2)) { LOG(INFO) << "ReleaseOutfeedBufferAfterPopulation: " << ShapeString(shape_ptr, shape_length); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc index 2ac950e6d93ade315808f2ca1d0bdd7bc85f53b9..bc4cfc099965e2ab12212f55e62bdf79c0cfb739 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" @@ -46,7 +46,7 @@ std::unique_ptr> MaybeTransposeArray2D(const Array2D& array, if (transpose) { std::swap(output_width, output_height); } - auto output = MakeUnique>(output_height, output_width); + auto output = absl::make_unique>(output_height, output_width); for (int y = 0; y < array.height(); y++) { for (int x = 0; x < array.width(); x++) { if (transpose) { @@ -93,7 +93,7 @@ std::unique_ptr> EigenMatrixMultiply(const Array2D& a, // Since we're going to transpose c before returning it. Swap the order of the // dimension sizes to ensure the returned array is properly dimensioned. - auto c_transpose = MakeUnique>(n, m); + auto c_transpose = absl::make_unique>(n, m); if (single_threaded) { __xla_cpu_runtime_EigenSingleThreadedMatMulF32( nullptr, c_transpose->data(), a_transpose->data(), b_transpose->data(), @@ -204,7 +204,7 @@ std::unique_ptr> MKLMatrixMultiply(const Array2D& a, // Since we're going to transpose c before returning it, swap the order of the // dimension sizes to ensure the returned array is properly dimensioned. - auto c_transpose = MakeUnique>(n, m); + auto c_transpose = absl::make_unique>(n, m); if (single_threaded) { __xla_cpu_runtime_MKLSingleThreadedMatMulF32( nullptr, c_transpose->data(), a_transpose->data(), b_transpose->data(), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index 156166bf2b1ea6d3821da8f67ea2b2eca6825ca6..b07cd675ffc4dbd0c7d56da715b29014bb12ce88 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" @@ -173,7 +174,7 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, Status CpuTransferManager::TransferLiteralFromOutfeed( se::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) { + MutableBorrowingLiteral literal) { if (!ShapeUtil::IsTuple(literal_shape)) { int64 size = GetByteSizeRequirement(literal_shape); // Note: OSS build didn't like implicit conversion from @@ -181,18 +182,16 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( tensorflow::gtl::ArraySlice dimensions( tensorflow::bit_cast(literal_shape.dimensions().data()), literal_shape.dimensions().size()); - *literal = std::move(*LiteralUtil::CreateFromDimensions( - literal_shape.element_type(), dimensions)); - TF_ASSIGN_OR_RETURN(Shape received_shape, - TransferArrayBufferFromOutfeed( - executor, literal->untyped_data(), size)); - TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal->shape())) + TF_ASSIGN_OR_RETURN( + Shape received_shape, + TransferArrayBufferFromOutfeed(executor, literal.untyped_data(), size)); + TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal.shape())) << "Shape received from outfeed " << ShapeUtil::HumanString(received_shape) << " did not match the shape that was requested for outfeed: " << ShapeUtil::HumanString(literal_shape); TF_RET_CHECK(size == GetByteSizeRequirement(received_shape)); - *literal->mutable_shape_do_not_use() = received_shape; + *literal.mutable_shape_do_not_use() = received_shape; return Status::OK(); } @@ -201,22 +200,12 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( "Nested tuple outfeeds are not yet implemented on CPU."); } - std::vector> elements; std::vector> buffer_data; for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) { const Shape& tuple_element_shape = ShapeUtil::GetTupleElementShape(literal_shape, i); - // Note: OSS build didn't like implicit conversion from - // literal_shape.dimensions() to the array slice on 2017-07-10. - tensorflow::gtl::ArraySlice dimensions( - tensorflow::bit_cast( - tuple_element_shape.dimensions().data()), - tuple_element_shape.dimensions().size()); - auto empty = LiteralUtil::CreateFromDimensions( - tuple_element_shape.element_type(), dimensions); int64 size = GetByteSizeRequirement(tuple_element_shape); - buffer_data.push_back({empty->untyped_data(), size}); - elements.push_back(std::move(empty)); + buffer_data.push_back({literal.untyped_data({i}), size}); } TF_ASSIGN_OR_RETURN(Shape received_shape, @@ -230,11 +219,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( TF_RET_CHECK(GetByteSizeRequirement(literal_shape) == GetByteSizeRequirement(received_shape)); - for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) { - *elements[i]->mutable_shape_do_not_use() = received_shape.tuple_shapes(i); - } - *literal = std::move(*LiteralUtil::MakeTupleOwned(std::move(elements))); - TF_RET_CHECK(ShapeUtil::Equal(literal->shape(), literal_shape)); + TF_RET_CHECK(ShapeUtil::Equal(literal.shape(), literal_shape)); return Status::OK(); } @@ -272,7 +257,7 @@ StatusOr CpuTransferManager::TransferBuffersFromOutfeedInternal( VLOG(2) << "Enqueueing outfeed buffer (for the device to populate) of length " << size_32 << "B"; - buffers.emplace_back(MakeUnique(b.first, size_32)); + buffers.emplace_back(absl::make_unique(b.first, size_32)); } std::vector buffer_pointers; @@ -299,7 +284,7 @@ StatusOr CpuTransferManager::TransferBuffersFromOutfeedInternal( } // namespace xla static std::unique_ptr CreateCpuTransferManager() { - return xla::MakeUnique(); + return absl::make_unique(); } static bool InitModule() { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h index 593575c0fdaddc71cd6bd844fd179096a9fb0fdc..80ef953d532798281c10b7a212b9c4d84a790c27 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h" #include "tensorflow/compiler/xla/service/generic_transfer_manager.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -41,7 +42,7 @@ class CpuTransferManager : public GenericTransferManager { const LiteralSlice& literal) override; Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) override; + MutableBorrowingLiteral literal) override; private: Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 645888de783e4025cffd6fa4835e60b84bbd7d99..f2ac742b6e6fc12076e7a2a242155c005f4b05b8 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -1066,7 +1066,7 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled( << config.GetCacheKey(); const bool enable_fast_math = - hlo_module_config_.debug_options().xla_enable_fast_math(); + hlo_module_config_.debug_options().xla_cpu_enable_fast_math(); const bool optimize_for_size = options::OptimizeForSizeRequested(hlo_module_config_); @@ -1149,7 +1149,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer(); const bool enable_fast_math = - hlo_module_config_.debug_options().xla_enable_fast_math(); + hlo_module_config_.debug_options().xla_cpu_enable_fast_math(); const bool optimize_for_size = options::OptimizeForSizeRequested(hlo_module_config_); diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index cf955a8add394c204673be0746a451d4edcadc96..db54454707983ade31594119b2e868fa168d4cc2 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -19,6 +19,8 @@ limitations under the License. #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/types.h" @@ -28,47 +30,6 @@ limitations under the License. namespace xla { namespace cpu { -StatusOr CpuElementalIrEmitter::EmitFloatUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { - switch (op->opcode()) { - case HloOpcode::kTanh: { - PrimitiveType element_type = op->shape().element_type(); - bool cast_result_to_fp16 = false; - string function_name; - switch (element_type) { - case F16: - cast_result_to_fp16 = true; - operand_value = b_->CreateFPCast(operand_value, b_->getFloatTy()); - TF_FALLTHROUGH_INTENDED; - case F32: - function_name = "tanhf"; - break; - case F64: - function_name = "tanh"; - break; - default: - return Unimplemented("tanh"); - } - // Create a function declaration. - llvm::Function* function = - llvm::cast(module_->getOrInsertFunction( - llvm_ir::AsStringRef(function_name), operand_value->getType(), - operand_value->getType())); - function->setCallingConv(llvm::CallingConv::C); - function->setDoesNotThrow(); - function->setDoesNotAccessMemory(); - // Create an instruction to call the function. - llvm::Value* result = b_->CreateCall(function, operand_value); - if (cast_result_to_fp16) { - result = b_->CreateFPCast(result, b_->getHalfTy()); - } - return result; - } - default: - return ElementalIrEmitter::EmitFloatUnaryOp(op, operand_value); - } -} - StatusOr CpuElementalIrEmitter::EmitAtan2( PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { string function_name; @@ -104,6 +65,39 @@ StatusOr CpuElementalIrEmitter::EmitAtan2( return result; } +StatusOr CpuElementalIrEmitter::EmitTanh( + PrimitiveType prim_type, llvm::Value* value) const { + bool cast_result_to_fp16 = false; + string function_name; + switch (prim_type) { + case F16: + cast_result_to_fp16 = true; + value = b_->CreateFPCast(value, b_->getFloatTy()); + TF_FALLTHROUGH_INTENDED; + case F32: + function_name = "tanhf"; + break; + case F64: + function_name = "tanh"; + break; + default: + return Unimplemented("tanh"); + } + // Create a function declaration. + llvm::Function* function = llvm::cast( + module_->getOrInsertFunction(llvm_ir::AsStringRef(function_name), + value->getType(), value->getType())); + function->setCallingConv(llvm::CallingConv::C); + function->setDoesNotThrow(); + function->setDoesNotAccessMemory(); + // Create an instruction to call the function. + llvm::Value* result = b_->CreateCall(function, value); + if (cast_result_to_fp16) { + result = b_->CreateFPCast(result, b_->getHalfTy()); + } + return result; +} + llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator) const { @@ -117,9 +111,8 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( ElementwiseSourceIndex(index, *hlo, i))); operands.push_back(operand_value); } - return ir_emitter_->EmitScalarCall(hlo->shape().element_type(), - hlo->to_apply(), operands, - llvm_ir::IrName(hlo)); + return ir_emitter_->EmitElementalMap(*Cast(hlo), + operands, llvm_ir::IrName(hlo)); }; } return ElementalIrEmitter::MakeElementGenerator(hlo, operand_to_generator); diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h index 9598a886ab49fcecf5df7bd65f425fe485de3574..76833e765d05f2477961cd06cead66797c5be623 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h @@ -39,10 +39,10 @@ class CpuElementalIrEmitter : public ElementalIrEmitter { const HloToElementGeneratorMap& operand_to_generator) const override; protected: - StatusOr EmitFloatUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const override; StatusOr EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const override; + StatusOr EmitTanh(PrimitiveType prim_type, + llvm::Value* value) const override; IrEmitter* ir_emitter_; }; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 9d9d3e04a93fe9bbc20a2fa84ec0e07d70ea37aa..6f433b4f30372da9cf4503396dbb60172cfc0cb0 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -51,6 +51,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" @@ -98,7 +99,7 @@ IrEmitter::IrEmitter( target_machine_features_(*target_machine_features) { b_.setFastMathFlags(llvm_ir::GetFastMathFlags( /*fast_math_enabled=*/hlo_module_config_.debug_options() - .xla_enable_fast_math())); + .xla_cpu_enable_fast_math())); } StatusOr IrEmitter::EmitComputation( @@ -115,6 +116,19 @@ StatusOr IrEmitter::EmitComputation( computation->root_instruction()->outer_dimension_partitions().size(); } + if (computation->root_instruction()->opcode() != HloOpcode::kOutfeed) { + TF_ASSIGN_OR_RETURN( + computation_root_allocation_, + assignment_.GetUniqueTopLevelSlice(computation->root_instruction())); + } + + for (const HloInstruction* param : computation->parameter_instructions()) { + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice param_slice, + assignment_.GetUniqueTopLevelSlice(param)); + computation_parameter_allocations_[param_slice.allocation()->index()] = + param->parameter_number(); + } + InitializeIrFunction(function_name); // The rdtscp instruction is x86 specific. We will fallback to LLVM's generic // readcyclecounter if it is unavailable. @@ -131,6 +145,8 @@ StatusOr IrEmitter::EmitComputation( // Delete 'compute_function', finalizing 'ir_function' and restoring caller // IR insert point. compute_function_.reset(); + computation_root_allocation_ = BufferAllocation::Slice(); + computation_parameter_allocations_.clear(); return ir_function; } @@ -142,11 +158,11 @@ void IrEmitter::InitializeIrFunction(const string& function_name) { is_top_level_computation_ ? llvm::GlobalValue::ExternalLinkage : llvm::GlobalValue::InternalLinkage; // Create and initialize new IrFunction. - compute_function_.reset( - new IrFunction(function_name, linkage, - options::OptimizeForSizeRequested(hlo_module_config_), - hlo_module_config_.debug_options().xla_enable_fast_math(), - module_, &b_, num_dynamic_loop_bounds_)); + compute_function_.reset(new IrFunction( + function_name, linkage, + options::OptimizeForSizeRequested(hlo_module_config_), + hlo_module_config_.debug_options().xla_cpu_enable_fast_math(), module_, + &b_, num_dynamic_loop_bounds_)); } IrEmitter::~IrEmitter() {} @@ -175,25 +191,36 @@ llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) { result_global, IrShapeType(literal.shape())->getPointerTo()); } -Status IrEmitter::HandleConstant(HloInstruction* constant) { - VLOG(2) << "HandleConstant: " << constant->ToString(); - const Literal& literal = constant->literal(); - llvm::Constant* global_for_const; +Status IrEmitter::EmitConstantGlobals() { + for (const BufferAllocation& allocation : assignment_.Allocations()) { + if (!allocation.is_constant()) { + continue; + } - auto it = emitted_literals_.find(&literal); - if (it != emitted_literals_.end()) { - global_for_const = it->second; - } else { - global_for_const = EmitGlobalForLiteral(literal); - emitted_literals_[&literal] = global_for_const; + const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation); + llvm::Constant* global_for_const; + auto it = emitted_literals_.find(&literal); + if (it != emitted_literals_.end()) { + global_for_const = it->second; + } else { + global_for_const = EmitGlobalForLiteral(literal); + InsertOrDie(&emitted_literals_, &literal, global_for_const); + } + + InsertOrDie(&constant_buffer_to_global_, allocation.index(), + global_for_const); } - emitted_value_[constant] = global_for_const; - VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*global_for_const); - VLOG(2) << " its type: " - << llvm_ir::DumpToString(*global_for_const->getType()); + return Status::OK(); } +Status IrEmitter::HandleConstant(HloInstruction* constant) { + VLOG(2) << "HandleConstant: " << constant->ToString(); + // IrEmitter::EmitConstantGlobals has already taken care of emitting the body + // of the constant. + return EmitTargetAddressForOp(constant); +} + Status IrEmitter::HandleCopy(HloInstruction* copy) { if (ShapeUtil::IsTuple(copy->shape())) { // kCopy shallow copies a tuple so just memcpy the top-level buffer. @@ -472,23 +499,11 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { return Status::OK(); } -StatusOr IrEmitter::EmitTargetElementLoopBodyForMap( - HloMapInstruction* map, const llvm_ir::IrArray::Index& index) { - llvm::Function* mapped_ir_function = - FindOrDie(emitted_functions_, map->to_apply()); - std::vector parameter_addresses; - for (const HloInstruction* operand : map->operands()) { - const llvm_ir::IrArray& array = GetIrArrayFor(operand); - parameter_addresses.push_back(array.EmitArrayElementAddress(index, &b_)); - } - return EmitElementFunctionCall(mapped_ir_function, map->shape(), - parameter_addresses, "map_function"); -} - -Status IrEmitter::HandleMap(HloInstruction* map) { - return EmitTargetElementLoop(map, [&](const llvm_ir::IrArray::Index& index) { - return EmitTargetElementLoopBodyForMap(Cast(map), index); - }); +llvm::Value* IrEmitter::EmitElementalMap( + const HloMapInstruction& map_instr, + tensorflow::gtl::ArraySlice elemental_operands, + tensorflow::StringPiece name) { + return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name); } StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( @@ -496,9 +511,6 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( const llvm_ir::IrArray::Index& index) { const HloInstruction* operand = reduce_window->operand(0); const Window& window = reduce_window->window(); - HloComputation* function = reduce_window->to_apply(); - // The called computation should have been emitted previously. - llvm::Function* reducer_function = FindOrDie(emitted_functions_, function); // We fold inputs into the accumulator and initialize it to // the initial value on the reduce_window. @@ -551,11 +563,10 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( // We are not in the padding, so carry out the computation. llvm_ir::IrArray input_array(GetIrArrayFor(operand)); - llvm::Value* input_value_address = - input_array.EmitArrayElementAddress(input_index, &b_); - llvm::Value* result = EmitElementFunctionCall( - reducer_function, reduce_window->shape(), - {accumulator_address, input_value_address}, "reducer_function"); + llvm::Value* input_value = input_array.EmitReadArrayElement(input_index, &b_); + llvm::Value* result = EmitThreadLocalCall( + *reduce_window->to_apply(), + {b_.CreateLoad(accumulator_address), input_value}, "reducer_function"); b_.CreateStore(result, accumulator_address); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); @@ -566,7 +577,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*reduce_window, /*operands=*/{reduce_window->operand(0)}, - /*supported_types=*/{F32, BF16, S32})); + /*supported_types=*/{F32, BF16, S32, F16})); // TODO(b/31410564): Implement dilation for reduce-window. if (window_util::HasDilation(reduce_window->window())) { @@ -611,12 +622,6 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { "Dilation for SelectAndScatter is not implemented on CPU. "); } - // The select and scatter computations should have been emitted previously. - llvm::Function* select_function = - FindOrDie(emitted_functions_, select_and_scatter->select()); - llvm::Function* scatter_function = - FindOrDie(emitted_functions_, select_and_scatter->scatter()); - // Pseudo code for select-and-scatter: // // initialized_flag is initially off for every window, and is turned on after @@ -721,11 +726,12 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { // If the initialized_flag is true, call the `select` function to potentially // update the selected value and index with the currently visiting operand. SetToFirstInsertPoint(if_initialized.true_block, &b_); - const Shape output_shape = ShapeUtil::MakeShape(PRED, {}); llvm::Value* operand_address = operand_array.EmitArrayElementAddress(operand_index, &b_); - llvm::Value* result = EmitElementFunctionCall( - select_function, output_shape, {selected_value_address, operand_address}, + llvm::Value* operand_element = b_.CreateLoad(operand_address); + llvm::Value* result = EmitThreadLocalCall( + *select_and_scatter->select(), + {b_.CreateLoad(selected_value_address), operand_element}, "select_function"); // If the 'select' function returns false, update the selected value and the @@ -752,14 +758,14 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { selected_index.push_back(b_.CreateLoad(selected_index_address_slot)); } llvm_ir::IrArray source_array(GetIrArrayFor(source)); - llvm::Value* source_value_address = - source_array.EmitArrayElementAddress(source_index, &b_); + llvm::Value* source_value = + source_array.EmitReadArrayElement(source_index, &b_); llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter)); - llvm::Value* output_value_address = - output_array.EmitArrayElementAddress(selected_index, &b_); - llvm::Value* scatter_value = EmitElementFunctionCall( - scatter_function, source->shape(), - {output_value_address, source_value_address}, "scatter_function"); + llvm::Value* output_value = + output_array.EmitReadArrayElement(selected_index, &b_); + llvm::Value* scatter_value = + EmitThreadLocalCall(*select_and_scatter->scatter(), + {output_value, source_value}, "scatter_function"); output_array.EmitWriteArrayElement(selected_index, scatter_value, &b_); SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), &b_); @@ -1236,46 +1242,7 @@ static llvm_ir::IrArray::Index FillReducedDimensionIndex( Status IrEmitter::HandleParameter(HloInstruction* parameter) { VLOG(2) << "HandleParameter: " << parameter->ToString(); - auto param_number = parameter->parameter_number(); - auto param_shape = parameter->shape(); - - // We have to access the parameter at offset param_number in the params - // array. The code generated here is equivalent to this C code: - // - // i8* param_address_untyped = params[param_number]; - // Param* param_address_typed = (Param*)param_address_untyped; - // - // Where Param is the actual element type of the underlying buffer (for - // example, float for an XLA F32 element type). - llvm::Value* params = compute_function_->parameters_arg(); - llvm::Value* param_address_offset = - llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_); - llvm::LoadInst* param_address_untyped = b_.CreateLoad(param_address_offset); - param_address_untyped->setName(AsStringRef(IrName(parameter, "untyped"))); - if (is_top_level_computation_ && - hlo_module_config_.debug_options() - .xla_llvm_enable_invariant_load_metadata()) { - // In the entry computation the parameter slots in the %params argument are - // invariant through program execution. In computations that are called - // from the entry computation (via kWhile, kCall and kConditional) the - // parameter slots are *not* invariant since they're written to by their - // callers. - param_address_untyped->setMetadata( - llvm::LLVMContext::MD_invariant_load, - llvm::MDNode::get(param_address_untyped->getContext(), /*MDs=*/{})); - } - - llvm::Value* param_address_typed = b_.CreateBitCast( - param_address_untyped, IrShapeType(param_shape)->getPointerTo()); - emitted_value_[parameter] = param_address_typed; - - if (!ShapeUtil::IsOpaque(param_shape)) { - AttachAlignmentMetadataForLoad(param_address_untyped, param_shape); - AttachDereferenceableMetadataForLoad(param_address_untyped, param_shape); - } - - VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*param_address_typed); - return Status::OK(); + return EmitTargetAddressForOp(parameter); } // Returns true if the relative order of the unreduced dimensions stays the same @@ -1739,9 +1706,6 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduce( const HloInstruction* arg = reduce->mutable_operand(0); const HloInstruction* init_value = reduce->mutable_operand(1); gtl::ArraySlice dimensions(reduce->dimensions()); - HloComputation* function = reduce->to_apply(); - // The called computation should have been emitted previously. - llvm::Function* reducer_function = FindOrDie(emitted_functions_, function); // Initialize an accumulator with init_value. PrimitiveType accumulator_type = reduce->shape().element_type(); @@ -1781,10 +1745,9 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduce( CHECK(index.end() == it); // Apply the reduction function to the loaded value. - llvm::Value* input_address = - arg_array.EmitArrayElementAddress(input_index, &b_); - llvm::Value* result = EmitElementFunctionCall( - reducer_function, reduce->shape(), {accumulator_addr, input_address}, + llvm::Value* input_element = arg_array.EmitReadArrayElement(input_index, &b_); + llvm::Value* result = EmitThreadLocalCall( + *reduce->to_apply(), {b_.CreateLoad(accumulator_addr), input_element}, "reduce_function"); b_.CreateStore(result, accumulator_addr); @@ -1793,6 +1756,10 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduce( } Status IrEmitter::HandleReduce(HloInstruction* reduce) { + // TODO(b/112040122): Support variadic reduce. + if (!ShapeUtil::IsArray(reduce->shape())) { + return Unimplemented("Variadic reduce is not supported on CPU"); + } auto arg = reduce->mutable_operand(0); auto init_value = reduce->mutable_operand(1); gtl::ArraySlice dimensions(reduce->dimensions()); @@ -1830,6 +1797,10 @@ Status IrEmitter::HandleSendDone(HloInstruction* send_done) { return Unimplemented("Send-done is not implemented on CPU."); } +Status IrEmitter::HandleScatter(HloInstruction*) { + return Unimplemented("Scatter is not implemented on CPUs."); +} + Status IrEmitter::HandleSlice(HloInstruction* slice) { VLOG(2) << "HandleSlice: " << slice->ToString(); auto operand = slice->operand(0); @@ -2122,18 +2093,13 @@ Status IrEmitter::HandleCall(HloInstruction* call) { HloComputation* computation = call->to_apply(); llvm::Function* call_ir_function = FindOrDie(emitted_functions_, computation); - std::vector parameter_addresses; - for (const HloInstruction* operand : call->operands()) { - parameter_addresses.push_back(GetEmittedValueFor(operand)); - } - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call)); if (!computation->root_instruction()->outer_dimension_partitions().empty()) { // ParallelTaskAssignment assigned partitions, emit call to // ParallelForkJoin. std::vector call_args = GetArrayFunctionCallArguments( - parameter_addresses, &b_, computation->name(), + {}, &b_, computation->name(), /*return_value_buffer=*/emitted_value_[call], /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), /*temp_buffers_arg=*/GetTempBuffersArgument(), @@ -2144,8 +2110,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) { call_args, root->shape(), root->outer_dimension_partitions(), &b_, call_ir_function, computation->name())); } else { - EmitArrayFunctionCallInto(call_ir_function, parameter_addresses, - emitted_value_[call], computation->name()); + EmitGlobalCall(*computation, computation->name()); } return Status::OK(); @@ -2226,12 +2191,6 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { const HloInstruction* init = xla_while->operand(0); emitted_value_[xla_while] = GetEmittedValueFor(init); - // The called computation should have been emitted previously. - llvm::Function* condition_ir_function = - FindOrDie(emitted_functions_, condition); - llvm::Function* body_ir_function = - FindOrDie(emitted_functions_, xla_while->while_body()); - // Generating: // while (Condition(while_result)) { // // CopyInsertion pass inserts copies which enable 'while_result' to @@ -2248,12 +2207,10 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { // Calls the condition function to determine whether to proceed with the // body. It must return a bool, so use the scalar call form. - llvm::Value* while_result = GetEmittedValueFor(xla_while); - llvm::Value* while_condition = EmitElementFunctionCall( - condition_ir_function, condition->root_instruction()->shape(), - {while_result}, IrName(xla_while, "cond")); + EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond")); llvm::Value* while_predicate = b_.CreateICmpNE( - while_condition, + b_.CreateLoad( + GetBufferForGlobalCallReturnValue(*xla_while->while_condition())), llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0)); // Branches to the body or to the while exit depending on the condition. @@ -2268,8 +2225,8 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { b_.SetInsertPoint(body_bb); // Calls the body function. - EmitArrayFunctionCallInto(body_ir_function, {while_result}, while_result, - IrName(xla_while, "body")); + EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body")); + // Finishes with a branch back to the header. b_.CreateBr(header_bb); @@ -2437,8 +2394,6 @@ Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) { Status IrEmitter::HandleConditional(HloInstruction* conditional) { auto pred = conditional->operand(0); - auto true_arg = conditional->operand(1); - auto false_arg = conditional->operand(2); TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()) && pred->shape().element_type() == PRED) << "Predicate on a Conditional must be bool; got: " @@ -2460,13 +2415,7 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) { << " and " << ShapeUtil::HumanString(false_computation->root_instruction()->shape()); - llvm::Function* true_function = - FindOrDie(emitted_functions_, true_computation); - llvm::Function* false_function = - FindOrDie(emitted_functions_, false_computation); - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional)); - llvm::Value* conditional_result = GetEmittedValueFor(conditional); // Generating: // if (pred) @@ -2483,12 +2432,12 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) { llvm_ir::EmitIfThenElse(pred_cond, "conditional", &b_); SetToFirstInsertPoint(if_data.true_block, &b_); - EmitArrayFunctionCallInto(true_function, {GetEmittedValueFor(true_arg)}, - conditional_result, IrName(conditional, "_true")); + EmitGlobalCall(*conditional->true_computation(), + IrName(conditional, "_true")); SetToFirstInsertPoint(if_data.false_block, &b_); - EmitArrayFunctionCallInto(false_function, {GetEmittedValueFor(false_arg)}, - conditional_result, IrName(conditional, "_false")); + EmitGlobalCall(*conditional->false_computation(), + IrName(conditional, "_false")); SetToFirstInsertPoint(if_data.after_block, &b_); return Status::OK(); @@ -2689,40 +2638,76 @@ llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() { return compute_function_->exec_run_options_arg(); } -llvm::Value* IrEmitter::EmitTempBufferPointer( +llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer( const BufferAllocation::Slice& slice, const Shape& target_shape) { - llvm::Type* element_type = IrShapeType(target_shape); - // The alignment and number of bytes within the temporary buffer is determined - // by the maximal shape as determined by buffer assignment. - const BufferAllocation& allocation = assignment_.GetAllocation(slice.index()); - if (allocation.is_thread_local()) { + const BufferAllocation& allocation = *slice.allocation(); + llvm::Value* tempbuf_address = [&]() -> llvm::Value* { + if (slice == computation_root_allocation_) { + llvm::Argument* retval = compute_function_->result_arg(); + llvm::AttrBuilder attr_builder; + attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); + attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape)); + retval->addAttrs(attr_builder); + return retval; + } + + auto param_it = + computation_parameter_allocations_.find(slice.allocation()->index()); + if (param_it != computation_parameter_allocations_.end()) { + int64 param_number = param_it->second; + // We have to access the parameter at offset param_number in the params + // array. The code generated here is equivalent to this C code: + // + // i8* param_address_untyped = params[param_number]; + // Param* param_address_typed = (Param*)param_address_untyped; + // + // Where Param is the actual element type of the underlying buffer (for + // example, float for an XLA F32 element type). + llvm::Value* params = compute_function_->parameters_arg(); + llvm::Value* param_address_offset = + llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_); + llvm::LoadInst* param_address_untyped = + b_.CreateLoad(param_address_offset); + + if (!ShapeUtil::IsOpaque(target_shape)) { + AttachAlignmentMetadataForLoad(param_address_untyped, target_shape); + AttachDereferenceableMetadataForLoad(param_address_untyped, + target_shape); + } + return param_address_untyped; + } + // Thread-local allocations should only be assigned a single buffer. const auto& assigned_buffers = allocation.assigned_buffers(); CHECK_EQ(1, assigned_buffers.size()); const Shape& shape = assigned_buffers.begin()->first->shape(); - llvm::AllocaInst*& tempbuf_address = - thread_local_buffers_[{b_.GetInsertBlock()->getParent(), slice}]; - if (tempbuf_address == nullptr) { - tempbuf_address = llvm_ir::EmitAllocaAtFunctionEntry( + std::pair key = { + compute_function_->function(), slice}; + auto buf_it = thread_local_buffers_.find(key); + if (buf_it == thread_local_buffers_.end()) { + llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry( IrShapeType(shape), tensorflow::strings::StrCat("thread_local", slice.ToString()), &b_, MinimumAlignmentForShape(target_shape)); + auto it_inserted_pair = thread_local_buffers_.insert({key, buffer}); + CHECK(it_inserted_pair.second); + buf_it = it_inserted_pair.first; } - return b_.CreateBitCast(tempbuf_address, element_type->getPointerTo()); - } + return buf_it->second; + }(); + return b_.CreateBitCast(tempbuf_address, + IrShapeType(target_shape)->getPointerTo()); +} +llvm::Value* IrEmitter::EmitGlobalTempBufferPointer( + const BufferAllocation::Slice& slice, const Shape& target_shape) { + const BufferAllocation& allocation = *slice.allocation(); llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP( GetTempBuffersArgument(), slice.index(), &b_); llvm::LoadInst* tempbuf_address_base = b_.CreateLoad(tempbuf_address_ptr); - if (is_top_level_computation_ && - hlo_module_config_.debug_options() + if (hlo_module_config_.debug_options() .xla_llvm_enable_invariant_load_metadata()) { - // In the entry computation the parameter slots in the %params argument are - // invariant through program execution. In computations that are called - // from the entry computation (via kWhile, kCall and kConditional) the - // parameter slots are *not* invariant since they're written to by their - // callers. tempbuf_address_base->setMetadata( llvm::LLVMContext::MD_invariant_load, llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{})); @@ -2737,85 +2722,25 @@ llvm::Value* IrEmitter::EmitTempBufferPointer( b_.CreateInBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset())); } return b_.CreateBitCast(tempbuf_address_untyped, - element_type->getPointerTo()); + IrShapeType(target_shape)->getPointerTo()); } -// Emits a function call returning a single array element. Allocates space -// for a single element_type value, and loads it after call. -llvm::Value* IrEmitter::EmitElementFunctionCall( - llvm::Function* function, const Shape& return_shape, - gtl::ArraySlice parameter_addresses, - tensorflow::StringPiece name) { - llvm::Value* return_value_buffer = EmitArrayFunctionCall( - function, return_shape, 1, parameter_addresses, name); - return b_.CreateLoad( - return_value_buffer, - AsStringRef(tensorflow::strings::StrCat(name, "_return_value"))); -} - -// Emits a core function call based on the following pseudo-code. -// -// char** parameter_addresses_buffer = -// allocate buffer with a pointer for each parameter to the function -// for each parameter index, i.e. for i = 0, ..., #parameters: -// parameter_addresses_buffer[i] = parameter_addresses[i] -// call function(return_value_buffer, -// parameter_addresses_buffer, -// temps) -// return return_value_buffer -- address of the return value. -void IrEmitter::EmitArrayFunctionCallInto( - llvm::Function* function, gtl::ArraySlice parameter_addresses, - llvm::Value* return_value_buffer, tensorflow::StringPiece name) { - b_.CreateCall(function, - GetArrayFunctionCallArguments( - parameter_addresses, &b_, name, - /*return_value_buffer=*/return_value_buffer, - /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), - /*temp_buffers_arg=*/GetTempBuffersArgument(), - /*profile_counters_arg=*/GetProfileCountersArgument())); -} - -llvm::Value* IrEmitter::EmitArrayFunctionCall( - llvm::Function* function, const Shape& return_shape, int64 element_count, - gtl::ArraySlice parameter_addresses, - tensorflow::StringPiece name) { - llvm::Value* elements = - llvm::ConstantInt::get(b_.getInt64Ty(), element_count); - PrimitiveType return_type = return_shape.element_type(); - llvm::Value* return_value_buffer = - llvm_ir::EmitAllocaAtFunctionEntryWithCount( - llvm_ir::PrimitiveTypeToIrType(return_type, module_), elements, - tensorflow::strings::StrCat(name, "_return_value_address"), &b_, - MinimumAlignmentForPrimitiveType(return_type)); - EmitArrayFunctionCallInto(function, parameter_addresses, return_value_buffer, - name); - return return_value_buffer; +llvm::Value* IrEmitter::EmitTempBufferPointer( + const BufferAllocation::Slice& slice, const Shape& target_shape) { + if (slice.allocation()->is_thread_local()) { + return EmitThreadLocalTempBufferPointer(slice, target_shape); + } else if (slice.allocation()->is_constant()) { + return FindOrDie(constant_buffer_to_global_, slice.allocation()->index()); + } else { + return EmitGlobalTempBufferPointer(slice, target_shape); + } } Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { - llvm::Value* addr; const Shape& target_shape = op->shape(); - if (op == op->parent()->root_instruction()) { - // For the root node, we write directly to the output buffer of the - // function. - llvm::Argument* retval = compute_function_->result_arg(); - if ((ShapeUtil::IsArray(target_shape) && - !ShapeUtil::IsZeroElementArray(target_shape)) || - (ShapeUtil::IsTuple(target_shape) && - !ShapeUtil::IsEmptyTuple(target_shape))) { - llvm::AttrBuilder attr_builder; - attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); - attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape)); - retval->addAttrs(attr_builder); - } - addr = b_.CreateBitCast(retval, IrShapeType(target_shape)->getPointerTo()); - } else { - // For other nodes, we need the temporary buffer allocated for this node to - // write the result into. - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, - assignment_.GetUniqueTopLevelSlice(op)); - addr = EmitTempBufferPointer(slice, target_shape); - } + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, + assignment_.GetUniqueTopLevelSlice(op)); + llvm::Value* addr = EmitTempBufferPointer(slice, target_shape); addr->setName(AsStringRef(IrName(op))); emitted_value_[op] = addr; return Status::OK(); @@ -2920,20 +2845,69 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator)); } -StatusOr IrEmitter::EmitScalarCall( - PrimitiveType return_type, HloComputation* computation, - const std::vector& arguments, tensorflow::StringPiece name) { - llvm::Function* llvm_function = FindOrDie(emitted_functions_, computation); - std::vector argument_addrs; - for (auto argument : arguments) { - llvm::Value* argument_addr = llvm_ir::EmitAllocaAtFunctionEntry( - argument->getType(), "arg_addr", &b_); - b_.CreateStore(argument, argument_addr); - argument_addrs.push_back(argument_addr); +llvm::Value* IrEmitter::EmitThreadLocalCall( + const HloComputation& callee, + tensorflow::gtl::ArraySlice parameters, + tensorflow::StringPiece name) { + const Shape& return_shape = callee.root_instruction()->shape(); + + // Lifting this restriction to allow "small" arrays should be easy. Allowing + // larger arrays is difficult because we allocate the buffer for this return + // value on the stack. + CHECK(ShapeUtil::IsScalar(return_shape)); + + PrimitiveType return_type = return_shape.element_type(); + + std::vector parameter_addrs; + for (llvm::Value* parameter : parameters) { + CHECK(!parameter->getType()->isPointerTy()); + llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry( + parameter->getType(), "arg_addr", &b_); + b_.CreateStore(parameter, parameter_addr); + parameter_addrs.push_back(parameter_addr); } - return EmitElementFunctionCall(llvm_function, - ShapeUtil::MakeShape(return_type, {}), - argument_addrs, name); + + llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(return_type, module_), + tensorflow::strings::StrCat(name, "_retval_addr"), &b_, + MinimumAlignmentForPrimitiveType(return_type)); + + b_.CreateCall( + FindOrDie(emitted_functions_, &callee), + GetArrayFunctionCallArguments( + parameter_addrs, &b_, name, + /*return_value_buffer=*/return_value_buffer, + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*temp_buffers_arg=*/ + llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()), + /*profile_counters_arg=*/GetProfileCountersArgument())); + + return b_.CreateLoad(return_value_buffer); +} + +void IrEmitter::EmitGlobalCall(const HloComputation& callee, + tensorflow::StringPiece name) { + b_.CreateCall(FindOrDie(emitted_functions_, &callee), + GetArrayFunctionCallArguments( + /*parameter_addresses=*/{}, &b_, name, + /*return_value_buffer=*/ + llvm::Constant::getNullValue(b_.getInt8PtrTy()), + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*temp_buffers_arg=*/GetTempBuffersArgument(), + /*profile_counters_arg=*/GetProfileCountersArgument())); +} + +llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue( + const HloComputation& callee) { + const HloInstruction* root_inst = callee.root_instruction(); + if (root_inst->opcode() == HloOpcode::kOutfeed) { + return llvm::Constant::getNullValue(b_.getInt8PtrTy()); + } + + const BufferAllocation::Slice root_buffer = + assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie(); + return EmitTempBufferPointer(root_buffer, root_inst->shape()); } + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index cf7fa05b20753dcd87c69ddcf8cc7e70f1412248..c9a1dab62dcbcd926baa82737d24efa03fd326e9 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -100,10 +100,14 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::IRBuilder<>* b() { return &b_; } - // Emits a call to `computation` with scalar arguments `arguments`. - StatusOr EmitScalarCall( - PrimitiveType return_type, HloComputation* computation, - const std::vector& arguments, tensorflow::StringPiece name); + // Emit an LLVM global variable for every constant buffer allocation. + Status EmitConstantGlobals(); + + // Emit code to map one element according to `map_instr`. + llvm::Value* EmitElementalMap( + const HloMapInstruction& map_instr, + tensorflow::gtl::ArraySlice elemental_operands, + tensorflow::StringPiece name); protected: // @@ -140,13 +144,13 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleRecvDone(HloInstruction* recv_done) override; Status HandlePad(HloInstruction* pad) override; Status HandleTuple(HloInstruction* tuple) override; - Status HandleMap(HloInstruction* map) override; Status HandleFusion(HloInstruction* fusion) override; Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction* custom_call) override; Status HandleWhile(HloInstruction* xla_while) override; Status HandleConcatenate(HloInstruction* concatenate) override; Status HandleConditional(HloInstruction* conditional) override; + Status HandleScatter(HloInstruction* scatter) override; Status HandleAfterAll(HloInstruction* gen_token) override; Status HandleIota(HloInstruction* iota) override; Status HandleRng(HloInstruction* rng) override; @@ -215,9 +219,18 @@ class IrEmitter : public DfsHloVisitorWithDefault { // computation function being emitted by this emitter. llvm::Value* GetTempBuffersArgument(); - // Emits code that computes the address of the given temporary buffer to the - // function. target_shape is the shape of this temporary buffer. - // The returned Value's type is a pointer to element_type. + // Helper for EmitTempBufferPointer. + llvm::Value* EmitGlobalTempBufferPointer(const BufferAllocation::Slice& slice, + const Shape& target_shape); + + // Helper for EmitTempBufferPointer. + llvm::Value* EmitThreadLocalTempBufferPointer( + const BufferAllocation::Slice& slice, const Shape& target_shape); + + // Emits code that computes the address of the given buffer allocation slice. + // + // TODO(sanjoy): This should be renamed to reflect that it no longer provides + // access to just temporaries. llvm::Value* EmitTempBufferPointer(const BufferAllocation::Slice& slice, const Shape& target_shape); @@ -229,44 +242,27 @@ class IrEmitter : public DfsHloVisitorWithDefault { tensorflow::StringPiece function_name_suffix); // Used for LLVM IR register names. - // Methods that emit a function call. - // Parameters: - // function - The LLVM function to call. - // return_shape - The return shape of the HLO computation that was used to - // make the function. Not the same as the return type of the function - // in LLVM, since we use output parameters for the return type. - // element_count - number of elements to return (array form only). - // parameter_addresses - pointers to be passed to the function as - // parameters. - // name - used for LLVM IR register names. - - // Emits a function call, returning a scalar, often an element of a larger - // array. Returns a Value for the scalar element returned by the function. - llvm::Value* EmitElementFunctionCall( - llvm::Function* function, const Shape& return_shape, - tensorflow::gtl::ArraySlice parameter_addresses, + // Emits a call to a thread local function (e.g. to the computation nested + // within a reduce or a map). Thread local callees (by definition) only write + // to and read from thread local allocations. + // + // `parameters` holds the *scalar values* that need to be passed to the + // callee. The return value is the scalar returned by the callee. + llvm::Value* EmitThreadLocalCall( + const HloComputation& callee, + tensorflow::gtl::ArraySlice parameters, tensorflow::StringPiece name); - // Array function call emitter. Stores the function's result into a supplied - // buffer. - // Parameters: - // function - The LLVM function to call. - // parameter_addresses - pointers to be passed to the function as - // parameters. - // return_value - pointer to a buffer where the call result is stored. - - void EmitArrayFunctionCallInto( - llvm::Function* function, - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::Value* return_value_buffer, tensorflow::StringPiece name); - - // Array function call emitter. Returns a Value for the function's return - // value buffer address. The return value buffer is alloca'ed by this - // function. - llvm::Value* EmitArrayFunctionCall( - llvm::Function* function, const Shape& return_shape, int64 element_count, - tensorflow::gtl::ArraySlice parameter_addresses, - tensorflow::StringPiece name); + // Emits a call to a "global" function (e.g. to the computation nested within + // a kWhile or a kCall). Buffer assignment unabiguously assignes buffers to + // the parameters and return values for these computations so there is no need + // to explicitly pass parameters or return results. + void EmitGlobalCall(const HloComputation& callee, + tensorflow::StringPiece name); + + // Returns the buffer to which a global call to `callee` would have written + // its result. + llvm::Value* GetBufferForGlobalCallReturnValue(const HloComputation& callee); // Verifies that the element types of all of the given operand instructions // match and are of one of the given supported types. @@ -405,11 +401,10 @@ class IrEmitter : public DfsHloVisitorWithDefault { NameUniquer name_uniquer_; // Map containing all previously emitted computations. - std::map emitted_functions_; + std::map emitted_functions_; // Map containing all previously emitted thread-local temporary buffers. - std::map, - llvm::AllocaInst*> + std::map, llvm::Value*> thread_local_buffers_; // The following fields track the IR emission state. According to LLVM memory @@ -419,6 +414,16 @@ class IrEmitter : public DfsHloVisitorWithDefault { std::unique_ptr compute_function_; llvm::IRBuilder<> b_; + // The buffer allocation slice for the root of the computation being compiled. + // Only relevant for thread local computations. + BufferAllocation::Slice computation_root_allocation_; + + // Maps the buffer allocation slices for the parameters to the computation + // being compiled to their parameter numbers. Only relevant for thread local + // computations. + tensorflow::gtl::FlatMap + computation_parameter_allocations_; + // Maps HLO instructions to their index into the profile counter array. const std::unordered_map instruction_to_profile_idx_; @@ -560,6 +565,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { LiteralPtrHashFunctor, LiteralPtrEqualityFunctor> emitted_literals_; + tensorflow::gtl::FlatMap + constant_buffer_to_global_; + TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter); }; diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc index 6aff838462ac6bfe8a31971108a721b66dbe45bd..2db4d000f5b149969c88fb4325ca28aa11dc3708 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc @@ -80,9 +80,16 @@ void IrFunction::Initialize(const string& function_name, // void function(i8* retval, i8* run_options, i8** params, i8** temps, // i64* dynamic_loop_bounds, i64* prof_counters) // - // retval: points to the returned value. - // params: address of an array with pointers to parameters. - // temps: address of an array with pointers to temporary buffers. + // For thread local functions: + // retval: points to the returned value. + // params: address of an array with pointers to parameters. + // temps: is null + // + // For global functions: + // retval: is null + // params: is null + // temps: address of an array with pointers to temporary buffers and entry + // computation parameters. // // Therefore, the generated function's signature (FunctionType) is statically // determined - parameter unpacking is done in code generated into the @@ -196,18 +203,25 @@ std::vector GetArrayFunctionCallArguments( llvm::IRBuilder<>* b, tensorflow::StringPiece name, llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) { - llvm::Value* parameter_addresses_buffer = - llvm_ir::EmitAllocaAtFunctionEntryWithCount( - b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()), - tensorflow::strings::StrCat(name, "_parameter_addresses"), b); - for (size_t i = 0; i < parameter_addresses.size(); ++i) { - llvm::Value* parameter_as_i8ptr = - b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(), - AsStringRef(tensorflow::strings::StrCat( - name, "_parameter_", i, "_address_as_i8ptr"))); - llvm::Value* slot_in_param_addresses = - b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)}); - b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses); + llvm::Value* parameter_addresses_buffer; + + if (parameter_addresses.empty()) { + parameter_addresses_buffer = + llvm::Constant::getNullValue(b->getInt8PtrTy()->getPointerTo()); + } else { + parameter_addresses_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount( + b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()), + tensorflow::strings::StrCat(name, "_parameter_addresses"), b); + + for (size_t i = 0; i < parameter_addresses.size(); ++i) { + llvm::Value* parameter_as_i8ptr = + b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(), + AsStringRef(tensorflow::strings::StrCat( + name, "_parameter_", i, "_address_as_i8ptr"))); + llvm::Value* slot_in_param_addresses = + b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)}); + b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses); + } } const auto to_int8_ptr = [=](llvm::Value* ptr) { diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index 4fa5984b0466b178a587e97cbced97deac749f74..286d407ca6e796a184738aee4d14bd5ed7e2f356 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" @@ -109,7 +110,7 @@ ParallelTaskAssignment::ParallelTaskAssignment( : target_machine_features_(*target_machine_features) { VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism; // Run cost analysis on 'module'. - auto cost_analysis = MakeUnique(shape_size); + auto cost_analysis = absl::make_unique(shape_size); HloComputation* computation = module->entry_computation(); Status status = computation->root_instruction()->Accept(cost_analysis.get()); if (status.ok()) { diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index 36c9f743859ae2da6c4fb3fd753bd7862fe2d3ab..ee272b5f4f49904a9e75a4653b7dc1fdc89434c1 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -110,9 +110,10 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) { const string hlo_string = R"( HloModule TestTaskParallel_infeed_outfeed ENTRY InfeedOutfeed { - infeed0 = (u32[12345678,2]{1,0}, token[]) infeed() + token = token[] after-all() + infeed0 = (u32[12345678,2]{1,0}, token[]) infeed(token) infeed0.data = u32[12345678,2]{1,0} get-tuple-element((u32[12345678,2]{1,0}, token[]) infeed0), index=0 - ROOT outfeed0 = token[] outfeed(infeed0.data) + ROOT outfeed0 = token[] outfeed(infeed0.data, token) } )"; diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc index d03da46575b331de113cc5f33c2b4267504e8308..a5f34908d70dd18ec017bdf9833c7df40f80db07 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc @@ -20,6 +20,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -58,13 +59,14 @@ using ComputeFunctionType = void (*)(void*, const void*, const void**, void**, // [partition1_dim2_start] // [partition1_dim2_limit] // -void __xla_cpu_runtime_ParallelForkJoin( +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin( void* result_ptr, const void* run_options_ptr, const void** params, void** temps, uint64* prof_counters, int32 num_partitions, int64* partitions, int32 num_partitioned_dims, void* function_ptr) { VLOG(2) << "ParallelForkJoin ENTRY" << " num_partitions: " << num_partitions << " num_partitioned_dims: " << num_partitioned_dims; + CHECK_EQ(params, nullptr); CHECK_GT(num_partitions, 1); CHECK_GT(num_partitioned_dims, 0); const xla::ExecutableRunOptions* run_options = @@ -79,9 +81,9 @@ void __xla_cpu_runtime_ParallelForkJoin( for (int32 i = 1; i < num_partitions; ++i) { const int64 offset = i * stride; run_options->intra_op_thread_pool()->enqueueNoNotification( - [i, function, result_ptr, run_options_ptr, params, temps, prof_counters, + [i, function, result_ptr, run_options_ptr, temps, prof_counters, partitions, offset, &bc]() { - function(result_ptr, run_options_ptr, params, temps, + function(result_ptr, run_options_ptr, nullptr, temps, &partitions[offset], prof_counters); bc.DecrementCount(); VLOG(3) << "ParallelForkJoin partition " << i << " done."; diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc index 39b13183ff093611a42b3931d45f64eadb420622..a71a85913cfef271bc2a226cb0cf2dd4204499a4 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc @@ -20,6 +20,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matvec.h" +#include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/types.h" using tensorflow::int32; @@ -77,27 +78,24 @@ void MatMulImpl(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, } // namespace -void __xla_cpu_runtime_EigenMatMulF16(const void* run_options_ptr, - Eigen::half* out, Eigen::half* lhs, - Eigen::half* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF16( + const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, + Eigen::half* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, + int32 transpose_rhs) { MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } -void __xla_cpu_runtime_EigenMatMulF32(const void* run_options_ptr, float* out, - float* lhs, float* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF32( + const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m, + int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } -void __xla_cpu_runtime_EigenMatMulF64(const void* run_options_ptr, double* out, - double* lhs, double* rhs, int64 m, - int64 n, int64 k, int32 transpose_lhs, - int32 transpose_rhs) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF64( + const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m, + int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc index f8c8dd5e93d53db8d87be0208b5cf4daac3464f1..8dc5f3c93b6ba1a722ea7b23b4b5190ac0600cd6 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML) +#if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY) #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "third_party/intel_mkl_ml/include/mkl_cblas.h" #include "third_party/intel_mkl_ml/include/mkl_service.h" @@ -23,6 +23,7 @@ limitations under the License. #define EIGEN_USE_THREADS #include "third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool" +#include "tensorflow/core/platform/dynamic_annotations.h" using tensorflow::int32; using tensorflow::int64; @@ -74,10 +75,9 @@ void MatMulF64(const void* run_options_ptr, double* out, double* lhs, } // namespace -void __xla_cpu_runtime_MKLMatMulF32(const void* run_options_ptr, float* out, - float* lhs, float* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_MKLMatMulF32( + const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m, + int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { const xla::ExecutableRunOptions* run_options = static_cast(run_options_ptr); // BLAS GEMM MatMul uses OpenMP for parallelization, so we pass the thread @@ -88,11 +88,11 @@ void __xla_cpu_runtime_MKLMatMulF32(const void* run_options_ptr, float* out, // Set thread number back to the previous number. mkl_set_num_threads_local(prev_num_threads); } + // BLAS GEMM API for 64-bit Matrix Multiplication -void __xla_cpu_runtime_MKLMatMulF64(const void* run_options_ptr, double* out, - double* lhs, double* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_MKLMatMulF64( + const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m, + int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { const xla::ExecutableRunOptions* run_options = static_cast(run_options_ptr); // BLAS GEMM MatMul uses OpenMP for parallelization, so we pass the thread @@ -103,22 +103,26 @@ void __xla_cpu_runtime_MKLMatMulF64(const void* run_options_ptr, double* out, // Set thread number back to the previous number. mkl_set_num_threads_local(prev_num_threads); } -void __xla_cpu_runtime_MKLSingleThreadedMatMulF32(const void* run_options_ptr, - float* out, float* lhs, - float* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_MKLSingleThreadedMatMulF32(const void* run_options_ptr, + float* out, float* lhs, float* rhs, + int64 m, int64 n, int64 k, + int32 transpose_lhs, + int32 transpose_rhs) { // Set the thread number to 1 for single threaded excution. int prev_num_threads = mkl_set_num_threads_local(1); MatMulF32(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); // Set thread number back to the previous number. mkl_set_num_threads_local(prev_num_threads); } -void __xla_cpu_runtime_MKLSingleThreadedMatMulF64(const void* run_options_ptr, - double* out, double* lhs, - double* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_MKLSingleThreadedMatMulF64(const void* run_options_ptr, + double* out, double* lhs, + double* rhs, int64 m, int64 n, + int64 k, int32 transpose_lhs, + int32 transpose_rhs) { // Set the thread number to 1 for single threaded excution. int prev_num_threads = mkl_set_num_threads_local(1); MatMulF64(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc index 17303e2f0d34e531a3a56aa147608b949e0f43ae..16692e7f2e6145b2649b67987eef47916e958be2 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc @@ -17,6 +17,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/service/cpu/runtime_matvec.h" +#include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/types.h" using tensorflow::int32; @@ -71,7 +72,8 @@ void SingleThreadedMatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, } // namespace -void __xla_cpu_runtime_EigenSingleThreadedMatMulF16( +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedMatMulF16( const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, Eigen::half* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { @@ -79,16 +81,22 @@ void __xla_cpu_runtime_EigenSingleThreadedMatMulF16( transpose_lhs, transpose_rhs); } -void __xla_cpu_runtime_EigenSingleThreadedMatMulF32( - const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m, - int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedMatMulF32(const void* run_options_ptr, + float* out, float* lhs, + float* rhs, int64 m, int64 n, + int64 k, int32 transpose_lhs, + int32 transpose_rhs) { SingleThreadedMatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } -void __xla_cpu_runtime_EigenSingleThreadedMatMulF64( - const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m, - int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedMatMulF64(const void* run_options_ptr, + double* out, double* lhs, + double* rhs, int64 m, int64 n, + int64 k, int32 transpose_lhs, + int32 transpose_rhs) { SingleThreadedMatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc index eb83432f5785738bd2d5d534a2a3a360f11719a5..f227e4ae139b92e56786e38ef8eef72c9e2cd424 100644 --- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc +++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index be772cfb7e564cebc5725854dbf5678e5c507556..b026aef3fec729716234a1f38c4ac4993666aeb5 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -20,13 +20,13 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "llvm/ExecutionEngine/ExecutionEngine.h" #include "llvm/ExecutionEngine/JITSymbol.h" #include "llvm/ExecutionEngine/SectionMemoryManager.h" #include "llvm/IR/Mangler.h" #include "llvm/Support/CodeGen.h" #include "llvm/Support/Host.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h" diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index e6d25680b56bd79a249c0222552f310d1ea05ca8..4635fa5d74f86eb7f2543d263132d87e6eaa20e0 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -51,6 +51,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) @@ -94,6 +95,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:filecheck", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", "@llvm//:core", ], ) @@ -135,9 +137,9 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc index d98856fdbf4165a5909f193ebe8512e21af83dfc..b68ac67574d0b9f20ecc0370cdaed87d4465b225 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc index be3fae5161be13f08c52db38cace6abc7e7486ed..c35569c6619ba5b534c5d8bb7ad683d84b6ecf4b 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -220,7 +220,7 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { // The body adds the reduced value of the Infeed data (first tuple element) // to the previous accumulator, and returns the accumulator and the continue // flag (second tuple element) as a tuple. - const auto build_body = [this, &result_shape](const Shape& infeed_shape) { + const auto build_body = [&result_shape](const Shape& infeed_shape) { XlaComputation body; XlaBuilder builder("body"); auto prev = Parameter(&builder, 0, result_shape, "prev"); diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc index 90b99c828e2fcfd77579026a39d3a6711599feee..3b87683ffffefd2aa24dd234cc072425bef00a24 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -38,7 +38,8 @@ while_body { while_cond { arg_cond = f32[2,3,2] parameter(0) - infeed = (pred[], token[]) infeed() + token = token[] after-all() + infeed = (pred[], token[]) infeed(token) ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0 } @@ -50,8 +51,9 @@ ENTRY main { {{2, 1}, {2001, 3002}, {2001, 2002}}}) const_b = f32[2,3,2] while(f32[2,3,2] const_a), condition=while_cond, body=while_body - out0 = token[] outfeed(f32[2,3,2] const_a) - ROOT out1 = token[] outfeed(f32[2,3,2] const_b) + token = token[] after-all() + out0 = token[] outfeed(f32[2,3,2] const_a, token[] token) + ROOT out1 = token[] outfeed(f32[2,3,2] const_b, token[] token) } )"; @@ -85,7 +87,8 @@ while_body { while_cond { arg_cond = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) - infeed = (pred[], token[]) infeed() + token = token[] after-all() + infeed = (pred[], token[]) infeed(token) ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0 } @@ -94,8 +97,9 @@ ENTRY main { const_a = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} )) const_b = (f32[2,1]{1,0}, f32[1]{0}) while((f32[2,1]{1,0}, f32[1]{0}) const_a), condition=while_cond, body=while_body - out0 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_a) - ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_b) + token = token[] after-all() + out0 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_a, token[] token) + ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_b, token[] token) } )"; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc index 01daed4bcd38323bfe33e798a78c2b00b150a1bc..bb105194f1c9001ca4d9fff9174e1ea7e5d8b72a 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -62,7 +62,8 @@ TEST_F(CpuNoAliasTest, Concat) { // Now that we have an HLO module, build an llvm_ir::AliasAnalysis for it. auto status_or_buffer_assn = BufferAssigner::Run( - hlo_module.get(), MakeUnique(hlo_module.get()), + hlo_module.get(), + absl::make_unique(hlo_module.get()), backend().compiler()->BufferSizeBytesFunction(), [](LogicalBuffer::Color) { return /*alignment=*/1; }); ASSERT_EQ(status_or_buffer_assn.status(), Status::OK()); diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc index dac416e1c78c2f60d458480c5062f48b77d4878d..780c07f819ea2f94ed2f27dc0be0983f0389bfbc 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc @@ -32,7 +32,8 @@ ENTRY main { {{{1, 2}, {1001, 1002}, {2001, 2002}}, {{2, 1}, {2001, 3002}, {2001, 2002}}}) - outfeed = token[] outfeed(f32[2,3,2] const_a) + token = token[] after-all() + outfeed = token[] outfeed(f32[2,3,2] const_a, token) ROOT root = () tuple() } )"; diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc index 3274be8d9dbfaa55e250748a389ad34fdeb81922..962ea69c09487735a7d5e3309dfbf2969655da81 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" +#include "absl/algorithm/container.h" #include "llvm/Support/raw_ostream.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -422,8 +423,8 @@ TileVariable::TileVariable(VectorSupportLibrary* vector_support, std::vector TileVariable::Get() const { std::vector result; - c_transform(storage_, std::back_inserter(result), - [&](VectorVariable vect_var) { return vect_var.Get(); }); + absl::c_transform(storage_, std::back_inserter(result), + [&](VectorVariable vect_var) { return vect_var.Get(); }); return result; } diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc index d938f3a2c4b5bfdd70d5a614b9890b4d7bf050f7..48e44714998f61c9bdccaa43719abc533eb83565 100644 --- a/tensorflow/compiler/xla/service/despecializer.cc +++ b/tensorflow/compiler/xla/service/despecializer.cc @@ -21,8 +21,33 @@ limitations under the License. namespace xla { +namespace { + +// Pass which strips control dependencies from all instructions in the module. +class ControlDepRemover : public HloPassInterface { + public: + ControlDepRemover() = default; + tensorflow::StringPiece name() const override { + return "control-dep-remover"; + } + + StatusOr Run(HloModule* module) override { + bool changed = false; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + changed = changed || !instruction->control_predecessors().empty(); + TF_RETURN_IF_ERROR(instruction->DropAllControlDeps()); + } + } + return changed; + } +}; + +} // namespace + Despecializer::Despecializer() : pipeline_("despecializer") { // TODO(b/70588125): Also deal with window reversal in a fast way. + pipeline_.AddPass(); pipeline_.AddPass(); pipeline_.AddPass(); pipeline_.AddPass(); diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 097fa23027bf55ad0b92c347c5a1209bb5836695..86d57581f84920e8005e8f3c420e7488fc095434 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -106,6 +106,7 @@ class DfsHloVisitorBase { virtual Status HandleConvolution(HloInstructionPtr hlo) = 0; virtual Status HandleFft(HloInstructionPtr fft) = 0; virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0; + virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; virtual Status HandleCompare(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); } @@ -233,6 +234,7 @@ class DfsHloVisitorBase { virtual Status HandleWhile(HloInstructionPtr hlo) = 0; virtual Status HandleConditional(HloInstructionPtr hlo) = 0; virtual Status HandleGather(HloInstructionPtr hlo) = 0; + virtual Status HandleScatter(HloInstructionPtr hlo) = 0; virtual Status HandlePad(HloInstructionPtr hlo) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index f4316e0fb77855aad1c4710908df09c604da896e..617a5a2eb4796d8003099e39e3d26389e532e954 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -94,6 +94,9 @@ class DfsHloVisitorWithDefaultBase Status HandleCrossReplicaSum(HloInstructionPtr crs) override { return DefaultAction(crs); } + Status HandleAllToAll(HloInstructionPtr crs) override { + return DefaultAction(crs); + } Status HandleRng(HloInstructionPtr random) override { return DefaultAction(random); } @@ -194,6 +197,9 @@ class DfsHloVisitorWithDefaultBase Status HandleGather(HloInstructionPtr gather) override { return DefaultAction(gather); } + Status HandleScatter(HloInstructionPtr scatter) override { + return DefaultAction(scatter); + } Status HandleAfterAll(HloInstructionPtr token) override { return DefaultAction(token); } diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index f883eb828c7f6365dfd4d5e0b514dc6894adc12b..4b19aa5df972001ab1975fac5f88ad02703ff84b 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -21,6 +21,7 @@ limitations under the License. #include // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/algorithm/container.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" @@ -431,6 +432,8 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( return EmitCos(op->shape().element_type(), operand_value); case HloOpcode::kSin: return EmitSin(op->shape().element_type(), operand_value); + case HloOpcode::kTanh: + return EmitTanh(op->shape().element_type(), operand_value); case HloOpcode::kFloor: return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor, {operand_value}, @@ -1060,6 +1063,11 @@ StatusOr ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, return Unimplemented("atan2"); } +StatusOr ElementalIrEmitter::EmitTanh(PrimitiveType prim_type, + llvm::Value* value) const { + return Unimplemented("tanh"); +} + StatusOr ElementalIrEmitter::EmitReducePrecision( const HloInstruction* hlo, llvm::Value* x) const { if (hlo->operand(0)->shape().element_type() != F32) { @@ -1239,13 +1247,23 @@ StatusOr ElementalIrEmitter::ConvertValueForDistribution( // Convert raw integer to float in range [0, 1) if the element is a float. llvm::Value* elem_value = raw_value; if (elem_ir_ty->isFloatingPointTy()) { - elem_value = b_->CreateUIToFP(elem_value, elem_ir_ty); unsigned raw_value_size_in_bits = raw_value_ty->getPrimitiveSizeInBits(); CHECK(raw_value_size_in_bits == 32 || raw_value_size_in_bits == 64); - elem_value = b_->CreateFDiv( - elem_value, - llvm::ConstantFP::get(elem_ir_ty, - raw_value_size_in_bits == 64 ? 0x1p64 : 0x1p32)); + // Perform the division using the float type with the same number of bits + // as the raw value to avoid overflow. + if (raw_value_size_in_bits == 32) { + elem_value = b_->CreateUIToFP(elem_value, b_->getFloatTy()); + elem_value = b_->CreateFDiv( + elem_value, llvm::ConstantFP::get(b_->getFloatTy(), std::exp2(32))); + } else { + elem_value = b_->CreateUIToFP(elem_value, b_->getDoubleTy()); + elem_value = b_->CreateFDiv( + elem_value, llvm::ConstantFP::get(b_->getDoubleTy(), std::exp2(64))); + } + + if (elem_ir_ty != elem_value->getType()) { + elem_value = b_->CreateFPTrunc(elem_value, elem_ir_ty); + } } // Convert the value for the requested distribution. @@ -1302,6 +1320,7 @@ int32 GetNumberOfElementsPerPhiloxRngSample(PrimitiveType elem_prim_ty) { case F16: return 4; case U64: + case S64: case F64: return 2; default: @@ -1654,22 +1673,21 @@ StatusOr ElementalIrEmitter::EmitElementalGather( std::vector operand_to_output_dim(operand_shape.dimensions_size(), -1); for (int64 i = 0, e = operand_shape.dimensions_size(), operand_index_dim = 0; i < e; i++) { - if (c_binary_search(dim_numbers.elided_window_dims(), i)) { + if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) { operand_index.push_back(index.GetConstantWithIndexType(0)); } else { - int64 output_window_dim = - dim_numbers.output_window_dims(operand_index_dim++); + int64 output_window_dim = dim_numbers.offset_dims(operand_index_dim++); operand_to_output_dim[i] = output_window_dim; operand_index.push_back(index[output_window_dim]); } } - // This is the index of the index vector in the gather_indices tensor. + // This is the index of the index vector in the start_indices tensor. IrArray::Index gather_index_index(index_type); { std::vector gather_index_index_components; for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) { - if (!c_binary_search(dim_numbers.output_window_dims(), i)) { + if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) { gather_index_index.push_back(index[i]); } } @@ -1682,7 +1700,7 @@ StatusOr ElementalIrEmitter::EmitElementalGather( auto add_to_operand_index = [&](llvm::Value* index_component, int64 dim) { llvm::Value* gather_dim_component_extended = b_->CreateSExtOrTrunc(index_component, index_type); - int64 operand_dim = dim_numbers.gather_dims_to_operand_dims(dim); + int64 operand_dim = dim_numbers.start_index_map(dim); int64 output_dim = operand_to_output_dim[operand_dim]; // If 'output_dim' is -1, it means 'operand_dim' is an elided window dim. // This means we set the iteration index to 0, so for the purpose of the @@ -2134,7 +2152,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( return EmitElementalDot(hlo, operand_to_generator, dot_result_index); }; default: - return [this, hlo, &operand_to_generator](const IrArray::Index& index) { + return [hlo](const IrArray::Index& index) { return Unimplemented("Unhandled opcode for elemental IR emission: %s", HloOpcodeString(hlo->opcode()).c_str()); }; diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index fcb34557a52d35ef30a5dee643171e17407d05c2..1598a4dd85632cfa9835a81a21eddff3e57bfa1f 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -122,6 +122,9 @@ class ElementalIrEmitter { llvm::Value* lhs, llvm::Value* rhs) const; + virtual StatusOr EmitTanh(PrimitiveType prim_type, + llvm::Value* value) const; + virtual StatusOr EmitReducePrecision(const HloInstruction* hlo, llvm::Value* x) const; diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index fd75847d0c0e737957401b8efc420d504a3c0706..1c9f396b68fa20a03986d81d642d1726b26cd0dc 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/executable.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/status.h" @@ -76,8 +77,8 @@ StatusOr Executable::ExecuteOnStreamWrapper( std::unique_ptr profile_ptr = module_config().debug_options().xla_hlo_profile() && hlo_profiling_enabled() - ? MakeUnique(&hlo_profile_printer_data(), - &hlo_profile_index_map()) + ? absl::make_unique(&hlo_profile_printer_data(), + &hlo_profile_index_map()) : nullptr; StatusOr return_value = diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc index 6794cfe297b0fb9a15eb9b7e6906d225f9597d07..70a78c8a2b6f3cf360ca2ac7255f8dc35235125e 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.cc +++ b/tensorflow/compiler/xla/service/execution_tracker.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -25,7 +25,7 @@ limitations under the License. namespace xla { AsyncExecution::AsyncExecution(Backend* backend, - std::vector streams, + std::vector streams, const ExecutionProfile& profile, GlobalDataHandle result) : backend_(CHECK_NOTNULL(backend)), @@ -46,14 +46,15 @@ Status AsyncExecution::BlockUntilDone() const { ExecutionTracker::ExecutionTracker() : next_handle_(1) {} -ExecutionHandle ExecutionTracker::Register( - Backend* backend, std::vector streams, - const ExecutionProfile& profile, GlobalDataHandle result) { +ExecutionHandle ExecutionTracker::Register(Backend* backend, + std::vector streams, + const ExecutionProfile& profile, + GlobalDataHandle result) { tensorflow::mutex_lock lock(execution_mutex_); int64 handle = next_handle_++; auto inserted = handle_to_execution_.emplace( - handle, - MakeUnique(backend, std::move(streams), profile, result)); + handle, absl::make_unique(backend, std::move(streams), + profile, result)); CHECK(inserted.second); ExecutionHandle execution_handle; diff --git a/tensorflow/compiler/xla/service/execution_tracker.h b/tensorflow/compiler/xla/service/execution_tracker.h index 4458152dd9a98890fc3a3e7f324245ec68821467..4e9b9f883e26f5564a9c63a40d2b4b9348908214 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.h +++ b/tensorflow/compiler/xla/service/execution_tracker.h @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/backend.h" -#include "tensorflow/compiler/xla/service/pool.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -40,7 +40,7 @@ namespace xla { // the stream when destructed. class AsyncExecution { public: - AsyncExecution(Backend* backend, std::vector streams, + AsyncExecution(Backend* backend, std::vector streams, const ExecutionProfile& profile, GlobalDataHandle result); Status BlockUntilDone() const; @@ -54,7 +54,7 @@ class AsyncExecution { Backend* backend_; // Stream on which the execution is launched. - std::vector streams_; + std::vector streams_; // Profile object of the execution to be returned to the user. ExecutionProfile profile_; @@ -72,7 +72,7 @@ class ExecutionTracker { // Registers an execution with its backend, streams, and data handle to the // execution result. Returns a handle for the registered execution. ExecutionHandle Register(Backend* backend, - std::vector stream, + std::vector stream, const ExecutionProfile& profile, GlobalDataHandle data); diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index e3a42d0d06be9e4c9ef96ed2e6ff5daa8eebaf3e..d889fd8e88ed4008749c116314e9a0c54e6fa63d 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gather_expander.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" @@ -27,85 +28,85 @@ namespace xla { using tensorflow::gtl::ArraySlice; static StatusOr TransposeIndexVectorDimToLast( - HloInstruction* gather_indices, int64 index_vector_dim) { - const Shape& gather_indices_shape = gather_indices->shape(); + HloInstruction* start_indices, int64 index_vector_dim) { + const Shape& start_indices_shape = start_indices->shape(); - if (gather_indices_shape.dimensions_size() == index_vector_dim) { - return gather_indices; + if (start_indices_shape.dimensions_size() == index_vector_dim) { + return start_indices; } - if (index_vector_dim == (gather_indices_shape.dimensions_size() - 1)) { - return gather_indices; + if (index_vector_dim == (start_indices_shape.dimensions_size() - 1)) { + return start_indices; } std::vector permutation; - permutation.reserve(gather_indices_shape.dimensions_size()); - for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) { + permutation.reserve(start_indices_shape.dimensions_size()); + for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) { if (i != index_vector_dim) { permutation.push_back(i); } } permutation.push_back(index_vector_dim); - return MakeTransposeHlo(gather_indices, permutation); + return MakeTransposeHlo(start_indices, permutation); } -// Canonicalizes the gather_indices tensors so that we only have deal with some +// Canonicalizes the start_indices tensors so that we only have deal with some // specific cases in the while loop that does the heavy lifting. // // See the "High Level Algorithm" section for a broader picture. static StatusOr CanonicalizeGatherIndices( - HloInstruction* gather_indices, int64 index_vector_dim) { + HloInstruction* start_indices, int64 index_vector_dim) { // Transpose the non-index-vector dimensions to the front. TF_ASSIGN_OR_RETURN( - HloInstruction * transposed_gather_indices, - TransposeIndexVectorDimToLast(gather_indices, index_vector_dim)); + HloInstruction * transposed_start_indices, + TransposeIndexVectorDimToLast(start_indices, index_vector_dim)); bool indices_are_scalar = - index_vector_dim == gather_indices->shape().dimensions_size(); + index_vector_dim == start_indices->shape().dimensions_size(); - // The number of dimensions in gather_indices that are index dimensions. - const int64 index_dims_in_gather_indices = indices_are_scalar ? 0 : 1; + // The number of dimensions in start_indices that are index dimensions. + const int64 index_dims_in_start_indices = indices_are_scalar ? 0 : 1; - // If there is only one index (i.e. gather_indices has rank 1 and this gather + // If there is only one index (i.e. start_indices has rank 1 and this gather // is really just a dynamic slice) add a leading degenerate dimension for // uniformity. Otherwise create a "collapsed" leading dimension that subsumes // all of the non-index-vector dimensions. - const Shape& shape = transposed_gather_indices->shape(); - if (shape.dimensions_size() == index_dims_in_gather_indices) { - return PrependDegenerateDims(transposed_gather_indices, 1); + const Shape& shape = transposed_start_indices->shape(); + if (shape.dimensions_size() == index_dims_in_start_indices) { + return PrependDegenerateDims(transposed_start_indices, 1); } else { - // Collapse all but the dimensions (0 or 1) in gather_indices containing the + // Collapse all but the dimensions (0 or 1) in start_indices containing the // index vectors. return CollapseFirstNDims( - transposed_gather_indices, - shape.dimensions_size() - index_dims_in_gather_indices); + transposed_start_indices, + shape.dimensions_size() - index_dims_in_start_indices); } } // Expands out or contracts away the gather dimensions in the accumulator // produced by the while loop. -static StatusOr AdjustGatherDimsInAccumulator( - const Shape& gather_indices_shape, HloInstruction* accumulator, +static StatusOr AdjustBatchDimsInAccumulator( + const Shape& start_indices_shape, HloInstruction* accumulator, int64 index_vector_dim) { - std::vector output_gather_dim_bounds; - output_gather_dim_bounds.reserve(gather_indices_shape.dimensions_size()); - for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) { + std::vector batch_dim_bounds; + batch_dim_bounds.reserve(start_indices_shape.dimensions_size()); + for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) { if (i != index_vector_dim) { - output_gather_dim_bounds.push_back(gather_indices_shape.dimensions(i)); + batch_dim_bounds.push_back(start_indices_shape.dimensions(i)); } } - if (output_gather_dim_bounds.empty()) { - // If output_gather_dim_bounds is empty we must be lowering a (effectively) + if (batch_dim_bounds.empty()) { + // If batch_dim_bounds is empty we must be lowering a (effectively) // dynamic-slice. In that case, there is a leading degenerate gather // dimension that we added to make this special case play well with the // general while loop which we need to remove now. return ElideDegenerateDims(accumulator, {0}); } - return ExpandFirstDimIntoNDims(accumulator, output_gather_dim_bounds); + return ExpandFirstDimIntoNDims(accumulator, batch_dim_bounds); } -// Expand an index vector from the gather_indices tensor into a vector that can +// Expand an index vector from the start_indices tensor into a vector that can // be used to dynamic-slice out of the gather operand. static StatusOr ExpandIndexVectorIntoOperandSpace( HloInstruction* index_vector, const GatherDimensionNumbers& dim_numbers, @@ -121,10 +122,8 @@ static StatusOr ExpandIndexVectorIntoOperandSpace( std::vector expanded_index_components; for (int i = 0; i < operand_rank; i++) { - int64 index_vector_dim_index = - FindIndex(dim_numbers.gather_dims_to_operand_dims(), i); - if (index_vector_dim_index != - dim_numbers.gather_dims_to_operand_dims_size()) { + int64 index_vector_dim_index = FindIndex(dim_numbers.start_index_map(), i); + if (index_vector_dim_index != dim_numbers.start_index_map_size()) { TF_ASSIGN_OR_RETURN( HloInstruction * component_to_concat, MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index}, @@ -147,10 +146,10 @@ static StatusOr> GatherLoopBody( const GatherDimensionNumbers& dim_numbers = gather.gather_dimension_numbers(); CHECK_EQ(incoming_loop_state.size(), 3); HloInstruction* const operand = incoming_loop_state[0]; - HloInstruction* const gather_indices = incoming_loop_state[1]; + HloInstruction* const start_indices = incoming_loop_state[1]; HloInstruction* const output_accumulator = incoming_loop_state[2]; - bool has_scalar_indices = gather_indices->shape().dimensions_size() == 1; + bool has_scalar_indices = start_indices->shape().dimensions_size() == 1; CHECK_EQ(has_scalar_indices, dim_numbers.index_vector_dim() == gather.operand(1)->shape().dimensions_size()); @@ -163,24 +162,24 @@ static StatusOr> GatherLoopBody( HloInstruction* index_vector; if (has_scalar_indices) { - // In this case gather_indices has rank 1 and induction_var_as_vector (of + // In this case start_indices has rank 1 and induction_var_as_vector (of // shape {1}) is an index into this rank 1 tensor. TF_ASSIGN_OR_RETURN( index_vector, - MakeDynamicSliceHlo(gather_indices, induction_var_as_vector, {1})); + MakeDynamicSliceHlo(start_indices, induction_var_as_vector, {1})); } else { - // In this case gather_indices has rank 2 and induction_var_as_vector (of + // In this case start_indices has rank 2 and induction_var_as_vector (of // shape {1}) is an index into just the first dimension of this rank 2 // tensor. TF_ASSIGN_OR_RETURN( - HloInstruction * index_into_gather_indices, + HloInstruction * index_into_start_indices, PadVectorWithZeros(induction_var_as_vector, /*zeros_to_prepend=*/0, /*zeros_to_append=*/1)); - int64 index_vector_size = gather_indices->shape().dimensions(1); + int64 index_vector_size = start_indices->shape().dimensions(1); TF_ASSIGN_OR_RETURN( HloInstruction * index_vector_2d, - MakeDynamicSliceHlo(gather_indices, index_into_gather_indices, + MakeDynamicSliceHlo(start_indices, index_into_start_indices, {1, index_vector_size})); TF_ASSIGN_OR_RETURN(index_vector, @@ -194,26 +193,26 @@ static StatusOr> GatherLoopBody( TF_ASSIGN_OR_RETURN(HloInstruction * gathered_slice, MakeDynamicSliceHlo(operand, gathered_slice_start, - gather.gather_window_bounds())); + gather.gather_slice_sizes())); TF_ASSIGN_OR_RETURN( - HloInstruction * gathered_slice_with_dims_elided, + HloInstruction* const gathered_slice_with_dims_collapsed, ElideDegenerateDims(gathered_slice, - AsInt64Slice(dim_numbers.elided_window_dims()))); + AsInt64Slice(dim_numbers.collapsed_slice_dims()))); TF_ASSIGN_OR_RETURN( - HloInstruction * gathered_slice_for_update, - PrependDegenerateDims(gathered_slice_with_dims_elided, 1)); + HloInstruction* const gathered_slice_for_update, + PrependDegenerateDims(gathered_slice_with_dims_collapsed, 1)); TF_ASSIGN_OR_RETURN( - HloInstruction * index_vector_into_accumulator, + HloInstruction* const index_vector_into_accumulator, PadVectorWithZeros( induction_var_as_vector, /*zeros_to_prepend=*/0, /*zeros_to_append=*/ - gathered_slice_with_dims_elided->shape().dimensions_size())); + gathered_slice_with_dims_collapsed->shape().dimensions_size())); TF_ASSIGN_OR_RETURN( - HloInstruction * updated_accumulator, + HloInstruction* const updated_accumulator, MakeDynamicUpdateSliceHlo(output_accumulator, gathered_slice_for_update, index_vector_into_accumulator)); @@ -221,19 +220,19 @@ static StatusOr> GatherLoopBody( // WhileUtil::MakeCountedLoop functions takes care of the induction variable // and the while loop exit condition. return StatusOr>{ - {operand, gather_indices, updated_accumulator}}; + {operand, start_indices, updated_accumulator}}; } static StatusOr CreateGatherLoopAccumulatorInitValue( HloComputation* computation, PrimitiveType element_type, - ArraySlice window_bounds, int64 gather_loop_trip_count, + ArraySlice slice_sizes, int64 gather_loop_trip_count, const GatherDimensionNumbers& dim_numbers) { std::vector accumulator_state_shape_dims; - accumulator_state_shape_dims.reserve(1 + window_bounds.size()); + accumulator_state_shape_dims.reserve(1 + slice_sizes.size()); accumulator_state_shape_dims.push_back(gather_loop_trip_count); - for (int64 i = 0; i < window_bounds.size(); i++) { - if (!c_binary_search(dim_numbers.elided_window_dims(), i)) { - accumulator_state_shape_dims.push_back(window_bounds[i]); + for (int64 i = 0; i < slice_sizes.size(); i++) { + if (!absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) { + accumulator_state_shape_dims.push_back(slice_sizes[i]); } } return BroadcastZeros(computation, element_type, @@ -241,23 +240,23 @@ static StatusOr CreateGatherLoopAccumulatorInitValue( } // `accumulator` is almost the tensor the gather operation would have produced, -// except that it has the dimensions in the wrong order -- the gather dimensions -// are the major dimensions and the window dimensions are the minor dimensions. +// except that it has the dimensions in the wrong order -- the batch dimensions +// are the major dimensions and the offset dimensions are the minor dimensions. // Fix this up with a transpose. -static StatusOr PermuteGatherAndWindowDims( - HloInstruction* accumulator, ArraySlice output_window_dims, +static StatusOr PermuteBatchAndOffsetDims( + HloInstruction* accumulator, ArraySlice offset_dims, int64 output_rank) { std::vector permutation; permutation.reserve(output_rank); - int64 gather_idx_counter = 0; - int64 window_idx_counter = output_rank - output_window_dims.size(); + int64 batch_idx_counter = 0; + int64 offset_idx_counter = output_rank - offset_dims.size(); for (int64 i = 0; i < output_rank; i++) { - bool is_window_dim = c_binary_search(output_window_dims, i); - if (is_window_dim) { - permutation.push_back(window_idx_counter++); + bool is_offset_dim = absl::c_binary_search(offset_dims, i); + if (is_offset_dim) { + permutation.push_back(offset_idx_counter++); } else { - permutation.push_back(gather_idx_counter++); + permutation.push_back(batch_idx_counter++); } } @@ -268,11 +267,11 @@ static StatusOr PermuteGatherAndWindowDims( // // We follow the following steps in sequence: // -// 1. We canonicalize the gather_indices tensor such that it has rank +// 1. We canonicalize the start_indices tensor such that it has rank // 2 (i.e. is a matrix) where each row is an index vector into the // operand. // 2. We iterate over the set of indices in the canonicalized -// gather_indices tensor using a while loop, accumulating slices +// start_indices tensor using a while loop, accumulating slices // of the operand tensor into an accumulator using // DynamicUpdateSlice. // 3. The accumulator result from the while loop from (2) is then @@ -287,11 +286,11 @@ static StatusOr PermuteGatherAndWindowDims( // operand = s32[3,3] parameter(0) // indices = s32[2,2] parameter(1) // ROOT gather = s32[2,3,2] gather(operand, indices), -// output_window_dims={1}, -// elided_window_dims={1}, -// gather_dims_to_operand_dims={1}, +// offset_dims={1}, +// collapsed_slice_dims={1}, +// start_index_map={1}, // index_vector_dim=2, -// window_bounds={3, 1} +// slice_sizes={3, 1} // } // // We'd first reshape indices to s32[4,1], where each row is an index @@ -305,8 +304,8 @@ StatusOr GatherExpander::ExpandGather( HloComputation* computation = gather_instr->parent(); HloInstruction* operand = gather_instr->mutable_operand(0); - HloInstruction* gather_indices = gather_instr->mutable_operand(1); - const Shape& gather_indices_shape = gather_indices->shape(); + HloInstruction* start_indices = gather_instr->mutable_operand(1); + const Shape& start_indices_shape = start_indices->shape(); const Shape& output_shape = gather_instr->shape(); int64 output_rank = output_shape.dimensions_size(); @@ -314,9 +313,9 @@ StatusOr GatherExpander::ExpandGather( gather_instr->gather_dimension_numbers(); int64 gather_loop_trip_count = 1; - for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) { + for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) { if (i != dim_numbers.index_vector_dim()) { - gather_loop_trip_count *= gather_indices_shape.dimensions(i); + gather_loop_trip_count *= start_indices_shape.dimensions(i); } } @@ -327,24 +326,24 @@ StatusOr GatherExpander::ExpandGather( gather_instr->ToString().c_str()); } - TF_ASSIGN_OR_RETURN(HloInstruction * canonical_gather_indices, - CanonicalizeGatherIndices( - gather_indices, dim_numbers.index_vector_dim())); + TF_ASSIGN_OR_RETURN( + HloInstruction * canonical_start_indices, + CanonicalizeGatherIndices(start_indices, dim_numbers.index_vector_dim())); CHECK_EQ(gather_loop_trip_count, - canonical_gather_indices->shape().dimensions(0)); + canonical_start_indices->shape().dimensions(0)); TF_ASSIGN_OR_RETURN( HloInstruction * accumulator_init, CreateGatherLoopAccumulatorInitValue( computation, output_shape.element_type(), - gather_instr->gather_window_bounds(), gather_loop_trip_count, + gather_instr->gather_slice_sizes(), gather_loop_trip_count, gather_instr->gather_dimension_numbers())); StatusOr> gather_loop_result_or_error = WhileUtil::MakeCountedLoop( computation, gather_loop_trip_count, - {operand, canonical_gather_indices, accumulator_init}, + {operand, canonical_start_indices, accumulator_init}, [&](HloInstruction* indvar, const std::vector& loop_state) { return GatherLoopBody(*gather_instr, indvar, loop_state); @@ -356,13 +355,13 @@ StatusOr GatherExpander::ExpandGather( HloInstruction* accumulator_result = gather_loop_result.back(); TF_ASSIGN_OR_RETURN( - HloInstruction * accumulator_with_output_gather_dims_decanonicalized, - AdjustGatherDimsInAccumulator(gather_indices->shape(), accumulator_result, - dim_numbers.index_vector_dim())); + HloInstruction* const accumulator_with_batch_dims_decanonicalized, + AdjustBatchDimsInAccumulator(start_indices->shape(), accumulator_result, + dim_numbers.index_vector_dim())); - return PermuteGatherAndWindowDims( - accumulator_with_output_gather_dims_decanonicalized, - AsInt64Slice(dim_numbers.output_window_dims()), output_rank); + return PermuteBatchAndOffsetDims(accumulator_with_batch_dims_decanonicalized, + AsInt64Slice(dim_numbers.offset_dims()), + output_rank); } StatusOr GatherExpander::Run(HloModule* module) { @@ -375,8 +374,8 @@ StatusOr GatherExpander::Run(HloModule* module) { std::vector gather_instrs; for (HloComputation* computation : module->MakeNonfusionComputations()) { - c_copy_if(computation->instructions(), std::back_inserter(gather_instrs), - is_nontrivial_gather); + absl::c_copy_if(computation->instructions(), + std::back_inserter(gather_instrs), is_nontrivial_gather); } for (HloInstruction* inst : gather_instrs) { diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc index 020ffcd106862cb2641a9f3bceb70acdd969a458..141dd4d6f10272ce749edc4e91153c365ed322e6 100644 --- a/tensorflow/compiler/xla/service/gather_expander_test.cc +++ b/tensorflow/compiler/xla/service/gather_expander_test.cc @@ -28,11 +28,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2147483647,5] parameter(1) ROOT gather = s32[2147483647,3,5] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=2, - window_bounds={3, 1} + slice_sizes={3, 1} } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -55,11 +55,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[3,2] gather(operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3, 1} + slice_sizes={3, 1} } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index e314a469f00abdb9f60ae812c0b78d273dc95dbe..0ce2db907b643f3beabd127388370dbe601179e1 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/interpreter/platform_id.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -60,17 +59,19 @@ Status GenericTransferManager::WriteSingleTupleIndexTable( void GenericTransferManager::TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer, - std::function>)> done) { + MutableBorrowingLiteral literal, std::function done) { Status status = stream->BlockHostUntilDone(); if (!status.ok()) { return done(status); } - done(TransferLiteralFromDeviceInternal(stream->parent(), device_buffer)); + + done(TransferLiteralFromDeviceInternal(stream->parent(), device_buffer, + literal)); } -StatusOr> -GenericTransferManager::TransferLiteralFromDeviceInternal( - se::StreamExecutor* executor, const ShapedBuffer& device_buffer) { +Status GenericTransferManager::TransferLiteralFromDeviceInternal( + se::StreamExecutor* executor, const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal) { VLOG(2) << "transferring literal from device ordinal " << executor->device_ordinal() << "; device buffer: " << device_buffer; TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); @@ -80,9 +81,6 @@ GenericTransferManager::TransferLiteralFromDeviceInternal( TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(), device_buffer.on_host_shape())); - std::unique_ptr literal = - Literal::CreateFromShape(device_buffer.on_host_shape()); - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( device_buffer.on_host_shape(), [&](const Shape& subshape, const ShapeIndex& index) -> Status { @@ -91,12 +89,12 @@ GenericTransferManager::TransferLiteralFromDeviceInternal( /*source=*/device_buffer.buffer(index), /*size=*/GetByteSizeRequirement(subshape), /*destination=*/ - literal->untyped_data(index))); + literal.untyped_data(index))); } return Status::OK(); })); - return std::move(literal); + return Status::OK(); } Status GenericTransferManager::TransferLiteralToDeviceAsync( @@ -160,7 +158,7 @@ Status GenericTransferManager::TransferLiteralToInfeed( Status GenericTransferManager::TransferLiteralFromOutfeed( se::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) { + MutableBorrowingLiteral literal) { return Unimplemented("Generic transfer from Outfeed"); } diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 3cd002c1bf3555cc2d2891c88b3ad648f8d9fd8c..6c1a21587a7ef5199afb93715dc57be5139fbc22 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -19,7 +19,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/transfer_manager.h" -#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -41,9 +40,10 @@ class GenericTransferManager : public TransferManager { se::Platform::Id PlatformId() const override; - void TransferLiteralFromDevice( - se::Stream* stream, const ShapedBuffer& device_buffer, - std::function>)> done) override; + void TransferLiteralFromDevice(se::Stream* stream, + const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal, + std::function done) override; Status TransferLiteralToDeviceAsync( se::Stream* stream, const LiteralSlice& literal, @@ -53,7 +53,7 @@ class GenericTransferManager : public TransferManager { const LiteralSlice& literal) override; Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) override; + MutableBorrowingLiteral literal) override; Status ResetDevices( tensorflow::gtl::ArraySlice executors) override; @@ -67,8 +67,9 @@ class GenericTransferManager : public TransferManager { const Shape& shape, se::DeviceMemoryBase* region) override; private: - StatusOr> TransferLiteralFromDeviceInternal( - se::StreamExecutor* executor, const ShapedBuffer& device_buffer); + Status TransferLiteralFromDeviceInternal(se::StreamExecutor* executor, + const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal); // The platform this transfer manager targets. const se::Platform::Id platform_id_; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 6f1e766d1c2bde0871654b18831ed44a851febb5..17eefc430d215866b2877b57a0624ddd030830b9 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1,6 +1,7 @@ # Description: # GPU-specific components in XLA service implementation. +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") licenses(["notice"]) # Apache 2.0 @@ -55,6 +56,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -90,6 +92,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_reachability", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -106,6 +109,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -114,11 +118,13 @@ cc_library( srcs = ["hlo_to_ir_bindings.cc"], hdrs = ["hlo_to_ir_bindings.h"], deps = [ + ":buffer_allocations", ":ir_emission_utils", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/llvm_ir:alias_analysis", + "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", @@ -142,6 +148,7 @@ cc_library( ], deps = [ ":backend_configs", + ":buffer_allocations", ":cudnn_convolution_runner", ":elemental_ir_emitter", ":gpu_constants", @@ -150,7 +157,6 @@ cc_library( ":ir_emission_utils", ":parallel_loop_emitter", ":partition_assignment", - ":while_transformer", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -163,6 +169,8 @@ cc_library( "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:name_uniquer", + "//tensorflow/compiler/xla/service:while_loop_analysis", + "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", "//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", @@ -175,6 +183,8 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", "@llvm//:core", "@llvm//:support", ], @@ -238,6 +248,7 @@ cc_library( "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -248,10 +259,11 @@ cc_library( deps = [ "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_execution_profile", - "//tensorflow/compiler/xla/service:pool", + "//tensorflow/compiler/xla/service:stream_pool", "//tensorflow/core:lib", "//tensorflow/core:ptr_util", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -323,6 +335,7 @@ cc_library( "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/service:tuple_points_to_analysis", + "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", @@ -331,6 +344,7 @@ cc_library( "//tensorflow/core/platform/default/build_config:cufft_plugin", "//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep "//tensorflow/stream_executor", + "@com_google_absl//absl/memory", ], ) @@ -356,10 +370,12 @@ cc_library( hdrs = ["cudnn_convolution_algorithm_picker.h"], deps = [ ":backend_configs", + ":buffer_comparator", ":cudnn_convolution_runner", ":gpu_executable", ":ir_emission_utils", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", @@ -458,6 +474,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:multi_output_fusion", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -505,6 +522,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -526,6 +544,24 @@ cc_library( name = "pad_insertion", srcs = ["pad_insertion.cc"], hdrs = ["pad_insertion.h"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_creation_utils", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:shape_inference", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "pad_for_tensor_cores", + srcs = ["pad_for_tensor_cores.cc"], + hdrs = ["pad_for_tensor_cores.h"], deps = [ ":ir_emission_utils", "//tensorflow/compiler/xla:literal", @@ -539,6 +575,21 @@ cc_library( ], ) +tf_cc_test( + name = "pad_for_tensor_cores_test", + srcs = ["pad_for_tensor_cores_test.cc"], + deps = [ + ":ir_emission_utils", + ":pad_for_tensor_cores", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep + ], +) + cc_library( name = "gpu_transfer_manager", srcs = ["gpu_transfer_manager.cc"], @@ -560,6 +611,7 @@ cc_library( "//tensorflow/compiler/xla/service/gpu:infeed_manager", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", "@llvm//:core", ], alwayslink = True, # Contains per-platform transfer manager registration @@ -583,9 +635,11 @@ cc_library( ":ir_emission_utils", ":ir_emitter", ":multi_output_fusion", + ":pad_for_tensor_cores", ":pad_insertion", ":partition_assignment", ":stream_assignment", + ":stream_executor_util", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -597,7 +651,7 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", "//tensorflow/compiler/xla/service:conditional_simplifier", - "//tensorflow/compiler/xla/service:dot_decomposer", + "//tensorflow/compiler/xla/service:convolution_feature_group_converter", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", @@ -614,10 +668,10 @@ cc_library( "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", + "//tensorflow/compiler/xla/service:scatter_expander", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/compiler/xla/service:while_loop_constant_sinking", - "//tensorflow/compiler/xla/service:while_loop_invariant_code_motion", "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination", "//tensorflow/compiler/xla/service/gpu:cudnn_batchnorm_rewriter", @@ -628,6 +682,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", "@llvm//:core", ], alwayslink = True, # Contains compiler registration @@ -660,8 +715,8 @@ cc_library( ":xfeed_queue", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -676,6 +731,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -710,6 +766,8 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep ], @@ -723,12 +781,12 @@ cc_library( ":stream_assignment", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:buffer_value", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_reachability", "//tensorflow/compiler/xla/service:hlo_scheduling", + "@com_google_absl//absl/memory", ], ) @@ -745,21 +803,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - ], -) - -cc_library( - name = "while_transformer", - srcs = ["while_transformer.cc"], - hdrs = ["while_transformer.h"], - deps = [ - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -768,12 +812,12 @@ tf_cc_test( srcs = ["while_transformer_test.cc"], deps = [ ":instruction_fusion", - ":while_transformer", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/service:copy_insertion", "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/service:while_loop_analysis", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -809,6 +853,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:stream_executor_no_cuda", ], @@ -827,3 +872,35 @@ tf_cc_test( "//tensorflow/core:test", ], ) + +cc_library( + name = "buffer_comparator", + srcs = ["buffer_comparator.cc"], + hdrs = ["buffer_comparator.h"], + deps = [ + ":gpu_executable", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service:hlo_runner", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +xla_test( + name = "buffer_comparator_test", + srcs = ["buffer_comparator_test.cc"], + backends = [ + "cpu", + "gpu", + ], + deps = [ + ":buffer_comparator", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/service:backend", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc index b095d4cd731bb7877baffbf69cb17bd50e101d6b..e208ad61e331ecac12fe128359da7585a2a3a7b4 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc @@ -17,8 +17,8 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -40,16 +40,26 @@ StatusOr> BufferAllocations::Builder::Build( const BufferAssignment* buffer_assignment, int device_ordinal, DeviceMemoryAllocator* memory_allocator) { const int64 num_buffers = buffer_assignment->Allocations().size(); - auto buffer_allocations = WrapUnique(new BufferAllocations( + auto buffer_allocations = absl::WrapUnique(new BufferAllocations( num_buffers, device_ordinal, memory_allocator, buffer_assignment)); for (BufferAllocation::Index i = 0; i < num_buffers; ++i) { + const BufferAllocation& allocation = buffer_assignment->GetAllocation(i); + const int64 expected_alignment = [&] { + if (allocation.is_entry_computation_parameter()) { + return kEntryParameterAlignBytes; + } else if (allocation.is_constant()) { + return kConstantBufferAlignBytes; + } else { + return kXlaAllocatedBufferAlignBytes; + } + }(); + // If buffer #i's address is already registered (e.g. external arguments or // result buffers), use that registered buffer. if (registered_buffers_.count(i)) { se::DeviceMemoryBase address = FindOrDie(registered_buffers_, i); - if (reinterpret_cast(address.opaque()) % - kEntryParameterAlignBytes != + if (reinterpret_cast(address.opaque()) % expected_alignment != 0) { return InternalError( "Address of registered buffer %lld must be a multiple of %llx, but " @@ -62,7 +72,6 @@ StatusOr> BufferAllocations::Builder::Build( // Allocate each allocation that might escape, or is the temp buffer. bool seen_temp_buffer = false; - const BufferAllocation& allocation = buffer_assignment->GetAllocation(i); if (allocation.maybe_live_out() || allocation.IsPreallocatedTempBuffer()) { const int64 buffer_size = allocation.size(); se::DeviceMemoryBase buffer_address; @@ -70,8 +79,7 @@ StatusOr> BufferAllocations::Builder::Build( OwningDeviceMemory buffer; TF_ASSIGN_OR_RETURN( buffer, memory_allocator->Allocate(device_ordinal, buffer_size)); - if (reinterpret_cast(buffer.opaque()) % - kXlaAllocatedBufferAlignBytes != + if (reinterpret_cast(buffer.opaque()) % expected_alignment != 0) { return InternalError( "Address returned by memory_allocator->Allocate must be a " @@ -165,5 +173,10 @@ void BufferAllocations::SetBuffer(BufferAllocation::Index buffer_index, buffers_[buffer_index] = buffer; } +bool ShouldEmitLiteralInLlvmIr(const Literal& literal) { + // LLVM can sometimes do interesting optimizations using scalar constants. + return ShapeUtil::IsScalar(literal.shape()); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h index 636623502597b3a66523938ba430e9d5a82f796c..f13eab0dd787a2bfa687c991f9d808568360fd24 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h @@ -107,6 +107,12 @@ class BufferAllocations { bool torn_down_ = false; }; +// LLVM and PTXAS don't deal well with large constants, so we only emit very +// small constants directly in LLVM IR. Larger constants are emitted with zero +// initializers in LLVM IR and are later overwritten when the PTX/CUBIN is +// loaded. +bool ShouldEmitLiteralInLlvmIr(const Literal& literal); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc new file mode 100644 index 0000000000000000000000000000000000000000..6a285a6b989b29428fc15fd6aef29110577c226e --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc @@ -0,0 +1,205 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" + +#include +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace xla { +namespace gpu { + +static constexpr float kTolerance = 0.1f; + +static string GetCompHloText(size_t num_elements) { + // Implements the textual format of the comparison routine, as it's more + // readable. + static constexpr char kF16CompHloText[] = R"( +HloModule CompareF16 + +MaxF32 { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %max = f32[] maximum(%lhs, %rhs) +} + +Canonicalize (aparam: f16[SIZE]) -> f32[SIZE] { + %min_constant = f32[] constant(-65505) + %max_constant = f32[] constant(65505) + %large_constant = f32[] constant(1048576) + %min_values = f32[SIZE] broadcast(%min_constant), dimensions={} + %max_values = f32[SIZE] broadcast(%max_constant), dimensions={} + %large_values = f32[SIZE] broadcast(%large_constant), dimensions={} + + %a = f16[SIZE] parameter(0) + %converted = f32[SIZE] convert(%a) + %clamped = f32[SIZE] clamp(%min_values, %converted, %max_values) + + // Since the clamp() above already took care of infs, only NaNs will cause + // is-finite() to return false. + %is_finite = pred[SIZE] is-finite(%clamped) + ROOT %result = f32[SIZE] select(%is_finite, %clamped, %large_values) +} + +ENTRY MaxDifference { + %one_constant = f32[] constant(1.0) + %zero_constant = f32[] constant(0.0) + + %ones = f32[SIZE] broadcast(%one_constant), dimensions={} + + %lhs = f16[SIZE] parameter(0) + %rhs = f16[SIZE] parameter(1) + %lhs_canonical = f32[SIZE] call(%lhs), to_apply=Canonicalize + %rhs_canonical = f32[SIZE] call(%rhs), to_apply=Canonicalize + %sub = f32[SIZE] subtract(%lhs_canonical, %rhs_canonical) + %sub_abs = f32[SIZE] abs(%sub) + %lhs_abs = f32[SIZE] abs(%lhs_canonical) + %rhs_abs = f32[SIZE] abs(%rhs_canonical) + %max = f32[SIZE] maximum(%lhs_abs, %rhs_abs) + %denominator = f32[SIZE] add(%max, %ones) + %error = f32[SIZE] divide(%sub_abs, %denominator) + ROOT %max_diff = f32[] reduce(%error, %zero_constant), dimensions={0}, to_apply=MaxF32 +})"; + auto size_string = std::to_string(num_elements); + return tensorflow::str_util::StringReplace( + kF16CompHloText, "SIZE", {size_string.data(), size_string.size()}, true); +} + +StatusOr F16BufferComparator::Create( + se::DeviceMemory ref_buffer, Compiler* compiler, + DeviceMemoryAllocator* allocator, se::Stream* stream) { + auto stream_exec = stream->parent(); + int64 num_elements = ref_buffer.ElementCount(); + + // One may consider using hlo_runner to do all the compilation and execution. + // However, as of the time hlo_runner doesn't support injection for Compiler*, + // Stream*, or even the allocator. We may revisit this in the future if it + // proves to be a maintenance burden. + TF_ASSIGN_OR_RETURN( + auto exec, ([&]() -> StatusOr> { + HloModuleConfig config; + DebugOptions debug_options; + debug_options.set_xla_backend_optimization_level(2); + config.set_debug_options(debug_options); + TF_ASSIGN_OR_RETURN( + auto module, ParseHloString(GetCompHloText(num_elements), config)); + TF_ASSIGN_OR_RETURN( + module, + compiler->RunHloPasses(std::move(module), stream_exec, nullptr)); + return compiler->RunBackend(std::move(module), stream_exec, nullptr); + }())); + + TF_ASSIGN_OR_RETURN( + auto shaped_buffer, ([&]() -> StatusOr { + auto device_ordinal = stream_exec->device_ordinal(); + TF_ASSIGN_OR_RETURN( + auto owning_buffer, + allocator->Allocate(device_ordinal, ref_buffer.size())); + se::DeviceMemory buffer( + owning_buffer.AsDeviceMemoryBase()); + stream->ThenMemcpy(&buffer, ref_buffer, ref_buffer.size()); + Shape shape = ShapeUtil::MakeShape(xla::F16, {num_elements}); + ScopedShapedBuffer ret(shape, shape, allocator, device_ordinal); + ret.set_buffer(std::move(owning_buffer), {}); + return std::move(ret); + }())); + + return F16BufferComparator(stream, allocator, std::move(exec), + std::move(shaped_buffer)); +} + +StatusOr F16BufferComparator::CompareEqualImpl( + se::DeviceMemory test_buffer) { + if (ref_buffer_.root_buffer().size() != test_buffer.size()) { + return InternalError("Mismatched buffer size: %lld vs %lld", + ref_buffer_.root_buffer().size(), test_buffer.size()); + } + + int64 num_elements = test_buffer.ElementCount(); + + TF_ASSIGN_OR_RETURN( + auto result_buffer, ([&]() -> StatusOr { + auto stream_exec = stream_->parent(); + Shape shape = ShapeUtil::MakeShape(xla::F16, {num_elements}); + auto device_ordinal = stream_exec->device_ordinal(); + ShapedBuffer shaped_test_buffer(shape, shape, stream_exec->platform(), + device_ordinal); + shaped_test_buffer.set_buffer(test_buffer, {}); + ExecutableRunOptions run_options; + run_options.set_device_ordinal(stream_exec->device_ordinal()); + run_options.set_stream(stream_); + run_options.set_allocator(allocator_); + ServiceExecutableRunOptions service_run_options(run_options); + return exec_->ExecuteOnStream( + &service_run_options, {&ref_buffer_, &shaped_test_buffer}, nullptr); + }())); + + float result; + CHECK(result_buffer.root_buffer().size() == sizeof(result)); + stream_->ThenMemcpy(&result, result_buffer.root_buffer(), sizeof(result)); + TF_RETURN_IF_ERROR(stream_->BlockHostUntilDone()); + return result < kTolerance; +} + +StatusOr F16BufferComparator::CompareEqual( + se::DeviceMemory test_buffer) { + TF_ASSIGN_OR_RETURN(auto result, CompareEqualImpl(test_buffer)); + if (result) { + return true; + } + // Host side code that does the same thing, but report some of the + // differences as well. + int64 n = test_buffer.ElementCount(); + std::vector host_ref_buffer(n), host_test_buffer(n); + stream_->ThenMemcpy(host_ref_buffer.data(), ref_buffer_.root_buffer(), + ref_buffer_.root_buffer().size()); + stream_->ThenMemcpy(host_test_buffer.data(), test_buffer, test_buffer.size()); + TF_RETURN_IF_ERROR(stream_->BlockHostUntilDone()); + + const auto canonicalize = [](float a) -> float { + constexpr float kBigNumer = 1048576.; + constexpr float kMaxFp16Value = 65504.; + if (std::isnan(a)) { + return kBigNumer; + } + if (std::isinf(a)) { + if (a < 0) { + return -(kMaxFp16Value + 1); + } + return kMaxFp16Value + 1; + } + return a; + }; + int differences_seen = 0; + for (int64 i = 0; i < n && differences_seen < 10; i++) { + float original_ref = static_cast(host_ref_buffer[i]); + float original_test = static_cast(host_test_buffer[i]); + float ref = canonicalize(original_ref); + float test = canonicalize(original_test); + if (!(std::abs(ref - test) / (std::max(std::abs(ref), std::abs(test)) + 1) < + kTolerance)) { + differences_seen++; + LOG(ERROR) << "Difference at " << i << ": " << original_ref << " vs " + << original_test; + } + } + + return false; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.h b/tensorflow/compiler/xla/service/gpu/buffer_comparator.h new file mode 100644 index 0000000000000000000000000000000000000000..bf2ba78ceacaea1070830f758c3712b1378bd96f --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.h @@ -0,0 +1,71 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_ + +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// A fp16 comparator that internally keeps a reference buffer, and compares it +// against other test buffers. +class F16BufferComparator { + public: + F16BufferComparator(const F16BufferComparator&) = delete; + F16BufferComparator(F16BufferComparator&&) = default; + + // Creates a new comparator. It internally allocates a buffer initialized by + // ref_buffer. + static StatusOr Create( + se::DeviceMemory ref_buffer, Compiler* compiler, + DeviceMemoryAllocator* allocator, se::Stream* stream); + + // Returns true if the internally allocated buffer "compares equal" to + // test_buffer. The definition of "equal" is: + // * All NaNs equal. + // * All infs are treated as 65505 or -65505, so that this checker is tolerant + // to fp16 overflows. + // * With NaNs and infs taken care of, a and b compare equal iff: + // abs(a - b) / (max(abs(a), abs(b)) + 1) < tolerance + // + // See the implementation for the tolerance value. + StatusOr CompareEqual(se::DeviceMemory test_buffer); + + private: + F16BufferComparator(se::Stream* stream, DeviceMemoryAllocator* allocator, + std::unique_ptr exec, + ScopedShapedBuffer ref_buffer) + : stream_(stream), + allocator_(allocator), + exec_(std::move(exec)), + ref_buffer_(std::move(ref_buffer)) {} + + StatusOr CompareEqualImpl(se::DeviceMemory test_buffer); + + se::Stream* stream_; + DeviceMemoryAllocator* allocator_; + std::unique_ptr exec_; + ScopedShapedBuffer ref_buffer_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_ diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..33761d1bd8807df225e2cf505303b120e418576f --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc @@ -0,0 +1,126 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" + +#include +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class BufferComparatorTest : public testing::Test { + protected: + BufferComparatorTest() + : backend_(Backend::CreateDefaultBackend().ConsumeValueOrDie()), + stream_exec_(backend_->default_stream_executor()), + allocator_(stream_exec_->platform(), {stream_exec_}), + compiler_(Compiler::GetForPlatform(stream_exec_->platform()) + .ConsumeValueOrDie()) {} + + // Take floats only for convenience. Still uses half internally. + bool CompareEqualFloatBuffers(const std::vector& lhs_float, + const std::vector& rhs_float) { + std::vector lhs(lhs_float.begin(), lhs_float.end()); + std::vector rhs(rhs_float.begin(), rhs_float.end()); + se::Stream stream(stream_exec_); + stream.Init(); + + auto owning_lhs_buffer = + allocator_ + .Allocate(stream_exec_->device_ordinal(), lhs.size() * sizeof(half)) + .ConsumeValueOrDie(); + + auto owning_rhs_buffer = + allocator_ + .Allocate(stream_exec_->device_ordinal(), rhs.size() * sizeof(half)) + .ConsumeValueOrDie(); + + auto lhs_buffer = + se::DeviceMemory(owning_lhs_buffer.AsDeviceMemoryBase()); + auto rhs_buffer = + se::DeviceMemory(owning_rhs_buffer.AsDeviceMemoryBase()); + + stream.ThenMemcpy(&lhs_buffer, lhs.data(), lhs_buffer.size()); + stream.ThenMemcpy(&rhs_buffer, rhs.data(), rhs_buffer.size()); + + TF_CHECK_OK(stream.BlockHostUntilDone()); + + return F16BufferComparator::Create(lhs_buffer, compiler_, &allocator_, + &stream) + .ConsumeValueOrDie() + .CompareEqual(rhs_buffer) + .ConsumeValueOrDie(); + } + + std::unique_ptr backend_; + se::StreamExecutor* stream_exec_; + StreamExecutorMemoryAllocator allocator_; + Compiler* compiler_; +}; + +TEST_F(BufferComparatorTest, TestNaNs) { + EXPECT_TRUE(CompareEqualFloatBuffers({std::nanf("")}, {std::nanf("")})); + // NaN values with different bit patterns should compare equal. + EXPECT_TRUE(CompareEqualFloatBuffers({std::nanf("")}, {std::nanf("1234")})); + EXPECT_FALSE(CompareEqualFloatBuffers({std::nanf("")}, {1.})); +} + +TEST_F(BufferComparatorTest, TestInfs) { + const auto inf = std::numeric_limits::infinity(); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {std::nanf("")})); + EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {inf})); + EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {65504})); + EXPECT_TRUE(CompareEqualFloatBuffers({-inf}, {-65504})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-65504})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {65504})); + + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {20})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-20})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {20})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {-20})); +} + +TEST_F(BufferComparatorTest, TestNumbers) { + EXPECT_TRUE(CompareEqualFloatBuffers({20}, {20.1})); + EXPECT_FALSE(CompareEqualFloatBuffers({0}, {1})); + EXPECT_TRUE(CompareEqualFloatBuffers({0.9}, {1})); + EXPECT_TRUE(CompareEqualFloatBuffers({9}, {10})); + EXPECT_TRUE(CompareEqualFloatBuffers({10}, {9})); +} + +TEST_F(BufferComparatorTest, TestMultiple) { + EXPECT_TRUE(CompareEqualFloatBuffers({20, 30, 40, 50, 60}, + {20.1, 30.1, 40.1, 50.1, 60.1})); + std::vector lhs(200); + std::vector rhs(200); + for (int i = 0; i < 200; i++) { + EXPECT_TRUE(CompareEqualFloatBuffers(lhs, rhs)) + << "should be the same at index " << i; + lhs[i] = 3; + rhs[i] = 5; + EXPECT_FALSE(CompareEqualFloatBuffers(lhs, rhs)) + << "should be the different at index " << i; + lhs[i] = 0; + rhs[i] = 0; + } +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc index 5780e0af40699bb6ac2c190c09cd02023fb44db7..8b0426aa27fa3fbc7225dda81cef17e543f1cf28 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index 5a63e65208ac3e8e23944bc31634f4d29d91c10c..caeb89d78ea3a3d49182abffa879d7503419c352 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -16,11 +16,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/mutex.h" namespace xla { namespace gpu { @@ -29,7 +31,6 @@ namespace { using se::DeviceMemoryBase; using se::dnn::AlgorithmConfig; using se::dnn::AlgorithmDesc; -using tensorflow::gtl::nullopt; using tensorflow::gtl::optional; class ScratchAllocator : public se::ScratchAllocator { @@ -137,6 +138,28 @@ string NumBytesToString(int64 bytes) { tensorflow::strings::HumanReadableNumBytes(bytes), " (", bytes, "B)"); } +// Acquires a process-global lock on the device pointed to by the given +// StreamExecutor. +// +// This is used to prevent other XLA instances from trying to autotune on this +// device while we're using it. +tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + // se::Platform*s are global singletons guaranteed to live forever. + static auto* mutexes = + new std::map, + tensorflow::mutex>(); + + tensorflow::mutex_lock global_lock(mu); + auto it = mutexes + ->emplace(std::piecewise_construct, + std::make_tuple(stream_exec->platform(), + stream_exec->device_ordinal()), + std::make_tuple()) + .first; + return tensorflow::mutex_lock{it->second}; +} + } // anonymous namespace // We could have caching here so that we don't redo this work for two identical @@ -150,11 +173,24 @@ string NumBytesToString(int64 bytes) { // cache misses and doing extra work. Overall, caching doesn't seem worth the // trouble, but we may want to revisit this if we ever find a model where // caching would speed up compilation a lot. -optional> +StatusOr> CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, const ConvolutionDimensionNumbers& dnums, HloInstruction* instr) { + CHECK_EQ(input_shape.element_type(), filter_shape.element_type()); + CHECK_EQ(input_shape.element_type(), output_shape.element_type()); + // TODO(timshen): for now only check fp16. It can be expanded to other types, + // with some work on the HLO routines. + const bool cross_check_enabled = input_shape.element_type() == xla::F16; + + // Don't run this function concurrently on the same GPU. + // + // This is a bit of a hack and doesn't protect us against arbitrary concurrent + // use of a GPU, but it's sufficient to let us compile two HLO modules + // concurrently and then run them sequentially. + tensorflow::mutex_lock lock = LockGpu(stream_exec_); + // Create a stream for us to do our work on. se::Stream stream{stream_exec_}; stream.Init(); @@ -176,51 +212,75 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( // Allocate space for the input, filter, and output of the convolution. We // use a ScratchAllocator for this instead of calling allocator_ directly so // that our allocations don't leak. - // - // We don't put any data in these buffers, because (in theory, anyway) the - // speed of a conv isn't affected by the data being convolved. ScratchAllocator input_output_allocator(device_ordinal, allocator); - StatusOr maybe_input_buf = - input_output_allocator.AllocateBytes(&stream, - ShapeUtil::ByteSizeOf(input_shape)); - StatusOr maybe_filter_buf = - input_output_allocator.AllocateBytes(&stream, - ShapeUtil::ByteSizeOf(filter_shape)); - StatusOr maybe_output_buf = - input_output_allocator.AllocateBytes(&stream, - ShapeUtil::ByteSizeOf(output_shape)); - if (!maybe_input_buf.ok() || !maybe_filter_buf.ok() || - !maybe_output_buf.ok()) { - LOG(WARNING) - << "Couldn't allocate space for input/filter/output of convolution " - << instr->ToString() << ". Falling back to default algorithm."; - return nullopt; - } - - DeviceMemoryBase input_buf = maybe_input_buf.ValueOrDie(); - DeviceMemoryBase filter_buf = maybe_filter_buf.ValueOrDie(); - DeviceMemoryBase output_buf = maybe_output_buf.ValueOrDie(); - - // Although we don't have evidence this matters, zero out the buffers before - // autotuning. It's conceivable that using uninitialized memory as the inputs - // might affect performance if e.g. the inputs contain denormals, and this is - // easy enough. - if (!stream.ThenMemZero(&input_buf, input_buf.size()) - .ThenMemZero(&filter_buf, filter_buf.size()) - .ThenMemZero(&output_buf, output_buf.size()) - .BlockHostUntilDone() - .ok()) { - LOG(WARNING) - << "Couldn't zero out input/filter/output buffer for convolution " - << instr->ToString() << ". Falling back to default algorithm."; - return nullopt; + TF_ASSIGN_OR_RETURN(DeviceMemoryBase input_buf, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(input_shape))); + TF_ASSIGN_OR_RETURN(DeviceMemoryBase filter_buf, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(filter_shape))); + TF_ASSIGN_OR_RETURN(DeviceMemoryBase output_buf, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(output_shape))); + + if (cross_check_enabled) { + // Broadcast a constant to the buffer, instead of zeroing the buffer. A + // non-zero constant is useful for the cross checking, because zero-inputs + // may not always reveal the bugs. + const auto initialize_f16 = [&stream](DeviceMemoryBase buffer) { + CHECK_EQ(0, (uintptr_t)buffer.opaque() % 4); + size_t left_over_bytes = buffer.size() % 4; + CHECK_EQ(0, left_over_bytes % 2); + + constexpr float kBroadcastedConstant = 0.1f; + Eigen::half halfs[2] = {Eigen::half(kBroadcastedConstant), + Eigen::half(kBroadcastedConstant)}; + uint32 bits; + static_assert(sizeof(bits) == sizeof(halfs), ""); + memcpy(&bits, halfs, sizeof(bits)); + + size_t aligned_size = buffer.size() / 4 * 4; + stream.ThenMemset32(&buffer, bits, aligned_size); + + DeviceMemoryBase left_over( + static_cast(buffer.opaque()) + aligned_size, left_over_bytes); + stream.ThenMemcpy(&left_over, halfs, left_over_bytes); + }; + initialize_f16(input_buf); + initialize_f16(filter_buf); + initialize_f16(output_buf); + } else { + // Although we don't have evidence this matters, zero out the buffers before + // autotuning. It's conceivable that using uninitialized memory as the + // inputs might affect performance if e.g. the inputs contain denormals, and + // this is easy enough. + stream.ThenMemZero(&input_buf, input_buf.size()) + .ThenMemZero(&filter_buf, filter_buf.size()) + .ThenMemZero(&output_buf, output_buf.size()); } + TF_RETURN_IF_ERROR(stream.BlockHostUntilDone()); + + DeviceMemoryBase* result_buf = [&] { + switch (kind) { + case CudnnConvKind::kBackwardFilter: + return &filter_buf; + case CudnnConvKind::kBackwardInput: + return &input_buf; + case CudnnConvKind::kForward: + return &output_buf; + } + }(); const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo( input_shape, output_shape, dnums, stream_exec_); se::dnn::ProfileResult best_result; int64 best_result_bytes_used = 0; + optional comparator; + // Use the first algorithm that's supported as reference. There isn't a + // particular reason to use it, as any algorithm sufficies. It doesn't make + // this algorithm considered correct, though. + optional first_algorithm; for (const AlgorithmDesc& alg : GetAlgorithms(kind, use_winograd_nonfused, stream_exec_)) { ScratchAllocator scratch_allocator(device_ordinal, allocator); @@ -236,6 +296,42 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( .ok(); if (launch_ok && profile_result.is_valid()) { + const bool crash_on_checking_failure = + instr->GetModule() + ->config() + .debug_options() + .xla_gpu_crash_on_verification_failures(); + if (comparator.has_value()) { + StatusOr result = comparator->CompareEqual( + se::DeviceMemory(*result_buf)); + if (!result.ok()) { + LOG(ERROR) << "Unable to compare " + << AlgorithmToString(*first_algorithm) << " against " + << AlgorithmToString(alg) << " for " << instr->ToString() + << ": " << result.status(); + CHECK(!crash_on_checking_failure); + } else if (!result.ValueOrDie()) { + LOG(ERROR) << "Results mismatch between different convolution " + "algorithms. This is likely a bug in convolution, or " + "an excessive loss of precision in convolution. " + << instr->ToString() << " for " + << AlgorithmToString(*first_algorithm) << " vs " + << AlgorithmToString(alg); + CHECK(!crash_on_checking_failure); + } + } else if (cross_check_enabled) { + auto comp = F16BufferComparator::Create( + se::DeviceMemory(*result_buf), compiler_, allocator, + &stream); + if (comp.ok()) { + comparator.emplace(comp.ConsumeValueOrDie()); + first_algorithm.emplace(alg); + } else { + LOG(ERROR) << "Fail to initialize buffer comparator: " + << comp.status() << ", instruction: " << instr->ToString(); + CHECK(!crash_on_checking_failure); + } + } int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes(); VLOG(3) << "Run of algorithm " << AlgorithmToString(alg) << " succeeded, taking " << profile_result.elapsed_time_in_ms() @@ -262,9 +358,10 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( best_result_bytes_used); } - LOG(WARNING) << "All algorithms tried for convolution " << instr->ToString() - << " failed. Falling back to default algorithm."; - return nullopt; + return InternalError( + "All algorithms tried for convolution %s failed. Falling back to " + "default algorithm.", + instr->ToString().c_str()); } StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( @@ -275,12 +372,13 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( const auto& lhs_shape = instr->operand(0)->shape(); const auto& rhs_shape = instr->operand(1)->shape(); const auto& conv_result_shape = instr->shape().tuple_shapes(0); - optional> alg_scratch_and_tc; + StatusOr> alg_scratch_and_tc; if (call_target == kCudnnConvForwardCallTarget) { - alg_scratch_and_tc = PickBestAlgorithm( - CudnnConvKind::kForward, /*input_shape=*/lhs_shape, - /*filter_shape=*/rhs_shape, /*output_shape=*/conv_result_shape, - instr->window(), instr->convolution_dimension_numbers(), instr); + alg_scratch_and_tc = + PickBestAlgorithm(CudnnConvKind::kForward, /*input_shape=*/lhs_shape, + /*filter_shape=*/rhs_shape, + /*output_shape=*/conv_result_shape, instr->window(), + instr->convolution_dimension_numbers(), instr); } else if (call_target == kCudnnConvBackwardInputCallTarget) { alg_scratch_and_tc = PickBestAlgorithm( CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape, @@ -296,7 +394,8 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( << instr->ToString(); } - if (!alg_scratch_and_tc.has_value()) { + if (!alg_scratch_and_tc.ok()) { + LOG(ERROR) << alg_scratch_and_tc.status(); return false; } @@ -304,7 +403,8 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( bool tensor_ops_enabled; int64 scratch_bytes; - std::tie(algorithm, tensor_ops_enabled, scratch_bytes) = *alg_scratch_and_tc; + std::tie(algorithm, tensor_ops_enabled, scratch_bytes) = + alg_scratch_and_tc.ConsumeValueOrDie(); VLOG(1) << "Setting cudnn conv to use algorithm " << algorithm << " and " << NumBytesToString(scratch_bytes) diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h index bc5d1ce94afd2075a006899f0f6bcf64352e5e99..8b7749628a8d0c54f66c4cd23a9eebbe42788971 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ +#include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -34,8 +35,9 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface { // memory while timing the various convolution algorithms. If it's null, // we'll use the default allocator on the StreamExecutor. CudnnConvolutionAlgorithmPicker(se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* allocator) - : stream_exec_(stream_exec), allocator_(allocator) {} + DeviceMemoryAllocator* allocator, + Compiler* compiler) + : stream_exec_(stream_exec), allocator_(allocator), compiler_(compiler) {} tensorflow::StringPiece name() const override { return "cudnn-convolution-algorithm-picker"; @@ -46,13 +48,14 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface { private: StatusOr RunOnComputation(HloComputation* computation); StatusOr RunOnInstruction(HloInstruction* instr); - tensorflow::gtl::optional> PickBestAlgorithm( + StatusOr> PickBestAlgorithm( CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, const ConvolutionDimensionNumbers& dnums, HloInstruction* instr); se::StreamExecutor* stream_exec_; // never null DeviceMemoryAllocator* allocator_; // may be null + Compiler* compiler_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc index 0645fbb3ad39f1f1649caf45a6068b5a196c30b9..7b0d9e53d60dda620714b3443b627405e562b353 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -96,15 +96,9 @@ Status RunCudnnConvolution( // tensorflow/python/ops/nn_ops.py). const int effective_num_dimensions = std::max(2, num_dimensions); - if (std::is_same::value) { - CHECK_EQ(F32, output_shape.element_type()) - << ShapeUtil::HumanString(output_shape); - } else if (std::is_same::value) { - CHECK_EQ(F16, output_shape.element_type()) - << ShapeUtil::HumanString(output_shape); - } else { - LOG(FATAL) << ShapeUtil::HumanString(output_shape); - } + CHECK_EQ(primitive_util::NativeToPrimitiveType(), + output_shape.element_type()) + << ShapeUtil::HumanString(output_shape); CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size()); CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size()); @@ -246,21 +240,31 @@ Status RunCudnnConvolution( se::dnn::AlgorithmConfig algorithm, se::Stream* stream, se::dnn::ProfileResult* profile_result) { PrimitiveType output_primitive_type = output_shape.element_type(); - CHECK(output_primitive_type == F32 || output_primitive_type == F16) - << ShapeUtil::HumanString(output_shape); - if (output_primitive_type == F32) { - return RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, - se::DeviceMemory(input_buf), se::DeviceMemory(filter_buf), - se::DeviceMemory(output_buf), scratch_allocator, window, dnums, - algorithm, stream, profile_result); + switch (output_primitive_type) { + case F16: + return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, + se::DeviceMemory(input_buf), + se::DeviceMemory(filter_buf), + se::DeviceMemory(output_buf), + scratch_allocator, window, dnums, algorithm, + stream, profile_result); + case F32: + return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, + se::DeviceMemory(input_buf), + se::DeviceMemory(filter_buf), + se::DeviceMemory(output_buf), + scratch_allocator, window, dnums, algorithm, + stream, profile_result); + case F64: + return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, + se::DeviceMemory(input_buf), + se::DeviceMemory(filter_buf), + se::DeviceMemory(output_buf), + scratch_allocator, window, dnums, algorithm, + stream, profile_result); + default: + LOG(FATAL) << ShapeUtil::HumanString(output_shape); } - return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, - se::DeviceMemory(input_buf), - se::DeviceMemory(filter_buf), - se::DeviceMemory(output_buf), - scratch_allocator, window, dnums, algorithm, - stream, profile_result); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index cc38db27e2680e950f74e104cef8829585c7b81c..9b6de115ad7e7f87e431f839c1690858f4bce3fd 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -210,11 +210,13 @@ StatusOr GpuElementalIrEmitter::EmitPowerOp( return make_sqrt(); } - if (hlo_module_config_.debug_options().xla_enable_fast_math() && - IsFPLiteralWithValue(rhs, -.5)) { + if (IsFPLiteralWithValue(rhs, -.5)) { VLOG(10) << "emitting pow(A, -.5) as 1/sqrt(A): " << op->ToString(); // LLVM's NVPTX backend knows how to transform 1/sqrt(A) into the NVPTX // rsqrt.approx instruction. + // + // TODO(jlebar): Does this happen with fastmath disabled? If not, should + // we force-enable it? TF_ASSIGN_OR_RETURN(auto* sqrt, make_sqrt()); return b_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt); } @@ -272,27 +274,20 @@ StatusOr GpuElementalIrEmitter::EmitAtan2( prim_type); } -StatusOr GpuElementalIrEmitter::EmitFloatUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { - PrimitiveType input_type = op->operand(0)->shape().element_type(); - PrimitiveType output_type = op->shape().element_type(); - switch (op->opcode()) { - case HloOpcode::kTanh: - // If we don't care much about precision, emit a fast approximation of - // tanh. - if (hlo_module_config_.debug_options().xla_enable_fast_math()) { - // Upcast F16 to F32 if necessary. - llvm::Type* type = - input_type == F16 ? b_->getFloatTy() : operand_value->getType(); - llvm::Value* input = b_->CreateFPCast(operand_value, type); - llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input); - return b_->CreateFPCast(fast_tanh, operand_value->getType()); - } - return EmitLibdeviceMathCall("__nv_tanh", {operand_value}, {input_type}, - output_type); - default: - return ElementalIrEmitter::EmitFloatUnaryOp(op, operand_value); - } +StatusOr GpuElementalIrEmitter::EmitTanh( + PrimitiveType prim_type, llvm::Value* value) const { + // Emit a fast approximation of tanh instead of calling __nv_tanh. + // __nv_tanh is particularly bad because it contains branches, thus + // preventing LLVM's load-store vectorizer from working its magic across a + // function which contains tanh calls. + // + // This routine isn't numerically precise, but it's good enough for ML. + + // Upcast F16 to F32 if necessary. + llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType(); + llvm::Value* input = b_->CreateFPCast(value, type); + llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input); + return b_->CreateFPCast(fast_tanh, value->getType()); } llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( @@ -445,6 +440,8 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( return b_->CreateLoad(accum_ptr); }; case HloOpcode::kReduce: + // TODO(b/112040122): This should be supported. + CHECK_EQ(hlo->operand_count(), 2) << "Did not expect variadic reduce"; return [=, &operand_to_generator]( const IrArray::Index& output_index) -> StatusOr { const HloInstruction* operand = hlo->operand(0); diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index e3eacef133cb8b615a645ca2f11dd6dedf9f0176..84454d31bb820a3de6ef3364bd205b8115bd95c0 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -51,9 +51,6 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { const HloToElementGeneratorMap& operand_to_generator) const override; protected: - StatusOr EmitFloatUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const override; - StatusOr EmitFloatBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const override; @@ -85,6 +82,9 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { StatusOr EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const override; + StatusOr EmitTanh(PrimitiveType prim_type, + llvm::Value* value) const override; + llvm::Value* EmitThreadId() const override; private: diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc index b3a3c5dcb4d77889b65a119f09ddef9ba95d6b52..88f0b4d71c915c37f0b58cb91a8788fd8f9cc452 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/for_thunk.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -28,7 +28,7 @@ ForThunk::ForThunk(const int64 loop_limit, const HloInstruction* hlo) : Thunk(Kind::kWhile, hlo), loop_limit_(loop_limit), - body_thunk_sequence_(MakeUnique( + body_thunk_sequence_(absl::make_unique( // Pass nullptr as the HloInstruction* to the body_thunk_sequence_ // constructor because this SequentialThunk is logically "part of" // this ForThunk, and shouldn't be profiled separately from it. @@ -43,6 +43,8 @@ Status ForThunk::Initialize(const GpuExecutable& executable, Status ForThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, se::Stream* stream, HloExecutionProfiler* profiler) { + VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters for " + << (hlo_instruction() ? hlo_instruction()->ToString() : ""); auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); for (int64 i = 0; i < loop_limit_; ++i) { profiler->StartHloComputation(); diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index 3cd30b754c3242f00c704de1afab2282ed827b41..9b86e5315bf51e88cca569499fe9acbe17998e48 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -64,10 +65,11 @@ double CalculateBytesReadByFusionParameter(HloInstruction* param) { // Slice for a more accurate estimate of bytes read. double bytes = 0.0; for (auto& instruction : instructions) { - if (c_all_of(instruction->users(), [](const HloInstruction* instruction) { - return instruction->opcode() == HloOpcode::kSlice || - instruction->opcode() == HloOpcode::kDynamicSlice; - })) { + if (absl::c_all_of( + instruction->users(), [](const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kSlice || + instruction->opcode() == HloOpcode::kDynamicSlice; + })) { // All users are slice: accumulate bytes of all user slice instructions. for (auto& user : instruction->users()) { bytes += ShapeUtil::ByteSizeOf(user->shape()); @@ -223,7 +225,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // Skip 'fusion' instruction if we cannot merge into all of its users. // Merging into all users enables the removal of 'fusion' from the // computation. - if (!c_all_of(fusion->users(), [](const HloInstruction* user) { + if (!absl::c_all_of(fusion->users(), [](const HloInstruction* user) { return user->opcode() == HloOpcode::kFusion && (user->fusion_kind() == HloInstruction::FusionKind::kLoop || user->fusion_kind() == HloInstruction::FusionKind::kInput); @@ -241,11 +243,11 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // If 'fusion' has just one user, then an earlier fusion pass chose not to // fuse this producer/comsumer pair (likely because of expensive instruction // re-use by the consumer), and so we honor that choice here as well. - if (c_any_of(fusion->fused_instructions(), - [](const HloInstruction* instruction) { - return instruction->opcode() != HloOpcode::kParameter && - GpuInstructionFusion::IsExpensive(*instruction); - })) { + if (absl::c_any_of(fusion->fused_instructions(), + [](const HloInstruction* instruction) { + return instruction->opcode() != HloOpcode::kParameter && + GpuInstructionFusion::IsExpensive(*instruction); + })) { VLOG(3) << "Not merging " << fusion->name() << ": Contains one or more expensive instructions."; ++num_fail_expensive_fused_instruction_; diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index dbc7754e251eb8075ab97dd2f36bbc400530fcf5..74282c568c09921dbeec2e9cce79b6c73b6ea592 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -31,16 +32,19 @@ namespace { // dimensions. struct MatrixDescriptor { MatrixDescriptor(se::DeviceMemoryBase matrix_data, bool needs_transpose, - int64 matrix_num_rows, int64 matrix_num_cols) + int64 matrix_num_rows, int64 matrix_num_cols, + int64 matrix_batch_size) : data(matrix_data), transpose(needs_transpose), num_rows(matrix_num_rows), - num_cols(matrix_num_cols) {} + num_cols(matrix_num_cols), + batch_size(matrix_batch_size) {} se::DeviceMemoryBase data; bool transpose; // Whether this matrix needs to be transposed. int64 num_rows; int64 num_cols; + int64 batch_size; }; // Performs a gemm call without an explicit algorithm on lhs_matrix and @@ -50,6 +54,9 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, MatrixDescriptor output_matrix, double alpha, se::Stream* stream) { DCHECK(!output_matrix.transpose); + const int64 batch_size = lhs_matrix.batch_size; + CHECK_EQ(batch_size, rhs_matrix.batch_size); + CHECK_EQ(batch_size, output_matrix.batch_size); se::DeviceMemory lhs_data(lhs_matrix.data); se::DeviceMemory rhs_data(rhs_matrix.data); se::DeviceMemory output_data(output_matrix.data); @@ -60,13 +67,30 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, : se::blas::Transpose::kNoTranspose; auto k = lhs_matrix.transpose ? lhs_matrix.num_rows : lhs_matrix.num_cols; + if (batch_size == 1) { + return stream + ->ThenBlasGemm( + lhs_transpose, rhs_transpose, output_matrix.num_rows, + output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha, + lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data, + /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0, + &output_data, /*leading dim of output=*/output_matrix.num_rows) + .ok(); + } + + int64 lhs_stride = lhs_matrix.num_rows * lhs_matrix.num_cols; + int64 rhs_stride = rhs_matrix.num_rows * rhs_matrix.num_cols; + int64 output_stride = output_matrix.num_rows * output_matrix.num_cols; return stream - ->ThenBlasGemm( + ->ThenBlasGemmStridedBatched( lhs_transpose, rhs_transpose, output_matrix.num_rows, - output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha, - lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data, - /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0, - &output_data, /*leading dim of output=*/output_matrix.num_rows) + output_matrix.num_cols, /*size of reduce dim=*/k, + /*alpha=*/alpha, lhs_data, + /*leading dim of LHS=*/lhs_matrix.num_rows, lhs_stride, rhs_data, + /*leading dim of RHS=*/rhs_matrix.num_rows, rhs_stride, + /*beta=*/0.0, &output_data, + /*leading dim of output=*/output_matrix.num_rows, output_stride, + batch_size) .ok(); } @@ -93,6 +117,10 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix, se::blas::ProfileResult* output_profile_result) { DCHECK(!output_matrix.transpose); + CHECK_EQ(1, lhs_matrix.batch_size); + CHECK_EQ(1, rhs_matrix.batch_size); + CHECK_EQ(1, output_matrix.batch_size); + se::DeviceMemory lhs_data(lhs_matrix.data); se::DeviceMemory rhs_data(rhs_matrix.data); se::DeviceMemory output_data(output_matrix.data); @@ -141,9 +169,15 @@ StatusOr DoGemmAutotune( alpha, computation_type, algorithm, stream, &profile_result)); - if (profile_result.is_valid() && profile_result.elapsed_time_in_ms() < - best_result.elapsed_time_in_ms()) { - best_result = profile_result; + if (profile_result.is_valid()) { + VLOG(3) << "cublas gemm algorithm " << algorithm << " took " + << profile_result.elapsed_time_in_ms() << "ms"; + if (profile_result.elapsed_time_in_ms() < + best_result.elapsed_time_in_ms()) { + best_result = profile_result; + } + } else { + VLOG(4) << "cublas gemm algorithm " << algorithm << " failed."; } } @@ -167,6 +201,8 @@ auto GetGemmFn(PrimitiveType type) -> decltype(&DoGemm) { return &DoGemm; case F64: return &DoGemm; + case C64: + return &DoGemm>; default: LOG(FATAL) << "Unsupported type."; } @@ -180,6 +216,8 @@ auto GetGemmWithAlgorithmFn(PrimitiveType type) return &DoGemmWithAlgorithm; case F64: return &DoGemmWithAlgorithm; + case C64: + return &DoGemmWithAlgorithm>; default: LOG(FATAL) << "Unsupported type."; } @@ -192,6 +230,8 @@ auto GetGemmAutotuneFn(PrimitiveType type) -> decltype(&DoGemmAutotune) { return &DoGemmAutotune; case F64: return &DoGemmAutotune; + case C64: + return &DoGemmAutotune>; default: LOG(FATAL) << "Unsupported type."; } @@ -210,6 +250,8 @@ se::blas::ComputationType GetBlasComputationType(PrimitiveType type) { return se::blas::ComputationType::kF32; case F64: return se::blas::ComputationType::kF64; + case C64: + return se::blas::ComputationType::kComplexF32; default: LOG(FATAL) << "Unsupported type."; } @@ -263,12 +305,37 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, se::DeviceMemoryBase output_data = buffer_allocations.GetDeviceAddress(output_buffer_); + DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction()); + CHECK_EQ(dim_nums.lhs_batch_dimensions_size(), + dim_nums.rhs_batch_dimensions_size()); + CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, + ShapeUtil::Rank(output_shape_)); + + int64 row_dim = dim_nums.lhs_batch_dimensions_size(); + int64 col_dim = dim_nums.lhs_batch_dimensions_size() + 1; + int64 batch_size = std::accumulate(output_shape_.dimensions().begin(), + output_shape_.dimensions().end() - 2, 1, + std::multiplies()); + + // Check that the batch dims don't cover the last two dims. + for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) { + CHECK_NE(row_dim, batch_dim); + CHECK_NE(col_dim, batch_dim); + } + + // Verify that the non-batch dimensions are minor-most. This is required for + // efficient access. + for (const auto* shape : {&lhs_shape_, &rhs_shape_, &output_shape_}) { + CHECK_LT(shape->layout().minor_to_major(row_dim), 2); + CHECK_LT(shape->layout().minor_to_major(col_dim), 2); + } + // BLAS gemm reduces rows of LHS and columns of RHS. The Dot operator between // matrices reduces dimension 1 of LHS and dimension 0 of RHS regardless of // their layout. Therefore, we should treat dimension 0 as row and dimension 1 // as column when mapping a matrix Dot to BLAS gemm. - int64 output_num_rows = output_shape_.dimensions(0); - int64 output_num_cols = output_shape_.dimensions(1); + int64 output_num_rows = output_shape_.dimensions(row_dim); + int64 output_num_cols = output_shape_.dimensions(col_dim); // BLAS gemm expects the inputs and the output are in column-major order. // Therefore, we need to convert dot between row-major matrices to that @@ -291,34 +358,46 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, // the leading dimension of the LHS matrix of gemm is the number of rows in // B^T and thus the number of columns in B. - auto make_descriptor = [this](se::DeviceMemoryBase data, const Shape& shape, - bool transpose) -> MatrixDescriptor { - bool is_row_major = LayoutUtil::Minor(shape.layout(), 0) != 0; - bool layout_mismatch = LayoutUtil::Minor(shape.layout(), 0) != - LayoutUtil::Minor(output_shape_.layout(), 0); - return MatrixDescriptor(data, transpose ^ layout_mismatch, - shape.dimensions(is_row_major), - shape.dimensions(!is_row_major)); + auto make_descriptor = [&](se::DeviceMemoryBase data, const Shape& shape, + bool transpose) -> MatrixDescriptor { + bool is_row_major = LayoutUtil::Minor(shape.layout(), row_dim) != 0; + bool layout_mismatch = LayoutUtil::Minor(shape.layout(), row_dim) != + LayoutUtil::Minor(output_shape_.layout(), row_dim); + return MatrixDescriptor( + data, transpose ^ layout_mismatch, + shape.dimensions(row_dim + static_cast(is_row_major)), + shape.dimensions(row_dim + static_cast(!is_row_major)), + batch_size); }; - DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction()); - const MatrixDescriptor lhs_descriptor = make_descriptor( - lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == 0); + lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == row_dim); const MatrixDescriptor rhs_descriptor = make_descriptor( - rhs_data, rhs_shape_, dim_nums.rhs_contracting_dimensions(0) == 1); + rhs_data, rhs_shape_, dim_nums.rhs_contracting_dimensions(0) == col_dim); // Dispatches to a regular cublas gemm, a gemm-with-algorithm, or attempts to // autotune this gemm to figure out the best algorithm. - auto launch = [this](MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, - MatrixDescriptor output_matrix, se::Stream* stream) { + auto launch = [&](MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, + MatrixDescriptor output_matrix, se::Stream* stream) { PrimitiveType element_type = output_shape_.element_type(); se::blas::ComputationType computation_type = GetBlasComputationType(element_type); + // TODO(b/112111608): Implement auto tune for batched gemm. + if (batch_size != 1) { + return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix, + alpha_, stream); + } + + auto thunk_name = [&] { + return hlo_instruction() != nullptr ? hlo_instruction()->ToString() + : ""; + }; + const string& device_name = stream->parent()->GetDeviceDescription().name(); auto autotune_it = autotune_results_.find(device_name); if (autotune_it == autotune_results_.end()) { + VLOG(3) << "Starting autotune of GemmThunk " << thunk_name(); StatusOr best_algorithm = GetGemmAutotuneFn(element_type)(lhs_matrix, rhs_matrix, output_matrix, alpha_, computation_type, stream); @@ -326,11 +405,11 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, autotune_results_.insert({device_name, best_algorithm}).first; if (autotune_it->second.ok()) { - VLOG(2) << "Autotune on GemmThunk " << this + VLOG(2) << "Autotune on GemmThunk " << thunk_name() << " successful; best algorithm is " << best_algorithm.ValueOrDie(); } else { - VLOG(2) << "Autotune on GemmThunk " << this + VLOG(2) << "Autotune on GemmThunk " << thunk_name() << " unsuccessful. Will use generic gemm."; } } @@ -340,7 +419,7 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, if (best_algorithm.ok()) { auto algorithm = best_algorithm.ValueOrDie(); VLOG(2) << "Using algorithm " << algorithm - << " chosen by autotuning on GemmThunk " << this; + << " chosen by autotuning on GemmThunk " << thunk_name(); return GetGemmWithAlgorithmFn(element_type)( lhs_matrix, rhs_matrix, output_matrix, alpha_, computation_type, algorithm, stream, @@ -355,16 +434,16 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); bool launch_ok; - if (LayoutUtil::Minor(output_shape_.layout(), 0) == 0) { - launch_ok = launch( - lhs_descriptor, rhs_descriptor, - MatrixDescriptor(output_data, false, output_num_rows, output_num_cols), - stream); + if (LayoutUtil::Minor(output_shape_.layout(), row_dim) == 0) { + launch_ok = launch(lhs_descriptor, rhs_descriptor, + MatrixDescriptor(output_data, false, output_num_rows, + output_num_cols, batch_size), + stream); } else { - launch_ok = launch( - rhs_descriptor, lhs_descriptor, - MatrixDescriptor(output_data, false, output_num_cols, output_num_rows), - stream); + launch_ok = launch(rhs_descriptor, lhs_descriptor, + MatrixDescriptor(output_data, false, output_num_cols, + output_num_rows, batch_size), + stream); } if (!launch_ok) { diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index 939c7f85e35b4fcb943a25aa6346d72798432920..12c81f9bfc6bfdac63edf9c826b835057107fa41 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -52,12 +52,12 @@ class GemmThunk : public Thunk { se::Stream* stream, HloExecutionProfiler* profiler) override; - // Returns true if we'll perform autotuning if run on the given stream. If - // so, we want the GPU to be quiescent during autotuning, so as not to - // introduce noise in our results. - bool ShouldHaltAllActivityBeforeRunning(se::Stream* stream) override { - return autotune_results_.count( - stream->parent()->GetDeviceDescription().name()) != 0; + bool WillAutotuneKernel(se::Stream* stream) override { + // We will autotune this kernel if we don't already have a autotune result + // for the stream device. + return autotune_results_.find( + stream->parent()->GetDeviceDescription().name()) == + autotune_results_.end(); } private: @@ -75,6 +75,8 @@ class GemmThunk : public Thunk { // results. The map's value is the best algorithm we've found for this thunk // on this device, or an error if none of the algorithms worked and we should // use the regular gemm without an algorithm. + // + // TODO(b/112415150): Make this thread safe. std::unordered_map> autotune_results_; }; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_constants.cc b/tensorflow/compiler/xla/service/gpu/gpu_constants.cc index e6ddea6d2578bbb482c481a511cc8d8adb5fa2d6..7f0b030fece8f25578bd90a538279d455350278a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_constants.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_constants.cc @@ -30,5 +30,7 @@ const int64 kEntryParameterAlignBytes = 16; const int64 kXlaAllocatedBufferAlignBytes = tensorflow::Allocator::kAllocatorAlignment; +const int64 kConstantBufferAlignBytes = kXlaAllocatedBufferAlignBytes; + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_constants.h b/tensorflow/compiler/xla/service/gpu/gpu_constants.h index 925e6927b64625f011efe7b4b960421f41ddee79..6f5f1fa09c57dfd246d702c0adc92c7e2e76805a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_constants.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_constants.h @@ -28,6 +28,9 @@ extern const int64 kEntryParameterAlignBytes; // out (result) buffers. extern const int64 kXlaAllocatedBufferAlignBytes; +// Minimum alignment for constant buffers. +extern const int64 kConstantBufferAlignBytes; + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc index fbc1303085b579e898d2f503a341754109768567..75f414e47fe3edcc1b10b392ed5cc5038be6c190 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -48,80 +48,17 @@ StatusOr GpuCopyInsertion::Run(HloModule* module) { TF_ASSIGN_OR_RETURN(bool changed, generic_copy_insertion.Run(module)); - TF_ASSIGN_OR_RETURN(std::unique_ptr dataflow, - HloDataflowAnalysis::Run(*module)); - - // Make sure all operands of a library call are in memory instead of constants - // in IR. Also, init values of while and conditional nodes cannot be - // constants. Insert copies for any constants found at the operands of these - // nodes. - tensorflow::gtl::FlatSet inserted_copies; + // Check the assumption that the epsilon and feature_index constants of the + // CUDNN batchnorm op are not shared with other ops where we would replace + // them with a copy. These custom op calls are generated with the + // CudnnBatchNormRewriter, so this would only happen if HloCSE merges them. for (HloComputation* computation : module->computations()) { for (HloInstruction* hlo : computation->instructions()) { - // Inserts a copy of hlo->operand(n) if it's a constant. - auto copy_operand_if_constant = [&](int64 n) -> Status { - HloInstruction* operand = hlo->mutable_operand(n); - // Skip the operands that have already been replaced with a copy in a - // previous iteration (which is possible when a constant is used as an - // operand in multiple places). - if (ContainsKey(inserted_copies, operand)) { - return Status::OK(); - } - for (auto& pair : dataflow->GetInstructionValueSet(operand)) { - const HloValueSet& value_set = pair.second; - for (const HloValue* value : value_set.values()) { - if (value->defining_instruction()->IsConstant() && - !ContainsKey(hlo_to_copy_map_, value->defining_instruction())) { - HloInstruction* constant = value->defining_instruction(); - TF_ASSIGN_OR_RETURN(HloInstruction * copy, - FindOrInsertCopy(constant)); - TF_RETURN_IF_ERROR(constant->ReplaceAllUsesWith(copy)); - inserted_copies.insert(copy); - changed = true; - } - } - } - return Status::OK(); - }; - - if (IsCustomCallToDnnBatchNorm(*hlo)) { - // The epsilon and feature_index operands to a CUDNN batchnorm op don't - // need to be materialized in memory -- in fact, they must be constants. - // These are the last two operands of all three batchnorm ops. - for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { - TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); - } - } else if (ImplementedAsLibraryCall(*hlo) || - hlo->opcode() == HloOpcode::kCrossReplicaSum || - hlo->opcode() == HloOpcode::kWhile || - hlo->opcode() == HloOpcode::kConditional) { - // For all other library calls, cross-replica-sum, while and conditional - // ops materialize all the operands into memory. (Cross-replica-sum - // gets its constant args materialized even if it's not implemented as a - // libcall to simplify the implementation. It's slower, but we can - // constant fold away constant args *anyway*, so we just need to make it - // work.) - for (int64 i = 0; i < hlo->operand_count(); ++i) { - TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); - } + if (!IsCustomCallToDnnBatchNorm(*hlo)) { + continue; } - } - } - - if (changed) { - // Check the assumption that the epsilon and feature_index constants of the - // CUDNN batchnorm op are not shared with other ops where we would replace - // them with a copy. These custom op calls are generated with the - // CudnnBatchNormRewriter, so this would only happen if HloCSE merges them. - for (HloComputation* computation : module->computations()) { - for (HloInstruction* hlo : computation->instructions()) { - if (!IsCustomCallToDnnBatchNorm(*hlo)) { - continue; - } - for (int64 i = hlo->operand_count() - 2; i < hlo->operand_count(); - ++i) { - CHECK_EQ(hlo->operand(i)->opcode(), HloOpcode::kConstant); - } + for (int64 i = hlo->operand_count() - 2; i < hlo->operand_count(); ++i) { + CHECK_EQ(hlo->operand(i)->opcode(), HloOpcode::kConstant); } } } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 0cad2958c72797b4d70f00676928b2b21d7a3e8d..a1fbd8022db55abb05a0c5f3f85bc27bf52652c2 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -19,11 +19,12 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -84,7 +85,7 @@ Status GpuExecutable::ExecuteThunks( } // Stream 0 indicates `main_stream` and substreams start from stream 1. - std::vector::SmartPtr> sub_streams; + std::vector sub_streams; sub_streams.reserve(thunk_schedule_->StreamCount() - 1); while (sub_streams.size() + 1 < thunk_schedule_->StreamCount()) { sub_streams.emplace_back(); @@ -130,9 +131,10 @@ Status GpuExecutable::ExecuteThunks( stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get()); } - // If this thunk requests it, wait for all currently-executing thunks to - // finish. This is useful e.g. if the thunk is about to perform autotuning. - if (thunk->ShouldHaltAllActivityBeforeRunning(stream)) { + // If this thunk is about to autotune then wait for all currently executing + // thunks to finish. This reduces noise and thus the probability of + // choosing a suboptimal algorithm. + if (thunk->WillAutotuneKernel(stream)) { TF_RETURN_IF_ERROR(main_stream->BlockHostUntilDone()); } @@ -142,7 +144,7 @@ Status GpuExecutable::ExecuteThunks( TF_RETURN_IF_ERROR( thunk->ExecuteOnStream(buffer_allocations, stream, &profiler)); if (thunk_schedule_->Depended(thunk)) { - auto finish_event = MakeUnique(main_stream->parent()); + auto finish_event = absl::make_unique(main_stream->parent()); finish_event->Init(); stream->ThenRecordEvent(finish_event.get()); thunk_to_finish_event[thunk] = std::move(finish_event); @@ -181,6 +183,55 @@ Status GpuExecutable::ExecuteThunks( return Status::OK(); } +StatusOr +GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) { + tensorflow::mutex_lock lock(module_handle_mutex_); + auto it = module_globals_.find(executor); + if (it != module_globals_.end()) { + return &it->second; + } + + se::MultiModuleLoaderSpec module_spec; + if (!cubin().empty()) { + module_spec.AddCudaCubinInMemory(cubin()); + } + module_spec.AddCudaPtxInMemory(ptx().c_str()); + + tensorflow::gtl::FlatMap globals; + se::ModuleHandle module_handle; + executor->LoadModule(module_spec, &module_handle); + + for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); + ++i) { + const BufferAllocation& allocation = assignment_->GetAllocation(i); + if (allocation.is_constant()) { + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase global, + executor->GetUntypedSymbol( + llvm_ir::ConstantBufferAllocationToGlobalName(allocation), + module_handle)); + VLOG(3) << "Resolved global " + << llvm_ir::ConstantBufferAllocationToGlobalName(allocation) + << " to " << global.opaque(); + InsertOrDie(&globals, i, global); + + const Literal& literal = + llvm_ir::LiteralForConstantAllocation(allocation); + CHECK(ShapeUtil::IsArray(literal.shape())); + if (!ShouldEmitLiteralInLlvmIr(literal)) { + VLOG(3) << "H2D memcpy for constant with shape " + << ShapeUtil::HumanString(literal.shape()); + TF_RETURN_IF_ERROR(executor->SynchronousMemcpyH2D( + literal.untyped_data(), allocation.size(), &global)); + } + } + } + + module_handles_.emplace(executor, + se::ScopedModuleHandle(executor, module_handle)); + return &module_globals_.emplace(executor, std::move(globals)).first->second; +} + StatusOr GpuExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, @@ -192,6 +243,10 @@ StatusOr GpuExecutable::ExecuteOnStream( } BufferAllocations::Builder buffer_allocations_builder; + se::StreamExecutor* executor = run_options->stream()->parent(); + + TF_ASSIGN_OR_RETURN(auto* const globals, ResolveConstantGlobals(executor)); + for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); ++i) { const BufferAllocation& allocation = assignment_->GetAllocation(i); @@ -213,8 +268,12 @@ StatusOr GpuExecutable::ExecuteOnStream( buffer_allocations_builder.RegisterBuffer(i, buffer); } + + if (allocation.is_constant()) { + buffer_allocations_builder.RegisterBuffer(i, FindOrDie(*globals, i)); + } } - se::StreamExecutor* executor = run_options->stream()->parent(); + TF_ASSIGN_OR_RETURN( auto buffer_allocations, buffer_allocations_builder.Build( @@ -235,7 +294,7 @@ StatusOr GpuExecutable::ExecuteOnStream( // the respective location in ShapedBuffer. std::set buffers_in_result; TF_RETURN_IF_ERROR(shaped_buffer.buffers().ForEachMutableElementWithStatus( - [&buffer_allocations, &buffers_in_result, &shaped_buffer, this]( + [&buffer_allocations, &buffers_in_result, this]( const ShapeIndex& index, se::DeviceMemoryBase* device_memory) { const auto& sources = this->GetRootPointsToSet().element(index); // The points-to set is unambiguous so the set should be a diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 80ec38c3ac114fe4ad9d56784330c1144d913db1..c7ce6d0acbbbe594040271c0d45c71c016e36514 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -34,6 +34,8 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -66,7 +68,7 @@ class GpuExecutable : public Executable { } // Returns the compiled PTX for the computation. - tensorflow::StringPiece ptx() const { return ptx_; } + const string& ptx() const { return ptx_; } // Returns the cubin (compiled PTX) stored in this GpuExecutable. May be // empty, in which case compilation is left up to the GPU driver. @@ -98,6 +100,15 @@ class GpuExecutable : public Executable { // computation. Uses points-to analysis from buffer assignment. const PointsToSet& GetRootPointsToSet() const; + using BufferAllocToDeviceMemoryMap = + tensorflow::gtl::FlatMap; + + // Loads the PTX or CUBIN for this executable into `executor` and resolves the + // globals corresponding to constant buffers. Returns a map mapping buffer + // allocation indices to GPU pointers. + StatusOr ResolveConstantGlobals( + stream_executor::StreamExecutor* executor); + // The LLVM IR, in string format, of the unoptimized module generated for this // GpuExecutable. We save a string instead of an llvm::Module* because leaving // llvm::Module* in a singleton can cause the heap checker to emit false @@ -126,6 +137,14 @@ class GpuExecutable : public Executable { // memory for every output/temp buffers. const std::unique_ptr assignment_; + // Cache of module handles and constant buffer allocation maps used by + // `ResolveConstantGlobals`. + tensorflow::mutex module_handle_mutex_; + std::map + module_handles_ GUARDED_BY(module_handle_mutex_); + std::map + module_globals_ GUARDED_BY(module_handle_mutex_); + TF_DISALLOW_COPY_AND_ASSIGN(GpuExecutable); }; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index 09ef62c87f8875a5803497e8eb628769f883202a..d033faee8d25ed81a1483f8314652ef999ab36c5 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -31,20 +31,13 @@ limitations under the License. namespace xla { namespace gpu { -using stream_executor::dnn::DataLayout; -using stream_executor::dnn::FilterLayout; - -static bool IsVoltaOrLater(const se::StreamExecutor& stream_executor) { - int major, minor; - CHECK(stream_executor.GetDeviceDescription().cuda_compute_capability(&major, - &minor)); - return major >= 7; -} +using se::dnn::DataLayout; +using se::dnn::FilterLayout; // Returns (input, filter, output) layouts. static std::tuple HeuristicLayoutAssignment(const HloInstruction* instr, - stream_executor::StreamExecutor* stream_executor) { + se::StreamExecutor* stream_executor) { // DataLayout and FilterLayout uses weird enum names. Translations: // N <=> Batch or Output // C <=> Depth or Input @@ -52,31 +45,44 @@ HeuristicLayoutAssignment(const HloInstruction* instr, // W <=> X // // Therefore kOutputInputYX and kBatchDepthYX mean NCHW. + // + // If you have trouble keeping these straight, consider that all that matters + // is the location of the channel dim: Is it major (NCHW), or minor (NHWC)? + + constexpr auto kAllNCHW = + std::make_tuple(DataLayout::kBatchDepthYX, FilterLayout::kOutputInputYX, + DataLayout::kBatchDepthYX); + constexpr auto kAllNHWC = + std::make_tuple(DataLayout::kBatchYXDepth, FilterLayout::kOutputYXInput, + DataLayout::kBatchYXDepth); - // As of today, our empirical evidence is that cudnn 7.0 is faster on V100 x - // fp16 with the mostly-NHWC layout. The heuristic may change as cudnn version - // changes, as well as the hardware updates. + // If we're not Volta or not fp16, the decision is easy: Use NCHW. if (!(instr->operand(0)->shape().element_type() == xla::PrimitiveType::F16 && IsVoltaOrLater(*stream_executor))) { - return std::make_tuple(DataLayout::kBatchDepthYX, - FilterLayout::kOutputInputYX, - DataLayout::kBatchDepthYX); + return kAllNCHW; } + VLOG(2) << "Using heuristic to figure out layouts for " << instr->ToString(); - // For BackwardInput that has stride, full NHWC layouts run significantly - // slower than (NHWC, NCHW, NCHW) or (NHWC, NCHW, NHWC). + + // Empirically we've found with Volta and cudnn 7 that backward-input convs + // with stride are significantly faster with NCHW layouts. // - // TODO(timshen): more closely compare (NHWC, NCHW, NCHW) and (NHWC, NCHW, - // NHWC). + // We could have used a mixed layout combination, e.g. (NHWC, NCHW, NCHW), + // which on paper gives good performance. However, there are two observations: + // * a mixed layout combination is more cuDNN-bug prone, based on empirical + // envidence. + // * we've also observed that for mixed layouts, cuDNN transposes data back + // and forth from a different layout combination. If we end up with + // transposes anyway, we prefer to have them in XLA, as they can be fused. + // TODO(timshen): Figure out the exact condition. This may be achieved by + // auto-tuning layouts offline. if (instr->custom_call_target() == kCudnnConvBackwardInputCallTarget && window_util::HasStride(instr->window())) { - return std::make_tuple(DataLayout::kBatchYXDepth, - FilterLayout::kOutputInputYX, - DataLayout::kBatchDepthYX); + return kAllNCHW; } - return std::make_tuple(DataLayout::kBatchYXDepth, - FilterLayout::kOutputYXInput, - DataLayout::kBatchYXDepth); + + // For other Volta f16 convolutions, use NHWC. + return kAllNHWC; } // Adds layout constraints on the cudnn custom-call instruction. The layout @@ -170,6 +176,38 @@ Status GpuLayoutAssignment::AddBackendConstraints( TF_RETURN_IF_ERROR( AddBackendConstraintsToDnnConvCustomCall(instruction, constraints)); } + + // For batched dot we require the default layout. + // TODO(b/112111608): This is overly conservative, the only real restriction + // is that batch dimensions must be major. + if (instruction->opcode() == HloOpcode::kDot && + ImplementedAsGemm(*instruction) && + instruction->dot_dimension_numbers().lhs_batch_dimensions_size() > 0) { + // Verify that the batch dims come before the row and col dims. + const DotDimensionNumbers& dim_nums = + instruction->dot_dimension_numbers(); + CHECK_EQ(dim_nums.lhs_batch_dimensions_size(), + dim_nums.rhs_batch_dimensions_size()); + CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, + ShapeUtil::Rank(instruction->shape())); + for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) { + CHECK_LT(batch_dim, ShapeUtil::Rank(instruction->shape()) - 2); + } + + // Set both inputs and the output to default layout. + Shape op0_shape = instruction->operand(0)->shape(); + LayoutUtil::SetToDefaultLayout(&op0_shape); + Shape op1_shape = instruction->operand(1)->shape(); + LayoutUtil::SetToDefaultLayout(&op1_shape); + Shape output_shape = instruction->shape(); + LayoutUtil::SetToDefaultLayout(&output_shape); + TF_RETURN_IF_ERROR( + constraints->SetOperandLayout(op0_shape, instruction, 0)); + TF_RETURN_IF_ERROR( + constraints->SetOperandLayout(op1_shape, instruction, 1)); + TF_RETURN_IF_ERROR( + constraints->SetInstructionLayout(output_shape, instruction)); + } } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 95f78ae29326caad2f0785e2ba285a996e685899..286547ebae2f1a4b8d783a06d13b4dd96052b952 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -20,8 +20,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -31,6 +33,8 @@ namespace xla { namespace gpu { namespace { +namespace op = xla::testing::opcode_matchers; + using LayoutAssignmentTest = HloTestBase; TEST_F(LayoutAssignmentTest, Elementwise) { @@ -327,6 +331,33 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) { } } +TEST_F(LayoutAssignmentTest, DotLayout) { + const char* hlo_text = R"( + HloModule DotLayout + ENTRY dot { + p0 = f32[8,8,256,64]{3,1,2,0} parameter(0) + p1 = f32[8,8,256,64]{3,1,2,0} parameter(1) + ROOT dot.1330.10585 = f32[8,8,256,256]{3,2,1,0} dot(p0, p1), + lhs_batch_dims={0,1}, lhs_contracting_dims={3}, + rhs_batch_dims={0,1}, rhs_contracting_dims={3} + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text)); + + ComputationLayout computation_layout( + module->entry_computation()->ComputeProgramShape()); + GpuLayoutAssignment layout_assignment(&computation_layout, + backend().default_stream_executor()); + EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); + + Shape expected_shape = + ShapeUtil::MakeShapeWithLayout(F32, {8, 8, 256, 64}, {3, 2, 1, 0}); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Dot(op::ShapeWithLayout(expected_shape), + op::ShapeWithLayout(expected_shape))); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index 79b3f1efecdf06bfa93b17a1799f3009d517f3b5..44303724bb5cda4f392c8d17d60c114286b6b7e2 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "llvm/IR/DataLayout.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -117,38 +118,37 @@ StatusOr GpuTransferManager::TransferBufferToInfeedInternal( return std::move(buffer); } -static std::unique_ptr ShapeTreeToLiteral( +static void ShapeTreeToLiteral( ShapeTree>* shape_tree) { // This is a struct instead of a lambda for std::function-free recursion. struct Helper { - static std::unique_ptr helper( + static void helper( ShapeTree>* shape_tree, ShapeIndex* index) { const Shape& shape = ShapeUtil::GetSubshape(shape_tree->shape(), *index); if (ShapeUtil::IsArray(shape)) { - return (*shape_tree->mutable_element(*index))->WaitUntilAvailable(); + (*shape_tree->mutable_element(*index))->WaitUntilAvailable(); + return; } CHECK(ShapeUtil::IsTuple(shape)) << ShapeUtil::HumanStringWithLayout(shape); const int64 tuple_element_count = ShapeUtil::TupleElementCount(shape); index->push_back(0); - std::vector> tuple_operands; for (int64 i = 0; i < tuple_element_count; ++i) { index->back() = i; - tuple_operands.push_back(helper(shape_tree, index)); + helper(shape_tree, index); } index->pop_back(); - return LiteralUtil::MakeTupleOwned(std::move(tuple_operands)); } }; ShapeIndex index; - return Helper::helper(shape_tree, &index); + Helper::helper(shape_tree, &index); } Status GpuTransferManager::TransferLiteralFromOutfeed( se::StreamExecutor* /*executor*/, const Shape& literal_shape, - Literal* literal) { + MutableBorrowingLiteral literal) { ShapeTree> outfeed_buffers( &literal_shape); @@ -161,7 +161,10 @@ Status GpuTransferManager::TransferLiteralFromOutfeed( if (ShapeUtil::IsTuple(shape)) { return; } - *buffer = MakeUnique(GetByteSizeRequirement(shape)); + *buffer = absl::make_unique( + GetByteSizeRequirement(shape)); + (*buffer)->set_destination( + absl::make_unique(literal, index)); }); // Give the tree of buffers to the outfeed mananger. The device will fill it @@ -169,8 +172,8 @@ Status GpuTransferManager::TransferLiteralFromOutfeed( gpu::OutfeedManager* outfeed_manager = gpu::GetOrCreateOutfeedManager(); outfeed_manager->EnqueueDestination(&outfeed_buffers); - // Now turn the tree of buffers back into a literal. - *literal = std::move(*ShapeTreeToLiteral(&outfeed_buffers)); + // Now wait for the tree of buffers are written. + ShapeTreeToLiteral(&outfeed_buffers); return Status::OK(); } @@ -178,7 +181,7 @@ Status GpuTransferManager::TransferLiteralFromOutfeed( } // namespace xla static std::unique_ptr CreateNVPTXTransferManager() { - return xla::MakeUnique( + return absl::make_unique( /*id=*/stream_executor::cuda::kCudaPlatformId, /*pointer_size=*/llvm::DataLayout(xla::gpu::NVPTXCompiler::kDataLayout) .getPointerSize(0 /* default address space */)); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h index dceeb9e2eb01a7dd5e978d819ed1db56d828f353..7929042869763dfeab2fe8f87093b7ea758337d0 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h @@ -42,7 +42,7 @@ class GpuTransferManager : public GenericTransferManager { const LiteralSlice& literal) override; Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) override; + MutableBorrowingLiteral literal) override; private: // Initiates the infeed data transfers. InfeedBuffer->Done() must be diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc index 19420e590d05892417da4d5e62fdcde5eba9d9f1..b9c21e8edb2bdde03acb1fe6197a399724c9c8ab 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc @@ -20,10 +20,11 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/pool.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/util/ptr_util.h" @@ -33,14 +34,13 @@ namespace gpu { namespace { void InitAndStartTimer(std::stack>* timers, se::Stream* stream) { - timers->push(MakeUnique(stream->parent())); + timers->push(absl::make_unique(stream->parent())); stream->InitTimer(timers->top().get()).ThenStartTimer(timers->top().get()); } -uint64 GetCyclesTaken( - std::stack>* timers, - const std::vector::SmartPtr>& sub_streams, - se::Stream* stream, double clock_rate_ghz) { +uint64 GetCyclesTaken(std::stack>* timers, + const std::vector& sub_streams, + se::Stream* stream, double clock_rate_ghz) { CHECK_GT(timers->size(), 0); stream->ThenWaitFor(&sub_streams); stream->ThenStopTimer(timers->top().get()); @@ -53,7 +53,7 @@ uint64 GetCyclesTaken( HloExecutionProfiler::HloExecutionProfiler( bool do_profile, HloExecutionProfile* profile, se::Stream* stream, - const std::vector::SmartPtr>& sub_streams, + const std::vector& sub_streams, const HloComputation* computation) : do_profile_(do_profile), profile_(profile), @@ -116,7 +116,7 @@ HloExecutionProfiler::MakeScopedInstructionProfiler( CHECK(hlo_instructions_.insert(hlo_instruction).second) << hlo_instruction->name(); } - return MakeUnique(this, hlo_instruction); + return absl::make_unique(this, hlo_instruction); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h index 6654850bef3efa46028defbba81e3537fafbf143..80cde75f2bbb555f514fffea58ad92edf92fd0d1 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/pool.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -38,10 +38,10 @@ class ScopedInstructionProfiler; class HloExecutionProfiler { public: // If profiling is enabled, start an execution timer running. - explicit HloExecutionProfiler( - bool do_profile, HloExecutionProfile* profile, se::Stream* stream, - const std::vector::SmartPtr>& sub_streams, - const HloComputation* computation); + explicit HloExecutionProfiler(bool do_profile, HloExecutionProfile* profile, + se::Stream* stream, + const std::vector& sub_streams, + const HloComputation* computation); // If profiling is enabled, sets the total cycle count on the profile from the // execution timer. @@ -72,7 +72,7 @@ class HloExecutionProfiler { double clock_rate_ghz_; HloExecutionProfile* profile_; se::Stream* stream_; - const std::vector::SmartPtr>& sub_streams_; + const std::vector& sub_streams_; const HloComputation* computation_; std::stack> timers_; // Contains the HLO instructions for which we are currently measuring the diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc index 19de37b0fbed15455e8c6a9bfe427ba3d9f0a9dc..76055ff009c05499ecfbfce31d87c65f3e39785d 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" @@ -59,8 +59,8 @@ GpuHloOrdering::GpuHloOrdering( : PredecessorHloOrdering(module) { // The entry computation has a total order when there's only one stream. if (stream_assignment.StreamCount() == 1) { - entry_sequence_ = - MakeUnique>(thunk_launch_order); + entry_sequence_ = absl::make_unique>( + thunk_launch_order); } // The ordering of instructions for the entry computation is determined by the @@ -75,7 +75,7 @@ GpuHloOrdering::GpuHloOrdering( // same-stream predecessors of each instruction. // Compute the set of all instructions we will want to set reachability on. - auto predecessor_map = MakeUnique( + auto predecessor_map = absl::make_unique( module->entry_computation()->MakeInstructionPostOrder()); // The most recently visited instruction per stream. @@ -208,7 +208,7 @@ StatusOr> HloSchedule::Build( BFSLaunchOrder(entry_computation, &schedule->thunk_launch_order_); } - schedule->hlo_ordering_ = MakeUnique( + schedule->hlo_ordering_ = absl::make_unique( &module, stream_assignment, schedule->thunk_launch_order_); return std::move(schedule); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc index 45f0a1c645b2875cf90d2c11cfb66c3dd855d097..d4a96cd5b353436ea4d1d6db3810b3e777449cd4 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -47,7 +48,7 @@ class HloScheduleTest : public HloTestBase { auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); config.set_debug_options(debug_options); - return MakeUnique("test_module", config); + return absl::make_unique("test_module", config); } HloVec RemoveHlo(const HloVec& input, diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 1b6315ec0305712d1367a9380f0de3eed91e2ee1..8c11cd05419289d82b033c936bb60884f45cb636 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -18,8 +18,10 @@ limitations under the License. #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -110,6 +112,12 @@ void HloToIrBindings::EmitBasePointersForHlos( llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_); BindHloToIrValue(*non_io_hlo, b_->CreateAlloca(pointee_type), index); + } else if (slice.allocation()->is_constant()) { + llvm::Value* global_for_constant = + module_->getGlobalVariable(llvm_ir::AsStringRef( + llvm_ir::ConstantBufferAllocationToGlobalName( + *slice.allocation()))); + BindHloToIrValue(*non_io_hlo, global_for_constant); } else { const int64 offset = slice.offset(); CHECK_NE(nullptr, temp_buffer_base_); @@ -135,6 +143,14 @@ llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte, EmitGetTupleElement(gte->operand(0), base_ptr), b_, module_); } +// Returns true if `value` has a name that should not be changed. +static bool HasMeaningfulName(llvm::Value* value) { + if (auto* global = llvm::dyn_cast(value)) { + return global->getLinkage() != llvm::GlobalValue::PrivateLinkage; + } + return false; +} + llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, ShapeIndexView shape_index, llvm::Value* ir_value) { @@ -149,8 +165,13 @@ llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, } else { typed_ir_value = b_->CreateBitCast(ir_value, pointee_type->getPointerTo()); } - ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "raw"))); - typed_ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "typed"))); + if (!HasMeaningfulName(ir_value)) { + ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "raw"))); + } + if (!HasMeaningfulName(typed_ir_value)) { + typed_ir_value->setName( + llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "typed"))); + } return typed_ir_value; } diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc index c5f0cdf6cd5d3e076bffa875fbba991bf0681ee8..a4364b0deb6c97b7b580e18bf67d5f3a8fd3cc62 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" namespace xla { namespace gpu { @@ -24,7 +24,7 @@ se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) { tensorflow::mutex_lock l(host_to_device_stream_mu_); if (host_to_device_executor_ == nullptr) { host_to_device_executor_ = executor; - host_to_device_stream_ = MakeUnique(executor); + host_to_device_stream_ = absl::make_unique(executor); host_to_device_stream_->Init(); } diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index af6259ae83e3e18ad4b69ab42fc126e7486794f1..0f2c83aeb2633a007559d8caac78ea2d233539ed 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -202,6 +202,7 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, IsIEEEFloatingPointScalarConstant(producer->operand(0)) && fused_parameter_users[0]->opcode() == HloOpcode::kMultiply; } + return false; } // Other output fusions are not currently supported on GPUs. diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 6352b330d17d77da65ed4ffb5a225535ff6caf82..c349063c71f000435a05306101ad724505f2d197 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -38,24 +38,27 @@ namespace gpu { namespace { // Return whether the given shape is a matrix with no padding. -bool IsRank2WithNoPadding(const Shape& shape) { - return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape); +bool IsRank2WithNoPadding(const Shape& shape, int64 batch_dimensions_size) { + return ShapeUtil::Rank(shape) == batch_dimensions_size + 2 && + !LayoutUtil::IsPadded(shape); } // In a gemm operation where output = lhs * rhs, check whether the given shapes // are valid for the operation. bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, - const Shape& output_shape) { + const Shape& output_shape, + int64 batch_dimensions_size) { // The inputs and the output must // 1) be matrices with no padding and a non-zero number of elements, // 2) have an allowed element type. PrimitiveType output_primitive_type = output_shape.element_type(); bool type_is_allowed = (output_primitive_type == F16 || output_primitive_type == F32 || - output_primitive_type == F64); - return type_is_allowed && IsRank2WithNoPadding(lhs_shape) && - IsRank2WithNoPadding(rhs_shape) && - IsRank2WithNoPadding(output_shape) && + output_primitive_type == F64 || output_primitive_type == C64); + return type_is_allowed && + IsRank2WithNoPadding(lhs_shape, batch_dimensions_size) && + IsRank2WithNoPadding(rhs_shape, batch_dimensions_size) && + IsRank2WithNoPadding(output_shape, batch_dimensions_size) && !ShapeUtil::IsZeroElementArray(lhs_shape) && !ShapeUtil::IsZeroElementArray(rhs_shape); } @@ -64,14 +67,15 @@ bool DotImplementedAsGemm(const HloInstruction& dot) { CHECK_EQ(dot.opcode(), HloOpcode::kDot); const Shape& lhs_shape = dot.operand(0)->shape(); const Shape& rhs_shape = dot.operand(1)->shape(); + const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); // If gemm can accept the operand shapes, use it rather than a custom // kernel. - if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape())) { + if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape(), + dim_numbers.lhs_batch_dimensions_size())) { // The size of the reduction dimension should match. The shape inference // guarantees this invariant, so the check here is for programming // errors. - const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)), rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))); return true; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 973848c336e77491fd4f47d7dbea89eaae2720db..7111b53944770c9dbfcd0611f67b18900bcf1ffb 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/algorithm/container.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Instructions.h" @@ -64,7 +65,7 @@ IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config, hlo_module_config_(hlo_module_config) { b_.setFastMathFlags(llvm_ir::GetFastMathFlags( /*fast_math_enabled=*/hlo_module_config.debug_options() - .xla_enable_fast_math())); + .xla_gpu_enable_fast_math())); } Status IrEmitter::DefaultAction(HloInstruction* hlo) { @@ -81,19 +82,6 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { } Status IrEmitter::HandleConstant(HloInstruction* constant) { - const Literal& literal = constant->literal(); - llvm::Constant* initializer = - llvm_ir::ConvertLiteralToIrConstant(literal, module_); - llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( - *module_, initializer->getType(), - /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer, - /*Name=*/""); - VLOG(2) << "HandleConstant: " << constant->ToString() << std::endl - << " emitted_value: " << llvm_ir::DumpToString(*global_for_const) - << std::endl - << " its type: " - << llvm_ir::DumpToString(*global_for_const->getType()); - bindings_.BindHloToIrValue(*constant, global_for_const); return Status::OK(); } @@ -138,6 +126,10 @@ Status IrEmitter::HandleRecvDone(HloInstruction*) { return Unimplemented("Recv-done is not implemented on GPU"); } +Status IrEmitter::HandleScatter(HloInstruction*) { + return Unimplemented("Scatter is not implemented on GPUs."); +} + Status IrEmitter::HandleTuple(HloInstruction* tuple) { std::vector base_ptrs; for (const HloInstruction* operand : tuple->operands()) { @@ -463,6 +455,9 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { const Shape& lhs_shape = lhs_instruction->shape(); const Shape& rhs_shape = rhs_instruction->shape(); + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + CHECK_EQ(dnums.lhs_batch_dimensions_size(), + dnums.rhs_batch_dimensions_size()); // TODO(b/110211620): Convert to use i32 index_type when it is possible. llvm::Type* index_type = b_.getInt64Ty(); @@ -498,9 +493,15 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { const int64 lhs_reduction_dimension = ShapeUtil::GetDimensionNumber(lhs_shape, -1); const int64 rhs_reduction_dimension = - ShapeUtil::Rank(rhs_shape) >= 2 + ShapeUtil::Rank(rhs_shape) >= 2 + dnums.lhs_batch_dimensions_size() ? ShapeUtil::GetDimensionNumber(rhs_shape, -2) - : 0; + : dnums.lhs_batch_dimensions_size(); + + // Check that the batch dims don't cover the last two dims. + for (int64 batch_dim : dnums.lhs_batch_dimensions()) { + CHECK_NE(lhs_reduction_dimension, batch_dim); + CHECK_NE(rhs_reduction_dimension, batch_dim); + } // Verify the reduction dimension in the two operands are the same size. TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) == @@ -515,6 +516,13 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { llvm_ir::IrArray::Index rhs_index = loop_nest.EmitOperandArrayLoopNest( rhs_array, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs"); + // We don't have to iterate over the batch dimensions in both arrays, simplify + // the loop nest of the rhs. + for (int i = 0; i != dnums.lhs_batch_dimensions_size(); ++i) { + DCHECK(absl::c_linear_search(dnums.lhs_batch_dimensions(), i)); + rhs_index[i] = lhs_index[i]; + } + // Create the reduction loop which does the sum of products reduction. std::unique_ptr reduction_loop = loop_nest.AddLoop( /*start_index=*/0, @@ -577,7 +585,9 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { target_index.push_back(lhs_index[dimension]); } } - for (size_t dimension = 0; dimension < rhs_index.size(); ++dimension) { + // Skip over the batch dimensions to not have them in the index twice. + for (size_t dimension = dnums.lhs_batch_dimensions_size(); + dimension < rhs_index.size(); ++dimension) { if (dimension != rhs_reduction_dimension) { target_index.push_back(rhs_index[dimension]); } @@ -623,6 +633,10 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) { } Status IrEmitter::HandleReduce(HloInstruction* reduce) { + // TODO(b/112040122): Support variadic reduce. + if (!ShapeUtil::IsArray(reduce->shape())) { + return Unimplemented("Variadic reduce is not supported on GPU"); + } auto arg = reduce->operand(0); auto init_value = reduce->operand(1); tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 80e2a203ac3a1fbe95bf38a886288ea8be130148..561c6838798aa92ce2c96b3c45d5ba42fe6edef3 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -86,6 +86,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleParameter(HloInstruction* parameter) override; Status HandleReduce(HloInstruction* reduce) override; Status HandleTuple(HloInstruction* tuple) override; + Status HandleScatter(HloInstruction* scatter) override; Status HandleSelect(HloInstruction* select) override; Status HandleTupleSelect(HloInstruction* tuple_select) override; Status HandleFusion(HloInstruction* fusion) override; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index db6a4e6f30949abe0066bd46f05338d8aa662abb..dea2a31920e092ccc860145ca281651f7b011a85 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -21,6 +21,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" @@ -29,10 +31,10 @@ limitations under the License. #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" @@ -55,10 +57,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h" #include "tensorflow/compiler/xla/service/gpu/while_thunk.h" -#include "tensorflow/compiler/xla/service/gpu/while_transformer.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" @@ -66,6 +68,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -75,6 +78,7 @@ limitations under the License. #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -168,40 +172,6 @@ Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) { return DfsHloVisitor::Postprocess(hlo); } -namespace { -bool ImplementedAsHostToDeviceMemcpy(const BufferAssignment& buffer_assignment, - const HloInstruction& hlo) { - // `hlo` needs to satisfy the following conditions to be implemented as a - // host-to-device cuMemcpy. - // - // 1. `hlo` is a kCopy instruction. - // 2. `hlo`'s only operand is a kConstant instruction. - // 3. `hlo` and its operand have the same shape (thus the same layout too). - // 4. The address of `hlo`'s buffer is known at runtime (without dereferencing - // pointers in a tuple). - return hlo.opcode() == HloOpcode::kCopy && - hlo.operand(0)->opcode() == HloOpcode::kConstant && - ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) && - buffer_assignment.GetUniqueTopLevelSlice(&hlo).ok(); -} - -bool ImplementedAsDeviceToDeviceMemcpy( - const BufferAssignment& buffer_assignment, const HloInstruction& hlo) { - // `hlo` needs to satisfy three conditions to be implemented as a - // device-to-device cuMemcpy. - // - // 1. `hlo` is a kCopy instruction. - // 2. `hlo` and its operand have the same shape (thus the same layout too). - // 3. `hlo` and its operand have a statically-known buffer assignment - // (constants do not, for instance), which means the source buffer also - // resides on the device. - return hlo.opcode() == HloOpcode::kCopy && - ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) && - buffer_assignment.GetUniqueTopLevelSlice(&hlo).ok() && - buffer_assignment.GetUniqueTopLevelSlice(hlo.operand(0)).ok(); -} -} // namespace - llvm::Function* IrEmitterUnnested::BuildKernelPrototype( const HloInstruction& inst, tensorflow::gtl::ArraySlice args) { @@ -230,11 +200,20 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( ++arg_it; kernel->addDereferenceableAttr(arg_no + 1, alloc->size()); + + const int64 alignment = [&] { + if (alloc->is_entry_computation_parameter()) { + return kEntryParameterAlignBytes; + } else if (alloc->is_constant()) { + return kConstantBufferAlignBytes; + } else { + return kXlaAllocatedBufferAlignBytes; + } + }(); + kernel->addParamAttr( - arg_no, llvm::Attribute::get(context, llvm::Attribute::Alignment, - alloc->is_entry_computation_parameter() - ? kEntryParameterAlignBytes - : kXlaAllocatedBufferAlignBytes)); + arg_no, + llvm::Attribute::get(context, llvm::Attribute::Alignment, alignment)); if (alloc->IsPreallocatedTempBuffer()) { fn_arg->setName("temp_buf"); @@ -336,13 +315,13 @@ llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size, }; // Check the size of input tensors - if (!c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) { + if (!absl::c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) { return i64_ty; } // Check the size of the internal result tensors if (unnested_hlo->opcode() == HloOpcode::kFusion) { - if (!c_all_of( + if (!absl::c_all_of( unnested_hlo->fused_instructions_computation()->instructions(), hlo_shape_in_range)) { return i64_ty; @@ -367,11 +346,6 @@ Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) { } Status IrEmitterUnnested::HandleDot(HloInstruction* dot) { - const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); - if (dnums.lhs_batch_dimensions_size() > 0 || - dnums.rhs_batch_dimensions_size() > 0) { - return Unimplemented("Dot with batch dimensions not implemented."); - } if (ImplementedAsGemm(*dot)) { thunk_sequence_->emplace_back(BuildGemmThunk(dot)); return Status::OK(); @@ -410,7 +384,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { int64 feature_index_value = feature_index->literal().Get({}); thunk_sequence_->emplace_back( - MakeUnique( + absl::make_unique( /*operand=*/GetAllocationSlice(*custom_call->operand(0)), /*scale=*/GetAllocationSlice(*custom_call->operand(1)), /*offset=*/GetAllocationSlice(*custom_call->operand(2)), @@ -440,7 +414,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { auto output_mean = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); auto output_inv_stddev = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie(); thunk_sequence_->emplace_back( - MakeUnique( + absl::make_unique( /*operand=*/GetAllocationSlice(*custom_call->operand(0)), /*scale=*/GetAllocationSlice(*custom_call->operand(1)), /*offset=*/GetAllocationSlice(*custom_call->operand(2)), @@ -470,19 +444,20 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { auto output_grad_scale = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); auto output_grad_offset = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie(); - thunk_sequence_->emplace_back(MakeUnique( - /*operand=*/GetAllocationSlice(*custom_call->operand(0)), - /*scale=*/GetAllocationSlice(*custom_call->operand(1)), - /*mean=*/GetAllocationSlice(*custom_call->operand(2)), - /*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)), - /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)), - /*epsilon=*/epsilon_value, - /*feature_index=*/feature_index_value, - /*output_grad_data=*/output_grad_data, - /*output_grad_scale=*/output_grad_scale, - /*output_grad_offset=*/output_grad_offset, - /*output_tuple=*/GetAllocationSlice(*custom_call), - /*hlo=*/custom_call)); + thunk_sequence_->emplace_back( + absl::make_unique( + /*operand=*/GetAllocationSlice(*custom_call->operand(0)), + /*scale=*/GetAllocationSlice(*custom_call->operand(1)), + /*mean=*/GetAllocationSlice(*custom_call->operand(2)), + /*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)), + /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)), + /*epsilon=*/epsilon_value, + /*feature_index=*/feature_index_value, + /*output_grad_data=*/output_grad_data, + /*output_grad_scale=*/output_grad_scale, + /*output_grad_offset=*/output_grad_offset, + /*output_tuple=*/GetAllocationSlice(*custom_call), + /*hlo=*/custom_call)); return Status::OK(); } @@ -502,7 +477,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { const auto& target = custom_call->custom_call_target(); std::unique_ptr thunk; if (target == kCudnnConvForwardCallTarget) { - thunk = MakeUnique( + thunk = absl::make_unique( CudnnConvKind::kForward, /*input_buffer=*/lhs_slice, /*filter_buffer=*/rhs_slice, @@ -516,7 +491,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { backend_config.algorithm(), backend_config.tensor_ops_enabled(), custom_call); } else if (target == kCudnnConvBackwardInputCallTarget) { - thunk = MakeUnique( + thunk = absl::make_unique( CudnnConvKind::kBackwardInput, /*input_buffer=*/conv_result_slice, /*filter_buffer=*/rhs_slice, @@ -530,7 +505,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { backend_config.algorithm(), backend_config.tensor_ops_enabled(), custom_call); } else if (target == kCudnnConvBackwardFilterCallTarget) { - thunk = MakeUnique( + thunk = absl::make_unique( CudnnConvKind::kBackwardFilter, /*input_buffer=*/lhs_slice, /*filter_buffer=*/conv_result_slice, @@ -572,6 +547,11 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { switch (root->opcode()) { case HloOpcode::kTuple: case HloOpcode::kReduce: { + if (root->opcode() == HloOpcode::kReduce && + ShapeUtil::IsTuple(root->shape())) { + // TODO(b/112040122): Support variadic reduce. + return Unimplemented("Variadic reduce is not supported on GPU"); + } VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString(); std::vector> thunks; ArraySlice output_instructions = @@ -598,7 +578,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { thunks.push_back( BuildKernelThunk(fusion, /*implements_whole_instruction=*/false)); thunk_sequence_->emplace_back( - MakeUnique(std::move(thunks), fusion)); + absl::make_unique(std::move(thunks), fusion)); std::vector parameter_arrays; for (HloInstruction* operand : fusion->operands()) { parameter_arrays.push_back(GetIrArray(*operand, *fusion)); @@ -718,13 +698,12 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { } Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { - if (ImplementedAsHostToDeviceMemcpy(ir_emitter_context_->buffer_assignment(), - *copy)) { - thunk_sequence_->emplace_back(BuildHostToDeviceCopyThunk(copy)); - return Status::OK(); - } - if (ImplementedAsDeviceToDeviceMemcpy( - ir_emitter_context_->buffer_assignment(), *copy)) { + CHECK(ShapeUtil::Compatible(copy->operand(0)->shape(), copy->shape())); + const BufferAssignment& buffer_assignment = + ir_emitter_context_->buffer_assignment(); + if (LayoutUtil::Equal(copy->operand(0)->shape().layout(), + copy->shape().layout()) && + buffer_assignment.GetUniqueTopLevelSlice(copy->operand(0)).ok()) { thunk_sequence_->emplace_back(BuildDeviceToDeviceCopyThunk(copy)); return Status::OK(); } @@ -1722,6 +1701,10 @@ Status IrEmitterUnnested::EmitReductionToVector( } Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { + // TODO(b/112040122): Support multi-output reduce. + if (!ShapeUtil::IsArray(reduce->shape())) { + return Unimplemented("Multi-output reduce is not supported on GPU"); + } auto input = reduce->operand(0); auto init_value = reduce->operand(1); tensorflow::gtl::ArraySlice dimensions_to_reduce(reduce->dimensions()); @@ -1737,7 +1720,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { thunks.push_back( BuildKernelThunk(reduce, /*implements_whole_instruction=*/false)); thunk_sequence_->emplace_back( - MakeUnique(std::move(thunks), reduce)); + absl::make_unique(std::move(thunks), reduce)); return EmitReductionToVector( reduce, input->shape(), {[&](const IrArray::Index& index) { @@ -1757,11 +1740,13 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { bool all_tuple_elements_have_buffer = - c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) { + absl::c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) { return ir_emitter_context_->buffer_assignment() .GetUniqueTopLevelSlice(tuple_element) .ok(); }); + // TODO(b/111689850): This logic isn't quite correct. + // // Tuples (especially tuples that are the final result of a computation) can // be so huge that if we were to emit a kernel that took each tuple element as // a parameter, we would exceed the max allowable number of parameters to a @@ -1769,15 +1754,15 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { // buffer, we collect their buffer addresses in a host array, and then copy // that array to the tuple's buffer. // - // Some tuple elements (e.g. const or bitcast of const) might not have a - // buffer -- their contents are stored in code. In that case, we fall back to - // emitting kernels which have access to their buffer addresses in code. + // Some tuple elements might not have an unambiguous buffer (like the result + // of a select-tuple). In that case, we fall back to emitting kernels which + // have access to their buffer addresses in code. if (all_tuple_elements_have_buffer) { std::vector tuple_element_buffers; for (const HloInstruction* tuple_element : tuple->operands()) { tuple_element_buffers.push_back(GetAllocationSlice(*tuple_element)); } - thunk_sequence_->emplace_back(MakeUnique( + thunk_sequence_->emplace_back(absl::make_unique( tuple_element_buffers, GetAllocationSlice(*tuple), tuple)); return Status::OK(); } @@ -1809,8 +1794,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( thunks.push_back(std::move(initializer_thunk)); thunks.push_back(BuildKernelThunk(select_and_scatter, /*implements_whole_instruction=*/false)); - thunk_sequence_->emplace_back( - MakeUnique(std::move(thunks), select_and_scatter)); + thunk_sequence_->emplace_back(absl::make_unique( + std::move(thunks), select_and_scatter)); // TODO(b/31410564): Implement dilation rate for select-and-scatter. if (window_util::HasDilation(window)) { @@ -1989,19 +1974,13 @@ Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) { condition->root_instruction()->shape().element_type() == PRED) << "While condition computation must return bool"; // Build ForThunk for conformant while loops, otherwise build WhileThunk. - auto result = CanTransformWhileToFor(xla_while); - if (result.ok()) { - auto tuple = result.ConsumeValueOrDie(); - // loop_trip_count = (limit - start + increment - 1) / increment - const int64 loop_trip_count = - (std::get<1>(tuple) - std::get<0>(tuple) + std::get<2>(tuple) - 1) / - std::get<2>(tuple); - thunk_sequence_->emplace_back(BuildForThunk(xla_while, loop_trip_count)); + // TODO(b/112163966): Move trip count computation earlier in the pipeline. + if (auto loop_trip_count = ComputeWhileLoopTripCount(xla_while)) { + thunk_sequence_->emplace_back(BuildForThunk(xla_while, *loop_trip_count)); VLOG(3) << "Built ForThunk for while: " << xla_while->name(); } else { thunk_sequence_->emplace_back(BuildWhileThunk(xla_while)); - VLOG(3) << "Built WhileThunk for while: " << xla_while->name() - << " while-to-for transform status: " << result.status(); + VLOG(3) << "Built WhileThunk for while: " << xla_while->name(); } return Status::OK(); } @@ -2041,7 +2020,7 @@ Status IrEmitterUnnested::HandleRng(HloInstruction* rng) { thunks.push_back(std::move(rng_thunk)); thunks.push_back(std::move(increment_seed_thunk)); thunk_sequence_->emplace_back( - MakeUnique(std::move(thunks), rng)); + absl::make_unique(std::move(thunks), rng)); return Status::OK(); } @@ -2054,28 +2033,34 @@ Status IrEmitterUnnested::HandleSelect(HloInstruction* select) { Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { std::vector> thunks; + auto keys = sort->operand(0); auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr; + ShapeIndex keys_shape_index({}); + ShapeIndex values_shape_index({}); if (values != nullptr) { - // TODO(b/26783907): Also sort the values by their corresponding key. - return Unimplemented("Key/Value Sort is not implemented on GPU"); + keys_shape_index = ShapeIndex({0}); + values_shape_index = ShapeIndex({1}); } + auto keys_destination = GetAllocationSlice(*sort, keys_shape_index); + auto values_destination = GetAllocationSlice(*sort, values_shape_index); - // First copy the operand to the output, so that we can sort in-place. - // TODO(b/26783907): Share buffer of output and operand when it is possible. - if (sort->operand(0)->IsConstant()) { - thunks.push_back(MakeUnique( - /*source_address=*/sort->operand(0)->literal().untyped_data(), - /*destination_buffer=*/GetAllocationSlice(*sort), - /*mem_size=*/ShapeUtil::ByteSizeOf(sort->shape()), sort)); - } else { - thunks.push_back(MakeUnique( - /*source_address=*/GetAllocationSlice(*sort->operand(0)), - /*destination_buffer=*/GetAllocationSlice(*sort), - /*mem_size=*/ShapeUtil::ByteSizeOf(sort->shape()), sort)); + if (keys_destination != GetAllocationSlice(*keys)) { + thunks.push_back(absl::make_unique( + /*source_address=*/GetAllocationSlice(*keys), + /*destination_buffer=*/keys_destination, + /*mem_size=*/ShapeUtil::ByteSizeOf(keys->shape()), nullptr)); + } + if (values != nullptr && values_destination != GetAllocationSlice(*values)) { + // TODO(b/26783907): Figure out why we never seem to share buffers for + // key/value sort. + thunks.push_back(absl::make_unique( + /*source_address=*/GetAllocationSlice(*values), + /*destination_buffer=*/values_destination, + /*mem_size=*/ShapeUtil::ByteSizeOf(values->shape()), nullptr)); } int64 dimension_to_sort = sort->dimensions(0); - int64 dimension_to_sort_bound = sort->shape().dimensions(dimension_to_sort); + int64 dimension_to_sort_bound = keys->shape().dimensions(dimension_to_sort); int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound); auto index_type = b_.getInt64Ty(); @@ -2099,7 +2084,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { thunks.push_back( BuildKernelThunk(sort, /*implements_whole_instruction=*/false)); LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - sort->shape(), ir_emitter_context_->device_description()); + keys->shape(), ir_emitter_context_->device_description()); UpdateLaunchDimensions(launch_dimensions, thunks.back().get(), ir_emitter_context_->llvm_module()); @@ -2111,13 +2096,16 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { } TF_RETURN_IF_ERROR(llvm_ir::EmitSortInPlace( - dimension_to_sort, GetIrArray(*sort, *sort), IrName(sort), xor_mask, - &b_, &launch_dimensions)); + dimension_to_sort, GetIrArray(*sort, *sort, keys_shape_index), + values != nullptr ? tensorflow::gtl::make_optional( + GetIrArray(*sort, *sort, values_shape_index)) + : tensorflow::gtl::nullopt, + IrName(sort), xor_mask, &b_, &launch_dimensions)); } } thunk_sequence_->emplace_back( - MakeUnique(std::move(thunks), sort)); + absl::make_unique(std::move(thunks), sort)); return Status::OK(); } @@ -2144,7 +2132,7 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { if (crs->operand_count() == 1) { CHECK(ShapeUtil::IsArray(crs->operand(0)->shape())) << "Operands to cross-replica-sum must be arrays: " << crs->ToString(); - thunk_sequence_->push_back(MakeUnique( + thunk_sequence_->push_back(absl::make_unique( /*source_address=*/GetAllocationSlice(*crs->operand(0)), /*destination_buffer=*/GetAllocationSlice(*crs), /*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape()), crs)); @@ -2159,17 +2147,17 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { tuple_element_buffers.push_back(ir_emitter_context_->buffer_assignment() .GetUniqueSlice(crs, {i}) .ValueOrDie()); - thunks.push_back(MakeUnique( + thunks.push_back(absl::make_unique( /*source_address=*/GetAllocationSlice(*crs->operand(i)), /*destination_buffer=*/tuple_element_buffers.back(), /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), nullptr)); } // Output a tuple of the buffers above. - thunks.push_back(MakeUnique(tuple_element_buffers, - GetAllocationSlice(*crs), nullptr)); + thunks.push_back(absl::make_unique( + tuple_element_buffers, GetAllocationSlice(*crs), nullptr)); thunk_sequence_->push_back( - MakeUnique(std::move(thunks), crs)); + absl::make_unique(std::move(thunks), crs)); return Status::OK(); } @@ -2274,11 +2262,6 @@ GetHloBufferSlices(const HloInstruction* hlo, // Adds entries for all subshapes of instr to `slices`. auto add_slices_for = [&](const HloInstruction* instr) { - // GPU constants don't have buffers; don't bother looking for one. - if (instr->IsConstant()) { - return; - } - ShapeUtil::ForEachSubshape( instr->shape(), [&](const Shape& /*shape*/, const ShapeIndex& index) { if (slices.count({instr, index})) { @@ -2340,21 +2323,25 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( // We'll pass a pointer to each of the elements of `buffers` to our kernel, in // this order. - std::vector buffers(buffers_needed.begin(), - buffers_needed.end()); - std::sort(buffers.begin(), buffers.end(), + std::vector non_constant_buffers; + absl::c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers), + [](const BufferAllocation* allocation) { + return !allocation->is_constant(); + }); + + std::sort(non_constant_buffers.begin(), non_constant_buffers.end(), [](const BufferAllocation* a, const BufferAllocation* b) { return a->index() < b->index(); }); - llvm::Function* kernel = BuildKernelPrototype(*inst, buffers); + llvm::Function* kernel = BuildKernelPrototype(*inst, non_constant_buffers); // Build a map from a BufferAllocation to the corresponding argument in our // kernel. std::unordered_map kernel_args; { auto arg_it = kernel->arg_begin(); - auto buffers_it = buffers.begin(); + auto buffers_it = non_constant_buffers.begin(); for (; arg_it != kernel->arg_end(); ++arg_it, ++buffers_it) { kernel_args[*buffers_it] = arg_it; } @@ -2372,8 +2359,16 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( << " is found in slice " << slice.ToString() << " at GTE index " << gte_index.ToString(); - llvm::Value* loc = b_.CreateInBoundsGEP(kernel_args.at(slice.allocation()), - {b_.getInt64(slice.offset())}); + llvm::Value* loc; + if (slice.allocation()->is_constant()) { + loc = ir_emitter_context_->llvm_module()->getGlobalVariable( + llvm_ir::AsStringRef(llvm_ir::ConstantBufferAllocationToGlobalName( + *slice.allocation()))); + CHECK_NE(loc, nullptr); + } else { + loc = b_.CreateInBoundsGEP(kernel_args.at(slice.allocation()), + {b_.getInt64(slice.offset())}); + } // If gte_index is nonempty, we have to dereference `loc` to get to the // value we're ultimately interested in. @@ -2396,16 +2391,16 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( llvm::ConstantPointerNull::get(b_.getInt8PtrTy())); } - return MakeUnique(buffers, llvm_ir::AsString(kernel->getName()), - implements_whole_instruction ? inst : nullptr, - unroll_factor); + return absl::make_unique( + non_constant_buffers, llvm_ir::AsString(kernel->getName()), + implements_whole_instruction ? inst : nullptr, unroll_factor); } std::unique_ptr IrEmitterUnnested::BuildHostToDeviceCopyThunk( const HloInstruction* inst) { const HloInstruction* operand = inst->operand(0); CHECK_EQ(HloOpcode::kConstant, operand->opcode()); - return MakeUnique( + return absl::make_unique( /*source_address=*/operand->literal().untyped_data(), /*destination_buffer=*/GetAllocationSlice(*inst), /*mem_size=*/ @@ -2417,7 +2412,7 @@ std::unique_ptr IrEmitterUnnested::BuildHostToDeviceCopyThunk( std::unique_ptr IrEmitterUnnested::BuildDeviceToDeviceCopyThunk( const HloInstruction* inst) { const HloInstruction* operand = inst->operand(0); - return MakeUnique( + return absl::make_unique( /*source_address=*/GetAllocationSlice(*operand), /*destination_buffer=*/GetAllocationSlice(*inst), /*mem_size=*/ @@ -2437,7 +2432,7 @@ std::unique_ptr IrEmitterUnnested::BuildInfeedThunk( .GetUniqueSlice(inst, index) .ConsumeValueOrDie(); }); - return MakeUnique(slices, inst); + return absl::make_unique(slices, inst); } std::unique_ptr IrEmitterUnnested::BuildOutfeedThunk( @@ -2454,7 +2449,7 @@ std::unique_ptr IrEmitterUnnested::BuildOutfeedThunk( *slice = status_or_slice.ConsumeValueOrDie(); } }); - return MakeUnique(std::move(slices), inst); + return absl::make_unique(std::move(slices), inst); } namespace { @@ -2477,7 +2472,7 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( if (inst->opcode() == HloOpcode::kDot) { const HloInstruction* lhs = inst->operand(0); const HloInstruction* rhs = inst->operand(1); - return MakeUnique( + return absl::make_unique( GetAllocationSlice(*lhs), // The buffer assigned to LHS. GetAllocationSlice(*rhs), // The buffer assigned to RHS. GetAllocationSlice(*inst), // The output buffer. @@ -2519,7 +2514,7 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( const HloInstruction* rhs = inst->operand(rhs_parameter->parameter_number()); - return MakeUnique( + return absl::make_unique( GetAllocationSlice(*lhs), // The buffer assigned to LHS. GetAllocationSlice(*rhs), // The buffer assigned to RHS. GetAllocationSlice(*inst), // The output buffer. @@ -2536,11 +2531,12 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( std::unique_ptr IrEmitterUnnested::BuildFftThunk( const HloInstruction* inst) { const HloInstruction* operand = inst->operand(0); - return MakeUnique(inst->fft_type(), inst->fft_length(), - /*input_buffer=*/GetAllocationSlice(*operand), - /*output_buffer=*/GetAllocationSlice(*inst), - /*input_shape=*/operand->shape(), - /*output_shape=*/inst->shape(), inst); + return absl::make_unique( + inst->fft_type(), inst->fft_length(), + /*input_buffer=*/GetAllocationSlice(*operand), + /*output_buffer=*/GetAllocationSlice(*inst), + /*input_shape=*/operand->shape(), + /*output_shape=*/inst->shape(), inst); } StatusOr> IrEmitterUnnested::BuildInitializerThunk( @@ -2589,9 +2585,9 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( // MemzeroThunk. ArraySlice literal_bytes( reinterpret_cast(literal.untyped_data()), num_bytes); - if (c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) { - return { - MakeUnique(GetAllocationSlice(*hlo, index), nullptr)}; + if (absl::c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) { + return {absl::make_unique(GetAllocationSlice(*hlo, index), + nullptr)}; } // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by @@ -2608,7 +2604,7 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( memcpy(&pattern16, literal_bytes.data(), sizeof(pattern16)); } uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16); - return {MakeUnique( + return {absl::make_unique( pattern32, GetAllocationSlice(*hlo, index), nullptr)}; } @@ -2619,7 +2615,7 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( literal_bytes.size() - 4) == 0) { uint32 word; memcpy(&word, literal_bytes.data(), sizeof(word)); - return {MakeUnique( + return {absl::make_unique( word, GetAllocationSlice(*hlo, index), nullptr)}; } } @@ -2635,7 +2631,17 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( // If the init_value was fused into this reduce we have to generate it first. if (fused && init_value_operand->opcode() != HloOpcode::kParameter) { CHECK_EQ(HloOpcode::kConstant, init_value_operand->opcode()); - TF_RETURN_IF_ERROR(HandleConstant(const_cast(init_value))); + + const Literal& literal = init_value_operand->literal(); + llvm::Constant* initializer = + llvm_ir::ConvertLiteralToIrConstant(literal, module_); + + llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( + *module_, initializer->getType(), + /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer, + /*Name=*/""); + global_for_const->setAlignment(kConstantBufferAlignBytes); + bindings_.BindHloToIrValue(*init_value_operand, global_for_const); } TF_RETURN_IF_ERROR(ParallelLoopEmitter( [=](const IrArray::Index& index) { @@ -2761,7 +2767,7 @@ std::unique_ptr IrEmitterUnnested::BuildWhileThunk( ir_emitter_context_); TF_CHECK_OK(body->Accept(&ir_emitter_body)); - return MakeUnique( + return absl::make_unique( GetAllocationSlice(*condition->root_instruction()), // cond result ir_emitter_condition.ConsumeThunkSequence(), ir_emitter_body.ConsumeThunkSequence(), hlo); @@ -2779,8 +2785,8 @@ std::unique_ptr IrEmitterUnnested::BuildForThunk( ir_emitter_context_); TF_CHECK_OK(body->Accept(&ir_emitter_body)); - return MakeUnique(loop_limit, - ir_emitter_body.ConsumeThunkSequence(), hlo); + return absl::make_unique( + loop_limit, ir_emitter_body.ConsumeThunkSequence(), hlo); } std::unique_ptr IrEmitterUnnested::BuildConditionalThunk( @@ -2800,7 +2806,7 @@ std::unique_ptr IrEmitterUnnested::BuildConditionalThunk( ir_emitter_context_); TF_CHECK_OK(false_computation->Accept(&ir_emitter_false)); - return MakeUnique( + return absl::make_unique( GetAllocationSlice(*hlo->operand(0)), GetAllocationSlice(*hlo->operand(1)), GetAllocationSlice(*hlo->operand(2)), @@ -3102,7 +3108,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( CeilOfRatio(output_dims_in_tiles[i], kTileSize); } const int64 num_tiles = - c_accumulate(output_dims_in_tiles, 1, std::multiplies()); + absl::c_accumulate(output_dims_in_tiles, 1, std::multiplies()); LaunchDimensions launch_dimensions(num_tiles, kThreadsPerTile); llvm::Type* index_ty = @@ -3367,5 +3373,47 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { return true; } +Status IrEmitterUnnested::EmitConstantGlobals() { + for (const BufferAllocation& allocation : + ir_emitter_context_->buffer_assignment().Allocations()) { + if (!allocation.is_constant()) { + continue; + } + + const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation); + const bool should_emit_initializer = ShouldEmitLiteralInLlvmIr(literal); + llvm::ArrayType* global_type = + llvm::ArrayType::get(b_.getInt8Ty(), allocation.size()); + llvm::Constant* initializer = + should_emit_initializer + ? llvm_ir::ConvertLiteralToIrConstant(literal, module_) + : llvm::ConstantAggregateZero::get(global_type); + if (should_emit_initializer) { + VLOG(3) << "Emitted initializer for constant with shape " + << ShapeUtil::HumanString(literal.shape()); + } + + // These globals will be looked up by name by GpuExecutable so we need to + // give them an external linkage. Not all of their uses are visible in the + // LLVM IR (e.g. TupleThunk) so we can't give then a linkage that merely + // preserves their names (like available_externally), we also need to ensure + // that they stick around even if they're "unused". + // + // We may have to be more more clever here in the future if we notice that + // we're keeping around too many globals because of their linkage. + llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( + global_type, /*isConstant=*/should_emit_initializer, + llvm::GlobalValue::ExternalLinkage, + /*Initializer=*/initializer, + llvm_ir::AsStringRef( + llvm_ir::ConstantBufferAllocationToGlobalName(allocation))); + global_for_const->setAlignment(kConstantBufferAlignBytes); + ir_emitter_context_->llvm_module()->getGlobalList().push_back( + global_for_const); + } + + return Status::OK(); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 616d8a2206e5a9666947008879c48f99a022e899..525441990795e160ba0e8facb910d5cc9796c4bb 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -92,6 +92,9 @@ class IrEmitterUnnested : public IrEmitter { const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter, KernelThunk* thunk); + // Emits LLVM global variables corresponding to constant instructions. + Status EmitConstantGlobals(); + private: // Builds the appropriate thunk for the instruction hlo and returns the owning // pointer to it. The caller needs to make sure `inst` outlives the lifetime diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index e76823ad103dfa5ba61a0d3ba81b2c028dfeb33e..6305396635eae7bb3fcda1d4675fb3b5f7d60553 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/types.h" @@ -95,7 +95,7 @@ Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, VLOG(3) << "Launching " << kernel->name(); // Launch the kernel with potentially multiple blocks and threads. static constexpr int kKernelArgsLimit = 1024; - auto kernel_args = MakeUnique>(); + auto kernel_args = absl::make_unique>(); for (const BufferAllocation* arg : args_) { const auto& buf = buffer_allocations.GetDeviceAddress(arg->index()); kernel_args->add_device_memory_argument(buf); diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD index eb93efc560efbb4c14065ec98b980a1ca78605c6..6bd9c58f83063554d57aea5e2289907be701a2c1 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD @@ -34,6 +34,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", "@llvm//:amdgpu_code_gen", "@llvm//:analysis", "@llvm//:bit_reader", diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc index 6c1c20fc0464927054deace8980620c3a9c6f09b..cce6e4814174c022f40b9aa199335a85ffaa6ed7 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -114,21 +114,20 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path, // Gets the GPU name as it's known to LLVM for a given compute capability. If // we see an unrecognized compute capability, we return "sm_30". static string GetSmName(std::pair compute_capability) { - static auto* m = new std::map, int>( - {{{2, 0}, 20}, - {{2, 1}, 21}, - {{3, 0}, 30}, - {{3, 2}, 32}, - {{3, 5}, 35}, - {{3, 7}, 37}, - {{5, 0}, 50}, - {{5, 2}, 52}, - {{5, 3}, 53}, - {{6, 0}, 60}, - {{6, 1}, 61}, - {{6, 2}, 62}, - // TODO: Change this to 70 once LLVM NVPTX supports it - {{7, 0}, 60}}); + static auto* m = new std::map, int>({ + {{3, 0}, 30}, + {{3, 2}, 32}, + {{3, 5}, 35}, + {{3, 7}, 37}, + {{5, 0}, 50}, + {{5, 2}, 52}, + {{5, 3}, 53}, + {{6, 0}, 60}, + {{6, 1}, 61}, + {{6, 2}, 62}, + {{7, 0}, 70}, + {{7, 2}, 72}, + }); int sm_version = 30; auto it = m->find(compute_capability); if (it != m->end()) { @@ -181,7 +180,7 @@ std::unique_ptr GetTargetMachine( TargetOptions target_options = InitTargetOptionsFromCodeGenFlags(); llvm_ir::SetTargetOptions( /*fast_math_enabled=*/hlo_module_config.debug_options() - .xla_enable_fast_math(), + .xla_gpu_enable_fast_math(), &target_options); // Enable FMA synthesis. @@ -206,7 +205,7 @@ std::unique_ptr GetTargetMachine( default: codegen_opt_level = CodeGenOpt::None; } - return WrapUnique(target->createTargetMachine( + return absl::WrapUnique(target->createTargetMachine( triple.str(), llvm_ir::AsStringRef(cpu_name), "+ptx60", target_options, Optional(RelocModel), Optional(CMModel), codegen_opt_level)); @@ -329,7 +328,7 @@ Status LinkLibdeviceIfNecessary(llvm::Module* module, if (linker.linkInModule( std::move(libdevice_module), llvm::Linker::Flags::LinkOnlyNeeded, [](Module& M, const StringSet<>& GVS) { - internalizeModule(M, [&M, &GVS](const GlobalValue& GV) { + internalizeModule(M, [&GVS](const GlobalValue& GV) { return !GV.hasName() || (GVS.count(GV.getName()) == 0); }); })) { diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index 6fef7208533e83484ee7fdee22528fb4d219272c..34a479b289d26964ba10f5e92c3a8829110787c9 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -113,17 +114,25 @@ bool IsInputFusibleReduction(HloInstruction* instr) { // of input parameters differ. In such situtations it is beneficial not to fuse. // We consider input params with maximum rank only. Params with smaller ranks // will be broadcasted and have not been observed to cause data locality issues. -// TODO(b/110927656): Improve reduce emitters to remove this limitation. +// TODO(b/111977086): Improve reduce emitters to remove this limitation. bool ReduceFriendlyInputLayouts(HloInstruction* instr) { + std::vector params; + if (instr->opcode() == HloOpcode::kFusion) { + params = instr->fused_parameters(); + } else { + for (HloInstruction* operand : instr->operands()) { + params.push_back(operand); + } + } int64 max_rank = 0; const Layout* max_rank_layout; - for (HloInstruction* param : instr->fused_parameters()) { + for (HloInstruction* param : params) { if (ShapeUtil::Rank(param->shape()) > max_rank) { max_rank = ShapeUtil::Rank(param->shape()); max_rank_layout = ¶m->shape().layout(); } } - return c_all_of(instr->fused_parameters(), [&](HloInstruction* param) { + return absl::c_all_of(params, [&](HloInstruction* param) { return (ShapeUtil::Rank(param->shape()) < max_rank) || (LayoutUtil::Equal(param->shape().layout(), *max_rank_layout)); }); @@ -221,7 +230,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { const bool is_loop_fusion = producer->opcode() == HloOpcode::kFusion && producer->fusion_kind() == HloInstruction::FusionKind::kLoop; - if (!is_loop_fusion) { + if (!producer->IsElementwise() && !is_loop_fusion) { VLOG(3) << producer->name() << " is not a loop fusion."; continue; } @@ -240,7 +249,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { } // Do not fuse a producer if the other operands of the fusion are // reachable from the producer, this would create a cycle. - if (c_any_of(consumer_operands, [&](HloInstruction* operand) { + if (absl::c_any_of(consumer_operands, [&](HloInstruction* operand) { return producer != operand && reachability()->IsReachable(producer, operand); })) { @@ -260,7 +269,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { for (auto& fusion_pair : potential_fusion_list) { HloInstruction* producer = fusion_pair.first; HloInstruction* consumer = fusion_pair.second; - if (!c_any_of(consumer->operands(), [&](HloInstruction* operand) { + if (!absl::c_any_of(consumer->operands(), [&](HloInstruction* operand) { return producer != operand && reachability()->IsReachable(producer, operand); })) { diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index ec4234b8d9a5da299a9dc574169b0bb5fe6a575f..14f157a5e518a0ec82c664c123629d04bd385bbf 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -256,6 +256,26 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) { op::Tuple(op::Multiply(), op::Divide())); } +TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + exp = f32[2,2,2]{2,1,0} exponential(p0) + reduce = f32[2,2]{1,0} reduce(exp, c0), dimensions={2}, to_apply=scalar_add_computation + ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce, exp) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement())); + const HloInstruction* fusion = root->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Exp())); +} + TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) { auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( fused_add { diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 2eefadebcd1098b294c79d6e400857b1c2760824..5868c1a42e6986c82648c9a7b2935d8e9100f968 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -21,20 +21,20 @@ limitations under the License. #include // NOLINT(build/c++11): only using std::call_once, not mutex. #include +#include "absl/memory/memory.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/DiagnosticPrinter.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Verifier.h" #include "tensorflow/compiler/xla/protobuf_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/batchnorm_expander.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h" -#include "tensorflow/compiler/xla/service/dot_decomposer.h" +#include "tensorflow/compiler/xla/service/convolution_feature_group_converter.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" @@ -52,9 +52,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h" #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" +#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h" #include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -71,10 +73,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "tensorflow/compiler/xla/service/scatter_expander.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" -#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -130,8 +132,12 @@ string GetLibdeviceDir(const string& config_cuda_data_dir) { } // Runs optimization passes on the given HLO module. +// +// It takes a compiler pointer, as passes may compile and execute HLOs on the +// fly for cuDNN verification or other purposes. Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* device_allocator) { + DeviceMemoryAllocator* device_allocator, + Compiler* compiler) { { HloPassPipeline pipeline("optimization"); pipeline.AddInvariantChecker(); @@ -146,7 +152,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // support BF16 operations without directly implementing a BF16 lowering for // most ops. pipeline.AddPass(BF16, F32); - pipeline.AddPass(); { auto& pass = @@ -168,6 +173,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // elimination has to come after that pass. pipeline.AddPass(); + pipeline.AddPass(); + pass.AddPass( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }); @@ -197,8 +204,16 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // (PadInsertion). HloPassPipeline pipeline("conv_canonicalization"); pipeline.AddInvariantChecker(); + // TODO(b/31709653): Directly use the grouped convolution support of Cudnn. + pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); + if (IsVoltaOrLater(*stream_exec)) { + pipeline.AddPass(); + // PadForTensorCores leaves behind unnecessary tuple/get-tuple-element + // pairs that TupleSimplifier fixes. + pipeline.AddPass(); + } TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } @@ -240,8 +255,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // the gte(customcall, 0) would probably already be into a fusion node. We // can't simplify across HloComputation boundaries, so in this case we // wouldn't be able to simplify away the new_tuple bits. - pipeline.AddPass(stream_exec, - device_allocator); + pipeline.AddPass( + stream_exec, device_allocator, compiler); // Clean up new_tuple described above. pipeline.AddPass(); @@ -275,14 +290,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, } } - { - // Do an aggressive LICM pass over while loops. In particular, this hoists - // constants that were sunk by WhileLoopConstantSinking. Leaving them in - // the while loop may result in unnecessary copies. - HloPassPipeline pipeline("while-loop-licm"); - pipeline.AddPass(true); - TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); - } return Status::OK(); } @@ -495,11 +502,15 @@ NVPTXCompiler::NVPTXCompiler() StatusOr> NVPTXCompiler::RunHloPasses( std::unique_ptr module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) { + // We dump the post-optimization HLO in RunBackend so no need to dump it here. + VLOG(2) << "*** HLO Before Optimization"; + XLA_VLOG_LINES(2, module->ToString()); + XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunHloPasses"); tracing::ScopedActivity activity("HLO Transforms", module->name(), /*is_expensive=*/true); TF_RETURN_IF_ERROR( - OptimizeHloModule(module.get(), stream_exec, device_allocator)); + OptimizeHloModule(module.get(), stream_exec, device_allocator, this)); return std::move(module); } @@ -540,15 +551,18 @@ StatusOr> NVPTXCompiler::RunBackend( // temporary buffers are required to run the computation. TF_ASSIGN_OR_RETURN( std::unique_ptr buffer_assignment, - BufferAssigner::Run(module.get(), hlo_schedule->ConsumeHloOrdering(), - BufferSizeBytesFunction(), - /*color_alignment=*/[](LogicalBuffer::Color) { - return kXlaAllocatedBufferAlignBytes; - })); + BufferAssigner::Run( + module.get(), hlo_schedule->ConsumeHloOrdering(), + BufferSizeBytesFunction(), + /*color_alignment=*/ + [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // BufferAssignment::Stats::ToString() and BufferAssignment::ToString() // include headers, so no need for us to print them ourselves. XLA_VLOG_LINES(1, buffer_assignment->GetStats().ToString()); XLA_VLOG_LINES(2, buffer_assignment->ToString()); + VLOG(2) << "*** HLO After Optimization"; XLA_VLOG_LINES(2, module->ToString()); const string xla_dump_optimized_hlo_proto_to = module->config().debug_options().xla_dump_optimized_hlo_proto_to(); @@ -565,6 +579,9 @@ StatusOr> NVPTXCompiler::RunBackend( HloComputation* entry_computation = module->entry_computation(); IrEmitterUnnested ir_emitter(module->config(), entry_computation, &ir_emitter_context); + + TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); + { XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunBackend - IR emission"); TF_RETURN_IF_ERROR(entry_computation->Accept(&ir_emitter)); @@ -673,7 +690,7 @@ StatusOr> NVPTXCompiler::RunBackend( const std::vector cubin = CompilePtxOrGetCachedResult(ptx, cc_major, cc_minor); - auto thunk_schedule = MakeUnique( + auto thunk_schedule = absl::make_unique( ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment), hlo_schedule->ThunkLaunchOrder()); VLOG(2) << "Printing the thunk schedule..."; @@ -687,7 +704,7 @@ StatusOr> NVPTXCompiler::RunBackend( cost_analysis.set_bytes_per_second( stream_exec->GetDeviceDescription().memory_bandwidth()); TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis)); - profile_index_map = MakeUnique(*module); + profile_index_map = absl::make_unique(*module); profile_printer = CreateHloProfilePrinterData(*profile_index_map, cost_analysis); } @@ -796,7 +813,7 @@ se::Platform::Id NVPTXCompiler::PlatformId() const { static bool InitModule() { xla::Compiler::RegisterCompilerFactory( stream_executor::cuda::kCudaPlatformId, - []() { return xla::MakeUnique(); }); + []() { return absl::make_unique(); }); return true; } static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc b/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc index 4aaf0c9e142106a0e74f319d71dad4c4c96d3f08..2fa170964e974a6535307d7a21eb3e7760d02536 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_manager.h b/tensorflow/compiler/xla/service/gpu/outfeed_manager.h index a752eb70119b00e8cca7ddce26da7730ef5db8cb..160ba4b691f818ff01b41b8603c11853ea12c253 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_manager.h +++ b/tensorflow/compiler/xla/service/gpu/outfeed_manager.h @@ -36,22 +36,19 @@ class OutfeedBuffer { OutfeedBuffer(int64 length) : length_(length) {} // Waits for the device transfer to be finished. - std::unique_ptr WaitUntilAvailable() { - done_.WaitForNotification(); - return std::move(destination_); - } + void WaitUntilAvailable() { done_.WaitForNotification(); } int64 length() const { return length_; } - void set_destination(std::unique_ptr destination) { + void set_destination(std::unique_ptr destination) { destination_ = std::move(destination); } - Literal* destination() { return destination_.get(); } + MutableBorrowingLiteral* destination() { return destination_.get(); } // Callback to signal that this buffer is consumed. void Done() { done_.Notify(); } private: - std::unique_ptr destination_; + std::unique_ptr destination_; const int64 length_; tensorflow::Notification done_; }; diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc index 7986e63f43ee508370f94fdb9057b91bfe4add18..b99d998c4d7df514c024b1f8d643d08c72059d0e 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc @@ -50,10 +50,6 @@ Status OutfeedThunk::ExecuteOnStream( if (!*buffer) { // Tuple pointers. return Status::OK(); } - // Allocate storage for the literal data. - const Shape& shape = - ShapeUtil::GetSubshape(outfeed_buffers->shape(), index); - (*buffer)->set_destination(Literal::CreateFromShape(shape)); BufferAllocation::Slice slice = outfeed_slices_.element(index); se::DeviceMemoryBase data_address; diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc new file mode 100644 index 0000000000000000000000000000000000000000..79f7d31816baf0b95b967771b956a9c06ac81e91 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc @@ -0,0 +1,233 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" + +namespace xla { +namespace gpu { + +using tensorflow::gtl::ArraySlice; + +// We want the input/output feature counts of an f16 conv to be factors of 8, +// because without this cudnn can't use tensor cores on the conv. +static constexpr int64 kDesiredNumFeaturesFactor = 8; + +// We won't pad a conv if doing so increases the total number of bytes in the +// lhs, rhs, or result by more than this amount. +// +// TODO(jlebar): This number was tuned experimentally. It represents a +// compromise on our current benchmarks; it speeds some up significantly, and +// doesn't slow any down. But we can observe by changing this value that +// there's additional room for speedups. Achieving those speedups without also +// slowing other things down will likely require a more sophisticated heuristic, +// possibly some form of auto-tuning. +static constexpr double kMaxBytesTouchedIncrease = 1.2; + +// Pads the given dimensions in the given shape up to a multiple of +// kDesiredNumFeaturesFactor. +static Shape PadShape(Shape s, ArraySlice dims) { + for (int64 dim : dims) { + int64 dim_to_pad_size = s.dimensions(dim); + int64 new_dim_to_pad_size = + RoundUpToNearest(dim_to_pad_size, kDesiredNumFeaturesFactor); + s.set_dimensions(dim, new_dim_to_pad_size); + } + return s; +} + +// Creates and returns an HLO that zero-pads one or more dimensions in the given +// instruction so that its shape is equal to the given shape. +// +// Padding is added to the end of each relevant dimension. +// +// If the instruction already has the given shape, simply returns it without an +// intervening pad. +static HloInstruction* PadInstruction(HloInstruction* instr, + const Shape& new_shape) { + HloComputation* comp = instr->parent(); + + const Shape& shape = instr->shape(); + auto* zero = comp->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(shape.element_type()).CloneToUnique())); + + PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape)); + + bool added_padding = false; + for (int64 dim = 0; dim < ShapeUtil::Rank(shape); ++dim) { + if (shape.dimensions(dim) == new_shape.dimensions(dim)) { + continue; + } + CHECK_GT(new_shape.dimensions(dim), shape.dimensions(dim)); + pad_config.mutable_dimensions(dim)->set_edge_padding_high( + new_shape.dimensions(dim) - shape.dimensions(dim)); + added_padding = true; + } + + if (!added_padding) { + return instr; + } + return comp->AddInstruction( + HloInstruction::CreatePad(new_shape, instr, zero, pad_config)); +} + +// Pads the input/output feature dimensions of the given cudnn convolution +// custom-call to be multiples of kDesiredNumFeaturesFactor. +static StatusOr PadFeaturesDims(HloInstruction* conv) { + CHECK_EQ(0, conv->shape().tuple_shapes(1).dimensions(0)) + << "conv must use 0 scratch bytes, i.e. this pass must be run " + "before CudnnConvolutionAlgorithmPicker."; + + const auto& target = conv->custom_call_target(); + const auto& dnums = conv->convolution_dimension_numbers(); + auto* lhs = conv->mutable_operand(0); + auto* rhs = conv->mutable_operand(1); + const Shape& result_shape = conv->shape().tuple_shapes(0); + + Shape new_lhs_shape = [&] { + if (target == kCudnnConvForwardCallTarget || + target == kCudnnConvBackwardFilterCallTarget) { + // LHS is "input". + return PadShape(lhs->shape(), {dnums.input_feature_dimension()}); + } + CHECK_EQ(target, kCudnnConvBackwardInputCallTarget); + // LHS is "output". + return PadShape(lhs->shape(), {dnums.output_feature_dimension()}); + }(); + + Shape new_rhs_shape = [&] { + if (target == kCudnnConvForwardCallTarget || + target == kCudnnConvBackwardInputCallTarget) { + // RHS is "filter". + return PadShape(rhs->shape(), {dnums.kernel_input_feature_dimension(), + dnums.kernel_output_feature_dimension()}); + } + CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget); + // RHS is "output". + return PadShape(rhs->shape(), {dnums.output_feature_dimension()}); + }(); + + if (ShapeUtil::Equal(lhs->shape(), new_lhs_shape) && + ShapeUtil::Equal(rhs->shape(), new_rhs_shape)) { + VLOG(3) << "No need to pad features of " << conv->ToString(); + return false; + } + + Shape new_result_shape = [&] { + if (target == kCudnnConvForwardCallTarget) { + // Result is "output". + return PadShape(result_shape, {dnums.output_feature_dimension()}); + } + if (target == kCudnnConvBackwardInputCallTarget) { + // Result is "input". + return PadShape(result_shape, {dnums.input_feature_dimension()}); + } + CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget); + // Result is "filter". + return PadShape(result_shape, {dnums.kernel_input_feature_dimension(), + dnums.kernel_output_feature_dimension()}); + }(); + + // Check that padding wouldn't increase the total bytes read/written by this + // operation too much. + auto check_size_increase = [&](const Shape& old_shape, + const Shape& new_shape) { + int64 old_bytes = ShapeUtil::ByteSizeOf(old_shape); + int64 new_bytes = ShapeUtil::ByteSizeOf(new_shape); + if (new_bytes <= old_bytes * kMaxBytesTouchedIncrease) { + return true; + } + VLOG(3) << "Not padding convolution; doing so would change input / result " + "shape from " + << ShapeUtil::HumanString(old_shape) << " to " + << ShapeUtil::HumanString(new_shape) << ", a size increase of " + << new_bytes / static_cast(old_bytes) << "x > " + << kMaxBytesTouchedIncrease << "x: " << conv->ToString(); + return false; + }; + if (!check_size_increase(lhs->shape(), new_lhs_shape) || + !check_size_increase(rhs->shape(), new_rhs_shape) || + !check_size_increase(result_shape, new_result_shape)) { + return false; + } + + // OK, let's do the transformation! + + auto* new_lhs = PadInstruction(lhs, new_lhs_shape); + auto* new_rhs = PadInstruction(rhs, new_rhs_shape); + CHECK(new_lhs != lhs || new_rhs != rhs) + << "We should have had to pad either LHS or RHS."; + + auto add = [&](std::unique_ptr new_instr) { + return conv->parent()->AddInstruction(std::move(new_instr)); + }; + + Shape new_conv_shape = ShapeUtil::MakeTupleShape( + {new_result_shape, ShapeUtil::MakeShape(U8, {0})}); + auto* new_conv = + add(conv->CloneWithNewOperands(new_conv_shape, {new_lhs, new_rhs})); + + // Slice the new conv result if necessary, keeping in mind that new_conv has + // tuple shape (new_result_shape, u8[0]). + if (!ShapeUtil::Equal(result_shape, new_result_shape)) { + std::vector start_indices(result_shape.dimensions_size(), 0); + std::vector end_indices(result_shape.dimensions().begin(), + result_shape.dimensions().end()); + std::vector strides(result_shape.dimensions_size(), 1); + + auto* new_conv_result = add( + HloInstruction::CreateGetTupleElement(new_result_shape, new_conv, 0)); + auto* empty_temp_buffer = + add(HloInstruction::CreateConstant(LiteralUtil::CreateR1({}))); + auto* sliced_result = add(HloInstruction::CreateSlice( + result_shape, new_conv_result, start_indices, end_indices, strides)); + new_conv = + add(HloInstruction::CreateTuple({sliced_result, empty_temp_buffer})); + } + + VLOG(2) << "Padded features of " << conv->ToString() << ", replaced with " + << new_conv->ToString(); + TF_RETURN_IF_ERROR(conv->parent()->ReplaceInstruction(conv, new_conv)); + return true; +} + +static std::vector GetRelevantConvs(HloComputation* comp) { + std::vector convs; + for (HloInstruction* instr : comp->instructions()) { + if (IsCustomCallToDnnConvolution(*instr) && + instr->operand(0)->shape().element_type() == F16) { + convs.push_back(instr); + } + } + return convs; +} + +StatusOr PadForTensorCores::Run(HloModule* module) { + bool changed = false; + for (HloComputation* comp : module->MakeNonfusionComputations()) { + for (HloInstruction* conv : GetRelevantConvs(comp)) { + TF_ASSIGN_OR_RETURN(bool result, PadFeaturesDims(conv)); + changed |= result; + } + } + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h new file mode 100644 index 0000000000000000000000000000000000000000..192359f026bfb2f1d5436713e4a30725fa0ad6ba --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +// Ensures that f16 cudnn convolutions have input/output channel dimensions that +// are multiples of 8, inserting pads/slices as necessary. +// +// This is useful primarily for Volta and newer GPUs, where tensor cores can +// only be used if the channel dims are multiples of 8. It's probably the +// opposite of useful on other GPUs, so you should check what GPU you're +// targeting before running this pass. +// +// TODO(jlebar): Also pad dots. +class PadForTensorCores : public HloPassInterface { + public: + tensorflow::StringPiece name() const override { + return "pad for tensor cores"; + } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_ diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..99e7580b826fc5cd6d98a037a5eb064552952e18 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc @@ -0,0 +1,164 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h" + +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +namespace gpu { +namespace { + +namespace op = xla::testing::opcode_matchers; +using ::testing::_; + +using PadForTensorCoresTest = HloVerifiedTestBase; + +TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) { + ParseAndVerifyModule(R"( + HloModule TestModule + + ENTRY TestComputation { + input = f16[10,20,30,41] parameter(0) + filter = f16[2,2,41,40] parameter(1) + ROOT result = (f16[10,20,30,40], u8[0]) custom-call(input, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convForward" + })"); + EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); + auto* root = module().entry_computation()->root_instruction(); + + SCOPED_TRACE(module().ToString()); + EXPECT_THAT(root, op::CustomCall(kCudnnConvForwardCallTarget, + op::Pad(op::Parameter(0), _), + op::Pad(op::Parameter(1), _))); + EXPECT_TRUE(ShapeUtil::Equal(root->operand(0)->shape(), + ShapeUtil::MakeShape(F16, {10, 20, 30, 48}))); + EXPECT_TRUE(ShapeUtil::Equal(root->operand(1)->shape(), + ShapeUtil::MakeShape(F16, {2, 2, 48, 40}))); +} + +TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvOutputChannels) { + ParseAndVerifyModule(R"( + HloModule TestModule + + ENTRY TestComputation { + output = f16[10,20,30,41] parameter(0) + filter = f16[2,2,40,41] parameter(1) + ROOT result = (f16[10,20,30,40], u8[0]) custom-call(output, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convBackwardInput" + })"); + EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); + auto* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, op::CustomCall(kCudnnConvBackwardInputCallTarget, + op::Pad(op::Parameter(0), _), + op::Pad(op::Parameter(1), _))); + EXPECT_TRUE(ShapeUtil::Equal(root->operand(0)->shape(), + ShapeUtil::MakeShape(F16, {10, 20, 30, 48}))); + EXPECT_TRUE(ShapeUtil::Equal(root->operand(1)->shape(), + ShapeUtil::MakeShape(F16, {2, 2, 40, 48}))); +} + +TEST_F(PadForTensorCoresTest, PadF16ForwardConvOutputChannels) { + ParseAndVerifyModule(R"( + HloModule TestModule + + ENTRY TestComputation { + input = f16[10,20,30,40] parameter(0) + filter = f16[2,2,40,41] parameter(1) + ROOT result = (f16[10,20,30,41], u8[0]) custom-call(input, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convForward" + })"); + EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); + auto* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::Slice(op::GetTupleElement(op::CustomCall( + kCudnnConvForwardCallTarget, op::Parameter(0), + op::Pad(op::Parameter(1), _)))), + _)); +} + +TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvInputChannels) { + ParseAndVerifyModule(R"( + HloModule TestModule + + ENTRY TestComputation { + output = f16[10,20,30,40] parameter(0) + filter = f16[2,2,41,40] parameter(1) + result = (f16[10,20,30,41], u8[0]) custom-call(output, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convBackwardInput" + ROOT gte = f16[10,20,30,41] get-tuple-element(result), index=0 + })"); + EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); + auto* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, op::GetTupleElement(op::Tuple( + op::Slice(op::GetTupleElement(op::CustomCall( + kCudnnConvBackwardInputCallTarget, op::Parameter(0), + op::Pad(op::Parameter(1), _)))), + _))); +} + +TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvInputChannels) { + ParseAndVerifyModule(R"( + HloModule TestModule + + ENTRY TestComputation { + input = f16[10,20,30,41] parameter(0) + output = f16[10,20,30,40] parameter(1) + result = (f16[2,2,41,40], u8[0]) custom-call(input, output), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convBackwardFilter" + ROOT gte = f16[2,2,41,40] get-tuple-element(result), index=0 + })"); + EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); + auto* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, op::GetTupleElement(op::Tuple( + op::Slice(op::GetTupleElement(op::CustomCall( + kCudnnConvBackwardFilterCallTarget, + op::Pad(op::Parameter(0), _), op::Parameter(1)))), + _))); +} + +TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvOutputChannels) { + ParseAndVerifyModule(R"( + HloModule TestModule + + ENTRY TestComputation { + input = f16[10,20,30,40] parameter(0) + output = f16[10,20,30,41] parameter(1) + result = (f16[2,2,40,41], u8[0]) custom-call(input, output), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convBackwardFilter" + ROOT gte = f16[2,2,40,41] get-tuple-element(result), index=0 + })"); + EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); + auto* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, op::GetTupleElement(op::Tuple( + op::Slice(op::GetTupleElement(op::CustomCall( + kCudnnConvBackwardFilterCallTarget, + op::Parameter(0), op::Pad(op::Parameter(1), _)))), + _))); +} + +} // anonymous namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index b22040eee167e784bed58dbc0d0ad2ae042037f3..98cc21ccac57268257f1f9a3999a3d876ef074fc 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -69,7 +70,7 @@ HloInstruction* MaybePaddedAndSlicedInput( PrimitiveType element_type = input->shape().element_type(); HloInstruction* padding = computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(LiteralUtil::Zero(element_type)))); + absl::make_unique(LiteralUtil::Zero(element_type)))); input = MakePadHlo(input, padding, padding_config).ValueOrDie(); } @@ -126,7 +127,7 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window, PrimitiveType element_type = kernel->shape().element_type(); HloInstruction* padding = computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(LiteralUtil::Zero(element_type)))); + absl::make_unique(LiteralUtil::Zero(element_type)))); return MakePadHlo(kernel, padding, padding_config).ValueOrDie(); } } // namespace @@ -236,7 +237,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( HloComputation* computation = backward_conv->parent(); HloInstruction* output = backward_conv->mutable_operand(1); HloInstruction* padding = computation->AddInstruction( - HloInstruction::CreateConstant(MakeUnique( + HloInstruction::CreateConstant(absl::make_unique( LiteralUtil::Zero(input->shape().element_type())))); HloInstruction* padded_input = MakePadHlo(input, padding, input_padding_config).ValueOrDie(); diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc index d3fd0544fb68809125e9b9f7a5e5b7eff8c6ef43..c927c5ee1666b6198d96750ff372ac83813a9df9 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc index 0806dd51614f4d2da12f3fbbc9fb98df5273d5c8..5b6cf2c04d05378a363232e33a6df6432cd6848e 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" @@ -119,7 +119,7 @@ int ComputeStreamToAssign( } // namespace std::unique_ptr AssignStreams(const HloModule& module) { - auto stream_assignment = MakeUnique(); + auto stream_assignment = absl::make_unique(); const HloComputation& computation = *module.entry_computation(); std::unique_ptr reachability = computation.ComputeReachability(); diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index 6f4bb0580e8dfc1dce1cca0a60cc3dd9ea600fb3..3f75d8b55959495017f1b08d61bd6e7b44bed27f 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -33,7 +34,7 @@ class StreamAssignmentTest : public HloTestBase { auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); config.set_debug_options(debug_options); - return MakeUnique("test_module", config); + return absl::make_unique("test_module", config); } // Pre-canned shapes. diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc index a50ddf6ac63c7fa7ccace94bc7f40f438aedccf8..05b305ea4cdfdbaeb42544b626a6b9990bb42f57 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc @@ -20,10 +20,17 @@ limitations under the License. namespace xla { namespace gpu { -using stream_executor::dnn::DataLayout; -using stream_executor::dnn::DataLayoutString; -using stream_executor::dnn::FilterLayout; -using stream_executor::dnn::FilterLayoutString; +using se::dnn::DataLayout; +using se::dnn::DataLayoutString; +using se::dnn::FilterLayout; +using se::dnn::FilterLayoutString; + +bool IsVoltaOrLater(const se::StreamExecutor& stream_executor) { + int major, minor; + CHECK(stream_executor.GetDeviceDescription().cuda_compute_capability(&major, + &minor)); + return major >= 7; +} StatusOr> StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h index 39a6a38d001f502b2abb8de6efe2ce623b478c71..1fc46bafa10e7ba6c896f081d5c836bd400886c9 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -25,18 +26,20 @@ limitations under the License. namespace xla { namespace gpu { +// Returns true if the given StreamExecutor is for a Volta or newer nvidia GPU. +bool IsVoltaOrLater(const se::StreamExecutor& stream_exec); + // Returns (input, filter, output) XLA Layout protos given the StreamExecutor // layouts. StatusOr> StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, - stream_executor::dnn::DataLayout input, - stream_executor::dnn::FilterLayout filter, - stream_executor::dnn::DataLayout output); + se::dnn::DataLayout input, + se::dnn::FilterLayout filter, + se::dnn::DataLayout output); // Returns (input, filter, output) StreamExecutor layouts given the XLA layouts. -StatusOr> +StatusOr< + std::tuple> XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, const Layout& input, const Layout& filter, const Layout& output); diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index 686c3c16c97151d5ac06983f8709f1d367eb596c..db4a33dc564b62b5fe54b725ea453a6fcbfb3287 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -35,13 +35,13 @@ cc_library( "requires-gpu-sm35", ], deps = [ - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:gpu_plugin", "//tensorflow/compiler/xla/service/gpu:gpu_executable", "//tensorflow/compiler/xla/tests:filecheck", "//tensorflow/compiler/xla/tests:llvm_irgen_test_base", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -60,6 +60,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) @@ -94,6 +95,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) @@ -111,8 +113,8 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -150,6 +152,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) @@ -168,6 +171,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc index 4b8415fe9106137e588f345a3492f93e46aeb5b6..0e84ec7e621fcd1778725dc2743d7a70fb01c47a 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/tests/filecheck.h" #include "tensorflow/core/platform/logging.h" @@ -32,7 +32,7 @@ std::unique_ptr GpuCodegenTest::CreateNewModuleWithFTZ(bool ftz) { debug_options.add_xla_disable_hlo_passes("constant_folding"); config.set_debug_options(debug_options); - return MakeUnique(TestName(), config); + return absl::make_unique(TestName(), config); } void GpuCodegenTest::CompileAndVerifyPtx(std::unique_ptr hlo_module, diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc index ce69e058e64aab1f3c292b2ad7c7b529d4666b35..4550f36fdfc097632fed4956fcd3e42ef8a919c5 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc index e5958165eff21d82faf821213e50fe30a11059a4..a06576df7b874745236a8d9075355a01ec42e777 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc index 6c9ae7bada5e7545b558b6fcb872ece60850cbe9..6a9ecd9dae7c9ddde0b56d8615e4a39fb3df0af9 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc index c42e5704a4d2e611a203293e60a86ba4104bca46..15198865bda98f9718342d5a444a20305f923b48 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc index ba5cd2d84dfc0cd1515875e8510c18d89e4ec5f7..9072b30317d253fd6d50e9d98949cad4eaebfe7b 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test_helpers.h" diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index 4df0bb005b623e5ac79a4dfcb7c5a8a7a400940c..e68bee035a029178844282995429eaa960cc4817 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -82,17 +82,9 @@ class Thunk { return Status::OK(); } - // Users of Thunk should call ShouldHaltAllActivityBeforeRunning(stream) - // before calling ExecuteOnStream(stream). If it returns true, it's the - // user's responsibility to wait for all activity on the GPU to finish before - // calling ExecuteOnStream. - // - // This value is not required to be constant for a given Thunk. For example, - // a Thunk that performs autotuning may return true for its first run and - // false thereafter. - virtual bool ShouldHaltAllActivityBeforeRunning(se::Stream* /*stream*/) { - return false; - } + // Returns true if this kernel will autotune for the stream device the next + // time it is run. + virtual bool WillAutotuneKernel(se::Stream* /*stream*/) { return false; } // Execute the kernel for the thunk on the given stream. This method must be // called after Initialize and can be called multiple times over Thunk's diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc index a10e40451c1db01ce73db7b56a3a0599769fa49b..989b542ff4503600b2e3c751a23345959fab6fd6 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" @@ -24,24 +25,32 @@ namespace gpu { Status TupleThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, se::Stream* stream, HloExecutionProfiler* profiler) { - std::vector tuple_element_buffer_addresses; - for (BufferAllocation::Slice tuple_element_buffer : tuple_element_buffers_) { - tuple_element_buffer_addresses.push_back( - buffer_allocations.GetDeviceAddress(tuple_element_buffer).opaque()); + auto size = tuple_element_buffers_.size(); + auto tuple_element_buffer_addresses = absl::make_unique(size); + for (int i = 0; i != size; ++i) { + tuple_element_buffer_addresses[i] = + buffer_allocations.GetDeviceAddress(tuple_element_buffers_[i]).opaque(); } se::DeviceMemory dest_buffer_address( buffer_allocations.GetDeviceAddress(dest_buffer_)); - auto host_size = tuple_element_buffer_addresses.size() * sizeof(void*); + auto host_size = size * sizeof(void*); auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); if (!stream ->ThenMemcpy(&dest_buffer_address, - tuple_element_buffer_addresses.data(), host_size) + tuple_element_buffer_addresses.get(), host_size) .ok()) { return InternalError( "Unable to launch MemcpyH2D from %p to %p with size %lu", - tuple_element_buffer_addresses.data(), dest_buffer_address.opaque(), - sizeof(void*) * tuple_element_buffer_addresses.size()); + tuple_element_buffer_addresses.get(), dest_buffer_address.opaque(), + host_size); + } + // Free the tuple address buffer when memcpy is done. + auto* buffers_raw = tuple_element_buffer_addresses.release(); + if (!stream->ThenDoHostCallback([buffers_raw] { delete[] buffers_raw; }) + .ok()) { + delete[] buffers_raw; + return InternalError("Unable to enqueue host callback!"); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc index 1315a4183a98d6ea9ed4c82d4c22e77c2109ec83..828fc2884bd7d58333d86c35a537f06467cf6e4a 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/while_thunk.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -34,9 +34,9 @@ WhileThunk::WhileThunk( // and body_thunk_sequence_ constructors because these SequentialThunks // are logically "part of" this WhileThunk, and shouldn't be profiled // separately from it. - condition_thunk_sequence_(MakeUnique( + condition_thunk_sequence_(absl::make_unique( std::move(*condition_thunk_sequence), nullptr)), - body_thunk_sequence_(MakeUnique( + body_thunk_sequence_(absl::make_unique( std::move(*body_thunk_sequence), nullptr)) {} Status WhileThunk::Initialize(const GpuExecutable& executable, @@ -57,6 +57,7 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, while (true) { // Invoke thunk sequence for while 'condition' computation. profiler->StartHloComputation(); + VLOG(3) << "Executing condition computation"; TF_RETURN_IF_ERROR(condition_thunk_sequence_->ExecuteOnStream( buffer_allocations, stream, profiler)); profiler->FinishHloComputation(hlo_instruction()->while_condition()); @@ -64,6 +65,7 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, // Copy the result of condition computation and break the loop if 'false'. bool condition_result; stream->ThenMemcpy(&condition_result, condition_result_data, sizeof(bool)); + VLOG(3) << "condition_result = " << condition_result; Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { return InternalError( @@ -78,6 +80,7 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, // We measure the time of one execution of the while body computation. The // while body may be executed more than once, the last measurement "wins". profiler->StartHloComputation(); + VLOG(3) << "Executing body computation"; // Invoke thunk sequence for while 'body' computation, and pass on // 'profiler' to measure the timing of the thunks in 'body_thunk_sequence_'. TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(buffer_allocations, diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc deleted file mode 100644 index c5321df6c466fcb3816fb2aedad65b7c3811cb37..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc +++ /dev/null @@ -1,521 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/gpu/while_transformer.h" - -#include -#include - -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/errors.h" - -namespace xla { -namespace gpu { - -namespace { - -// TODO(b/33483676) Use an expression tree to specify computations to pattern -// match for while transformations. - -// ExprTree is a simple recursive data structure used to express computation -// patterns to match. -// -// Each ExprTree node is comprised of an HloOpcode, and a set of operands (each -// of type ExprTree). Operands can be added by specifying the index and -// HloOpcode of the operand. -// -// For example, the following computation: -// -// Parameter -// | -// Const GetTupleElement -// \ / -// Add (root) -// -// Can be matched with the following expression tree: -// -// ExprTree add(HloOpcode::kAdd, -// ExprTree(HloOpcode::kConstant), -// ExprTree(HloOpcode::kGetTupleElement, -// tuple_index, ExprTree(HloOpcode::kParameter))); -// -// Match the ExprTree root against an Hlo graph: -// -// ExprTree::TaggedInstructionMap tagged_instructions; -// TF_RETURN_IF_ERROR(add.Match(computation_->root_instruction(), -// &tagged_instructions)); -// -// Instructions that are "tagged" with a context-specific string will -// be returned in 'tagged_instructions' for further processing (i.e. parsing -// constants or recording the tuple_index). -// -class ExprTree { - public: - explicit ExprTree(HloOpcode opcode) : opcode_(opcode) {} - ExprTree(HloOpcode opcode, const string& tag) : opcode_(opcode), tag_(tag) {} - ExprTree(HloOpcode opcode, const ExprTree& operand0) : opcode_(opcode) { - SetOperand(0, operand0); - } - ExprTree(HloOpcode opcode, int64 index0, const ExprTree& operand0) - : opcode_(opcode) { - SetOperand(index0, operand0); - } - ExprTree(HloOpcode opcode, int64 index0, const ExprTree& operand0, - int64 index1, const ExprTree& operand1) - : opcode_(opcode) { - SetOperand(index0, operand0); - SetOperand(index1, operand1); - } - ExprTree(HloOpcode opcode, const string& tag, const ExprTree& operand0) - : opcode_(opcode), tag_(tag) { - SetOperand(0, operand0); - } - ExprTree(HloOpcode opcode, const ExprTree& operand0, const ExprTree& operand1) - : opcode_(opcode) { - SetOperand(0, operand0); - SetOperand(1, operand1); - } - - ExprTree(const ExprTree& to_copy) { - opcode_ = to_copy.opcode_; - tag_ = to_copy.tag_; - if (to_copy.fused_root_tree_ != nullptr) { - fused_root_tree_.reset(new ExprTree(*to_copy.fused_root_tree_)); - } - for (auto& pair : to_copy.operands_) { - CHECK(operands_.find(pair.first) == operands_.end()); - operands_.insert(std::make_pair( - pair.first, std::unique_ptr(new ExprTree(*pair.second)))); - } - } - - void SetFusedRoot(const ExprTree& fused_root) { - fused_root_tree_.reset(new ExprTree(fused_root)); - } - - typedef std::unordered_map - TaggedInstructionMap; - - // Matches 'instruction' HloOpcode against 'opcode_'. - // Recursively matches each operand in 'operands_'. - // Recursively matches fused instructions starting at 'fused_root_tree_' - // if 'opcode_ == kFusion'. - // Returns OK status, and instructions in 'tagged_instructions' for each - // matched ExprTree node with a non-empty 'tag_'. - // Returns error message on failure. - Status Match(const HloInstruction* instruction, - TaggedInstructionMap* tagged_instructions) const { - if (opcode_ != instruction->opcode()) { - return InvalidArgument("got opcode %s, want %s", - HloOpcodeString(instruction->opcode()).c_str(), - HloOpcodeString(opcode_).c_str()); - } - - VLOG(2) << "Matched " << HloOpcodeString(opcode_) << ": " << tag_; - if (!tag_.empty()) { - tagged_instructions->insert({tag_, instruction}); - } - - if (instruction->opcode() == HloOpcode::kFusion) { - CHECK(fused_root_tree_ != nullptr); - // Match fused instructions for this node starting a 'fused_root_tree'. - TF_RETURN_IF_ERROR(fused_root_tree_->Match( - instruction->fused_expression_root(), tagged_instructions)); - } - - // Match each operand in 'operands_'. - for (auto& pair : operands_) { - TF_RETURN_IF_ERROR(pair.second->Match(instruction->operand(pair.first), - tagged_instructions)); - } - return Status::OK(); - } - - private: - void SetOperand(int64 index, const ExprTree& operand) { - CHECK_EQ(0, operands_.count(index)); - operands_.insert(std::make_pair(index, MakeUnique(operand))); - } - - HloOpcode opcode_; - std::unordered_map> operands_; - std::unique_ptr fused_root_tree_; - string tag_; -}; - -// MatcherBase is a base class that provides common functionality for -// sub-classes which match specific target sub-computations (i.e. loop -// induction variable initialization, comparison and update). -class MatcherBase { - public: - MatcherBase() {} - virtual ~MatcherBase() {} - - // Attempts to match each ExprTree in 'expr_trees_'. - // Returns OK on the first successful match, error status otherwise. - virtual Status Run() { - Status status; - for (const ExprTree& expr_tree : expr_trees_) { - status = MatchExprTree(expr_tree); - if (status.ok()) { - return status; - } - } - return status; - } - - virtual Status MatchExprTree(const ExprTree& expr_tree) = 0; - - // Returns the constant value parsed form kConstant 'instruction'. - // Returns error status otherwise. - Status ParseConstInteger(const HloInstruction* instruction, - int64* const_value) const { - CHECK_EQ(HloOpcode::kConstant, instruction->opcode()); - PrimitiveType element_type = instruction->shape().element_type(); - if (element_type != S32 && element_type != S64) { - return InvalidArgument("Expected constant of integral type."); - } - const Literal& literal = instruction->literal(); - PrimitiveType type = literal.shape().element_type(); - if (type != S32 && type != S64) { - return InvalidArgument("Must use S32 or S64 integral types."); - } - if (type == S32) { - *const_value = static_cast(literal.GetFirstElement()); - } else if (type == S64) { - *const_value = literal.GetFirstElement(); - } - return Status::OK(); - } - - StatusOr GetTaggedInstruction( - const string& tag, - const ExprTree::TaggedInstructionMap& tagged_instructions) { - auto it = tagged_instructions.find(tag); - if (it == tagged_instructions.end()) { - return InvalidArgument("Cound not find instruction for tag: %s", - tag.c_str()); - } - return it->second; - } - - protected: - std::vector expr_trees_; - - private: - TF_DISALLOW_COPY_AND_ASSIGN(MatcherBase); -}; - -// WhileConditionComputationMatcher attempts to match a target computation -// pattern in the while condition sub-computation. -// If the target pattern is matched, two pieces of information are extracted -// from 'tagged' instructions returned by the matcher: -// -// *) 'tuple_index': -// *) The loop induction variable tuple_index from the GetTupleElement -// instruction of the matched computation. -// *) Used in subsequent matching passes of while init operand and body -// computations to select loop induction variable tuple element. -// -// *) 'loop_limit': -// *) The integral value from Constant root operand in matched computation. -// *) Used as the constant for the loop limit. -// -class WhileConditionComputationMatcher : public MatcherBase { - public: - explicit WhileConditionComputationMatcher(const HloComputation* computation) - : computation_(computation) { - expr_trees_.emplace_back(BuildCondExprTree()); - } - - int64 loop_limit() const { return loop_limit_; } - int64 tuple_index() const { return tuple_index_; } - - private: - // Builds expression tree for the following condition computation: - // - // Const Parameter - // \ / - // Fusion ------------> FusionParam FusionParam - // \ / - // GTE / - // \ / - // LessThan (fused root) - // - ExprTree BuildCondExprTree() { - // Build ExprTree for fused instructions. - ExprTree fused_root( - HloOpcode::kLt, - ExprTree(HloOpcode::kGetTupleElement, "gte", - ExprTree(HloOpcode::kParameter, "gte.fusion_param.param0")), - ExprTree(HloOpcode::kParameter)); - - // Build top-level computation. - ExprTree root(HloOpcode::kFusion, - ExprTree(HloOpcode::kConstant, "loop_limit"), - ExprTree(HloOpcode::kParameter, "param0")); - - root.SetFusedRoot(fused_root); - return root; - } - - Status MatchExprTree(const ExprTree& expr_tree) override { - VLOG(2) << "MATCHING while condition"; - ExprTree::TaggedInstructionMap tagged_instructions; - TF_RETURN_IF_ERROR(expr_tree.Match(computation_->root_instruction(), - &tagged_instructions)); - - // Get tagged GTE instruction and set 'tuple_index_'. - TF_ASSIGN_OR_RETURN(const HloInstruction* gte, - GetTaggedInstruction("gte", tagged_instructions)); - tuple_index_ = gte->tuple_index(); - - // Get tagged Constant instruction and parse 'loop_limit_'. - TF_ASSIGN_OR_RETURN( - const HloInstruction* const_hlo, - GetTaggedInstruction("loop_limit", tagged_instructions)); - TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_limit_)); - - // Get tagged "param0" instruction, and check that it matches - // 'computation_' parameter 0. - TF_ASSIGN_OR_RETURN(const HloInstruction* param0, - GetTaggedInstruction("param0", tagged_instructions)); - if (param0 != computation_->parameter_instruction(0)) { - return InvalidArgument("Unexpected Parameter0 instruction : %s", - param0->name().c_str()); - } - - // Get tagged 'gte.fusion_param.param0', find its associated fusion operand, - // and compare it to 'computation_' parameter0. - TF_ASSIGN_OR_RETURN( - const HloInstruction* gte_fusion_param0, - GetTaggedInstruction("gte.fusion_param.param0", tagged_instructions)); - CHECK_EQ(HloOpcode::kParameter, gte_fusion_param0->opcode()); - CHECK(gte_fusion_param0->IsFused()); - if (gte_fusion_param0->parent()->FusionInstruction()->operand( - gte_fusion_param0->parameter_number()) != - computation_->parameter_instruction(0)) { - return InvalidArgument("Could not match fusion param: %s", - gte_fusion_param0->name().c_str()); - } - - return Status::OK(); - } - - const HloComputation* computation_; - - int64 loop_limit_ = -1; - int64 tuple_index_ = -1; - - TF_DISALLOW_COPY_AND_ASSIGN(WhileConditionComputationMatcher); -}; - -// WhileInitOperandMatcher matches a target computation pattern of the -// while instructions 'init' operand, indexing the tuple at 'tuple_index'. -// On success, parses constant 'loop_start' which represents the loop induction -// variable start values, then returns OK. -// Returns error status otherwise. -class WhileInitOperandMatcher : public MatcherBase { - public: - WhileInitOperandMatcher(const HloInstruction* while_hlo, - const int64 tuple_index) - : while_hlo_(while_hlo), tuple_index_(tuple_index) { - expr_trees_.emplace_back(BuildInitExprTree()); - } - - int64 loop_start() const { return loop_start_; } - - private: - // Builds expression tree for the following while init operand subcomputation: - // - // Const - // | - // Copy - // | - // Tuple0 - // | - // While - // - ExprTree BuildInitExprTree() { - return ExprTree( - HloOpcode::kWhile, "while", - ExprTree(HloOpcode::kTuple, tuple_index_, - ExprTree(HloOpcode::kCopy, - ExprTree(HloOpcode::kConstant, "loop_start")))); - } - - Status MatchExprTree(const ExprTree& expr_tree) override { - VLOG(2) << "MATCHING while init"; - ExprTree::TaggedInstructionMap tagged_instructions; - TF_RETURN_IF_ERROR(expr_tree.Match(while_hlo_, &tagged_instructions)); - - // Get tagged while instruction check against 'while_hlo_'. - TF_ASSIGN_OR_RETURN(const HloInstruction* while_hlo, - GetTaggedInstruction("while", tagged_instructions)); - if (while_hlo != while_hlo_) { - return InvalidArgument("Expected While for instruction : %s", - while_hlo->name().c_str()); - } - - // Get tagged Constant instruction and parse 'loop_start_'. - TF_ASSIGN_OR_RETURN( - const HloInstruction* const_hlo, - GetTaggedInstruction("loop_start", tagged_instructions)); - TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_start_)); - - return Status::OK(); - } - - const HloInstruction* while_hlo_; - const int64 tuple_index_; - - int64 loop_start_ = -1; - - TF_DISALLOW_COPY_AND_ASSIGN(WhileInitOperandMatcher); -}; - -// WhileBodyComputationMatcher matches a target computation pattern for -// the loop induction variable update. Matching proceeds from the while body -// computation root[tuple_index] to param[tuple_index], where 'tuple_index' -// If the target pattern is matched, parses a constant which represents the -// loop induction variable increment value, then returns status OK. -// Returns error status otherwise. -class WhileBodyComputationMatcher : public MatcherBase { - public: - WhileBodyComputationMatcher(const HloComputation* computation, - const int64 tuple_index) - : computation_(computation), tuple_index_(tuple_index) { - expr_trees_.emplace_back(BuildBodyExprTree(0, 1)); - expr_trees_.emplace_back(BuildBodyExprTree(1, 0)); - } - - int64 loop_increment() const { return loop_increment_; } - - private: - // Builds expression tree for the following while body computation: - // - // - // FusionParam FusionParam - // \ / - // Const Param \ GTE1 - // \ / \ / - // Fusion -----------> Add - // | - // Copy - // | - // Tuple0 - // - ExprTree BuildBodyExprTree(const int64 const_index, const int64 gte_index) { - // Build ExprTree for fused instructions. - ExprTree gte1 = - ExprTree(HloOpcode::kGetTupleElement, "gte", - ExprTree(HloOpcode::kParameter, "gte.fusion_param.param0")); - ExprTree fused_root(HloOpcode::kAdd, const_index, - ExprTree(HloOpcode::kParameter), gte_index, gte1); - - // Build fusion instruction (and set fused root). - ExprTree fusion(HloOpcode::kFusion, 0, - ExprTree(HloOpcode::kConstant, "loop_increment"), 1, - ExprTree(HloOpcode::kParameter, "param0")); - fusion.SetFusedRoot(fused_root); - - // Build top-level computation. - ExprTree tuple0(HloOpcode::kTuple, tuple_index_, - ExprTree(HloOpcode::kCopy, fusion)); - return tuple0; - } - - Status MatchExprTree(const ExprTree& expr_tree) override { - VLOG(2) << "MATCHING while body"; - ExprTree::TaggedInstructionMap tagged_instructions; - TF_RETURN_IF_ERROR(expr_tree.Match(computation_->root_instruction(), - &tagged_instructions)); - - for (const auto& pair : tagged_instructions) { - const auto& tag = pair.first; - const auto& inst = pair.second; - - if (tag == "gte" && inst->tuple_index() != tuple_index_) { - // Check that the matched GTE instruction is at the 'tuple_index' we - // matched in the while condition computation. - return InvalidArgument("Unexpected tuple index instruction : %s", - inst->name().c_str()); - } else if (tag == "loop_increment") { - // ParseHloString the constant which represents the loop induction - // variable increment value. - TF_RETURN_IF_ERROR(ParseConstInteger(inst, &loop_increment_)); - } else if (tag == "param0" && - inst != computation_->parameter_instruction(0)) { - // Check that the matched parameter == parameter 0 from 'computation_'. - return InvalidArgument("Unexpected Parameter0 instruction : %s", - inst->name().c_str()); - } else if (tag == "gte.fusion_param.param0") { - // Fusion parameter: lookup and compare with associated fusion operand. - CHECK_EQ(HloOpcode::kParameter, inst->opcode()); - CHECK(inst->IsFused()); - if (inst->parent()->FusionInstruction()->operand( - inst->parameter_number()) != - computation_->parameter_instruction(0)) { - return InvalidArgument("Could not match fusion param: %s", - inst->name().c_str()); - } - } - } - return Status::OK(); - } - - const HloComputation* computation_; - const int64 tuple_index_; - - int64 loop_increment_ = -1; - - TF_DISALLOW_COPY_AND_ASSIGN(WhileBodyComputationMatcher); -}; - -} // namespace - -StatusOr> CanTransformWhileToFor( - const HloInstruction* while_hlo) { - if (while_hlo->opcode() != HloOpcode::kWhile) { - return InvalidArgument("Expected While instruction."); - } - - WhileConditionComputationMatcher cond_matcher(while_hlo->while_condition()); - TF_RETURN_IF_ERROR(cond_matcher.Run()); - - WhileInitOperandMatcher init_matcher(while_hlo, cond_matcher.tuple_index()); - TF_RETURN_IF_ERROR(init_matcher.Run()); - - WhileBodyComputationMatcher body_matcher(while_hlo->while_body(), - cond_matcher.tuple_index()); - TF_RETURN_IF_ERROR(body_matcher.Run()); - - // Check for valid For loop parameters. - if (init_matcher.loop_start() >= cond_matcher.loop_limit()) { - return InvalidArgument("Loop start must be less than loop limit."); - } - if (body_matcher.loop_increment() <= 0) { - return InvalidArgument("Loop increment must greater than zero."); - } - return std::make_tuple(init_matcher.loop_start(), cond_matcher.loop_limit(), - body_matcher.loop_increment()); -} - -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.h b/tensorflow/compiler/xla/service/gpu/while_transformer.h deleted file mode 100644 index fe3a954e1828ee4a323872eea81f64c7e780ad24..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.h +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ - -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/statusor.h" - -namespace xla { -namespace gpu { - -// Runs an analysis of the while loop instruction 'while_hlo' (and its -// associated sub-computations) to determine if it can be transformed into an -// equivalent "for" loop with the following "for" loop parameters: -// -// *) 'loop_start': loop induction variable starting value. -// *) 'loop_limit': loop induction variable limit value. -// *) 'loop_increment': loop induction variable per-iteration increment value. -// -// Returns an std::tuple = (loop_start, loop_limit, loop_increment) on success. -// The values in the returned tuple are values extracted from the 'while_hlo' -// operand (and its sub-computations) during analysis. -// Returns an error status on failure. -StatusOr> CanTransformWhileToFor( - const HloInstruction* while_hlo); - -} // namespace gpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index dbc8442ed2785a112b674632689256c01282156b..c5f3906356d821e059d2b1213c9083c4408a4d1c 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -13,11 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/while_transformer.h" - #include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -110,12 +109,12 @@ class WhileTransformerTest : public HloTestBase { void RunFusionPasses() { // Run standard fusion passes. - EXPECT_TRUE(gpu::GpuInstructionFusion(/*may_duplicate=*/false) - .Run(module_.get()) - .ValueOrDie()); - EXPECT_TRUE(gpu::GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module_.get()) - .ValueOrDie()); + TF_ASSERT_OK(gpu::GpuInstructionFusion(/*may_duplicate=*/false) + .Run(module_.get()) + .status()); + TF_ASSERT_OK(gpu::GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module_.get()) + .status()); } void RunCopyInsertionPass() { @@ -141,10 +140,7 @@ class WhileTransformerTest : public HloTestBase { Shape condition_result_shape_; }; -// TODO(b/68830972): The while transformer is far too fragile. It patterns -// matches the exact expressions of opcodes. Re-enable when transformation is -// more general -TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement0) { +TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0) { // Build computation with induction variable at tuple element 0. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 10)); @@ -153,18 +149,13 @@ TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement0) { // Run HLO Optimization passes. RunFusionPasses(); RunCopyInsertionPass(); - // Run WhileTransformer. - auto result = gpu::CanTransformWhileToFor(while_hlo); - TF_ASSERT_OK(result.status()); - // Check results. - EXPECT_THAT(result.ConsumeValueOrDie(), - Eq(std::tuple(0, 10, 1))); + + auto result = ComputeWhileLoopTripCount(while_hlo); + ASSERT_TRUE(result); + EXPECT_EQ(10, *result); } -// TODO(b/68830972): The while transformer is far too fragile. It patterns -// matches the exact expressions of opcodes. Re-enable when transformation is -// more general -TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement1) { +TEST_F(WhileTransformerTest, InductionVariableAtTupleElement1) { // Build computation with induction variable at tuple element 1. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(1, 10)); @@ -173,19 +164,14 @@ TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement1) { // Run HLO Optimization passes. RunFusionPasses(); RunCopyInsertionPass(); - // Run WhileTransformer. - auto result = gpu::CanTransformWhileToFor(while_hlo); - TF_ASSERT_OK(result.status()); - // Check results. - EXPECT_THAT(result.ConsumeValueOrDie(), - Eq(std::tuple(0, 10, 1))); + + auto result = ComputeWhileLoopTripCount(while_hlo); + ASSERT_TRUE(result); + EXPECT_EQ(10, *result); } -// TODO(b/68830972): The while transformer is far too fragile. It patterns -// matches the exact expressions of opcodes. Re-enable when transformation is -// more general -TEST_F(WhileTransformerTest, DISABLED_InvalidLoopLimit) { - // Build computation with invalid loop limit. +TEST_F(WhileTransformerTest, ImpossibleLoopLimit) { + // Build computation with an impossible loop limit. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 5)); auto body = module_->AddEmbeddedComputation(BuildBodyComputation(0, 1, 1)); @@ -193,17 +179,13 @@ TEST_F(WhileTransformerTest, DISABLED_InvalidLoopLimit) { // Run HLO Optimization passes. RunFusionPasses(); RunCopyInsertionPass(); - // Run WhileTransformer. - auto result = gpu::CanTransformWhileToFor(while_hlo); - ASSERT_FALSE(result.ok()); - EXPECT_THAT(result.status().error_message(), - HasSubstr("Loop start must be less than loop limit.")); + + auto result = ComputeWhileLoopTripCount(while_hlo); + ASSERT_TRUE(result); + EXPECT_EQ(0, *result); } -// TODO(b/68830972): The while transformer is far too fragile. It patterns -// matches the exact expressions of opcodes. Re-enable when transformation is -// more general -TEST_F(WhileTransformerTest, DISABLED_InvalidLoopIncrement) { +TEST_F(WhileTransformerTest, InvalidLoopIncrement) { // Build computation with invalid loop increment. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 10)); @@ -212,11 +194,9 @@ TEST_F(WhileTransformerTest, DISABLED_InvalidLoopIncrement) { // Run HLO Optimization passes. RunFusionPasses(); RunCopyInsertionPass(); - // Run WhileTransformer. - auto result = gpu::CanTransformWhileToFor(while_hlo); - ASSERT_FALSE(result.ok()); - EXPECT_THAT(result.status().error_message(), - HasSubstr("Loop increment must greater than zero.")); + + auto result = ComputeWhileLoopTripCount(while_hlo); + ASSERT_FALSE(result); } } // namespace diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc index aa89567ee86e59e197045c0b51eed3b9aa59fef7..31431f115f8ffd72df65638a2b00e63b3c433a7e 100644 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -22,9 +22,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -84,7 +84,7 @@ HloComputation* CallForwardingComputation(HloComputation* computation, // the module. std::unique_ptr MakeBigGraph() { HloModuleConfig config; - auto module = MakeUnique("BigGraph", config); + auto module = absl::make_unique("BigGraph", config); auto builder = HloComputation::Builder("TestBigGraphvizGraph"); diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 4005fc0d114a3ec7a38dfb5edecdaeb1e8497ade..93a922b9046c9a8557a7cecb961dbcbe17fa1f9b 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/util.h" @@ -45,7 +46,7 @@ StatusOr HeapSimulator::MinimumMemoryForModule( // bound, by minimizing the liveness of sub-computations. TF_ASSIGN_OR_RETURN( HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique(), *module, + HeapSimulator::Run(absl::make_unique(), *module, module_sequence, *points_to_analysis, size_function)); return result.heap_size; } @@ -60,9 +61,10 @@ StatusOr HeapSimulator::MinimumMemoryForComputation( memory_by_computation) { TF_ASSIGN_OR_RETURN( HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique(), computation, - sequence, points_to_analysis, size_function, - HeapSimulator::Options(), memory_by_computation)); + HeapSimulator::Run(absl::make_unique(), + computation, sequence, points_to_analysis, + size_function, HeapSimulator::Options(), + memory_by_computation)); return result.heap_size; } @@ -344,7 +346,7 @@ HeapSimulator::HeapSimulator( const SequentialHloOrdering::HloModuleSequence* module_sequence, const tensorflow::gtl::FlatMap* memory_by_computation) - : no_fragmentation_stats_(MakeUnique()), + : no_fragmentation_stats_(absl::make_unique()), algorithm_(std::move(algorithm)), size_fn_(size_fn), options_(options), diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index b41dc66fe9f5e869a114be96b7cc01fc1a3d59da..5f85f145657b67634844c849447ef545a6dea468 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -137,7 +138,7 @@ class HeapSimulatorTracker { const string& name, std::unique_ptr computation, const std::vector& instruction_sequence) { HloModuleConfig config; - module_ = MakeUnique(name, config); + module_ = absl::make_unique(name, config); module_->AddEntryComputation(std::move(computation)); points_to_analysis_ = TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); @@ -146,8 +147,8 @@ class HeapSimulatorTracker { // the secondary sorting criteria of DecreasingSizeRunsHeap to sort calls by // buffer id, for determinism in the tests. auto zero_size = [](const BufferValue& buffer) { return 0; }; - auto algorithm = MakeUnique( - MakeUnique(&actual_calls_)); + auto algorithm = absl::make_unique( + absl::make_unique(&actual_calls_)); result_ = HeapSimulator::Run( std::move(algorithm), *module_->entry_computation(), instruction_sequence, *points_to_analysis_, zero_size) @@ -156,7 +157,7 @@ class HeapSimulatorTracker { explicit HeapSimulatorTracker(const string& name) { HloModuleConfig config; - module_ = MakeUnique(name, config); + module_ = absl::make_unique(name, config); } // Similar to the single entry computation constructor above, but runs the @@ -182,8 +183,8 @@ class HeapSimulatorTracker { auto size_fn = [&reverse_position](const BufferValue& buffer) { return reverse_position[buffer.instruction()]; }; - auto algorithm = MakeUnique( - MakeUnique(&actual_calls_)); + auto algorithm = absl::make_unique( + absl::make_unique(&actual_calls_)); result_ = HeapSimulator::Run(std::move(algorithm), *module_, module_sequence, *points_to_analysis_, size_fn) .ConsumeValueOrDie(); @@ -675,7 +676,8 @@ class HeapAlgorithmTestBase : public ::testing::Test { const BufferValue::Id id = buffers_.size(); auto const0 = builder_.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); - buffers_.emplace_back(MakeUnique(id, const0, ShapeIndex{})); + buffers_.emplace_back( + absl::make_unique(id, const0, ShapeIndex{})); return buffers_.back().get(); } @@ -724,7 +726,8 @@ class DecreasingSizeRunsHeapTest : public HeapAlgorithmTestBase {}; TEST_F(DecreasingSizeRunsHeapTest, Empty) { CallSequence call_sequence; - DecreasingSizeRunsHeap heap(MakeUnique(&call_sequence)); + DecreasingSizeRunsHeap heap( + absl::make_unique(&call_sequence)); heap.Finish(); EXPECT_EQ(call_sequence, CallSequence({ {kFinish, nullptr}, @@ -733,7 +736,8 @@ TEST_F(DecreasingSizeRunsHeapTest, Empty) { TEST_F(DecreasingSizeRunsHeapTest, Simple) { CallSequence call_sequence; - DecreasingSizeRunsHeap heap(MakeUnique(&call_sequence)); + DecreasingSizeRunsHeap heap( + absl::make_unique(&call_sequence)); heap.Alloc(buffer_a_, 10); heap.Alloc(buffer_b_, 20); heap.Alloc(buffer_c_, 30); @@ -760,7 +764,8 @@ TEST_F(DecreasingSizeRunsHeapTest, Simple) { TEST_F(DecreasingSizeRunsHeapTest, Mixed) { CallSequence call_sequence; - DecreasingSizeRunsHeap heap(MakeUnique(&call_sequence)); + DecreasingSizeRunsHeap heap( + absl::make_unique(&call_sequence)); heap.Alloc(buffer_a_, 10); heap.Alloc(buffer_b_, 20); heap.Free(buffer_b_, 20); diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 63a8a813cddf304e60fa9b4bbf709eca2d7c2cae..fa218657fe51769321d75685703b44c29bd34291 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -34,6 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. +// Next ID: 51 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -74,6 +75,11 @@ message HloInstructionProto { // Describes the dimension numbers used for a convolution. xla.ConvolutionDimensionNumbers convolution_dimension_numbers = 16; + // The number of feature groups. Used for a convolution. Must be a divisor of + // the input feature dimension and output feature dimension. If not specified, + // it will use a default value of 1. + int64 feature_group_count = 50; + // Describes the [begin, end) index range and stride for slices. message SliceDimensions { int64 start = 1; @@ -133,7 +139,7 @@ message HloInstructionProto { // Gather dimension numbers. xla.GatherDimensionNumbers gather_dimension_numbers = 33; - repeated int64 gather_window_bounds = 34; + repeated int64 gather_slice_sizes = 34; // Compute Host. string channel_name = 41; @@ -151,8 +157,11 @@ message HloInstructionProto { // Backend configuration for the instruction. Has backend-specific meaning. string backend_config = 43; - // Cross Replica Sum fields. + // Cross replica op fields. + // TODO(b/112107579): remove replica_group_ids field and always use + // replica_groups. repeated int64 replica_group_ids = 44; + repeated ReplicaGroup replica_groups = 49; int64 all_reduce_id = 45; string cross_replica_sum_barrier = 46; @@ -160,6 +169,8 @@ message HloInstructionProto { // present for Send and Recv instructions and their SendDone and RecvDone // partners. bool is_host_transfer = 47; + + xla.ScatterDimensionNumbers scatter_dimension_numbers = 48; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index e8a4b034b4396860bd5873f43003844ce92dea6c..0ca489846e7137a9ffa341e63c8a289ed4af2043 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -457,7 +457,7 @@ StatusOr> HloAliasAnalysis::Run( VLOG(2) << "HloAliasAnalysis::Run on module " << module->name(); XLA_VLOG_LINES(2, module->ToString()); - auto alias_analysis = WrapUnique(new HloAliasAnalysis(module)); + auto alias_analysis = absl::WrapUnique(new HloAliasAnalysis(module)); TF_ASSIGN_OR_RETURN(alias_analysis->dataflow_analysis_, HloDataflowAnalysis::Run(*module, /*ssa_form=*/true, /*bitcast_defines_value=*/false, diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 441288da1a6859a3f393a298ee02eb4b435e42e0..bae78c94bdcb9b23d8337f62b5d0797c7028b7d0 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -23,9 +23,10 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -56,8 +57,8 @@ std::unique_ptr HloComputation::Builder::Build( HloInstruction* root = root_instruction ? root_instruction : last_added_instruction_; CHECK_NE(nullptr, root); - return WrapUnique(new HloComputation(name_, parameter_count, &instructions_, - root, fusion_instruction_)); + return absl::WrapUnique(new HloComputation( + name_, parameter_count, &instructions_, root, fusion_instruction_)); } HloComputation::HloComputation( @@ -493,9 +494,9 @@ HloComputation::CreateFromProto( return to_proto_id[a.get()] < to_proto_id[b.get()]; }); - return WrapUnique(new HloComputation(proto.name(), parameter_count, - &instructions, root, - /*fusion_instruction=*/nullptr)); + return absl::WrapUnique(new HloComputation(proto.name(), parameter_count, + &instructions, root, + /*fusion_instruction=*/nullptr)); } void HloComputation::FuseInstructionsInto( @@ -674,7 +675,7 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, std::unique_ptr HloComputation::ComputeReachability() const { const auto& all = MakeInstructionPostOrder(); - auto result = MakeUnique(all); + auto result = absl::make_unique(all); std::vector inputs; for (const HloInstruction* hlo : all) { @@ -829,7 +830,7 @@ std::unique_ptr HloComputation::CloneWithReplacements( HloCloneContext* context, const string& suffix) { std::unique_ptr context_ptr; if (context == nullptr) { - context_ptr = MakeUnique(parent(), suffix); + context_ptr = absl::make_unique(parent(), suffix); context = context_ptr.get(); } @@ -901,9 +902,9 @@ void HloComputation::UniquifyName(NameUniquer* name_uniquer) { HloInstruction* HloComputation::GetInstructionWithName( tensorflow::StringPiece name) { auto instructions_in_computation = instructions(); - auto it = c_find_if(instructions_in_computation, [&](HloInstruction* instr) { - return instr->name() == name; - }); + auto it = absl::c_find_if( + instructions_in_computation, + [&](HloInstruction* instr) { return instr->name() == name; }); return it == instructions_in_computation.end() ? nullptr : *it; } diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 7229031c0c7f8bd374cfb495c7d8c11e9ca8b95e..6dddda1ca8902e6047cde59aced3867bda5c4303 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -38,7 +39,7 @@ StatusOr HloConstantFolding::Run(HloModule* module) { // Limit the constant folding to 0 iterations to skip folding loops. This // retains the behavior from before while loop support in HloEvaluator and may // be revised. - auto evaluator = MakeUnique(/*max_loop_iterations=*/0); + auto evaluator = absl::make_unique(/*max_loop_iterations=*/0); XLA_VLOG_LINES(2, "HloConstantFolding::Run(), before:\n" + module->ToString()); diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 1f672502f72f9c658b681383e858995f6e94d2c7..1bbb0ff08e26f626f4c3992a5f20ec4990f7db2d 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -49,9 +49,9 @@ Status HloCostAnalysis::Preprocess(const HloInstruction* hlo) { // The default number of bytes accessed for an instruction is the sum of the // sizes of the inputs and outputs. The default ShapeUtil::ByteSizeOf does not // handle opaque types. - float bytes_accessed = shape_size_(hlo->shape()); + float bytes_accessed = GetShapeSize(hlo->shape()); for (const HloInstruction* operand : hlo->operands()) { - bytes_accessed += shape_size_(operand->shape()); + bytes_accessed += GetShapeSize(operand->shape()); } current_properties_[kBytesAccessedKey] = bytes_accessed; @@ -121,6 +121,13 @@ Status HloCostAnalysis::HandleElementwiseOp( } } +int64 HloCostAnalysis::GetShapeSize(const Shape& shape) const { + if (!LayoutUtil::HasLayout(shape)) { + return 0; + } + return shape_size_(shape); +} + Status HloCostAnalysis::HandleElementwiseUnary(const HloInstruction* hlo) { return HandleElementwiseOp(hlo); } @@ -181,21 +188,21 @@ Status HloCostAnalysis::HandleReverse(const HloInstruction*) { } Status HloCostAnalysis::HandleSlice(const HloInstruction* slice) { - current_properties_[kBytesAccessedKey] = shape_size_(slice->shape()) * 2; + current_properties_[kBytesAccessedKey] = GetShapeSize(slice->shape()) * 2; return Status::OK(); } Status HloCostAnalysis::HandleDynamicSlice( const HloInstruction* dynamic_slice) { current_properties_[kBytesAccessedKey] = - shape_size_(dynamic_slice->shape()) * 2; + GetShapeSize(dynamic_slice->shape()) * 2; return Status::OK(); } Status HloCostAnalysis::HandleDynamicUpdateSlice( const HloInstruction* dynamic_update_slice) { current_properties_[kBytesAccessedKey] = - shape_size_(dynamic_update_slice->operand(1)->shape()) * 2; + GetShapeSize(dynamic_update_slice->operand(1)->shape()) * 2; return Status::OK(); } @@ -204,7 +211,7 @@ Status HloCostAnalysis::HandleTuple(const HloInstruction* tuple) { // through them). The memory touched is then only the size of the output // index table of the tuple. - current_properties_[kBytesAccessedKey] = shape_size_(tuple->shape()); + current_properties_[kBytesAccessedKey] = GetShapeSize(tuple->shape()); return Status::OK(); } @@ -526,12 +533,25 @@ Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) { // TODO(b/33004697): Compute correct cost here, taking the actual number of // replicas into account. double flops = 0.0; - ShapeUtil::ForEachSubshape( - crs->shape(), [&, this](const Shape& subshape, const ShapeIndex&) { - if (ShapeUtil::IsArray(subshape)) { - flops += ShapeUtil::ElementsIn(subshape); - } - }); + ShapeUtil::ForEachSubshape(crs->shape(), + [&](const Shape& subshape, const ShapeIndex&) { + if (ShapeUtil::IsArray(subshape)) { + flops += ShapeUtil::ElementsIn(subshape); + } + }); + current_properties_[kFlopsKey] = flops; + return Status::OK(); +} + +Status HloCostAnalysis::HandleAllToAll(const HloInstruction* hlo) { + // TODO(b/110096724): Compute correct cost here. + double flops = 0.0; + ShapeUtil::ForEachSubshape(hlo->shape(), + [&](const Shape& subshape, const ShapeIndex&) { + if (ShapeUtil::IsArray(subshape)) { + flops += ShapeUtil::ElementsIn(subshape); + } + }); current_properties_[kFlopsKey] = flops; return Status::OK(); } @@ -546,15 +566,9 @@ Status HloCostAnalysis::HandleRng(const HloInstruction* random) { } Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { - // Compute the properties of the fused expression and attribute them to the - // fusion node. Use a dummy shape_size to avoid any errors from trying to - // calculate the size of a shape that does not have a layout, since nodes - // inside fusion nodes do not necessarily have a layout assigned. - ShapeSizeFunction shape_size = [](const Shape& shape) { return 0; }; TF_ASSIGN_OR_RETURN( current_properties_, - ProcessSubcomputation(fusion->fused_instructions_computation(), - &shape_size)); + ProcessSubcomputation(fusion->fused_instructions_computation())); // Fusion nodes that produce a tuple also produce the entries in the tuple. // Ignore the memory accessed inside fused ops, since fusion is supposed to @@ -563,11 +577,11 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { ShapeUtil::ForEachSubshape( fusion->shape(), [this](const Shape& subshape, const ShapeIndex& /*shape_index*/) { - current_properties_[kBytesAccessedKey] += shape_size_(subshape); + current_properties_[kBytesAccessedKey] += GetShapeSize(subshape); }); for (const HloInstruction* operand : fusion->operands()) { - current_properties_[kBytesAccessedKey] += shape_size_(operand->shape()); + current_properties_[kBytesAccessedKey] += GetShapeSize(operand->shape()); } return Status::OK(); @@ -648,6 +662,11 @@ Status HloCostAnalysis::HandleGather(const HloInstruction* gather) { return Status::OK(); } +Status HloCostAnalysis::HandleScatter(const HloInstruction* scatter) { + // TODO(b/32945756): Compute the properties of the sub-computation. + return Status::OK(); +} + Status HloCostAnalysis::FinishVisit(const HloInstruction*) { return Status::OK(); } @@ -685,11 +704,8 @@ float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const { } StatusOr HloCostAnalysis::ProcessSubcomputation( - HloComputation* computation, const ShapeSizeFunction* shape_size) { - if (shape_size == nullptr) { - shape_size = &shape_size_; - } - HloCostAnalysis visitor(*shape_size, per_second_rates_); + HloComputation* computation) { + HloCostAnalysis visitor(shape_size_, per_second_rates_); TF_RETURN_IF_ERROR(computation->Accept(&visitor)); return visitor.properties(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 82d650dc7b2a7fdd7c156d5fadcabd40f5535161..193a04bea0831de2b3aca19b17a445ad73e02e49 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -71,6 +71,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleConvolution(const HloInstruction* convolution) override; Status HandleFft(const HloInstruction* fft) override; Status HandleCrossReplicaSum(const HloInstruction* crs) override; + Status HandleAllToAll(const HloInstruction* hlo) override; Status HandleInfeed(const HloInstruction* infeed) override; Status HandleOutfeed(const HloInstruction* outfeed) override; Status HandleHostCompute(const HloInstruction* host_compute) override; @@ -104,6 +105,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleWhile(const HloInstruction* xla_while) override; Status HandleConditional(const HloInstruction* conditional) override; Status HandleGather(const HloInstruction* gather) override; + Status HandleScatter(const HloInstruction* scatter) override; Status FinishVisit(const HloInstruction* root) override; Status Preprocess(const HloInstruction* hlo) override; @@ -149,11 +151,8 @@ class HloCostAnalysis : public ConstDfsHloVisitor { const Properties& per_second_rates); // Returns the properties computed from visiting the computation rooted at the - // given hlo. Uses shape_size_ to calculate shape sizes if shape_size is null, - // otherwise uses shape_size_. - StatusOr ProcessSubcomputation( - HloComputation* computation, - const ShapeSizeFunction* shape_size = nullptr); + // given hlo. + StatusOr ProcessSubcomputation(HloComputation* computation); // Utility function to handle all element-wise operations. Status HandleElementwiseOp(const HloInstruction* hlo_instruction); @@ -170,6 +169,10 @@ class HloCostAnalysis : public ConstDfsHloVisitor { static float GetPropertyForHlo(const HloInstruction& hlo, const string& key, const HloToProperties& hlo_to_properties); + // Decorates shape_size_ by returning 0 immediately if the shape does not have + // a layout. + int64 GetShapeSize(const Shape& shape) const; + // Function which computes the size of the top-level of a given shape (not // including nested elements, if any). If null then bytes_accessed methods // return an error. diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index b2241cd423d702b38d4c5dc013217ba42753c767..2c854eea18642eb7cb081b4fdfe3bc83627e41ae 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/local_service.h" diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 90d2be118d94d52135820e5b8138fcb06389c684..c4e27dc558ecb2a3a0acfd036de73506ce7631fa 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -14,9 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" @@ -149,13 +150,13 @@ StatusOr MakeConcatHlo(ArraySlice operands, CHECK_GT(operands.size(), 0); HloComputation* computation = operands[0]->parent(); - CHECK(c_all_of(operands, [&](HloInstruction* instr) { + CHECK(absl::c_all_of(operands, [&](HloInstruction* instr) { return instr->parent() == computation; })); std::vector operand_shapes; - c_transform(operands, std::back_inserter(operand_shapes), - [](HloInstruction* instr) { return &instr->shape(); }); + absl::c_transform(operands, std::back_inserter(operand_shapes), + [](HloInstruction* instr) { return &instr->shape(); }); TF_ASSIGN_OR_RETURN(Shape concat_shape, ShapeInference::InferConcatOpShape( operand_shapes, dimension)); @@ -174,6 +175,29 @@ StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers)); } +StatusOr MakeMapHlo( + tensorflow::gtl::ArraySlice operands, + HloComputation* map_computation) { + CHECK(!operands.empty()) << "Map Hlo requires at least one operand."; + HloComputation* computation = operands.front()->parent(); + std::vector operand_shapes; + int64 max_operand_rank = 0; + for (const HloInstruction* operand : operands) { + CHECK_EQ(computation, operand->parent()); + operand_shapes.push_back(&operand->shape()); + max_operand_rank = + std::max(max_operand_rank, ShapeUtil::Rank(operand->shape())); + } + std::vector map_dims(max_operand_rank); + std::iota(map_dims.begin(), map_dims.end(), 0); + TF_ASSIGN_OR_RETURN( + Shape map_shape, + ShapeInference::InferMapShape( + operand_shapes, map_computation->ComputeProgramShape(), map_dims)); + return computation->AddInstruction( + HloInstruction::CreateMap(map_shape, operands, map_computation)); +} + StatusOr CollapseFirstNDims(HloInstruction* operand, int64 n) { CHECK_GT(n, 0); @@ -205,7 +229,7 @@ StatusOr PrependDegenerateDims(HloInstruction* operand, const Shape& operand_shape = operand->shape(); new_shape_dims.reserve(n + operand_shape.dimensions_size()); new_shape_dims.insert(new_shape_dims.begin(), n, 1); - c_copy(operand_shape.dimensions(), std::back_inserter(new_shape_dims)); + absl::c_copy(operand_shape.dimensions(), std::back_inserter(new_shape_dims)); return MakeReshapeHlo(new_shape_dims, operand); } @@ -217,7 +241,7 @@ StatusOr ExpandFirstDimIntoNDims( std::vector expanded_shape_dim_bounds; expanded_shape_dim_bounds.reserve(expanded_dims.size() + operand->shape().dimensions_size() - 1); - c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds)); + absl::c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds)); std::copy(operand->shape().dimensions().begin() + 1, operand->shape().dimensions().end(), std::back_inserter(expanded_shape_dim_bounds)); @@ -228,7 +252,7 @@ StatusOr ExpandFirstDimIntoNDims( StatusOr ElideDegenerateDims(HloInstruction* operand, ArraySlice dims_to_elide) { - CHECK(c_is_sorted(dims_to_elide)); + CHECK(absl::c_is_sorted(dims_to_elide)); const Shape& input_shape = operand->shape(); // First accumulate in reverse @@ -245,12 +269,44 @@ StatusOr ElideDegenerateDims(HloInstruction* operand, } } - c_reverse(new_shape_dim_bounds); + absl::c_reverse(new_shape_dim_bounds); Shape output_shape = ShapeUtil::MakeShape(input_shape.element_type(), new_shape_dim_bounds); return MakeReshapeHlo(output_shape, operand); } +StatusOr InsertDegenerateDims( + HloInstruction* operand, ArraySlice dims_to_insert) { + CHECK(absl::c_is_sorted(dims_to_insert)); + + const Shape& operand_shape = operand->shape(); + int64 output_shape_rank = + operand_shape.dimensions_size() + dims_to_insert.size(); + for (auto dim_to_insert : dims_to_insert) { + CHECK_LT(dim_to_insert, output_shape_rank); + } + + std::vector output_shape_dim_bounds; + output_shape_dim_bounds.reserve(output_shape_rank); + int64 operand_dims_idx = 0; + int64 dims_to_insert_idx = 0; + for (int64 i = 0; i < output_shape_rank; ++i) { + if (dims_to_insert_idx < dims_to_insert.size() && + i == dims_to_insert[dims_to_insert_idx]) { + output_shape_dim_bounds.push_back(1); + ++dims_to_insert_idx; + } else { + output_shape_dim_bounds.push_back( + operand_shape.dimensions(operand_dims_idx)); + ++operand_dims_idx; + } + } + + Shape output_shape = ShapeUtil::MakeShape(operand_shape.element_type(), + output_shape_dim_bounds); + return MakeReshapeHlo(output_shape, operand); +} + StatusOr PadVectorWithZeros(HloInstruction* operand, int64 zeros_to_prepend, int64 zeros_to_append) { @@ -263,7 +319,7 @@ StatusOr PadVectorWithZeros(HloInstruction* operand, *padding_config.add_dimensions() = padding_config_dim; HloInstruction* zero = computation->AddInstruction( - HloInstruction::CreateConstant(MakeUnique( + HloInstruction::CreateConstant(absl::make_unique( LiteralUtil::Zero(operand->shape().element_type())))); return MakePadHlo(operand, zero, padding_config); } @@ -273,7 +329,7 @@ StatusOr BroadcastZeros( ArraySlice broadcast_dimensions) { HloInstruction* zero = computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(LiteralUtil::Zero(element_type)))); + absl::make_unique(LiteralUtil::Zero(element_type)))); return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{}, /*result_shape_bounds=*/broadcast_dimensions); } diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 49b1402d689a74874e34423a1832a0b6aa15f469..5ff8946fb098b57ae563a8ade47e8323f807a369 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -102,6 +102,12 @@ StatusOr MakeConcatHlo( StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, const DotDimensionNumbers& dim_numbers); +// Creates a Map HLO instruction and adds it to the computation containing the +// operands. All operands must be in the same computation. +StatusOr MakeMapHlo( + tensorflow::gtl::ArraySlice operands, + HloComputation* map_computation); + // ----------------------------------------------------------------------------- // Some other miscellaneous helpers to generate common HLO patterns. All of // these add all the instructions they generate into the computation containing @@ -144,6 +150,16 @@ StatusOr ExpandFirstDimIntoNDims( StatusOr ElideDegenerateDims( HloInstruction* operand, tensorflow::gtl::ArraySlice dims_to_elide); +// Inserts (via reshape) a set of degenerate dimensions (dimensions containing +// exactly one element), `dims_to_insert` into `operand`. The dimensions in +// `dims_to_insert` refer to the dimensions in the result, and hence should be +// less than the rank of the result. Also, `dims_to_insert` must be sorted. +// +// For example, if `operand` is of shape f32[12,21,8,34] and dims_to_insert is +// {0, 2}, then the result is `operand` reshaped to [1,12,1,21,8,34]. +StatusOr InsertDegenerateDims( + HloInstruction* operand, tensorflow::gtl::ArraySlice dims_to_insert); + // Pads `operand` (which must have rank 1) with `zeros_to_prepend` zeros in the // front and `zeros_to_append` zeros in the back. StatusOr PadVectorWithZeros(HloInstruction* operand, diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc index 60d3e71757d5ce31e025c744e089ff56091d9a43..a8de285d16fdf6c5824f4076860b57b3fdc279a0 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -28,7 +28,7 @@ using tensorflow::gtl::ArraySlice; class HloCreationUtilsTest : public HloTestBase { protected: - static std::unique_ptr CreateModuleWithProgramShape( + std::unique_ptr CreateModuleWithProgramShape( PrimitiveType primitive_type, ArraySlice input_shape_dims, ArraySlice output_shape_dims, HloInstruction** param, HloComputation** entry_computation) { diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 90fbaa37c5a70a78a9a818b4a8968f3406c671b1..406d712ec6783a310aabc6600b8b70e1a1ae30a9 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -20,9 +20,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index de1a32d8bd9217baabda4ab4b02bf28baebad531..9b150579298bd9f4fd73457a5d14ada62be29d66 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -886,7 +886,7 @@ StatusOr> HloDataflowAnalysis::Run( VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name(); XLA_VLOG_LINES(2, module.ToString()); - auto dataflow_analysis = WrapUnique(new HloDataflowAnalysis( + auto dataflow_analysis = absl::WrapUnique(new HloDataflowAnalysis( module, ssa_form, bitcast_defines_value, fusion_can_share_buffer)); TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets()); @@ -1017,19 +1017,17 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( } if (user->opcode() == HloOpcode::kFusion) { + if (fusion_can_share_buffer_ != nullptr) { + return fusion_can_share_buffer_(user, operand); + } // Get the parameter associated with 'operand'; HloInstruction* fusion_param = user->fused_parameter(user->operand_index(operand)); const HloValue& value = GetValueDefinedAt(fusion_param, operand_index); - if (value.uses().size() != 1) { - if (MultiDynamicSliceUseShareSameIndices(value.uses())) { - return true; - } - return false; + if (MultiDynamicSliceUseShareSameIndices(value.uses())) { + return true; } - const HloUse& use = value.uses()[0]; - if (user->fusion_kind() == HloInstruction::FusionKind::kLoop || user->fusion_kind() == HloInstruction::FusionKind::kInput) { if (user->fused_expression_root()->opcode() == @@ -1039,13 +1037,17 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( // Returns true iff there is exactly one use of 'operand' at shape index // 'operand_index', and this singleton use is the fused root at operand // index 0. - return use.instruction == user->fused_expression_root() && - use.operand_number == 0; - } else { - return AreTransitiveUsesElementwiseOrTuple(fusion_param); + if (value.uses().size() == 1) { + const HloUse& use = value.uses()[0]; + return use.instruction == user->fused_expression_root() && + use.operand_number == 0; + } + return false; } - } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && - user->fused_expression_root()->opcode() == HloOpcode::kAdd) { + return AreTransitiveUsesElementwiseOrTuple(fusion_param); + } + if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && + user->fused_expression_root()->opcode() == HloOpcode::kAdd) { // Output fusion with kAdd fused root. // Check if one operand of kAdd fused root is kDot or kConvolution. @@ -1066,11 +1068,12 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( // Returns true iff there is exactly one use of 'operand' at shape index // 'operand_index', and this singleton use is the fused root (at operand // index 'other_add_operand_index'). - return use.instruction == user->fused_expression_root() && - use.operand_number == other_add_operand_index; - } else if (fusion_can_share_buffer_ != nullptr && - fusion_can_share_buffer_(user, operand)) { - return true; + if (value.uses().size() == 1) { + const HloUse& use = value.uses()[0]; + return use.instruction == user->fused_expression_root() && + use.operand_number == other_add_operand_index; + } + return false; } } @@ -1081,6 +1084,21 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( std::vector operand_indices = user->OperandIndices(operand); return operand_indices.size() == 1 && operand_indices[0] == 0; } + if (user->opcode() == HloOpcode::kSort) { + // Only valid if there are no other users. + if (operand->users().size() != 1) { + return false; + } + // If we only sort keys, the output of sort is not a tuple, so we can always + // share the buffer. + if (user->operand_count() == 1) { + return true; + } + CHECK(!user_index.empty()); + // Only share with the right tuple element buffer. + std::vector operand_indices = user->OperandIndices(operand); + return operand_indices.size() == 1 && user_index[0] == operand_indices[0]; + } if (user->opcode() == HloOpcode::kCall) { // Get all uses of value defined by 'operand' at 'operand_index'. const auto& uses = GetValueDefinedAt(operand, operand_index).uses(); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 37bc2d2c9d2a0d0624917337b36c5d5f625c0991..4755c4a0cf8d268b1c47e596a14605eb2c60b36c 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -2232,6 +2232,48 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { dataflow_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {})); } +TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { + auto builder = HloComputation::Builder(TestName()); + + Shape keys_shape = ShapeUtil::MakeShape(F32, {8}); + auto keys = builder.AddInstruction( + HloInstruction::CreateParameter(0, keys_shape, "keys")); + auto sort = + builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { + auto builder = HloComputation::Builder(TestName()); + + Shape keys_shape = ShapeUtil::MakeShape(F32, {8}); + Shape values_shape = ShapeUtil::MakeShape(F32, {8}); + auto keys = builder.AddInstruction( + HloInstruction::CreateParameter(0, keys_shape, "keys")); + auto values = builder.AddInstruction( + HloInstruction::CreateParameter(1, values_shape, "values")); + auto sort = builder.AddInstruction(HloInstruction::CreateSort( + ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // The buffer for the keys can be shared with the first tuple entry. + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {0})); + // The buffer for the values can be shared with the second tuple entry. + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(values, {}, sort, {1})); + // Verify that the buffers are not shared with the "wrong" tuple entry. + EXPECT_FALSE( + dataflow_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {1})); + EXPECT_FALSE( + dataflow_analysis_->CanShareOperandBufferWithUser(values, {}, sort, {0})); +} + TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { auto builder = HloComputation::Builder(TestName()); Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); @@ -2323,7 +2365,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) { TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - auto make_cond = [this, &data_shape]() { + auto make_cond = [&data_shape]() { auto builder = HloComputation::Builder(TestName() + ".Cond"); auto data = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "data")); @@ -2332,7 +2374,7 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { return builder.Build(); }; - auto make_body = [this, &data_shape]() { + auto make_body = [&data_shape]() { auto builder = HloComputation::Builder(TestName() + ".Body"); auto data = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "data")); diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index 26e3736e01270dbc6ca67647e814843aba2d1e3d..3b5cde2996c4195ef458662cd21de85a832d8d55 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -17,9 +17,9 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc index 9e096320db5048457435199627a1ef1fe1572177..edf0073f3091ef4da7ced3f13b56961a7db4b430 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/types.h" @@ -25,14 +26,14 @@ namespace xla { /* static */ StatusOr> HloDomainMap::Create( HloComputation* computation, string domain_kind) { - auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind))); + auto domain_map = absl::WrapUnique(new HloDomainMap(std::move(domain_kind))); TF_RETURN_IF_ERROR(domain_map->Populate(computation)); return std::move(domain_map); } /* static */ StatusOr> HloDomainMap::Create( HloModule* module, string domain_kind) { - auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind))); + auto domain_map = absl::WrapUnique(new HloDomainMap(std::move(domain_kind))); for (HloComputation* computation : module->computations()) { TF_RETURN_IF_ERROR(domain_map->Populate(computation)); } @@ -56,14 +57,14 @@ Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) { // both sides. for (HloInstruction* operand : instruction->unique_operands()) { if (IsDomainInstruction(operand)) { - auto domain = MakeUnique(); + auto domain = absl::make_unique(); domain->enter_domains.insert(operand); domain->exit_domains.insert(instruction); TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); } } if (instruction == instruction->parent()->root_instruction()) { - auto domain = MakeUnique(); + auto domain = absl::make_unique(); domain->enter_domains.insert(instruction); TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); } @@ -143,7 +144,7 @@ Status HloDomainMap::ExpandDomain(HloInstruction* instruction, StatusOr> HloDomainMap::CreateDomain( HloInstruction* instruction) const { - auto domain = MakeUnique(); + auto domain = absl::make_unique(); TF_RETURN_IF_ERROR(ExpandDomain(instruction, domain.get())); domain->instructions = MakeNonDomainInstructions(domain->reach_set); return std::move(domain); diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index ffc18a0f886df86d87944d9c284a6faf8afe4c60..7d48be15cfdd2d89945a6ea28d8fee51838fbb16 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" @@ -80,7 +81,7 @@ class OpNameMetadata : public DomainMetadata { explicit OpNameMetadata(string opname) : opname_(std::move(opname)) {} std::unique_ptr Clone() const override { - return MakeUnique(opname_); + return absl::make_unique(opname_); } tensorflow::StringPiece Kind() const override { return KindName(); } @@ -110,9 +111,9 @@ std::unique_ptr OpNameDomainCreator(HloInstruction* instruction, return nullptr; } std::unique_ptr operand_side_metadata = - MakeUnique(operand->metadata().op_name()); + absl::make_unique(operand->metadata().op_name()); std::unique_ptr user_side_metadata = - MakeUnique(instruction->metadata().op_name()); + absl::make_unique(instruction->metadata().op_name()); return HloInstruction::CreateDomain(operand->shape(), operand, std::move(operand_side_metadata), std::move(user_side_metadata)); @@ -474,8 +475,8 @@ ENTRY entry { TEST_F(HloDomainTest, DumpParseNullSharding) { auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {}); - auto sharding_md_0 = MakeUnique(nullptr); - auto sharding_md_1 = MakeUnique(nullptr); + auto sharding_md_0 = absl::make_unique(nullptr); + auto sharding_md_1 = absl::make_unique(nullptr); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p")); HloInstruction* domain = builder.AddInstruction(HloInstruction::CreateDomain( @@ -490,5 +491,38 @@ TEST_F(HloDomainTest, DumpParseNullSharding) { ASSERT_TRUE(ParseModule(hlo_string).status().ok()); } +TEST_F(HloDomainTest, DomainTuple) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = f32[4] parameter(0), sharding={maximal device=0} + cst = u32[] constant(0), sharding={maximal device=1} + tpl = (u32[], f32[4]) tuple(cst, p0), sharding={{maximal device=1}, {maximal device=0}} + ROOT gte = f32[4] get-tuple-element(tpl), index=1, sharding={maximal device=0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + EXPECT_TRUE(isolator_changed); + + // Clear sharding of tpl instruction, in order to test domain sharding + // application. + auto tpl = FindInstruction(module, "tpl"); + tpl->clear_sharding(); + + HloDomainRemover remover(ShardingMetadata::KindName(), + ShardingMetadata::NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + EXPECT_TRUE(remover_changed); + + EXPECT_EQ(HloSharding::Tuple(tpl->shape(), {HloSharding::AssignDevice(1), + HloSharding::AssignDevice(0)}), + tpl->sharding()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc index c804f4364f6d16d5b8112219ce884495200aa827..b9244b8e9e5f34e7ac5113c8eacb6f8243eea314 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -144,6 +144,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { opcode == HloOpcode::kCrossReplicaSum || opcode == HloOpcode::kFusion || opcode == HloOpcode::kMap || opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow || + opcode == HloOpcode::kScatter || opcode == HloOpcode::kSelectAndScatter || opcode == HloOpcode::kConditional) { continue; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 51353eea6e72d5a131897f3c3ae312046051103e..35d9e799df6f0ba7c1c27909c36f7a3f3029d640 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -23,13 +23,14 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -95,7 +96,7 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, << HloOpcodeString(opcode); } - auto result = MakeUnique(shape); + auto result = absl::make_unique(shape); TF_RETURN_IF_ERROR(result->Populate([&](ArraySlice multi_index) { return compare_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); @@ -125,7 +126,7 @@ StatusOr> Compare( << HloOpcodeString(opcode); } - auto result = MakeUnique(shape); + auto result = absl::make_unique(shape); TF_RETURN_IF_ERROR(result->Populate([&](ArraySlice multi_index) { return compare_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); @@ -138,44 +139,57 @@ StatusOr> Compare( HloEvaluator::HloEvaluator(int64 max_loop_iterations) : max_loop_iterations_(max_loop_iterations) { - typed_visitors_[PRED] = MakeUnique>(this); - typed_visitors_[U8] = MakeUnique>(this); - typed_visitors_[U16] = MakeUnique([](HloInstruction*) { - return Unimplemented( - "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " - "U16."); - }); - typed_visitors_[U32] = MakeUnique>(this); - typed_visitors_[U64] = MakeUnique>(this); - typed_visitors_[S8] = MakeUnique>(this); - typed_visitors_[S16] = MakeUnique([](HloInstruction*) { - return Unimplemented( - "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " - "S16."); - }); - typed_visitors_[S32] = MakeUnique>(this); - typed_visitors_[S64] = MakeUnique>(this); + typed_visitors_[PRED] = + absl::make_unique>(this); + typed_visitors_[U8] = + absl::make_unique>(this); + typed_visitors_[U16] = + absl::make_unique([](HloInstruction*) { + return Unimplemented( + "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " + "U16."); + }); + typed_visitors_[U32] = + absl::make_unique>(this); + typed_visitors_[U64] = + absl::make_unique>(this); + typed_visitors_[S8] = absl::make_unique>(this); + typed_visitors_[S16] = + absl::make_unique([](HloInstruction*) { + return Unimplemented( + "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " + "S16."); + }); + typed_visitors_[S32] = + absl::make_unique>(this); + typed_visitors_[S64] = + absl::make_unique>(this); typed_visitors_[F16] = - MakeUnique>(this); - typed_visitors_[F32] = MakeUnique>(this); - typed_visitors_[F64] = MakeUnique>(this); - typed_visitors_[C64] = MakeUnique>(this); + absl::make_unique>(this); + typed_visitors_[F32] = + absl::make_unique>(this); + typed_visitors_[F64] = + absl::make_unique>(this); + typed_visitors_[C64] = + absl::make_unique>(this); // Most of the evaluator computations we use don't support BF16 (e.g., // std::ceil, std::tanh). To make evaluator work with BF16, we set all // elementwise computations to be done in F32 and do BF16<->F32 conversion // around the input and the output of the computations. typed_visitors_[BF16] = - MakeUnique>(this); - - typed_visitors_[TUPLE] = MakeUnique([](HloInstruction*) { - return Unimplemented( - "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE."); - }); - typed_visitors_[OPAQUE] = MakeUnique([](HloInstruction*) { - return Unimplemented( - "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE."); - }); + absl::make_unique>(this); + + typed_visitors_[TUPLE] = + absl::make_unique([](HloInstruction*) { + return Unimplemented( + "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE."); + }); + typed_visitors_[OPAQUE] = + absl::make_unique([](HloInstruction*) { + return Unimplemented( + "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE."); + }); } template @@ -555,43 +569,41 @@ Status HloEvaluator::HandleTuple(HloInstruction* tuple) { return Status::OK(); } -// Returns an ShapeUtil::IndexIterationSpace that iterates over the output -// gather dimensions while keeping the rest of the output dimensions clamped to -// 0. -ShapeUtil::IndexIterationSpace IterationSpaceForOutputGatherIndices( +// Returns an ShapeUtil::IndexIterationSpace that iterates over the output batch +// dimensions while keeping the rest of the output dimensions clamped to 0. +ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices( const Shape& output_shape, const GatherDimensionNumbers& dim_numbers) { int64 output_rank = output_shape.dimensions_size(); std::vector index_base(output_rank, 0); std::vector index_count; index_count.reserve(output_rank); for (int64 i = 0; i < output_rank; i++) { - bool is_output_gather_dim = - !c_binary_search(dim_numbers.output_window_dims(), i); - index_count.push_back(is_output_gather_dim ? output_shape.dimensions(i) - : 1); + bool is_output_batch_dim = + !absl::c_binary_search(dim_numbers.offset_dims(), i); + index_count.push_back(is_output_batch_dim ? output_shape.dimensions(i) : 1); } return {std::move(index_base), std::move(index_count), std::vector(output_rank, 1)}; } -// Return an ShapeUtil::IndexIterationSpace that iterates over the output window +// Return an ShapeUtil::IndexIterationSpace that iterates over the output slice // dimensions while keeping the rest of the output dimensions clamped to 0. -ShapeUtil::IndexIterationSpace IterationSpaceForOutputWindowIndices( - int64 output_rank, ArraySlice window_bounds, +ShapeUtil::IndexIterationSpace IterationSpaceForOutputOffsetIndices( + int64 output_rank, ArraySlice slice_sizes, const GatherDimensionNumbers& dim_numbers) { std::vector index_base(output_rank, 0); std::vector index_count(output_rank, 1); - int64 window_bounds_idx = 0; + int64 slice_sizes_idx = 0; for (int64 i = 0; i < output_rank; i++) { bool is_output_window_dim = - c_binary_search(dim_numbers.output_window_dims(), i); + absl::c_binary_search(dim_numbers.offset_dims(), i); if (is_output_window_dim) { - while (c_binary_search(dim_numbers.elided_window_dims(), - window_bounds_idx)) { - window_bounds_idx++; + while (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), + slice_sizes_idx)) { + slice_sizes_idx++; } - index_count[i] = window_bounds[window_bounds_idx++]; + index_count[i] = slice_sizes[slice_sizes_idx++]; } } @@ -599,30 +611,30 @@ ShapeUtil::IndexIterationSpace IterationSpaceForOutputWindowIndices( std::vector(output_rank, 1)}; } -// This functor computes the contribution of gather_indices to an input index +// This functor computes the contribution of start_indices to an input index // corresponding to an output index. That is, given an output index I, it picks -// out the gather output indices in I and uses them to look up a gather index, -// G, from the gather indices tensor, and expands G into the input space -// according to gather_dims_to_operand_dims. -class OutputGatherIndexToInputIndex { +// out the batch indices in I and uses them to look up a starting index, G, from +// the start indices tensor, and expands G into the input space according to +// start_index_map. +class OutputBatchIndexToInputIndex { public: // The constructor does some setup work that is amortized across all // iterations. - explicit OutputGatherIndexToInputIndex( + explicit OutputBatchIndexToInputIndex( const GatherDimensionNumbers* dim_numbers, const Shape& input_shape, - const Shape& output_shape, const Literal* gather_indices) - : dim_numbers_(*dim_numbers), gather_indices_(*gather_indices) { + const Shape& output_shape, const Literal* start_indices) + : dim_numbers_(*dim_numbers), start_indices_(*start_indices) { for (int64 i = 0; i < output_shape.dimensions_size(); i++) { - output_dim_is_gather_dims_.push_back( - !c_binary_search(dim_numbers_.output_window_dims(), i)); + output_dim_is_batch_dims_.push_back( + !absl::c_binary_search(dim_numbers_.offset_dims(), i)); } for (int64 i = 0; i < input_shape.dimensions_size(); i++) { int64 index_of_input_dim_in_index_vector = - std::distance(dim_numbers_.gather_dims_to_operand_dims().begin(), - c_find(dim_numbers_.gather_dims_to_operand_dims(), i)); + std::distance(dim_numbers_.start_index_map().begin(), + absl::c_find(dim_numbers_.start_index_map(), i)); if (index_of_input_dim_in_index_vector == - dim_numbers_.gather_dims_to_operand_dims_size()) { + dim_numbers_.start_index_map_size()) { input_dim_value_to_index_vector_.push_back(-1); } else { input_dim_value_to_index_vector_.push_back( @@ -630,14 +642,14 @@ class OutputGatherIndexToInputIndex { } } - index_vector_index_.resize(gather_indices_.shape().dimensions_size()); + index_vector_index_.resize(start_indices_.shape().dimensions_size()); input_index_.resize(input_shape.dimensions_size()); int64 index_vector_size = - gather_indices_.shape().dimensions(dim_numbers_.index_vector_dim()); + start_indices_.shape().dimensions(dim_numbers_.index_vector_dim()); index_vector_.resize(index_vector_size); } - // Returns the contribution of gather_indices to the input index corresponding + // Returns the contribution of start_indices to the input index corresponding // to output_index. See gather_inner_loop_body. // // This is conceptually a stateless transformation from output_index to the @@ -659,7 +671,7 @@ class OutputGatherIndexToInputIndex { } private: - // Propagates the gather index dimensions from the output index into + // Propagates the batch dimensions from the output index into // index_vector_index_ by mutating index_vector_index_ in place. Does not // update the dim_numbers.index_vector_dim() dimension -- that's the dimension // we iterate over in FetchIndexVector. @@ -667,7 +679,7 @@ class OutputGatherIndexToInputIndex { ArraySlice output_index) { int64 index_vector_index_i = 0; for (int64 i = 0, e = output_index.size(); i < e; i++) { - if (!output_dim_is_gather_dims_[i]) { + if (!output_dim_is_batch_dims_[i]) { continue; } @@ -679,14 +691,14 @@ class OutputGatherIndexToInputIndex { } } - // Populates index_vector_ by iterating over gather_indices_ according to + // Populates index_vector_ by iterating over start_indices_ according to // index_vector_index_. Status FetchIndexVector() { int64 index_vector_dim = dim_numbers_.index_vector_dim(); for (int64 i = 0, e = index_vector_.size(); i < e; i++) { index_vector_index_[index_vector_dim] = i; - TF_ASSIGN_OR_RETURN(index_vector_[i], gather_indices_.GetIntegralAsS64( - index_vector_index_)); + TF_ASSIGN_OR_RETURN(index_vector_[i], + start_indices_.GetIntegralAsS64(index_vector_index_)); } return Status::OK(); } @@ -708,15 +720,15 @@ class OutputGatherIndexToInputIndex { // PropagateIndexVectorToInputIndex. std::vector input_dim_value_to_index_vector_; - // output_dim_is_gather_dims_[i] is true iff the output index i is a gather + // output_dim_is_batch_dims_[i] is true iff the output index i is a gather // dimension. - std::vector output_dim_is_gather_dims_; + std::vector output_dim_is_batch_dims_; - // The buffer into which we construct an index into gather_indices_ to fetch + // The buffer into which we construct an index into start_indices_ to fetch // the index vector. std::vector index_vector_index_; - // The index vector fetched from gather_indices_. + // The index vector fetched from start_indices_. std::vector index_vector_; // The result computed by this functor. operator() returns an ArraySlice into @@ -724,24 +736,23 @@ class OutputGatherIndexToInputIndex { std::vector input_index_; const GatherDimensionNumbers& dim_numbers_; - const Literal& gather_indices_; + const Literal& start_indices_; }; -// This functor computes the contribution of the window indices in an output +// This functor computes the contribution of the offset indices in an output // index to an input index. That is, given an output index I it picks out the -// output window indices in I and expands it into a window index into the input -// shape. -class OutputWindowIndexToInputIndex { +// output offset indices in I and expands it into an index into the input shape. +class OutputOffsetIndexToInputIndex { public: // The constructor does some setup work that is amortized across all // iterations. - explicit OutputWindowIndexToInputIndex( + explicit OutputOffsetIndexToInputIndex( const GatherDimensionNumbers& dim_numbers, const Shape& input_shape, const Shape& output_shape) { std::vector window_index_to_output_index; int64 output_index_count = 0; for (int64 i = 0; i < output_shape.dimensions_size(); i++) { - if (c_binary_search(dim_numbers.output_window_dims(), i)) { + if (absl::c_binary_search(dim_numbers.offset_dims(), i)) { window_index_to_output_index.push_back(output_index_count++); } else { output_index_count++; @@ -750,7 +761,7 @@ class OutputWindowIndexToInputIndex { int64 window_dim_count = 0; for (int64 i = 0; i < input_shape.dimensions_size(); i++) { - if (c_binary_search(dim_numbers.elided_window_dims(), i)) { + if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) { input_dim_value_to_output_index_.push_back(-1); } else { input_dim_value_to_output_index_.push_back( @@ -808,20 +819,20 @@ class OutputWindowIndexToInputIndex { // Rehapes the gather indices input to have a trailing degenerate `1` dimension // if necessary. Hands over the ownership of the newly created literal (if -// there is one) to `reshaped_gather_indices`. +// there is one) to `reshaped_start_indices`. static StatusOr> ReshapedGatherIndices( - int64 index_vector_dim, const Literal& gather_indices, - std::unique_ptr* reshaped_gather_indices) { - if (gather_indices.shape().dimensions_size() != index_vector_dim) { - return std::cref(gather_indices); + int64 index_vector_dim, const Literal& start_indices, + std::unique_ptr* reshaped_start_indices) { + if (start_indices.shape().dimensions_size() != index_vector_dim) { + return std::cref(start_indices); } - std::vector new_shape(gather_indices.shape().dimensions().begin(), - gather_indices.shape().dimensions().end()); + std::vector new_shape(start_indices.shape().dimensions().begin(), + start_indices.shape().dimensions().end()); new_shape.push_back(1); - TF_ASSIGN_OR_RETURN(*reshaped_gather_indices, - gather_indices.Reshape(new_shape)); - return std::cref(**reshaped_gather_indices); + TF_ASSIGN_OR_RETURN(*reshaped_start_indices, + start_indices.Reshape(new_shape)); + return std::cref(**reshaped_start_indices); } Status HloEvaluator::HandleGather(HloInstruction* gather) { @@ -830,34 +841,33 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { const GatherDimensionNumbers& dim_numbers = gather->gather_dimension_numbers(); const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0)); - std::unique_ptr reshaped_gather_indices; + std::unique_ptr reshaped_start_indices; TF_ASSIGN_OR_RETURN( - const Literal& gather_indices, + const Literal& start_indices, ReshapedGatherIndices(dim_numbers.index_vector_dim(), GetEvaluatedLiteralFor(gather->operand(1)), - &reshaped_gather_indices)); + &reshaped_start_indices)); // We iterate over the gather dimensions in the output shape in an outer loop // nest, and iterate over the window dimensions in the output shape in an // inner loop nest. - ShapeUtil::IndexIterationSpace gather_indices_iteration_space = - IterationSpaceForOutputGatherIndices(shape, dim_numbers); - ShapeUtil::IndexIterationSpace window_indices_iteration_space = - IterationSpaceForOutputWindowIndices( - shape.dimensions_size(), gather->gather_window_bounds(), dim_numbers); + ShapeUtil::IndexIterationSpace start_indices_iteration_space = + IterationSpaceForOutputBatchIndices(shape, dim_numbers); + ShapeUtil::IndexIterationSpace offset_indices_iteration_space = + IterationSpaceForOutputOffsetIndices( + shape.dimensions_size(), gather->gather_slice_sizes(), dim_numbers); // Scratch buffers that hold an index in the output shape and the // corresponding index in the input shape. std::vector input_index(operand.shape().dimensions_size()); std::vector output_index(gather->shape().dimensions_size()); - std::vector input_gather_index_clamped( - operand.shape().dimensions_size()); + std::vector input_index_clamped(operand.shape().dimensions_size()); - OutputGatherIndexToInputIndex output_gather_index_to_input_index( + OutputBatchIndexToInputIndex output_batch_index_to_input_index( &gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(), - /*output_shape=*/shape, &gather_indices); - OutputWindowIndexToInputIndex output_window_index_to_input_index( + /*output_shape=*/shape, &start_indices); + OutputOffsetIndexToInputIndex output_offset_index_to_input_index( gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(), /*output_shape=*/shape); @@ -869,29 +879,29 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { ArraySlice output_gather_index) -> StatusOr { TF_ASSIGN_OR_RETURN( ArraySlice input_window_index, - output_window_index_to_input_index(output_window_index)); + output_offset_index_to_input_index(output_window_index)); for (int i = 0, e = output_index.size(); i < e; i++) { output_index[i] = output_gather_index[i] + output_window_index[i]; DCHECK_LT(output_index[i], shape.dimensions(i)); } for (int i = 0, e = input_gather_index.size(); i < e; i++) { int64 output_dim = - output_window_index_to_input_index.input_dim_value_to_output_index(i); + output_offset_index_to_input_index.input_dim_value_to_output_index(i); // If 'output_dim' is -1, it means 'i' is an elided window dim. This means // we set the iteration index to 0, so for the purpose of the following // calculations we can consider the output dimension size to be 1. int64 output_dim_size = output_dim == -1 ? 1 : shape.dimensions(output_dim); // Clamp the gather index so that the gather region fits in the operand. - // input_gather_index_clamped[i] = clamp(input_gather_index[i], 0, + // input_index_clamped[i] = clamp(input_gather_index[i], 0, // operand_shape.dimensions(i) - // output_dim_size); - input_gather_index_clamped[i] = + input_index_clamped[i] = std::min(operand_shape.dimensions(i) - output_dim_size, std::max(0LL, input_gather_index[i])); } for (int i = 0, e = input_index.size(); i < e; i++) { - input_index[i] = input_gather_index_clamped[i] + input_window_index[i]; + input_index[i] = input_index_clamped[i] + input_window_index[i]; DCHECK_GE(input_index[i], 0); DCHECK_LT(input_index[i], operand_shape.dimensions(i)); } @@ -902,18 +912,17 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { auto gather_outer_loop_body = [&](ArraySlice output_gather_index) -> StatusOr { - TF_ASSIGN_OR_RETURN( - ArraySlice input_gather_index, - output_gather_index_to_input_index(output_gather_index)); + TF_ASSIGN_OR_RETURN(ArraySlice input_gather_index, + output_batch_index_to_input_index(output_gather_index)); TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( - shape, window_indices_iteration_space, + shape, offset_indices_iteration_space, std::bind(gather_inner_loop_body, std::placeholders::_1, input_gather_index, output_gather_index))); return true; }; TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( - shape, gather_indices_iteration_space, gather_outer_loop_body)); + shape, start_indices_iteration_space, gather_outer_loop_body)); evaluated_[gather] = std::move(result); return Status::OK(); } @@ -960,7 +969,7 @@ Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) { const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand); - evaluated_[get_tuple_element] = MakeUnique( + evaluated_[get_tuple_element] = absl::make_unique( ShapeUtil::GetTupleElementShape(operand->shape(), index)); return evaluated_[get_tuple_element]->CopyFrom(operand_tuple_literal, /*dest_shape_index=*/{}, @@ -1162,10 +1171,11 @@ StatusOr> EvaluateSortInternal( result_keys.push_back(key_value.first); result_values.push_back(key_value.second); } - auto result_keys_literal = MakeUnique(keys_literal.shape()); + auto result_keys_literal = absl::make_unique(keys_literal.shape()); result_keys_literal->PopulateR1( tensorflow::gtl::ArraySlice(result_keys)); - auto result_values_literal = MakeUnique(values_literal.shape()); + auto result_values_literal = + absl::make_unique(values_literal.shape()); result_values_literal->PopulateR1( tensorflow::gtl::ArraySlice(result_values)); return std::make_pair(std::move(result_keys_literal), @@ -1180,8 +1190,9 @@ StatusOr> EvaluateSortInternal( } else { // For R2 sort, the desired semantics are to sort each matrix row // independently. - auto keys_result_literal = MakeUnique(keys_literal.shape()); - auto values_result_literal = MakeUnique(values_literal.shape()); + auto keys_result_literal = absl::make_unique(keys_literal.shape()); + auto values_result_literal = + absl::make_unique(values_literal.shape()); int64 r1_length = keys_literal.shape().dimensions(1); for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) { TF_ASSIGN_OR_RETURN(auto keys_r1_slice, diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index a4c37ef32827892194da070ee05ec6dc4f4c306f..7588916de5068416410daf1a71a0bbad56f3ef0b 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -226,7 +226,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { ShapeUtil::HumanString(operand->shape()).c_str()); } - auto result = MakeUnique(shape); + auto result = absl::make_unique(shape); TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice multi_index) { return unary_op(operand_literal.Get(multi_index)); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 5f575b24a1fb36c5384592028e0f1f6a8e9404b6..4b8e6260ac837fa88a64126aaf83998b060d7efc 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -21,7 +21,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "absl/memory/memory.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -52,7 +53,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, public HloVerifiedTestBase { protected: HloEvaluatorTest() : use_bfloat16_(GetParam()) { - evaluator_ = MakeUnique(); + evaluator_ = absl::make_unique(); } std::unique_ptr Evaluate( @@ -523,7 +524,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { std::unique_ptr result = Evaluate(); - auto expected_array = MakeUnique>(8, 5, 1, 1); + auto expected_array = absl::make_unique>(8, 5, 1, 1); expected_array->Fill(kPadValue); (*expected_array)(1, 0, 0, 0) = 1.0f; (*expected_array)(1, 2, 0, 0) = 2.0f; @@ -547,7 +548,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { // { 9, 10, 11 }, // { 13, 14, 15 }, // } - auto input_array = MakeUnique>(4, 3); + auto input_array = absl::make_unique>(4, 3); input_array->FillUnique(1.0f); auto input = LiteralUtil::CreateR2FromArray2D(*input_array); HloInstruction* input_instruction = @@ -568,7 +569,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { std::unique_ptr result = Evaluate(); // f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 } - auto expected_array = MakeUnique>(1, 5); + auto expected_array = absl::make_unique>(1, 5); (*expected_array)(0, 0) = 7.0f; (*expected_array)(0, 1) = 2.718f; (*expected_array)(0, 2) = 2.718f; @@ -588,7 +589,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { // { 9, 10, 11 }, // { 13, 14, 15 }, // } - auto input_array = MakeUnique>(4, 3); + auto input_array = absl::make_unique>(4, 3); input_array->FillUnique(1.0f); auto input = LiteralUtil::CreateR2FromArray2D(*input_array); HloInstruction* input_instruction = @@ -612,7 +613,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { std::unique_ptr result = Evaluate(); - auto expected_array = MakeUnique>(0, 9); + auto expected_array = absl::make_unique>(0, 9); auto expected = LiteralUtil::CreateR2FromArray2D(*expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); @@ -628,7 +629,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { // { 3 }, // { 4 }, // } - auto lhs_array = MakeUnique>(4, 1); + auto lhs_array = absl::make_unique>(4, 1); lhs_array->FillUnique(1.0f); auto lhs_literal = LiteralUtil::CreateR2FromArray2D(*lhs_array); HloInstruction* lhs_instruction = @@ -679,7 +680,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { // { 3, 4 }, // { 5, 6 }, // } - auto rhs_array = MakeUnique>(3, 2); + auto rhs_array = absl::make_unique>(3, 2); rhs_array->FillUnique(1.0f); auto rhs_literal = LiteralUtil::CreateR2FromArray2D(*rhs_array); HloInstruction* rhs_instruction = @@ -710,7 +711,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { // { 9, 10, 11 }, // { 13, 14, 15 }, // } - auto lhs_array = MakeUnique>(4, 3); + auto lhs_array = absl::make_unique>(4, 3); lhs_array->FillUnique(1.0f); auto lhs_literal = LiteralUtil::CreateR2FromArray2D(*lhs_array); HloInstruction* lhs_instruction = @@ -722,7 +723,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { // { 3, 4 }, // { 5, 6 }, // } - auto rhs_array = MakeUnique>(3, 2); + auto rhs_array = absl::make_unique>(3, 2); rhs_array->FillUnique(1.0f); auto rhs_literal = LiteralUtil::CreateR2FromArray2D(*rhs_array); HloInstruction* rhs_instruction = @@ -1297,7 +1298,7 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto arg_array = MakeUnique>(2, 3); + auto arg_array = absl::make_unique>(2, 3); arg_array->FillUnique(1.0f); auto arg_literal = LiteralUtil::CreateR2FromArray2D(*arg_array); @@ -1339,7 +1340,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto arg_array = MakeUnique>(2, 3); + auto arg_array = absl::make_unique>(2, 3); arg_array->FillUnique(1.0f); auto arg_literal = LiteralUtil::CreateR2FromArray2D(*arg_array); @@ -1390,7 +1391,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto arg_array = MakeUnique>(2, 3); + auto arg_array = absl::make_unique>(2, 3); arg_array->FillUnique(1.0f); auto arg_literal = LiteralUtil::CreateR2FromArray2D(*arg_array); @@ -1511,7 +1512,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) { // { 9, 10, 11, 12, 13 }, // { 17, 18, 19, 20, 21 }, // } - auto operand_array = MakeUnique>(3, 5); + auto operand_array = absl::make_unique>(3, 5); operand_array->FillUnique(1.0f); auto operand_literal = LiteralUtil::CreateR2FromArray2D(*operand_array); @@ -1544,7 +1545,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { // { 1, 2, 3, 4 }, // { 5, 6, 7, 8 }, // } - auto operand_array = MakeUnique>(2, 4); + auto operand_array = absl::make_unique>(2, 4); operand_array->FillUnique(1.0f); auto operand_literal = LiteralUtil::CreateR2FromArray2D(*operand_array); @@ -1580,7 +1581,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { // { 1, 2, 3, 4 }, // { 5, 6, 7, 8 }, // } - auto operand_array = MakeUnique>(2, 4); + auto operand_array = absl::make_unique>(2, 4); operand_array->FillUnique(1.0f); auto operand_literal = LiteralUtil::CreateR2FromArray2D(*operand_array); @@ -1614,7 +1615,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto operand_array = MakeUnique>(2, 3); + auto operand_array = absl::make_unique>(2, 3); operand_array->FillUnique(1.0); auto operand_literal = LiteralUtil::CreateR2FromArray2D(*operand_array); @@ -1651,7 +1652,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto operand_array = MakeUnique>(2, 3); + auto operand_array = absl::make_unique>(2, 3); operand_array->FillUnique(1.0); auto operand_literal2 = LiteralUtil::CreateR2FromArray2D(*operand_array); @@ -1687,7 +1688,7 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto operand_array = MakeUnique>(2, 3); + auto operand_array = absl::make_unique>(2, 3); operand_array->FillUnique(1.0); HloInstruction* operand2 = b.AddInstruction(HloInstruction::CreateConstant( @@ -1826,21 +1827,20 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[2,3] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1, 3} + slice_sizes={1, 3} } )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({0, 2}); + std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); EXPECT_TRUE(LiteralTestUtil::Equal( *LiteralUtil::CreateR2({{1, 2, 3}, {7, 8, 9}}), - *Evaluate({operand.get(), gather_indices.get()}))); + *Evaluate({operand.get(), start_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) { @@ -1851,21 +1851,20 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[3,2] gather(operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3, 1} + slice_sizes={3, 1} } )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({0, 2}); + std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); EXPECT_TRUE(LiteralTestUtil::Equal( *LiteralUtil::CreateR2({{1, 3}, {4, 6}, {7, 9}}), - *Evaluate({operand.get(), gather_indices.get()}))); + *Evaluate({operand.get(), start_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) { @@ -1876,22 +1875,22 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,3,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=2, - window_bounds={3, 1} + slice_sizes={3, 1} } )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); EXPECT_TRUE(LiteralTestUtil::Equal( *LiteralUtil::CreateR3( {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}), - *Evaluate({operand.get(), gather_indices.get()}))); + *Evaluate({operand.get(), start_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) { @@ -1902,11 +1901,11 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1,2} + slice_sizes={1,1,2} } )"; ParseAndVerifyModule(hlo_text); @@ -1914,11 +1913,11 @@ ENTRY main { LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); EXPECT_TRUE( LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{-1, 1}, {-4, 4}}), - *Evaluate({operand.get(), gather_indices.get()}))); + *Evaluate({operand.get(), start_indices.get()}))); } TEST_P(HloEvaluatorTest, @@ -1930,11 +1929,11 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1,2} + slice_sizes={1,1,2} } )"; ParseAndVerifyModule(hlo_text); @@ -1942,11 +1941,11 @@ ENTRY main { LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); EXPECT_TRUE( LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{-2, 2}, {-1, 1}}), - *Evaluate({operand.get(), gather_indices.get()}))); + *Evaluate({operand.get(), start_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) { @@ -1957,21 +1956,20 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[1,1] gather(operand, indices), - output_window_dims={0,1}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={0,1}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} } )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({1, 1}); + std::unique_ptr start_indices = LiteralUtil::CreateR1({1, 1}); EXPECT_TRUE( LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{5}}), - *Evaluate({operand.get(), gather_indices.get()}))); + *Evaluate({operand.get(), start_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) { @@ -1982,21 +1980,21 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,1,1] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} } )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); EXPECT_TRUE( LiteralTestUtil::Equal(*LiteralUtil::CreateR3({{{8}}, {{5}}}), - *Evaluate({operand.get(), gather_indices.get()}))); + *Evaluate({operand.get(), start_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) { @@ -2007,20 +2005,19 @@ ENTRY main { operand = s32[3,0] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[2,0] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1, 0} + slice_sizes={1, 0} } )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({0, 2}); + std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); EXPECT_TRUE( LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{}, {}}), - *Evaluate({operand.get(), gather_indices.get()}))); + *Evaluate({operand.get(), start_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) { @@ -2031,21 +2028,474 @@ ENTRY main { operand = s32[3] parameter(0) indices = s32[2,2,1] parameter(1) ROOT gather = s32[2,2] gather(operand, indices), - output_window_dims={}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=2, - window_bounds={1} + slice_sizes={1} } )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = LiteralUtil::CreateR1({0, 1, 2}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); EXPECT_TRUE( LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{0, 1}, {2, 1}}), - *Evaluate({operand.get(), gather_indices.get()}))); + *Evaluate({operand.get(), start_indices.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) { + const char* hlo_text = R"( +HloModule TensorFlowScatterV1 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}), + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) { + const char* hlo_text = R"( +HloModule TensorFlowScatterV2 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[3,2] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={0}, + inserted_window_dims={1}, + scatter_dims_to_operand_dims={1}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 30}, {40, 60}, {70, 90}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}), + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) { + const char* hlo_text = R"( +HloModule TensorFlowScatter + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}), + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) { + const char* hlo_text = R"( +HloModule TensorFlowScatter + +mul_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT mul = s32[] multiply(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=mul_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}), + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_F32) { + const char* hlo_text = R"( +HloModule TensorFlowScatter + +add_f32 (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(f32[] lhs, f32[] rhs) +} + +ENTRY main { + operand = f32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = f32[2,3] parameter(2) + ROOT scatter = f32[3,3] scatter(operand, indices, updates), + to_apply=add_f32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = LiteralUtil::CreateR2( + {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({2, 1}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}}); + EXPECT_TRUE(LiteralTestUtil::Near( + *LiteralUtil::CreateR2( + {{1.1, 2.2, 3.3}, {6.7, 8.6, 8.2}, {8.1, 9.9, 10.6}}), + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}), + ErrorSpec{0.1, 0.01})); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) { + const char* hlo_text = R"( +HloModule TensorFlowScatter + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({1, 1}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}), + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) { + const char* hlo_text = R"( +HloModule TensorFlowScatterMultipleBatchDims + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,3,2] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={1}, + scatter_dims_to_operand_dims={1}, + index_vector_dim=2 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + std::unique_ptr updates = LiteralUtil::CreateR3( + {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}), + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) { + const char* hlo_text = R"( +HloModule TensorFlowScatterNd + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3,3,2] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{-10, 10}, {-40, 40}}); + std::unique_ptr expected = + LiteralUtil::CreateR3({{{-10, 10}, {-2, 2}, {-3, 3}}, // + {{-40, 40}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *expected, + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, + EvaluateScatter_TensorFlowScatterNd_NonDefaultIndexVectorDim) { + const char* hlo_text = R"( +HloModule TensorFlowScatterNdNonDefaultIndexVectorDim + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3,3,2] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{-10, 10}, {-20, 20}}); + std::unique_ptr expected = + LiteralUtil::CreateR3({{{-20, 20}, {-10, 10}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *expected, + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) { + const char* hlo_text = R"( +HloModule DynamicUpdateSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[1,1] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={0,1}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({1, 1}); + std::unique_ptr updates = LiteralUtil::CreateR2({{10}}); + std::unique_ptr expected = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 10, 6}, {7, 8, 9}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *expected, + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) { + const char* hlo_text = R"( +HloModule BatchDynamicUpdateSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,1,1] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + std::unique_ptr updates = + LiteralUtil::CreateR3({{{10}}, {{20}}}); + std::unique_ptr expected = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 20, 6}, {7, 10, 9}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *expected, + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) { + const char* hlo_text = R"( +HloModule TensorFlowScatter_ZeroDimBounds + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,0] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,0] parameter(2) + ROOT scatter = s32[3,0] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = LiteralUtil::CreateR2({{}, {}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *operand, + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) { + const string hlo_text = R"( +HloModule Scatter_NoUpdateWindowDims + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3] parameter(0) + indices = s32[2,2,1] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=2 +} +)"; + ParseAndVerifyModule(hlo_text); + + std::unique_ptr operand = LiteralUtil::CreateR1({0, 1, 2}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20}, {30, 40}}); + std::unique_ptr expected = + LiteralUtil::CreateR1({10, 61, 32}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *expected, + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); } // Verifies that HloEvaluator evaluates a HLO instruction that performs @@ -2064,6 +2514,31 @@ TEST_P(HloEvaluatorTest, DoesCompareBF16) { std::move(rhs)); } +TEST_P(HloEvaluatorTest, Bf16Reduction) { + const string hlo_text = R"( +HloModule Bf16Reduction + +add_bf16 (lhs: bf16[], rhs: bf16[]) -> bf16[] { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(bf16[] lhs, bf16[] rhs) +} + +ENTRY main { + arg0 = bf16[4]{0} parameter(0) + init = bf16[] constant(0) + ROOT %reduce = bf16[] reduce(arg0, init), dimensions={0}, to_apply=add_bf16 +} +)"; + ParseAndVerifyModule(hlo_text); + + std::unique_ptr arg = LiteralUtil::CreateR1( + {bfloat16(1.0f), bfloat16(3.0f), bfloat16(-2.0f), bfloat16(42.0f)}); + std::unique_ptr expected = + LiteralUtil::CreateR0(bfloat16(44.0f)); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *Evaluate({arg.get()}))); +} + INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest, ::testing::ValuesIn(use_bf16_params)); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index d5b4be7e1284509a4494b0e804e5396c7cfcecc2..83d7b404f0bd51a401674c0235296a217e8dbef1 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -86,6 +88,29 @@ bool SafeLess(const NativeT& a, const NativeT& b) { // of this class. template class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { + private: + // Get the value in the given literal static_cast as a double. + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + double GetAsDouble(const Literal& literal, + tensorflow::gtl::ArraySlice input_index) { + return static_cast(literal.Get(input_index)); + } + + // Specialization for complex types. In this case it is not possible to + // static_cast value to a double so just CHECK fail. This method is not used + // at run-time, but must be available at compile-time to keep the compiler + // happy. + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + double GetAsDouble(const Literal& literal, + tensorflow::gtl::ArraySlice input_index) { + LOG(FATAL) << "Trying to get complex literal as double: " + << literal.ToString(); + } + public: explicit HloEvaluatorTypedVisitor(HloEvaluator* p) : parent_(p) {} @@ -873,7 +898,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { << ShapeUtil::HumanString(inferred_return_shape); const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto result = MakeUnique(result_shape); + auto result = absl::make_unique(result_shape); TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice out_index) { @@ -1030,7 +1055,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return static_cast(result_val); }; - auto result = MakeUnique(result_shape); + auto result = absl::make_unique(result_shape); TF_RETURN_IF_ERROR(result->PopulateParallel(func)); parent_->evaluated_[conv] = std::move(result); @@ -1104,7 +1129,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } } - auto result = MakeUnique(dot->shape()); + auto result = absl::make_unique(dot->shape()); TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice result_index) { ElementwiseT result_val = static_cast(0); @@ -1153,7 +1178,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Create new HLO of padded shape with padding value. ReturnT scalar = parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get({}); - auto result = MakeUnique(pad->shape()); + auto result = absl::make_unique(pad->shape()); TF_RETURN_IF_ERROR(result->Populate( [&scalar](tensorflow::gtl::ArraySlice multi_index) { return scalar; @@ -1318,7 +1343,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto operands = map->operands(); HloComputation* computation = map->to_apply(); - auto result = MakeUnique(map->shape()); + auto result = absl::make_unique(map->shape()); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); TF_RETURN_IF_ERROR(result->Populate( @@ -1432,7 +1457,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { [](const ReturnT& a, const ReturnT& b) { return SafeLess(a, b); }); - auto result_literal = MakeUnique(keys_literal.shape()); + auto result_literal = absl::make_unique(keys_literal.shape()); result_literal->PopulateR1( tensorflow::gtl::ArraySlice(result_data)); VLOG(3) << "HandleSort result_literal: " << result_literal->ToString(); @@ -1444,7 +1469,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } else { // For R2 sort, the desired semantics are to sort each matrix row // independently. - auto result_literal = MakeUnique(keys_literal.shape()); + auto result_literal = absl::make_unique(keys_literal.shape()); int64 r1_length = keys->shape().dimensions(1); for (int64 row = 0; row < keys->shape().dimensions(0); ++row) { TF_ASSIGN_OR_RETURN(auto r1_slice, @@ -1473,6 +1498,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } Status HandleReduce(HloInstruction* reduce) override { + // TODO(b/112040122): Support variadic reduce. + if (!ShapeUtil::IsArray(reduce->shape())) { + return Unimplemented("Variadic reduce is not supported in the Evaluator"); + } auto arg = reduce->operand(0); auto init_value = reduce->operand(1); tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); @@ -1481,8 +1510,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { ShapeUtil::Rank(arg->shape()) - dimensions.size()); TF_ASSIGN_OR_RETURN(auto inferred_return_shape, ShapeInference::InferReduceShape( - /*arg=*/arg->shape(), - /*init_value=*/init_value->shape(), + {&arg->shape(), &init_value->shape()}, /*dimensions_to_reduce=*/dimensions, /*to_apply=*/function->ComputeProgramShape())); TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) @@ -1515,7 +1543,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - auto result = MakeUnique(reduce->shape()); + auto result = absl::make_unique(reduce->shape()); // For each resulting dimension, calculate and assign computed value. TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice multi_index) { @@ -1533,7 +1561,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { IsScalarAdd(function)) { double computed_result = 0; auto func = [&](tensorflow::gtl::ArraySlice input_index) { - computed_result += arg_literal.Get(input_index); + computed_result += GetAsDouble(arg_literal, input_index); return true; }; ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, @@ -1596,7 +1624,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); auto init_scalar = init_literal.Get({}); - auto result = MakeUnique(select_and_scatter->shape()); + auto result = absl::make_unique(select_and_scatter->shape()); // Initialize result array with the init value. TF_RETURN_IF_ERROR(result->Populate( @@ -1732,7 +1760,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - auto result = MakeUnique(reduce_window->shape()); + auto result = absl::make_unique(reduce_window->shape()); // For each resulting dimension, calculate and assign computed value. TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice output_index) { @@ -1772,6 +1800,388 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + // Reshapes the scatter indices input to have a trailing degenerate `1` + // dimension if necessary. Hands over the ownership of the newly created + // literal (if there is one) to `reshaped_indices`. + StatusOr> ReshapedScatterIndices( + int64 index_vector_dim, const Literal& indices, + std::unique_ptr* reshaped_indices) { + if (indices.shape().dimensions_size() != index_vector_dim) { + return std::cref(indices); + } + + std::vector new_shape(indices.shape().dimensions().begin(), + indices.shape().dimensions().end()); + new_shape.push_back(1); + TF_ASSIGN_OR_RETURN(*reshaped_indices, indices.Reshape(new_shape)); + return std::cref(**reshaped_indices); + } + + // Returns an ShapeUtil::IndexIterationSpace that iterates over the update + // scatter dimensions while keeping the rest of the update dimensions clamped + // to 0. + ShapeUtil::IndexIterationSpace IterationSpaceForUpdateScatterIndices( + const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { + int64 updates_rank = updates_shape.dimensions_size(); + std::vector index_base(updates_rank, 0); + std::vector index_count(updates_rank, 1); + for (int64 i = 0; i < updates_rank; i++) { + bool is_update_scatter_dim = + !absl::c_binary_search(dim_numbers.update_window_dims(), i); + if (is_update_scatter_dim) { + index_count[i] = updates_shape.dimensions(i); + } + } + return {std::move(index_base), std::move(index_count), + std::vector(updates_rank, 1)}; + } + + // Return an ShapeUtil::IndexIterationSpace that iterates over the update + // window dimensions while keeping the rest of the update dimensions clamped + // to 0. + ShapeUtil::IndexIterationSpace IterationSpaceForUpdateWindowIndices( + const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { + int64 updates_rank = updates_shape.dimensions_size(); + std::vector index_base(updates_rank, 0); + std::vector index_count(updates_rank, 1); + for (int64 i = 0; i < updates_rank; i++) { + bool is_update_window_dim = + absl::c_binary_search(dim_numbers.update_window_dims(), i); + if (is_update_window_dim) { + index_count[i] = updates_shape.dimensions(i); + } + } + return {std::move(index_base), std::move(index_count), + std::vector(updates_rank, 1)}; + } + + // This functor computes the contribution of scatter_indices to an input index + // corresponding to an update index. That is, given an update index I, it + // picks out the scatter indices in I and uses them to look up a scatter + // index, S, from the scatter indices tensor, and expands S into the input + // space according to scatter_dims_to_operand_dims. + // + // This is similar to the class HloEvaluator::OutputGatherIndexToInputIndex + // that does the corresponding function for Gather. + class UpdateScatterIndexToInputIndex { + public: + // The constructor does some setup work that is amortized across all + // iterations. + explicit UpdateScatterIndexToInputIndex( + const ScatterDimensionNumbers* dim_numbers, const Shape& input_shape, + const Shape& updates_shape, const Literal* scatter_indices) + : dim_numbers_(*dim_numbers), scatter_indices_(*scatter_indices) { + for (int64 i = 0; i < updates_shape.dimensions_size(); i++) { + update_dim_is_scatter_dims_.push_back( + !absl::c_binary_search(dim_numbers_.update_window_dims(), i)); + } + + for (int64 i = 0; i < input_shape.dimensions_size(); i++) { + int64 index_of_input_dim_in_index_vector = + FindIndex(dim_numbers_.scatter_dims_to_operand_dims(), i); + if (index_of_input_dim_in_index_vector == + dim_numbers_.scatter_dims_to_operand_dims_size()) { + input_dim_value_to_index_vector_.push_back(-1); + } else { + input_dim_value_to_index_vector_.push_back( + index_of_input_dim_in_index_vector); + } + } + + index_vector_index_.resize(scatter_indices_.shape().dimensions_size()); + input_index_.resize(input_shape.dimensions_size()); + int64 index_vector_size = + scatter_indices_.shape().dimensions(dim_numbers_.index_vector_dim()); + index_vector_.resize(index_vector_size); + } + + // Returns the contribution of scatter_indices to the input index + // corresponding to update_index. See scatter_inner_loop_body. + // + // This is conceptually a stateless transformation from update_index to the + // scatter input index, but: + // + // - Instead of allocating memory to represent the scatter input index on + // every invocation we reuse the same storage for the result + // (input_index_), mutating it in place. + // - Instead of allocating buffers for temporary values like + // index_vector_index_ and index_vector on every invocation, we reuse the + // same storage for all invocations. + // + // This returns an arrayslice into memory owned by the class. + StatusOr> operator()( + tensorflow::gtl::ArraySlice update_index) { + PropagateUpdateIndexScatterDimsToIndexVectorIndex(update_index); + TF_RETURN_IF_ERROR(FetchIndexVector()); + PropagateIndexVectorToInputIndex(); + return tensorflow::gtl::ArraySlice(input_index_); + } + + private: + // Propagates the scatter index dimensions from the update index into + // index_vector_index_ by mutating index_vector_index_ in place. Does not + // update the dim_numbers.index_vector_dim() dimension -- that's the + // dimension we iterate over in FetchIndexVector. + void PropagateUpdateIndexScatterDimsToIndexVectorIndex( + tensorflow::gtl::ArraySlice update_index) { + int64 index_vector_index_i = 0; + for (int64 i = 0, e = update_index.size(); i < e; i++) { + if (!update_dim_is_scatter_dims_[i]) { + continue; + } + + if (index_vector_index_i == dim_numbers_.index_vector_dim()) { + index_vector_index_i++; + } + + index_vector_index_[index_vector_index_i++] = update_index[i]; + } + } + + // Populates index_vector_ by iterating over scatter_indices_ according to + // index_vector_index_. + Status FetchIndexVector() { + int64 index_vector_dim = dim_numbers_.index_vector_dim(); + for (int64 i = 0, e = index_vector_.size(); i < e; i++) { + index_vector_index_[index_vector_dim] = i; + TF_ASSIGN_OR_RETURN(index_vector_[i], scatter_indices_.GetIntegralAsS64( + index_vector_index_)); + } + return Status::OK(); + } + + // Populates input_index_. + void PropagateIndexVectorToInputIndex() { + for (int64 i = 0, e = input_index_.size(); i < e; i++) { + if (input_dim_value_to_index_vector_[i] != -1) { + input_index_[i] = index_vector_[input_dim_value_to_index_vector_[i]]; + } + + // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i] + // remains 0, as set by the constructor. + } + } + + // input_dim_value_to_index_vector_[i] tells us how to compute dimension i + // of the input index from the index vector. See + // PropagateIndexVectorToInputIndex. + std::vector input_dim_value_to_index_vector_; + + // update_dim_is_scatter_dims_[i] is true iff the update index i is a + // scatter dimension. + std::vector update_dim_is_scatter_dims_; + + // The buffer into which we construct an index into scatter_indices_ to + // fetch the index vector. + std::vector index_vector_index_; + + // The index vector fetched from scatter_indices_. + std::vector index_vector_; + + // The result computed by this functor. operator() returns an ArraySlice + // into this vector. + std::vector input_index_; + + const ScatterDimensionNumbers& dim_numbers_; + const Literal& scatter_indices_; + }; + + // This functor computes the contribution of the window indices in an update + // index to an input index. That is, given an update index I it picks out the + // update window indices in I and expands it into a window index into the + // input shape. + // + // This is similar to the class HloEvaluator::OutputWindowIndexToInputIndex + // that does the corresponding function for Gather. + class UpdateWindowIndexToInputIndex { + public: + // The constructor does some setup work that is amortized across all + // iterations. + explicit UpdateWindowIndexToInputIndex( + const ScatterDimensionNumbers& dim_numbers, const Shape& input_shape, + const Shape& updates_shape) { + std::vector window_index_to_update_index; + int64 update_index_count = 0; + for (int64 i = 0; i < updates_shape.dimensions_size(); i++) { + if (absl::c_binary_search(dim_numbers.update_window_dims(), i)) { + window_index_to_update_index.push_back(update_index_count++); + } else { + update_index_count++; + } + } + + int64 window_dim_count = 0; + for (int64 i = 0; i < input_shape.dimensions_size(); i++) { + if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) { + input_dim_value_to_update_index_.push_back(-1); + } else { + input_dim_value_to_update_index_.push_back( + window_index_to_update_index[window_dim_count++]); + } + } + + input_index_.resize(input_shape.dimensions_size()); + } + + // Returns the contribution of the window indices to the input index + // corresponding to update_index. See scatter_inner_loop_body. + // + // This is conceptually a stateless transformation from update_index to the + // window input index, but instead of allocating memory to represent the + // scatter input index on every invocation we reuse the same storage for the + // result (input_index_), mutating it in place. + // + // This returns an arrayslice into memory owned by the class. + StatusOr> operator()( + tensorflow::gtl::ArraySlice update_index) { + PropagateUpdateIndexWindowDimsToInputIndex(update_index); + return tensorflow::gtl::ArraySlice(input_index_); + } + + // Returns for a given 'input_dim' the corresponding update dimension index, + // or -1 if 'input_dim' is an elided window dimension. + int64 input_dim_value_to_update_index(int64 input_dim) { + return input_dim_value_to_update_index_[input_dim]; + } + + private: + // Propagates window dimensions from the update index to input_index_ by + // mutating input_index_ in place. + void PropagateUpdateIndexWindowDimsToInputIndex( + tensorflow::gtl::ArraySlice update_index) { + for (int64 i = 0, e = input_index_.size(); i < e; i++) { + if (input_dim_value_to_update_index_[i] != -1) { + input_index_[i] = update_index[input_dim_value_to_update_index_[i]]; + } + + // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i] + // remains 0, as set by the constructor. + } + } + + // input_dim_value_to_index_vector_[i] tells us how to compute dimension i + // of the input index from the update index. See + // PropagateUpdateIndexWindowDimsToInputIndex. + std::vector input_dim_value_to_update_index_; + + // The result computed by this functor. operator() returns an ArraySlice + // into this vector. + std::vector input_index_; + }; + + Status HandleScatter(HloInstruction* scatter) override { + const ScatterDimensionNumbers& dim_numbers = + scatter->scatter_dimension_numbers(); + const Literal& operand = + parent_->GetEvaluatedLiteralFor(scatter->operand(0)); + std::unique_ptr reshaped_scatter_indices; + TF_ASSIGN_OR_RETURN(const Literal& scatter_indices, + ReshapedScatterIndices(dim_numbers.index_vector_dim(), + parent_->GetEvaluatedLiteralFor( + scatter->operand(1)), + &reshaped_scatter_indices)); + const Literal& updates = + parent_->GetEvaluatedLiteralFor(scatter->operand(2)); + const Shape& updates_shape = updates.shape(); + const Shape& operand_shape = operand.shape(); + + ShapeUtil::IndexIterationSpace scatter_indices_iteration_space = + IterationSpaceForUpdateScatterIndices(updates_shape, dim_numbers); + ShapeUtil::IndexIterationSpace window_indices_iteration_space = + IterationSpaceForUpdateWindowIndices(updates_shape, dim_numbers); + + std::vector input_index(operand_shape.dimensions_size()); + std::vector update_index(updates_shape.dimensions_size()); + std::vector input_scatter_index_clamped( + operand_shape.dimensions_size()); + + UpdateScatterIndexToInputIndex update_scatter_index_to_input_index( + &scatter->scatter_dimension_numbers(), /*input_shape=*/operand_shape, + updates_shape, &scatter_indices); + UpdateWindowIndexToInputIndex update_window_index_to_input_index( + scatter->scatter_dimension_numbers(), /*input_shape=*/operand_shape, + updates_shape); + + // Initialize the result with the operand. This makes it easier to handle + // the updates even when the indices are repeated. + std::unique_ptr result = operand.CloneToUnique(); + HloEvaluator embedded_evaluator; + auto scatter_inner_loop_body = + [&](tensorflow::gtl::ArraySlice update_window_index, + tensorflow::gtl::ArraySlice input_scatter_index, + tensorflow::gtl::ArraySlice update_scatter_index) + -> StatusOr { + TF_ASSIGN_OR_RETURN( + tensorflow::gtl::ArraySlice input_window_index, + update_window_index_to_input_index(update_window_index)); + for (int i = 0, e = update_index.size(); i < e; i++) { + update_index[i] = update_scatter_index[i] + update_window_index[i]; + DCHECK_LT(update_index[i], updates_shape.dimensions(i)); + } + for (int i = 0, e = input_scatter_index.size(); i < e; i++) { + int64 update_dim = + update_window_index_to_input_index.input_dim_value_to_update_index( + i); + // If 'update_dim' is -1, it means 'i' is an elided window dim. This + // means we set the iteration index to 0, so for the purpose of the + // following calculations we can consider the update dimension size to + // be 1. + int64 update_dim_size = + update_dim == -1 ? 1 : updates_shape.dimensions(update_dim); + // Clamp the scatter index so that the scatter region fits in the + // operand. input_scatter_index_clamped[i] = + // clamp(input_scatter_index[i], 0, + // operand_shape.dimensions(i) - + // update_dim_size); + input_scatter_index_clamped[i] = + std::min(operand_shape.dimensions(i) - update_dim_size, + std::max(0LL, input_scatter_index[i])); + } + for (int i = 0, e = input_index.size(); i < e; i++) { + input_index[i] = input_scatter_index_clamped[i] + input_window_index[i]; + DCHECK_GE(input_index[i], 0); + DCHECK_LT(input_index[i], operand_shape.dimensions(i)); + } + + auto result_value_literal = + LiteralUtil::CreateR0(result->Get(input_index)); + auto update_value_literal = + LiteralUtil::CreateR0(updates.Get(update_index)); + std::unique_ptr updated_result = + embedded_evaluator + .Evaluate( + *scatter->to_apply(), + {result_value_literal.get(), update_value_literal.get()}) + .ConsumeValueOrDie(); + // Clear visit states so that the we can use the evaluate again on the + // same computation. + embedded_evaluator.ResetVisitStates(); + result->Set(input_index, updated_result->Get({})); + return true; + }; + + auto scatter_outer_loop_body = + [&](tensorflow::gtl::ArraySlice update_scatter_index) + -> StatusOr { + TF_ASSIGN_OR_RETURN( + tensorflow::gtl::ArraySlice input_scatter_index, + update_scatter_index_to_input_index(update_scatter_index)); + TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + updates_shape, window_indices_iteration_space, + [&](tensorflow::gtl::ArraySlice update_window_index) { + return scatter_inner_loop_body( + update_window_index, input_scatter_index, update_scatter_index); + })); + return true; + }; + + TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + updates_shape, scatter_indices_iteration_space, + scatter_outer_loop_body)); + parent_->evaluated_[scatter] = std::move(result); + return Status::OK(); + } + Status HandleSlice(HloInstruction* slice) override { auto operand = slice->operand(0); const Shape& shape = slice->shape(); @@ -2003,7 +2413,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::is_same::value || std::is_same::value>::type* = nullptr> Status HandleIota(HloInstruction* iota) { - auto result = MakeUnique(iota->shape()); + auto result = absl::make_unique(iota->shape()); auto data = result->data(); std::iota(data.begin(), data.end(), 0); parent_->evaluated_[iota] = std::move(result); @@ -2085,7 +2495,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } std::vector operand_indices(start.size()); - auto result = MakeUnique(result_shape); + auto result = absl::make_unique(result_shape); TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice multi_index) { for (int64 i = 0; i < operand_indices.size(); ++i) { @@ -2171,7 +2581,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - auto result = MakeUnique(shape); + auto result = absl::make_unique(shape); TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice multi_index) { @@ -2209,7 +2619,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); - auto result = MakeUnique(shape); + auto result = absl::make_unique(shape); TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice multi_index) { diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index c3ccbf0f0c75b569b49652807dea52faebdccc31..de3d7a167752f0de790585e50874dd6d2904bd37 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/human_readable_profile_builder.h" @@ -49,7 +51,7 @@ std::unique_ptr CreateHloProfilePrinterData( size_t profile_counters_size = hlo_profile_index_map.total_count(); std::unique_ptr profile_printer_data = - MakeUnique(); + absl::make_unique(); profile_printer_data->set_profile_counters_size(profile_counters_size); profile_printer_data->mutable_computation_infos()->Reserve( hlo_profile_index_map.computation_count()); @@ -67,11 +69,11 @@ std::unique_ptr CreateHloProfilePrinterData( // The profile indices were computed deterministically in // HloProfileIndexMap::HloProfileIndexMap. - c_sort(computation_and_profile_idx_list, - [](const std::pair& left, - const std::pair& right) { - return left.second < right.second; - }); + absl::c_sort(computation_and_profile_idx_list, + [](const std::pair& left, + const std::pair& right) { + return left.second < right.second; + }); for (const auto& pair : computation_and_profile_idx_list) { CHECK_LT(pair.second, profile_counters_size); diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index fd5085bed234068a1bdf18977b38d92badc02a49..1efa6eb5bda7e1cb90874e0466aafd2c788a3fbf 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -844,7 +844,10 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( *elem_count *= dim; } } - if (elem_count.has_value() && *elem_count <= 8) { + // Allow HloDotDumper to print HloInstruction reconstructed from HloProto + // collected from profiling tools. Those constants may not have a valid + // literal. + if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) { return Printf("%s (%s)", constant->literal().ToString(), ShapeUtil::HumanString(constant->shape())); } @@ -1019,6 +1022,8 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { return kWhite; } return kGreen; + case HloOpcode::kScatter: + // Do not de-emphasize Scatter, since it involves significant work. case HloOpcode::kCopy: // Emphasize copy nodes, which are either physical transposes (and thus // significant), or copies of read-only buffers (and thus dead weight). @@ -1043,6 +1048,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kMap: return kGray; case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllToAll: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kRecv: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 8b9bdd2f46fe8a63b419b45ef2c2a2e025c60c8f..e3d6b2e753b75d64c738ccbcfa4afdc56ca247d3 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -21,10 +21,11 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -224,7 +225,7 @@ StatusOr> HloInstruction::CreateFromProto( Literal::CreateFromProto(proto.literal())); instruction = CreateConstant(std::move(literal)); } else { - instruction = MakeUnique(proto.shape()); + instruction = absl::make_unique(proto.shape()); } break; } @@ -281,27 +282,14 @@ StatusOr> HloInstruction::CreateFromProto( case HloOpcode::kInfeed: { const Shape& data_shape = ShapeUtil::GetTupleElementShape(proto.shape(), 0); - if (proto.operand_ids_size() == 0) { - // TODO(b/80000000): Remove this when all uses of infeed are - // converted to take tokens. - instruction = CreateInfeed(data_shape, proto.infeed_config()); - } else { - CHECK_EQ(proto.operand_ids_size(), 1); - instruction = - CreateInfeed(data_shape, operands(0), proto.infeed_config()); - } + TF_RET_CHECK(proto.operand_ids_size() == 1); + instruction = + CreateInfeed(data_shape, operands(0), proto.infeed_config()); } break; case HloOpcode::kOutfeed: - if (proto.operand_ids_size() == 1) { - // TODO(b/80000000): Remove this when all uses of outfeed are - // converted to take tokens. - instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), - proto.outfeed_config()); - } else { - CHECK_EQ(proto.operand_ids_size(), 2); - instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), - operands(1), proto.outfeed_config()); - } + TF_RET_CHECK(proto.operand_ids_size() == 2); + instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), + operands(1), proto.outfeed_config()); break; case HloOpcode::kCrossReplicaSum: { TF_RET_CHECK(proto.called_computation_ids_size() == 1) @@ -320,15 +308,25 @@ StatusOr> HloInstruction::CreateFromProto( /*all_reduce_id=*/all_reduce_id); break; } + case HloOpcode::kAllToAll: { + instruction = CreateAllToAll( + proto.shape(), all_operands(), + /*replica_groups=*/ + std::vector(proto.replica_groups().begin(), + proto.replica_groups().end()), + /*barrier=*/proto.cross_replica_sum_barrier()); + break; + } case HloOpcode::kConvolution: TF_RET_CHECK(proto.operand_ids_size() == 2) << "Convolution instruction should have 2 operands but sees " << proto.operand_ids_size(); TF_RET_CHECK(proto.has_window()); TF_RET_CHECK(proto.has_convolution_dimension_numbers()); - instruction = - CreateConvolve(proto.shape(), operands(0), operands(1), - proto.window(), proto.convolution_dimension_numbers()); + instruction = CreateConvolve( + proto.shape(), operands(0), operands(1), proto.window(), + proto.convolution_dimension_numbers(), + std::max(static_cast(proto.feature_group_count()), 1LL)); break; case HloOpcode::kReduceWindow: TF_RET_CHECK(proto.operand_ids_size() == 2) @@ -382,7 +380,7 @@ StatusOr> HloInstruction::CreateFromProto( << "DynamicSlice instruction should have 2 operands but sees " << proto.operand_ids_size(); std::vector slice_sizes(proto.dynamic_slice_sizes_size()); - c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin()); + absl::c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin()); instruction = CreateDynamicSlice(proto.shape(), operands(0), operands(1), slice_sizes); break; @@ -394,18 +392,35 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.has_gather_dimension_numbers()) << "Gather instruction should have GatherDimensionNumbers set."; std::unique_ptr gather_dimension_numbers = - MakeUnique(proto.gather_dimension_numbers()); - std::vector gather_window_bounds; - for (int64 bound : proto.gather_window_bounds()) { - gather_window_bounds.push_back(bound); + absl::make_unique( + proto.gather_dimension_numbers()); + std::vector gather_slice_sizes; + for (int64 bound : proto.gather_slice_sizes()) { + gather_slice_sizes.push_back(bound); } + instruction = CreateGather(proto.shape(), operands(0), operands(1), + *gather_dimension_numbers, gather_slice_sizes); + break; + } + case HloOpcode::kScatter: { + TF_RET_CHECK(proto.operand_ids_size() == 3) + << "Scatter instruction should have 3 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.has_scatter_dimension_numbers()) + << "Scatter instruction should have ScatterDimensionNumbers set."; + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "Scatter instruction should have 1 called computation but sees " + << proto.called_computation_ids_size(); + auto scatter_dimension_numbers = + absl::make_unique( + proto.scatter_dimension_numbers()); instruction = - CreateGather(proto.shape(), operands(0), operands(1), - *gather_dimension_numbers, gather_window_bounds); + CreateScatter(proto.shape(), operands(0), operands(1), operands(2), + computations(0), *scatter_dimension_numbers); break; } default: { - instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); + instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape())); for (const int64 operand_id : proto.operand_ids()) { TF_RET_CHECK(ContainsKey(instruction_map, operand_id)) << "No instruction with id " << operand_id; @@ -436,7 +451,7 @@ StatusOr> HloInstruction::CreateFromProto( if (proto.has_dot_dimension_numbers()) { instruction->dot_dimension_numbers_ = - MakeUnique(proto.dot_dimension_numbers()); + absl::make_unique(proto.dot_dimension_numbers()); } if (proto.has_sharding()) { @@ -450,34 +465,36 @@ StatusOr> HloInstruction::CreateFromProto( /* static */ std::unique_ptr HloInstruction::CreateParameter( int64 parameter_number, const Shape& shape, const string& name) { - return MakeUnique(parameter_number, shape, name); + return absl::make_unique(parameter_number, shape, + name); } /* static */ std::unique_ptr HloInstruction::CreateTrace( const string& tag, HloInstruction* operand) { - return MakeUnique(tag, operand); + return absl::make_unique(tag, operand); } /* static */ std::unique_ptr HloInstruction::CreateConstant( std::unique_ptr literal) { - return MakeUnique(std::move(literal)); + return absl::make_unique(std::move(literal)); } /* static */ std::unique_ptr HloInstruction::CreateIota( const Shape& shape) { - return WrapUnique(new HloInstruction(HloOpcode::kIota, shape)); + return absl::WrapUnique(new HloInstruction(HloOpcode::kIota, shape)); } /* static */ std::unique_ptr HloInstruction::CreateGetTupleElement(const Shape& shape, HloInstruction* operand, int64 index) { - return MakeUnique(shape, operand, index); + return absl::make_unique(shape, operand, + index); } /* static */ std::unique_ptr HloInstruction::CreateRng( const Shape& shape, RandomDistribution distribution, tensorflow::gtl::ArraySlice parameters) { - return MakeUnique(shape, distribution, parameters); + return absl::make_unique(shape, distribution, parameters); } /* static */ std::unique_ptr HloInstruction::CreateNary( @@ -487,7 +504,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, // It is impossible to copy an opaque shape, we don't know how big it is. CHECK(!ShapeUtil::IsOpaque(shape)); } - auto instruction = WrapUnique(new HloInstruction(opcode, shape)); + auto instruction = absl::WrapUnique(new HloInstruction(opcode, shape)); for (auto operand : operands) { instruction->AppendOperand(operand); } @@ -592,31 +609,33 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateMap( const Shape& shape, tensorflow::gtl::ArraySlice operands, HloComputation* map_computation) { - return MakeUnique(shape, operands, map_computation); + return absl::make_unique(shape, operands, map_computation); } /* static */ std::unique_ptr HloInstruction::CreateConvolve( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers) { - return MakeUnique(shape, lhs, rhs, window, - dimension_numbers); + const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count) { + return absl::make_unique( + shape, lhs, rhs, window, dimension_numbers, feature_group_count); } /* static */ std::unique_ptr HloInstruction::CreateFft( const Shape& shape, HloInstruction* operand, FftType fft_type, tensorflow::gtl::ArraySlice fft_length) { - return MakeUnique(shape, operand, fft_type, fft_length); + return absl::make_unique(shape, operand, fft_type, + fft_length); } /* static */ std::unique_ptr HloInstruction::CreateDot( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const DotDimensionNumbers& dimension_numbers) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); + auto instruction = + absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); instruction->AppendOperand(lhs); instruction->AppendOperand(rhs); instruction->dot_dimension_numbers_ = - MakeUnique(dimension_numbers); + absl::make_unique(dimension_numbers); return instruction; } @@ -625,10 +644,12 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2); CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2); - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); + auto instruction = + absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); instruction->AppendOperand(lhs); instruction->AppendOperand(rhs); - instruction->dot_dimension_numbers_ = MakeUnique(); + instruction->dot_dimension_numbers_ = + absl::make_unique(); instruction->dot_dimension_numbers_->add_lhs_contracting_dimensions(1); instruction->dot_dimension_numbers_->add_rhs_contracting_dimensions(0); return instruction; @@ -639,7 +660,7 @@ HloInstruction::CreateReducePrecision(const Shape& shape, HloInstruction* operand, const int exponent_bits, const int mantissa_bits) { - return MakeUnique( + return absl::make_unique( shape, operand, exponent_bits, mantissa_bits); } @@ -650,41 +671,38 @@ HloInstruction::CreateCrossReplicaSum( tensorflow::gtl::ArraySlice replica_group_ids, tensorflow::StringPiece barrier, const tensorflow::gtl::optional& all_reduce_id) { - return MakeUnique( + return absl::make_unique( shape, operands, reduce_computation, replica_group_ids, barrier, all_reduce_id); } -/* static */ std::unique_ptr HloInstruction::CreateInfeed( - const Shape& infeed_shape, HloInstruction* token_operand, - const string& config) { - return MakeUnique(infeed_shape, token_operand, config); +/* static */ std::unique_ptr HloInstruction::CreateAllToAll( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + const std::vector& replica_groups, + tensorflow::StringPiece barrier) { + return absl::make_unique(shape, operands, + replica_groups, barrier); } /* static */ std::unique_ptr HloInstruction::CreateInfeed( - const Shape& infeed_shape, const string& config) { - return MakeUnique(infeed_shape, config); + const Shape& infeed_shape, HloInstruction* token_operand, + const string& config) { + return absl::make_unique(infeed_shape, token_operand, + config); } /* static */ std::unique_ptr HloInstruction::CreateOutfeed( const Shape& outfeed_shape, HloInstruction* operand, HloInstruction* token_operand, tensorflow::StringPiece outfeed_config) { - return MakeUnique(outfeed_shape, operand, - token_operand, outfeed_config); -} - -/* static */ std::unique_ptr HloInstruction::CreateOutfeed( - const Shape& outfeed_shape, HloInstruction* operand, - tensorflow::StringPiece outfeed_config) { - return MakeUnique(outfeed_shape, operand, - outfeed_config); + return absl::make_unique( + outfeed_shape, operand, token_operand, outfeed_config); } /* static */ std::unique_ptr HloInstruction::CreateSend( HloInstruction* operand, HloInstruction* token, int64 channel_id, bool is_host_transfer) { - return MakeUnique(operand, token, channel_id, - is_host_transfer); + return absl::make_unique(operand, token, channel_id, + is_host_transfer); } /* static */ std::unique_ptr HloInstruction::CreateSendDone( @@ -692,14 +710,15 @@ HloInstruction::CreateCrossReplicaSum( auto send_operand = DynCast(operand); CHECK(send_operand != nullptr) << "SendDone must take the context operand from Send"; - return MakeUnique(send_operand, is_host_transfer); + return absl::make_unique(send_operand, + is_host_transfer); } /* static */ std::unique_ptr HloInstruction::CreateRecv( const Shape& shape, HloInstruction* token, int64 channel_id, bool is_host_transfer) { - return MakeUnique(shape, token, channel_id, - is_host_transfer); + return absl::make_unique(shape, token, channel_id, + is_host_transfer); } /* static */ std::unique_ptr HloInstruction::CreateRecvDone( @@ -707,19 +726,20 @@ HloInstruction::CreateCrossReplicaSum( auto recv_operand = DynCast(operand); CHECK(recv_operand != nullptr) << "RecvDone must take the context operand from Recv"; - return MakeUnique(recv_operand, is_host_transfer); + return absl::make_unique(recv_operand, + is_host_transfer); } /* static */ std::unique_ptr HloInstruction::CreateReverse( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions) { - return MakeUnique(shape, operand, dimensions); + return absl::make_unique(shape, operand, dimensions); } /* static */ std::unique_ptr HloInstruction::CreateAfterAll( tensorflow::gtl::ArraySlice operands) { CHECK(!operands.empty()); - auto instruction = WrapUnique( + auto instruction = absl::WrapUnique( new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape())); for (auto operand : operands) { instruction->AppendOperand(operand); @@ -728,14 +748,15 @@ HloInstruction::CreateCrossReplicaSum( } /* static */ std::unique_ptr HloInstruction::CreateToken() { - return WrapUnique( + return absl::WrapUnique( new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape())); } /* static */ std::unique_ptr HloInstruction::CreateWhile( const Shape& shape, HloComputation* condition, HloComputation* body, HloInstruction* init) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kWhile, shape)); + auto instruction = + absl::WrapUnique(new HloInstruction(HloOpcode::kWhile, shape)); instruction->AppendOperand(init); // Body comes before condition computation in the vector. instruction->called_computations_.push_back(body); @@ -748,7 +769,7 @@ HloInstruction::CreateCrossReplicaSum( HloInstruction* true_computation_arg, HloComputation* true_computation, HloInstruction* false_computation_arg, HloComputation* false_computation) { auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kConditional, shape)); + absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape)); instruction->AppendOperand(pred); instruction->AppendOperand(true_computation_arg); instruction->AppendOperand(false_computation_arg); @@ -765,15 +786,15 @@ HloInstruction::CreateCrossReplicaSum( tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices, tensorflow::gtl::ArraySlice strides) { - return MakeUnique(shape, operand, start_indices, - limit_indices, strides); + return absl::make_unique(shape, operand, start_indices, + limit_indices, strides); } /* static */ std::unique_ptr HloInstruction::CreateDynamicSlice( const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, tensorflow::gtl::ArraySlice slice_sizes) { - return MakeUnique(shape, operand, start_indices, - slice_sizes); + return absl::make_unique( + shape, operand, start_indices, slice_sizes); } /* static */ std::unique_ptr @@ -781,8 +802,8 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, HloInstruction* operand, HloInstruction* update, HloInstruction* start_indices) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape)); + auto instruction = absl::WrapUnique( + new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape)); instruction->AppendOperand(operand); instruction->AppendOperand(update); instruction->AppendOperand(start_indices); @@ -792,12 +813,14 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateConcatenate( const Shape& shape, tensorflow::gtl::ArraySlice operands, int64 dimension) { - return MakeUnique(shape, operands, dimension); + return absl::make_unique(shape, operands, + dimension); } /* static */ std::unique_ptr HloInstruction::CreateConvert( const Shape& shape, HloInstruction* operand) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kConvert, shape)); + auto instruction = + absl::WrapUnique(new HloInstruction(HloOpcode::kConvert, shape)); instruction->AppendOperand(operand); return instruction; } @@ -806,24 +829,38 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, HloInstruction::CreateBitcastConvert(const Shape& shape, HloInstruction* operand) { auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape)); + absl::WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape)); instruction->AppendOperand(operand); return instruction; } /* static */ std::unique_ptr HloInstruction::CreateReduce( - const Shape& shape, HloInstruction* arg, HloInstruction* init_value, + const Shape& shape, HloInstruction* operand, HloInstruction* init_value, tensorflow::gtl::ArraySlice dimensions_to_reduce, HloComputation* reduce_computation) { - return MakeUnique( - shape, arg, init_value, dimensions_to_reduce, reduce_computation); + auto instruction = absl::WrapUnique(new HloReduceInstruction( + shape, {operand, init_value}, dimensions_to_reduce, reduce_computation)); + return std::move(instruction); +} + +/* static */ std::unique_ptr HloInstruction::CreateReduce( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::gtl::ArraySlice init_values, + tensorflow::gtl::ArraySlice dimensions_to_reduce, + HloComputation* reduce_computation) { + std::vector all_args; + all_args.reserve(operands.size() * 2); + all_args.insert(all_args.end(), operands.begin(), operands.end()); + all_args.insert(all_args.end(), init_values.begin(), init_values.end()); + return absl::make_unique( + shape, all_args, dimensions_to_reduce, reduce_computation); } /* static */ std::unique_ptr HloInstruction::CreateReduceWindow( const Shape& shape, HloInstruction* operand, HloInstruction* init_value, const Window& window, HloComputation* reduce_computation) { - return MakeUnique(shape, operand, init_value, - window, reduce_computation); + return absl::make_unique( + shape, operand, init_value, window, reduce_computation); } /* static */ std::unique_ptr @@ -832,7 +869,7 @@ HloInstruction::CreateBatchNormTraining(const Shape& shape, HloInstruction* scale, HloInstruction* offset, float epsilon, int64 feature_index) { - return MakeUnique( + return absl::make_unique( shape, operand, scale, offset, epsilon, feature_index); } @@ -841,7 +878,7 @@ HloInstruction::CreateBatchNormInference( const Shape& shape, HloInstruction* operand, HloInstruction* scale, HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, float epsilon, int64 feature_index) { - return MakeUnique( + return absl::make_unique( shape, operand, scale, offset, mean, variance, epsilon, feature_index); } @@ -851,9 +888,9 @@ HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand, HloInstruction* variance, HloInstruction* grad_output, float epsilon, int64 feature_index) { - return MakeUnique(shape, operand, scale, mean, - variance, grad_output, epsilon, - feature_index); + return absl::make_unique( + shape, operand, scale, mean, variance, grad_output, epsilon, + feature_index); } /* static */ std::unique_ptr @@ -861,15 +898,15 @@ HloInstruction::CreateSelectAndScatter( const Shape& shape, HloInstruction* operand, HloComputation* select, const Window& window, HloInstruction* source, HloInstruction* init_value, HloComputation* scatter) { - return MakeUnique( + return absl::make_unique( shape, operand, select, window, source, init_value, scatter); } /* static */ std::unique_ptr HloInstruction::CreateBroadcast( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice broadcast_dimensions) { - return MakeUnique(shape, operand, - broadcast_dimensions); + return absl::make_unique(shape, operand, + broadcast_dimensions); } /* static */ std::unique_ptr @@ -927,8 +964,8 @@ HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr HloInstruction::CreatePad( const Shape& shape, HloInstruction* operand, HloInstruction* padding_value, const PaddingConfig& padding_config) { - return MakeUnique(shape, operand, padding_value, - padding_config); + return absl::make_unique(shape, operand, padding_value, + padding_config); } /* static */ std::unique_ptr HloInstruction::CreateReshape( @@ -937,7 +974,8 @@ HloInstruction::CreateBroadcastSequence( ShapeUtil::ElementsIn(operand->shape())) << "shape: " << ShapeUtil::HumanString(shape) << " operand: " << ShapeUtil::HumanString(operand->shape()); - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReshape, shape)); + auto instruction = + absl::WrapUnique(new HloInstruction(HloOpcode::kReshape, shape)); instruction->AppendOperand(operand); return instruction; } @@ -945,26 +983,27 @@ HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr HloInstruction::CreateTranspose( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions) { - return MakeUnique(shape, operand, dimensions); + return absl::make_unique(shape, operand, dimensions); } /* static */ std::unique_ptr HloInstruction::CreateSort( const Shape& shape, int64 dimension, HloInstruction* keys, HloInstruction* values) { - return MakeUnique(shape, dimension, keys, values); + return absl::make_unique(shape, dimension, keys, values); } /* static */ std::unique_ptr HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) { - return MakeUnique(shape, fusion_kind, fused_root); + return absl::make_unique(shape, fusion_kind, + fused_root); } /* static */ std::unique_ptr HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, tensorflow::gtl::ArraySlice operands, HloComputation* fusion_computation) { - return MakeUnique(shape, fusion_kind, operands, - fusion_computation); + return absl::make_unique(shape, fusion_kind, operands, + fusion_computation); } void HloInstruction::set_single_sharding(const HloSharding& sharding) { @@ -1022,7 +1061,7 @@ bool HloInstruction::HasSideEffect() const { const Shape& shape, tensorflow::gtl::ArraySlice operands, HloComputation* computation) { std::unique_ptr instruction = - WrapUnique(new HloInstruction(HloOpcode::kCall, shape)); + absl::WrapUnique(new HloInstruction(HloOpcode::kCall, shape)); for (auto operand : operands) { instruction->AppendOperand(operand); } @@ -1033,15 +1072,15 @@ bool HloInstruction::HasSideEffect() const { /* static */ std::unique_ptr HloInstruction::CreateCustomCall( const Shape& shape, tensorflow::gtl::ArraySlice operands, tensorflow::StringPiece custom_call_target) { - return MakeUnique(shape, operands, - custom_call_target); + return absl::make_unique(shape, operands, + custom_call_target); } /* static */ std::unique_ptr HloInstruction::CreateHostCompute( const Shape& shape, tensorflow::gtl::ArraySlice operands, tensorflow::StringPiece channel_name, const int64 cost_estimate_ns) { - return MakeUnique(shape, operands, channel_name, - cost_estimate_ns); + return absl::make_unique( + shape, operands, channel_name, cost_estimate_ns); } /* static */ std::unique_ptr HloInstruction::CreateTuple( @@ -1055,18 +1094,29 @@ bool HloInstruction::HasSideEffect() const { } /* static */ std::unique_ptr HloInstruction::CreateGather( - const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices, + const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice window_bounds) { - return MakeUnique(shape, operand, gather_indices, - gather_dim_numbers, window_bounds); + tensorflow::gtl::ArraySlice slice_sizes) { + return absl::make_unique( + shape, operand, start_indices, gather_dim_numbers, slice_sizes); +} + +/* static */ std::unique_ptr HloInstruction::CreateScatter( + const Shape& shape, HloInstruction* operand, + HloInstruction* scatter_indices, HloInstruction* updates, + HloComputation* update_computation, + const ScatterDimensionNumbers& scatter_dim_numbers) { + return absl::make_unique( + shape, operand, scatter_indices, updates, update_computation, + scatter_dim_numbers); } /* static */ std::unique_ptr HloInstruction::CreateDomain( const Shape& shape, HloInstruction* operand, std::unique_ptr operand_side_metadata, std::unique_ptr user_side_metadata) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDomain, shape)); + auto instruction = + absl::WrapUnique(new HloInstruction(HloOpcode::kDomain, shape)); instruction->operand_side_metadata_ = std::move(operand_side_metadata); instruction->user_side_metadata_ = std::move(user_side_metadata); instruction->AppendOperand(operand); @@ -1113,6 +1163,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kGetTupleElement: case HloOpcode::kReducePrecision: case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllToAll: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kConvolution: @@ -1124,6 +1175,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kDynamicSlice: case HloOpcode::kSort: case HloOpcode::kGather: + case HloOpcode::kScatter: case HloOpcode::kIota: clone = CloneWithNewOperandsImpl(shape, new_operands, context); break; @@ -1579,6 +1631,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllToAll: case HloOpcode::kConvolution: case HloOpcode::kCustomCall: case HloOpcode::kReduceWindow: @@ -1587,6 +1640,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kPad: case HloOpcode::kDynamicSlice: case HloOpcode::kGather: + case HloOpcode::kScatter: LOG(FATAL) << "Base class impl called for opcode with subclass: " << opcode(); } @@ -1693,6 +1747,7 @@ HloComputation* HloInstruction::to_apply() const { case HloOpcode::kReduceWindow: case HloOpcode::kReduce: case HloOpcode::kCrossReplicaSum: + case HloOpcode::kScatter: CHECK_EQ(called_computations_.size(), 1); return called_computations_[0]; default: @@ -1711,6 +1766,7 @@ void HloInstruction::set_to_apply(HloComputation* computation) { case HloOpcode::kReduceWindow: case HloOpcode::kReduce: case HloOpcode::kCrossReplicaSum: + case HloOpcode::kScatter: CHECK_EQ(called_computations_.size(), 1); called_computations_[0] = computation; break; @@ -1977,7 +2033,8 @@ std::vector HloInstruction::ExtraAttributesToString( } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap || opcode() == HloOpcode::kReduceWindow || opcode() == HloOpcode::kReduce || - opcode() == HloOpcode::kCrossReplicaSum) { + opcode() == HloOpcode::kCrossReplicaSum || + opcode() == HloOpcode::kScatter) { extra.push_back( StrCat("to_apply=", PrintName(to_apply()->name(), options))); } else if (!called_computations().empty()) { @@ -2013,6 +2070,7 @@ std::vector HloInstruction::ExtraAttributesToString( case HloOpcode::kReduceWindow: case HloOpcode::kReduce: case HloOpcode::kCrossReplicaSum: + case HloOpcode::kScatter: extra.push_back( StrCat("to_apply=\n", to_apply()->ToString(new_options))); break; @@ -2219,6 +2277,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleFft(this); case HloOpcode::kCrossReplicaSum: return visitor->HandleCrossReplicaSum(this); + case HloOpcode::kAllToAll: + return visitor->HandleAllToAll(this); case HloOpcode::kTuple: return visitor->HandleTuple(this); case HloOpcode::kMap: @@ -2311,6 +2371,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleSendDone(this); case HloOpcode::kGather: return visitor->HandleGather(this); + case HloOpcode::kScatter: + return visitor->HandleScatter(this); case HloOpcode::kDomain: return visitor->HandleDomain(this); case HloOpcode::kAfterAll: @@ -3091,12 +3153,23 @@ const std::vector& HloInstruction::replica_group_ids() const { return Cast(this)->replica_group_ids(); } +const std::vector& HloInstruction::replica_groups() const { + return Cast(this)->replica_groups(); +} + string HloInstruction::cross_replica_sum_barrier() const { - return Cast(this)->cross_replica_sum_barrier(); + if (opcode() == HloOpcode::kCrossReplicaSum) { + return Cast(this)->cross_replica_sum_barrier(); + } + return Cast(this)->cross_replica_sum_barrier(); } void HloInstruction::set_cross_replica_sum_barrier(const string& barrier) { - return Cast(this)->set_cross_replica_sum_barrier( + if (opcode() == HloOpcode::kCrossReplicaSum) { + return Cast(this)->set_cross_replica_sum_barrier( + barrier); + } + return Cast(this)->set_cross_replica_sum_barrier( barrier); } @@ -3126,6 +3199,10 @@ void HloInstruction::set_convolution_dimension_numbers( } } +int64 HloInstruction::feature_group_count() const { + return Cast(this)->feature_group_count(); +} + HloComputation* HloInstruction::select() const { return Cast(this)->select(); } @@ -3166,9 +3243,13 @@ const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const { return Cast(this)->gather_dimension_numbers(); } -tensorflow::gtl::ArraySlice HloInstruction::gather_window_bounds() +tensorflow::gtl::ArraySlice HloInstruction::gather_slice_sizes() const { + return Cast(this)->gather_slice_sizes(); +} + +const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers() const { - return Cast(this)->gather_window_bounds(); + return Cast(this)->scatter_dimension_numbers(); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 30bff286c20033ec193fb29d8c4c935ce6475a27..30dbabfced09d0c090324d2b4a8db8e9ebcbf643 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -32,6 +32,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" @@ -402,7 +403,8 @@ class HloInstruction { static std::unique_ptr CreateConvolve( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1); // Creates an FFT op, of the type indicated by fft_type. static std::unique_ptr CreateFft( @@ -447,8 +449,27 @@ class HloInstruction { HloComputation* reduce_computation, tensorflow::gtl::ArraySlice replica_group_ids, tensorflow::StringPiece barrier, - const tensorflow::gtl::optional& all_reduce_id = - tensorflow::gtl::nullopt); + const tensorflow::gtl::optional& all_reduce_id); + + // This op handles the communication of an Alltoall operation. On each core, + // the operands are N ops in the same shape, where N is the number of cores + // participating the Alltoall. Then the N operands are scattered to N cores, + // e.g., the ith operand is sent to the ith core. Then each core gathers the + // received data into a tuple. + // + // - `replica_groups`: each ReplicaGroup contains a list of replica id. If + // empty, all replicas belong to one group in the order of 0 - (n-1). Alltoall + // will be applied within subgroups in the specified order. For example, + // replica groups = {{1,2,3},{4,5,0}} means, an Alltoall will be applied + // within replica 1, 2, 3, and in the gather phase, the received blocks will + // be concatenated in the order of 1, 2, 3; another Alltoall will be applied + // within replica 4, 5, 0, and the concatenation order is 4, 5, 0. + // + // TODO(b/110096724): This is NOT YET ready to use. + static std::unique_ptr CreateAllToAll( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + const std::vector& replica_groups, + tensorflow::StringPiece barrier); // Creates a conversion instruction, where operand is the data to convert and // shape is the target shape for the conversion. @@ -467,11 +488,6 @@ class HloInstruction { static std::unique_ptr CreateInfeed( const Shape& infeed_shape, HloInstruction* token_operand, const string& config); - // Overload which does not require a token. - // TODO(b/80000000): Remove this overload when all uses of infeed are - // converted to take tokens. - static std::unique_ptr CreateInfeed(const Shape& infeed_shape, - const string& config); // Creates an outfeed instruction, which outputs data. outfeed_shape is the // shape of the data being outfed *not* the shape of the outfeed instruction @@ -479,12 +495,6 @@ class HloInstruction { static std::unique_ptr CreateOutfeed( const Shape& outfeed_shape, HloInstruction* operand, HloInstruction* token_operand, tensorflow::StringPiece outfeed_config); - // Overload which does not require a token. - // TODO(b/80000000): Remove this overload when all uses of outfeed are - // converted to take tokens. - static std::unique_ptr CreateOutfeed( - const Shape& outfeed_shape, HloInstruction* operand, - tensorflow::StringPiece outfeed_config); // Creates an asynchronous send instruction with the given channel id, which // initiates sending the operand data to a unique receive instruction in @@ -542,17 +552,34 @@ class HloInstruction { int64 dimension); // Creates a reduce instruction, where the computation (given by the handle) - // is applied successively to every element in operand. That is, if f is the - // function to apply (which either takes 2 [accumulator, value] or 3 - // [accumulator, index, value] arguments) and init is a reduction operator - // specified initial value (for example, 0 for addition), then this operation - // will compute: - // f(f(init, [index0], value0), [index1], value1), ...) + // is applied successively to every element in operand. For example, let f be + // the function to apply, which takes 2 arguments, an accumulator and the + // current value. Let init be an initial value (which is normally chosen to be + // the identity element for f, e.g. 0 if f is addition). + // Then the reduce HLO will compute: + // f(f(init, value0), value1), ...) static std::unique_ptr CreateReduce( const Shape& shape, HloInstruction* operand, HloInstruction* init_value, tensorflow::gtl::ArraySlice dimensions_to_reduce, HloComputation* reduce_computation); + // A more general, multiple-argument version of the above. + // The function to apply, f, now takes N arguments: + // [accumulator0, accumulator1, ..., accumulatorN, value0, value1, ..., + // init_valueN], and returns an N-tuple. The performed computation is (for + // commutative and associative f operators) equivalent to: + // + // f_1 = f(init0, ... initN, input0.value0, ..., inputN.value0) + // f_2 = f(f_1.tuple_element(0), ..., f_1.tuple_element(N), input0.value1, + // ..., inputN.value1) + // ... + // TODO(b/112040122): Add support to this in HLO passes and in backends. + static std::unique_ptr CreateReduce( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::gtl::ArraySlice init_values, + tensorflow::gtl::ArraySlice dimensions_to_reduce, + HloComputation* reduce_computation); + // Creates a reduce-window instruction, where the computation (given // by the handle) is applied window-wise at each valid window // position in the operand. @@ -641,9 +668,15 @@ class HloInstruction { static std::unique_ptr CreateGather( const Shape& shape, HloInstruction* operand, - HloInstruction* gather_indices, + HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice window_bounds); + tensorflow::gtl::ArraySlice slice_sizes); + + static std::unique_ptr CreateScatter( + const Shape& shape, HloInstruction* operand, + HloInstruction* scatter_indices, HloInstruction* updates, + HloComputation* update_computation, + const ScatterDimensionNumbers& scatter_dim_numbers); // Creates a kDomain instruction which delimits an HLO domain which have // the provided user and operand side metadata. @@ -1015,14 +1048,12 @@ class HloInstruction { if (sharding_ == nullptr) { return tensorflow::gtl::optional(); } - auto device = sharding_->UniqueDevice(); - return device.ok() ? device.ValueOrDie() - : tensorflow::gtl::optional(); + return sharding_->UniqueDevice(); } // Sets the sharding of this operator. Should only be called by HloModule or // HloComputation methods. void set_sharding(const HloSharding& sharding) { - sharding_ = MakeUnique(sharding); + sharding_ = absl::make_unique(sharding); } void set_single_sharding(const HloSharding& sharding); // Sets a sharding that assigns the current instruction to device. @@ -1394,6 +1425,9 @@ class HloInstruction { // Delegates to HloAllReduceInstruction::replica_group_ids. const std::vector& replica_group_ids() const; + // Delegates to HloAllToAllInstruction::replica_groups. + const std::vector& replica_groups() const; + // Delegates to HloAllReduceInstruction::cross_replica_sum_barrier. string cross_replica_sum_barrier() const; void set_cross_replica_sum_barrier(const string& barrier); @@ -1423,6 +1457,10 @@ class HloInstruction { void set_convolution_dimension_numbers( const ConvolutionDimensionNumbers& dnums); + // The number of feature groups. Must be a divisor of the input feature + // dimension and output feature dimension. + int64 feature_group_count() const; + // Delegates to HloSelectAndScatterInstruction::select. HloComputation* select() const; @@ -1452,8 +1490,11 @@ class HloInstruction { // Delegates to HloGatherInstruction::gather_dimension_numbers. const GatherDimensionNumbers& gather_dimension_numbers() const; - // Delegates to HloGatherInstruction::gather_window_bounds. - tensorflow::gtl::ArraySlice gather_window_bounds() const; + // Delegates to HloGatherInstruction::gather_slice_sizes. + tensorflow::gtl::ArraySlice gather_slice_sizes() const; + + // Delegates to HloScatterInstruction::scatter_dimension_numbers(). + const ScatterDimensionNumbers& scatter_dimension_numbers() const; // Old methods kept for smooth subclassing transition END. diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index b75a2bd34bc5d3b5b6100515748df787b9e7f08a..504b13043f86f152cc83b0b961bf2e8fa3ad2afb 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1355,7 +1355,7 @@ TEST_F(HloInstructionTest, Stringification) { TEST_F(HloInstructionTest, StringifyGather_0) { Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); - Shape gather_indices_tensor_shape = + Shape start_indices_tensor_shape = ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5}); Shape gather_result_shape = ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}); @@ -1363,19 +1363,18 @@ TEST_F(HloInstructionTest, StringifyGather_0) { HloComputation::Builder builder("Gather"); HloInstruction* input = builder.AddInstruction( HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor")); - HloInstruction* gather_indices = + HloInstruction* start_indices = builder.AddInstruction(HloInstruction::CreateParameter( - 1, gather_indices_tensor_shape, "gather_indices")); - - HloInstruction* gather_instruction = - builder.AddInstruction(HloInstruction::CreateGather( - gather_result_shape, input, gather_indices, - HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, - /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26})); + 1, start_indices_tensor_shape, "start_indices")); + + HloInstruction* gather_instruction = builder.AddInstruction( + HloInstruction::CreateGather(gather_result_shape, input, start_indices, + HloGatherInstruction::MakeGatherDimNumbers( + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*slice_sizes=*/{30, 29, 28, 27, 26})); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -1383,15 +1382,15 @@ TEST_F(HloInstructionTest, StringifyGather_0) { EXPECT_EQ(gather_instruction->ToString(), "%gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} " "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, " - "s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), " - "output_window_dims={4,5,6,7,8}, elided_window_dims={}, " - "gather_dims_to_operand_dims={0,1,2,3,4}, " - "index_vector_dim=4, window_bounds={30,29,28,27,26}"); + "s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), " + "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, " + "start_index_map={0,1,2,3,4}, " + "index_vector_dim=4, slice_sizes={30,29,28,27,26}"); } TEST_F(HloInstructionTest, StringifyGather_1) { Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); - Shape gather_indices_tensor_shape = + Shape start_indices_tensor_shape = ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6}); Shape gather_result_shape = ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}); @@ -1399,19 +1398,18 @@ TEST_F(HloInstructionTest, StringifyGather_1) { HloComputation::Builder builder("Gather"); HloInstruction* input = builder.AddInstruction( HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor")); - HloInstruction* gather_indices = + HloInstruction* start_indices = builder.AddInstruction(HloInstruction::CreateParameter( - 1, gather_indices_tensor_shape, "gather_indices")); - - HloInstruction* gather_instruction = - builder.AddInstruction(HloInstruction::CreateGather( - gather_result_shape, input, gather_indices, - HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, - /*index_vector_dim=*/2), - /*window_bounds=*/{30, 29, 28, 27, 26})); + 1, start_indices_tensor_shape, "start_indices")); + + HloInstruction* gather_instruction = builder.AddInstruction( + HloInstruction::CreateGather(gather_result_shape, input, start_indices, + HloGatherInstruction::MakeGatherDimNumbers( + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/2), + /*slice_sizes=*/{30, 29, 28, 27, 26})); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -1419,10 +1417,59 @@ TEST_F(HloInstructionTest, StringifyGather_1) { EXPECT_EQ(gather_instruction->ToString(), "%gather = f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} " "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, " - "s64[10,9,5,7,6]{4,3,2,1,0} %gather_indices), " - "output_window_dims={4,5,6,7,8}, elided_window_dims={}, " - "gather_dims_to_operand_dims={0,1,2,3,4}, " - "index_vector_dim=2, window_bounds={30,29,28,27,26}"); + "s64[10,9,5,7,6]{4,3,2,1,0} %start_indices), " + "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, " + "start_index_map={0,1,2,3,4}, " + "index_vector_dim=2, slice_sizes={30,29,28,27,26}"); +} + +TEST_F(HloInstructionTest, StringifyScatter) { + Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); + Shape scatter_indices_tensor_shape = + ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6}); + Shape scatter_updates_shape = + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}); + + HloComputation::Builder builder("Scatter"); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor")); + HloInstruction* scatter_indices = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, scatter_indices_tensor_shape, "scatter_indices")); + HloInstruction* scatter_updates = + builder.AddInstruction(HloInstruction::CreateParameter( + 2, scatter_updates_shape, "scatter_updates")); + + HloComputation::Builder update_builder("Scatter.update"); + update_builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p1")); + update_builder.AddInstruction( + HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "p2")); + + auto module = CreateNewModule(); + auto* update_computation = + module->AddEmbeddedComputation(update_builder.Build()); + + HloInstruction* scatter_instruction = + builder.AddInstruction(HloInstruction::CreateScatter( + input_tensor_shape, input, scatter_indices, scatter_updates, + update_computation, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 7, 8}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/2))); + module->AddEntryComputation(builder.Build()); + + EXPECT_EQ( + scatter_instruction->ToString(), + "%scatter = f32[50,49,48,47,46]{4,3,2,1,0} " + "scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, " + "s64[10,9,5,7,6]{4,3,2,1,0} %scatter_indices, " + "f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %scatter_updates), " + "update_window_dims={4,5,6,7,8}, inserted_window_dims={}, " + "scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=2, " + "to_apply=%Scatter.update"); } TEST_F(HloInstructionTest, CanonnicalStringificationFusion) { diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index df26a2c744fbcac814727139e1cf7f23037dcc50..79a5e7481d76a4f87a609ce2258372ae0cae29be 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -17,6 +17,8 @@ limitations under the License. #include +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -89,7 +91,7 @@ HloBatchNormTrainingInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 3); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], new_operands[1], new_operands[2], epsilon(), feature_index()); } @@ -111,7 +113,7 @@ HloBatchNormInferenceInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 5); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3], new_operands[4], epsilon(), feature_index()); } @@ -133,7 +135,7 @@ HloBatchNormGradInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 5); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3], new_operands[4], epsilon(), feature_index()); } @@ -175,8 +177,8 @@ std::unique_ptr HloFftInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique(shape, new_operands[0], fft_type_, - fft_length_); + return absl::make_unique(shape, new_operands[0], fft_type_, + fft_length_); } HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode, @@ -230,8 +232,8 @@ std::unique_ptr HloSendInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique(new_operands[0], new_operands[1], - channel_id(), is_host_transfer()); + return absl::make_unique( + new_operands[0], new_operands[1], channel_id(), is_host_transfer()); } HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand, @@ -248,7 +250,7 @@ HloSendDoneInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique( + return absl::make_unique( Cast(new_operands[0]), is_host_transfer()); } @@ -269,7 +271,7 @@ std::unique_ptr HloRecvInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique( + return absl::make_unique( ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id(), is_host_transfer()); } @@ -291,7 +293,7 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique( + return absl::make_unique( Cast(new_operands[0]), is_host_transfer()); } @@ -354,11 +356,72 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, HloCloneContext* /*context*/) const { - return MakeUnique( + return absl::make_unique( shape, new_operands, to_apply(), replica_group_ids(), cross_replica_sum_barrier(), all_reduce_id()); } +HloAllToAllInstruction::HloAllToAllInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + const std::vector& replica_groups, + tensorflow::StringPiece barrier) + : HloInstruction(HloOpcode::kAllToAll, shape), + replica_groups_(replica_groups), + cross_replica_sum_barrier_(barrier.begin(), barrier.end()) { + for (auto operand : operands) { + AppendOperand(operand); + } +} + +bool HloAllToAllInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return ContainersEqual(replica_groups(), casted_other.replica_groups(), + [](const ReplicaGroup& a, const ReplicaGroup& b) { + return ContainersEqual(a.replica_ids(), + b.replica_ids()); + }) && + cross_replica_sum_barrier() == + casted_other.cross_replica_sum_barrier(); +} + +std::unique_ptr +HloAllToAllInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* /*context*/) const { + return absl::make_unique( + shape, new_operands, replica_groups(), cross_replica_sum_barrier()); +} + +std::vector HloAllToAllInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector result; + std::vector replica_group_str; + for (const ReplicaGroup& group : replica_groups()) { + replica_group_str.push_back( + StrCat("{", Join(group.replica_ids(), ","), "}")); + } + result.push_back( + StrCat("replica_groups={", Join(replica_group_str, ","), "}")); + + if (!cross_replica_sum_barrier().empty()) { + result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\"")); + } + + return result; +} + +HloInstructionProto HloAllToAllInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_replica_groups() = {replica_groups_.begin(), + replica_groups_.end()}; + proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_); + return proto; +} + HloReverseInstruction::HloReverseInstruction( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions) @@ -393,8 +456,8 @@ std::unique_ptr HloReverseInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique(shape, new_operands[0], - dimensions()); + return absl::make_unique(shape, new_operands[0], + dimensions()); } HloConcatenateInstruction::HloConcatenateInstruction( @@ -433,18 +496,19 @@ HloConcatenateInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { - return MakeUnique(shape, new_operands, - dimensions(0)); + return absl::make_unique(shape, new_operands, + dimensions(0)); } HloReduceInstruction::HloReduceInstruction( - const Shape& shape, HloInstruction* arg, HloInstruction* init_value, + const Shape& shape, tensorflow::gtl::ArraySlice args, tensorflow::gtl::ArraySlice dimensions_to_reduce, HloComputation* reduce_computation) : HloInstruction(HloOpcode::kReduce, shape), dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) { - AppendOperand(arg); - AppendOperand(init_value); + for (HloInstruction* arg : args) { + AppendOperand(arg); + } AppendComputation(reduce_computation); } @@ -477,8 +541,8 @@ std::unique_ptr HloReduceInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique( - shape, new_operands[0], new_operands[1], dimensions(), to_apply()); + return absl::make_unique(shape, new_operands, + dimensions(), to_apply()); } HloSortInstruction::HloSortInstruction(const Shape& shape, int64 dimension, @@ -518,7 +582,8 @@ std::unique_ptr HloSortInstruction::CloneWithNewOperandsImpl( HloCloneContext* context) const { HloInstruction* keys = new_operands[0]; HloInstruction* values = new_operands.size() == 2 ? new_operands[1] : nullptr; - return MakeUnique(shape, dimensions(0), keys, values); + return absl::make_unique(shape, dimensions(0), keys, + values); } HloTransposeInstruction::HloTransposeInstruction( @@ -571,8 +636,8 @@ HloTransposeInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique(shape, new_operands[0], - dimensions()); + return absl::make_unique(shape, new_operands[0], + dimensions()); } HloBroadcastInstruction::HloBroadcastInstruction( @@ -610,8 +675,8 @@ HloBroadcastInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique(shape, new_operands[0], - dimensions()); + return absl::make_unique(shape, new_operands[0], + dimensions()); } HloMapInstruction::HloMapInstruction( @@ -668,7 +733,7 @@ std::unique_ptr HloMapInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { - return MakeUnique(shape, new_operands, to_apply()); + return absl::make_unique(shape, new_operands, to_apply()); } HloSliceInstruction::HloSliceInstruction( @@ -730,8 +795,8 @@ std::unique_ptr HloSliceInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique(shape, new_operands[0], slice_starts_, - slice_limits_, slice_strides_); + return absl::make_unique( + shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_); } HloConstantInstruction::HloConstantInstruction(std::unique_ptr literal) @@ -783,7 +848,7 @@ HloConstantInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { - return MakeUnique(literal_->CloneToUnique()); + return absl::make_unique(literal_->CloneToUnique()); } string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( @@ -1277,8 +1342,8 @@ std::unique_ptr HloFusionInstruction::CloneWithNewOperandsImpl( new_fused_computation = module->AddEmbeddedComputation( fused_instructions_computation()->Clone("clone", context)); } - return MakeUnique(shape, fusion_kind(), new_operands, - new_fused_computation); + return absl::make_unique( + shape, fusion_kind(), new_operands, new_fused_computation); } Status HloFusionInstruction::DeduplicateFusionOperands() { @@ -1337,7 +1402,8 @@ std::unique_ptr HloRngInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { - return MakeUnique(shape, distribution_, new_operands); + return absl::make_unique(shape, distribution_, + new_operands); } HloParameterInstruction::HloParameterInstruction(int64 parameter_number, @@ -1373,7 +1439,8 @@ HloParameterInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { - return MakeUnique(parameter_number_, shape, name()); + return absl::make_unique(parameter_number_, shape, + name()); } HloGetTupleElementInstruction::HloGetTupleElementInstruction( @@ -1409,8 +1476,8 @@ HloGetTupleElementInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique(shape, new_operands[0], - tuple_index()); + return absl::make_unique( + shape, new_operands[0], tuple_index()); } HloReducePrecisionInstruction::HloReducePrecisionInstruction( @@ -1452,7 +1519,7 @@ HloReducePrecisionInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], exponent_bits(), mantissa_bits()); } @@ -1466,13 +1533,6 @@ HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape, AppendOperand(token_operand); } -HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape, - const string& config) - : HloInstruction(HloOpcode::kInfeed, - ShapeUtil::MakeTupleShape( - {infeed_shape, ShapeUtil::MakeTokenShape()})), - infeed_config_(config) {} - HloInstructionProto HloInfeedInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); proto.set_infeed_config(infeed_config_); @@ -1499,13 +1559,9 @@ std::unique_ptr HloInfeedInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { - if (new_operands.empty()) { - return MakeUnique(infeed_shape(), infeed_config()); - } else { - CHECK_EQ(new_operands.size(), 1); - return MakeUnique(infeed_shape(), new_operands[0], - infeed_config()); - } + CHECK_EQ(new_operands.size(), 1); + return absl::make_unique( + infeed_shape(), new_operands[0], infeed_config()); } HloOutfeedInstruction::HloOutfeedInstruction( @@ -1521,18 +1577,6 @@ HloOutfeedInstruction::HloOutfeedInstruction( AppendOperand(token_operand); } -HloOutfeedInstruction::HloOutfeedInstruction( - const Shape& outfeed_shape, HloInstruction* operand, - tensorflow::StringPiece outfeed_config) - : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()), - outfeed_shape_(outfeed_shape), - outfeed_config_(outfeed_config.begin(), outfeed_config.end()) { - CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape)) - << "Outfeed shape " << outfeed_shape - << " must be compatible with operand shape " << operand->shape(); - AppendOperand(operand); -} - HloInstructionProto HloOutfeedInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); proto.set_outfeed_config(outfeed_config()); @@ -1560,22 +1604,19 @@ std::unique_ptr HloOutfeedInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { - if (new_operands.size() == 1) { - return MakeUnique(outfeed_shape(), new_operands[0], - outfeed_config()); - } else { - CHECK_EQ(new_operands.size(), 2); - return MakeUnique(outfeed_shape(), new_operands[0], - new_operands[1], outfeed_config()); - } + CHECK_EQ(new_operands.size(), 2); + return absl::make_unique( + outfeed_shape(), new_operands[0], new_operands[1], outfeed_config()); } HloConvolutionInstruction::HloConvolutionInstruction( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const Window& window, const ConvolutionDimensionNumbers& dimension_numbers) + const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count) : HloInstruction(HloOpcode::kConvolution, shape), window_(window), - convolution_dimension_numbers_(dimension_numbers) { + convolution_dimension_numbers_(dimension_numbers), + feature_group_count_(feature_group_count) { if (window_util::HasBaseDilation(window)) { SetAndSanitizeName(StrCat(name(), "-base-dilated")); } @@ -1613,6 +1654,7 @@ std::vector HloConvolutionInstruction::ExtraAttributesToStringImpl( } extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString( convolution_dimension_numbers_))); + extra.push_back(StrCat("feature_group_count=", feature_group_count_)); return extra; } @@ -1634,9 +1676,9 @@ HloConvolutionInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique(shape, new_operands[0], - new_operands[1], window(), - convolution_dimension_numbers_); + return absl::make_unique( + shape, new_operands[0], new_operands[1], window(), + convolution_dimension_numbers_, feature_group_count_); } HloReduceWindowInstruction::HloReduceWindowInstruction( @@ -1679,7 +1721,7 @@ HloReduceWindowInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], new_operands[1], window(), to_apply()); } @@ -1728,7 +1770,7 @@ HloSelectAndScatterInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 3); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], select(), window(), new_operands[1], new_operands[2], scatter()); } @@ -1803,8 +1845,8 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { - auto cloned = MakeUnique(shape, new_operands, - custom_call_target()); + auto cloned = absl::make_unique( + shape, new_operands, custom_call_target()); if (window_ != nullptr) { cloned->set_window(*window_); } @@ -1845,7 +1887,7 @@ HloHostComputeInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { - return MakeUnique( + return absl::make_unique( shape, new_operands, channel_name_, cost_estimate_ns_); } @@ -1883,8 +1925,8 @@ std::unique_ptr HloPadInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique(shape, new_operands[0], new_operands[1], - padding_config_); + return absl::make_unique(shape, new_operands[0], + new_operands[1], padding_config_); } HloDynamicSliceInstruction::HloDynamicSliceInstruction( @@ -1923,56 +1965,55 @@ HloDynamicSliceInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], new_operands[1], dynamic_slice_sizes_); } HloGatherInstruction::HloGatherInstruction( - const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices, + const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice window_bounds) + tensorflow::gtl::ArraySlice slice_sizes) : HloInstruction(HloOpcode::kGather, shape) { AppendOperand(operand); - AppendOperand(gather_indices); + AppendOperand(start_indices); gather_dimension_numbers_ = - MakeUnique(gather_dim_numbers); - c_copy(window_bounds, std::back_inserter(gather_window_bounds_)); + absl::make_unique(gather_dim_numbers); + absl::c_copy(slice_sizes, std::back_inserter(gather_slice_sizes_)); } string HloGatherInstruction::GatherDimensionNumbersToString() const { CHECK(gather_dimension_numbers_ != nullptr); - string output_window_dims = - StrCat("output_window_dims={", - Join(gather_dimension_numbers_->output_window_dims(), ","), "}"); - string elided_window_dims = - StrCat("elided_window_dims={", - Join(gather_dimension_numbers_->elided_window_dims(), ","), "}"); - string gather_dims_to_operand_dims = StrCat( - "gather_dims_to_operand_dims={", - Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}"); + string offset_dims = + StrCat("offset_dims={", + Join(gather_dimension_numbers_->offset_dims(), ","), "}"); + string collapsed_slice_dims = + StrCat("collapsed_slice_dims={", + Join(gather_dimension_numbers_->collapsed_slice_dims(), ","), "}"); + string start_index_map = + StrCat("start_index_map={", + Join(gather_dimension_numbers_->start_index_map(), ","), "}"); string index_vector_dim = StrCat( "index_vector_dim=", gather_dimension_numbers_->index_vector_dim()); return Join>( - {output_window_dims, elided_window_dims, gather_dims_to_operand_dims, - index_vector_dim}, + {offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim}, ", "); } /* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers( - tensorflow::gtl::ArraySlice output_window_dims, - tensorflow::gtl::ArraySlice elided_window_dims, - tensorflow::gtl::ArraySlice gather_dims_to_operand_dims, + tensorflow::gtl::ArraySlice offset_dims, + tensorflow::gtl::ArraySlice collapsed_slice_dims, + tensorflow::gtl::ArraySlice start_index_map, int64 index_vector_dim) { GatherDimensionNumbers gather_dim_numbers; - for (int64 output_window_dim : output_window_dims) { - gather_dim_numbers.add_output_window_dims(output_window_dim); + for (int64 output_window_dim : offset_dims) { + gather_dim_numbers.add_offset_dims(output_window_dim); } - for (int64 elided_window_dim : elided_window_dims) { - gather_dim_numbers.add_elided_window_dims(elided_window_dim); + for (int64 elided_window_dim : collapsed_slice_dims) { + gather_dim_numbers.add_collapsed_slice_dims(elided_window_dim); } - for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) { - gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim); + for (int64 gather_dim_to_input_dim : start_index_map) { + gather_dim_numbers.add_start_index_map(gather_dim_to_input_dim); } gather_dim_numbers.set_index_vector_dim(index_vector_dim); @@ -1982,8 +2023,8 @@ string HloGatherInstruction::GatherDimensionNumbersToString() const { HloInstructionProto HloGatherInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); *proto.mutable_gather_dimension_numbers() = gather_dimension_numbers(); - for (int64 bound : gather_window_bounds()) { - proto.add_gather_window_bounds(bound); + for (int64 bound : gather_slice_sizes()) { + proto.add_gather_slice_sizes(bound); } return proto; } @@ -1991,7 +2032,7 @@ HloInstructionProto HloGatherInstruction::ToProto() const { std::vector HloGatherInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { return {GatherDimensionNumbersToString(), - StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")}; + StrCat("slice_sizes={", Join(gather_slice_sizes(), ","), "}")}; } bool HloGatherInstruction::IdenticalSlowPath( @@ -2002,7 +2043,7 @@ bool HloGatherInstruction::IdenticalSlowPath( return protobuf_util::ProtobufEquals( gather_dimension_numbers(), casted_other.gather_dimension_numbers()) && - gather_window_bounds() == casted_other.gather_window_bounds(); + gather_slice_sizes() == casted_other.gather_slice_sizes(); } std::unique_ptr HloGatherInstruction::CloneWithNewOperandsImpl( @@ -2010,9 +2051,96 @@ std::unique_ptr HloGatherInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], new_operands[1], gather_dimension_numbers(), - gather_window_bounds()); + gather_slice_sizes()); +} + +HloScatterInstruction::HloScatterInstruction( + const Shape& shape, HloInstruction* operand, + HloInstruction* scatter_indices, HloInstruction* updates, + HloComputation* update_computation, + const ScatterDimensionNumbers& scatter_dim_numbers) + : HloInstruction(HloOpcode::kScatter, shape) { + AppendOperand(operand); + AppendOperand(scatter_indices); + AppendOperand(updates); + AppendComputation(update_computation); + scatter_dimension_numbers_ = + absl::make_unique(scatter_dim_numbers); +} + +string HloScatterInstruction::ScatterDimensionNumbersToString() const { + string update_window_dims = + StrCat("update_window_dims={", + Join(scatter_dimension_numbers().update_window_dims(), ","), "}"); + string inserted_window_dims = StrCat( + "inserted_window_dims={", + Join(scatter_dimension_numbers().inserted_window_dims(), ","), "}"); + string scatter_dims_to_operand_dims = StrCat( + "scatter_dims_to_operand_dims={", + Join(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","), + "}"); + string index_vector_dim = StrCat( + "index_vector_dim=", scatter_dimension_numbers().index_vector_dim()); + + return Join>( + {update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims, + index_vector_dim}, + ", "); +} + +/* static */ ScatterDimensionNumbers +HloScatterInstruction::MakeScatterDimNumbers( + tensorflow::gtl::ArraySlice update_window_dims, + tensorflow::gtl::ArraySlice inserted_window_dims, + tensorflow::gtl::ArraySlice scatter_dims_to_operand_dims, + int64 index_vector_dim) { + ScatterDimensionNumbers scatter_dim_numbers; + for (int64 update_window_dim : update_window_dims) { + scatter_dim_numbers.add_update_window_dims(update_window_dim); + } + for (int64 inserted_window_dim : inserted_window_dims) { + scatter_dim_numbers.add_inserted_window_dims(inserted_window_dim); + } + for (int64 scatter_dim_to_operand_dim : scatter_dims_to_operand_dims) { + scatter_dim_numbers.add_scatter_dims_to_operand_dims( + scatter_dim_to_operand_dim); + } + scatter_dim_numbers.set_index_vector_dim(index_vector_dim); + return scatter_dim_numbers; +} + +HloInstructionProto HloScatterInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_scatter_dimension_numbers() = scatter_dimension_numbers(); + return proto; +} + +std::vector HloScatterInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {ScatterDimensionNumbersToString()}; +} + +bool HloScatterInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return protobuf_util::ProtobufEquals( + scatter_dimension_numbers(), + casted_other.scatter_dimension_numbers()) && + eq_computations(to_apply(), casted_other.to_apply()); +} + +std::unique_ptr HloScatterInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 3); + return absl::make_unique( + shape, new_operands[0], new_operands[1], new_operands[2], to_apply(), + scatter_dimension_numbers()); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index e4031f04d5c0062d73efb2c8f95b462b691407fa..19b69c2171146175a0021f1ab7ad7d39c0b6ad85 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -18,6 +18,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { @@ -224,8 +225,7 @@ class HloAllReduceInstruction : public HloInstruction { HloComputation* reduce_computation, tensorflow::gtl::ArraySlice replica_group_ids, tensorflow::StringPiece barrier, - const tensorflow::gtl::optional& all_reduce_id = - tensorflow::gtl::nullopt); + const tensorflow::gtl::optional& all_reduce_id); // Returns the group ids of each replica for CrossReplicaSum op. const std::vector& replica_group_ids() const { @@ -274,6 +274,47 @@ class HloAllReduceInstruction : public HloInstruction { tensorflow::gtl::optional all_reduce_id_; }; +class HloAllToAllInstruction : public HloInstruction { + public: + explicit HloAllToAllInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operand, + const std::vector& replica_groups, + tensorflow::StringPiece barrier); + + const std::vector& replica_groups() const { + return replica_groups_; + } + + // TODO(b/110096724): rename this. + void set_cross_replica_sum_barrier(string barrier) { + cross_replica_sum_barrier_ = barrier; + } + string cross_replica_sum_barrier() const { + return cross_replica_sum_barrier_; + } + + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector replica_groups_; + + // The string representation of the barrier config. + string cross_replica_sum_barrier_; +}; + class HloReverseInstruction : public HloInstruction { public: explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand, @@ -332,7 +373,7 @@ class HloConcatenateInstruction : public HloInstruction { class HloReduceInstruction : public HloInstruction { public: explicit HloReduceInstruction( - const Shape& shape, HloInstruction* arg, HloInstruction* init_value, + const Shape& shape, tensorflow::gtl::ArraySlice args, tensorflow::gtl::ArraySlice dimensions_to_reduce, HloComputation* reduce_computation); // Returns the dimension sizes or numbers associated with this instruction. @@ -341,6 +382,18 @@ class HloReduceInstruction : public HloInstruction { // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; + // Returns the input tensors to be reduced. + tensorflow::gtl::ArraySlice inputs() const { + return tensorflow::gtl::ArraySlice(operands(), 0, + operand_count() / 2); + } + + // Returns the init values of the reduction. + tensorflow::gtl::ArraySlice init_values() const { + return tensorflow::gtl::ArraySlice( + operands(), operand_count() / 2, operand_count()); + } + private: std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; @@ -535,6 +588,8 @@ class HloConstantInstruction : public HloInstruction { explicit HloConstantInstruction(const Shape& shape); // Returns the literal associated with this instruction. const Literal& literal() const { return *literal_; } + // Returns whether there is literal associated with this instruction. + bool HasLiteral() const { return literal_ != nullptr; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -829,10 +884,6 @@ class HloInfeedInstruction : public HloInstruction { explicit HloInfeedInstruction(const Shape& infeed_shape, HloInstruction* token_operand, const string& config); - // TODO(b/80000000): Remove this constructor when all uses of infeed are - // converted to take tokens. - explicit HloInfeedInstruction(const Shape& infeed_shape, - const string& config); // Returns the infeed configuration string. The infeed configuration includes // any metadata needed for the backend compiler (e.g., infeed buffer address) // and is target-dependent. @@ -871,12 +922,6 @@ class HloOutfeedInstruction : public HloInstruction { HloInstruction* operand, HloInstruction* token_operand, tensorflow::StringPiece outfeed_config); - // TODO(b/80000000): Remove this constructor when all uses of outfeed are - // converted to take tokens. - explicit HloOutfeedInstruction(const Shape& outfeed_shape, - HloInstruction* operand, - tensorflow::StringPiece outfeed_config); - // Returns the shape for the Outfeed instruction. const Shape& outfeed_shape() const { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape_)); @@ -911,7 +956,8 @@ class HloConvolutionInstruction : public HloInstruction { explicit HloConvolutionInstruction( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count); const Window& window() const override { return window_; } void set_window(const Window& window) override { window_ = window; } const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { @@ -921,6 +967,9 @@ class HloConvolutionInstruction : public HloInstruction { const ConvolutionDimensionNumbers& dnums) { convolution_dimension_numbers_ = dnums; } + // The number of feature groups. Must be a divisor of the input feature + // dimension and output feature dimension. + int64 feature_group_count() const { return feature_group_count_; } string ToCategory() const override; // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -940,6 +989,9 @@ class HloConvolutionInstruction : public HloInstruction { Window window_; // Describes the dimension numbers used for a convolution. ConvolutionDimensionNumbers convolution_dimension_numbers_; + // The number of feature groups. Must be a divisor of the input feature + // dimension and output feature dimension. + int64 feature_group_count_; }; class HloReduceWindowInstruction : public HloInstruction { @@ -1029,7 +1081,7 @@ class HloCustomCallInstruction : public HloInstruction { } void set_window(const Window& window) override { - window_ = MakeUnique(window); + window_ = absl::make_unique(window); } const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { @@ -1040,7 +1092,7 @@ class HloCustomCallInstruction : public HloInstruction { void set_convolution_dimension_numbers( const ConvolutionDimensionNumbers& dnums) { convolution_dimension_numbers_ = - MakeUnique(dnums); + absl::make_unique(dnums); } const string& custom_call_target() const { return custom_call_target_; } // Returns a serialized representation of this instruction. @@ -1161,15 +1213,15 @@ class HloGatherInstruction : public HloInstruction { public: explicit HloGatherInstruction( const Shape& shape, HloInstruction* operand, - HloInstruction* gather_indices, + HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice window_bounds); + tensorflow::gtl::ArraySlice slice_sizes); const GatherDimensionNumbers& gather_dimension_numbers() const { CHECK(gather_dimension_numbers_ != nullptr); return *gather_dimension_numbers_; } - tensorflow::gtl::ArraySlice gather_window_bounds() const { - return gather_window_bounds_; + tensorflow::gtl::ArraySlice gather_slice_sizes() const { + return gather_slice_sizes_; } // Returns the dump string of the gather dimension numbers. string GatherDimensionNumbersToString() const; @@ -1178,9 +1230,9 @@ class HloGatherInstruction : public HloInstruction { // Creates an instance of GatherDimensionNumbers. static GatherDimensionNumbers MakeGatherDimNumbers( - tensorflow::gtl::ArraySlice output_window_dims, - tensorflow::gtl::ArraySlice elided_window_dims, - tensorflow::gtl::ArraySlice gather_dims_to_operand_dims, + tensorflow::gtl::ArraySlice offset_dims, + tensorflow::gtl::ArraySlice collapsed_slice_dims, + tensorflow::gtl::ArraySlice start_index_map, int64 index_vector_dim); private: @@ -1196,7 +1248,46 @@ class HloGatherInstruction : public HloInstruction { HloCloneContext* context) const override; std::unique_ptr gather_dimension_numbers_; - std::vector gather_window_bounds_; + std::vector gather_slice_sizes_; +}; + +class HloScatterInstruction : public HloInstruction { + public: + explicit HloScatterInstruction( + const Shape& shape, HloInstruction* operand, + HloInstruction* scatter_indices, HloInstruction* updates, + HloComputation* update_computation, + const ScatterDimensionNumbers& scatter_dim_numbers); + const ScatterDimensionNumbers& scatter_dimension_numbers() const { + CHECK(scatter_dimension_numbers_ != nullptr); + return *scatter_dimension_numbers_; + } + // Returns the dump string of the scatter dimension numbers. + string ScatterDimensionNumbersToString() const; + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + // Creates an instance of ScatterDimensionNumbers. + static ScatterDimensionNumbers MakeScatterDimNumbers( + tensorflow::gtl::ArraySlice update_window_dims, + tensorflow::gtl::ArraySlice inserted_window_dims, + tensorflow::gtl::ArraySlice scatter_dims_to_operand_dims, + int64 index_vector_dim); + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::unique_ptr scatter_dimension_numbers_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index f0d9fdbc8f86da0bb9d7f9235239df677c9506bc..8e0d38b6a63917582b8bfa10f205e1ed511efef3 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -143,8 +143,47 @@ TokKind HloLexer::LexToken() { return TokKind::kLparen; case ')': return TokKind::kRparen; - case '/': - return LexComment(); + case '/': { + if (PeekCurrentChar() == '*') { + // This is the start of a /*...*/ delimited comment. Save the current + // location in case the comment is unterminated so the error message + // will point to the beginning of the comment. + const char* comment_start = current_ptr_; + current_ptr_++; + // Advance until '*/' is found. + while (true) { + int current = GetNextChar(); + if (current == '*' && PeekCurrentChar() == '/') { + // End of comment. + current_ptr_++; + break; + } + if (current == kEOF) { + // Unterminated comment. + current_ptr_ = comment_start; + return TokKind::kError; + } + } + // Return no token for the comment. Keep lexing. + continue; + } else if (PeekCurrentChar() == '/') { + // This is the start of a '//' delimited comment. Throw away + // everything until end of line or file. The end-of-line character(s) + // are left unlexed in the buffer which is harmless because these are + // skipped later by the lexer. This approach enables support for + // different end-of-line encodings. + while (true) { + int current = PeekCurrentChar(); + if (current == kEOF || current == '\n' || current == '\r') { + break; + } + current_ptr_++; + } + continue; + } + // A lone '/' is an error. + return TokKind::kError; + } case '"': return LexString(); } @@ -299,9 +338,12 @@ TokKind HloLexer::LexNumberOrPattern() { static LazyRE2 int_pattern = {R"([-]?\d+)"}; if (RE2::Consume(&consumable, *int_pattern)) { current_ptr_ = consumable.begin(); - tensorflow::strings::safe_strto64( - StringPieceFromPointers(token_start_, current_ptr_), &int64_val_); - return TokKind::kInt; + auto slice = StringPieceFromPointers(token_start_, current_ptr_); + if (tensorflow::strings::safe_strto64(slice, &int64_val_)) { + return TokKind::kInt; + } + LOG(ERROR) << "Failed to parse int literal: " << slice; + return TokKind::kError; } static LazyRE2 neg_inf = {"-inf"}; @@ -354,16 +396,6 @@ tensorflow::StringPiece HloLexer::GetLine(LocTy loc) const { return StringPieceFromPointers(start, end); } -TokKind HloLexer::LexComment() { - auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); - static LazyRE2 comment_pattern = {R"(\/\*.*?\*\/)"}; - if (RE2::Consume(&consumable, *comment_pattern)) { - current_ptr_ = consumable.begin(); - return TokKind::kComment; - } - return TokKind::kError; -} - // Lexes quoted string with escaping characters. If matched, the quoted string // will be unescaped and stored to str_val_. TokKind HloLexer::LexString() { @@ -409,8 +441,6 @@ string TokKindToString(TokKind kind) { return "kRparen"; case TokKind::kArrow: return "kArrow"; - case TokKind::kComment: - return "kComment"; case TokKind::kw_HloModule: return "kw_HloModule"; case TokKind::kw_ENTRY: diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h index ceb674f25e94ac3ac2e6a4a0687a93ffdcd065e0..003ac34ace5713446afa74eb3af96ae33087223e 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -105,7 +105,6 @@ class HloLexer { TokKind LexShape(); TokKind LexConstant(); TokKind LexNumberOrPattern(); - TokKind LexComment(); TokKind LexString(); const tensorflow::StringPiece buf_; diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc index 43c41ece6efc4f9e8ca74f16e0f63d29abc4de4e..18f17b75aede734b4971a07347f31ba45db9dc96 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc @@ -17,8 +17,8 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -296,7 +296,7 @@ StatusOr> HloLivenessAnalysis::Run( VLOG(1) << "HloLivenessAnalysis::Run on module " << module.name(); XLA_VLOG_LINES(2, module.ToString()); - auto liveness_analysis = WrapUnique(new HloLivenessAnalysis(module)); + auto liveness_analysis = absl::WrapUnique(new HloLivenessAnalysis(module)); liveness_analysis->RunAnalysis(); diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index b57c940238f0672692e3b65827f43e2f5499502d..c577b4359aae6c66f29860a0e56c3487b07afc02 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -231,6 +231,7 @@ HLO_MATCHER(Tanh); HLO_MATCHER(Trace); HLO_MATCHER(Transpose); HLO_MATCHER(Tuple); +HLO_MATCHER(TupleSelect); HLO_MATCHER(While); // The special cases below let you check additional information about the diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc index 7de59acc1efbc0150b95ebdd85a13ede48eec2f9..7961aece541faeb66875885b380158756c503250 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -157,9 +157,8 @@ TEST(HloMatchersTest, ShardingMatcher) { Array assignment({2}); assignment.SetValues({0, 1}); auto sharding = HloSharding::Tuple( - tuple_shape, - {HloSharding::Tile(ShapeUtil::MakeShape(F32, {5}), assignment), - HloSharding::AssignDevice(1), HloSharding::Replicate()}); + tuple_shape, {HloSharding::Tile(assignment), HloSharding::AssignDevice(1), + HloSharding::Replicate()}); p2->set_sharding(sharding); EXPECT_THAT(p0.get(), op::NoSharding()); @@ -172,8 +171,7 @@ TEST(HloMatchersTest, ShardingMatcher) { EXPECT_THAT( p2.get(), - op::Sharding( - "{{f32[5] devices=[2]0,1}, {maximal device=1}, {replicated}}")); + op::Sharding("{{devices=[2]0,1}, {maximal device=1}, {replicated}}")); EXPECT_THAT(Explain(p0.get(), op::Sharding(HloSharding::AssignDevice(1))), "%param.0 = f32[5]{0} parameter(0) has no sharding (expected: " diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 55ff073d3faf34aa0f1b8f0886946837e7a49bcc..d60b76d63f8fb0b3b775e743beaec58316fa3740 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -22,8 +22,9 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -274,7 +275,7 @@ StatusOr> HloModule::CreateFromProto( } TF_RET_CHECK(entry != nullptr); - auto module = MakeUnique(proto.name(), module_config); + auto module = absl::make_unique(proto.name(), module_config); // Sort the computations in the proto id's order. std::sort(computations.begin(), computations.end(), @@ -507,7 +508,7 @@ std::vector HloModule::MakeNonfusionComputations() const { std::unique_ptr HloModule::Clone(const string& suffix) const { VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n"; - auto module = MakeUnique(name_ + "-" + suffix, config_); + auto module = absl::make_unique(name_ + "-" + suffix, config_); HloCloneContext context(module.get(), suffix); auto cloned_computation = entry_computation_->Clone(suffix, &context); @@ -538,9 +539,9 @@ uint64 HloModule::RandomNew64() const { HloComputation* HloModule::GetComputationWithName( tensorflow::StringPiece name) { auto computations_in_module = computations(); - auto it = c_find_if(computations_in_module, [&](HloComputation* computation) { - return computation->name() == name; - }); + auto it = absl::c_find_if( + computations_in_module, + [&](HloComputation* computation) { return computation->name() == name; }); return it == computations_in_module.end() ? nullptr : *it; } diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc index 07a8c798dbee072db3b75d5e99ca0dcabb5fdf6b..f9708283eb4becd67a76ff30103001c81c2c703a 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.cc +++ b/tensorflow/compiler/xla/service/hlo_module_config.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/strings/str_util.h" diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index 10bf9ffd6c1960df5ca2a3555d120b0874407f15..3b512bf0f81afc27a6314955b22700baf23f9ef4 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -59,7 +59,7 @@ string HloModuleGroupMetadata::TrackedInstruction::ToString() const { /* static */ StatusOr> HloModuleGroupMetadata::Build(const std::vector& modules) { - auto metadata = MakeUnique(modules); + auto metadata = absl::make_unique(modules); TF_RETURN_IF_ERROR(metadata->Build()); return std::move(metadata); } @@ -383,7 +383,7 @@ Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1, if (!ContainsKey(companion_set_index_, instruction1) && !ContainsKey(companion_set_index_, instruction2)) { companion_sets_.push_back( - tensorflow::MakeUnique>()); + absl::make_unique>()); auto companion_set = companion_sets_.back().get(); companion_set->insert(instruction1); companion_set->insert(instruction2); diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index 84f2d3f5fbc1a6ff1df8ba3c0babd122e5701148..1b256cd00e6fc6c91c7b4a7de82eef438a75396f 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -166,7 +166,7 @@ class HloModuleGroupMetadata { // // Precondition: IsCompanionWhile(instruction) is true. const std::unordered_set& Companions( - HloInstruction* instruction) const { + const HloInstruction* instruction) const { CHECK_EQ(companion_set_index_.count(instruction), 1); return companion_set(companion_set_index_.at(instruction)); } @@ -243,7 +243,7 @@ class HloModuleGroupMetadata { companion_sets_; // Map from each companion while instruction to the index into companion_set_. - tensorflow::gtl::FlatMap companion_set_index_; + tensorflow::gtl::FlatMap companion_set_index_; // Map from computation to the instruction using it (a kWhile, kConditional). tensorflow::gtl::FlatMap diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index 9fd0ade153109c6c809c37aa08257f83a82c44d5..4f11ce322e619f2679716230e7161d474d83a503 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -22,13 +22,14 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -37,24 +38,38 @@ namespace xla { std::vector HloModuleGroupUtil::GlobalPredecessors( HloInstruction* instruction) { - std::vector predecessors; - - // Adds to the unique predecessors list and also add companion instructions - // if the given predecessor has those. + std::vector + predecessors; // Use a vector to avoid non-determinism. + tensorflow::gtl::FlatSet unique; + + // Adds to the unique predecessors list; if the predecessors is a companion + // instruction, also add companion instructions; if the predecessors is a + // cross-module all-reduce, also add the all-reduce instructions in the same + // group. auto add_unique_predecessor = [&](HloInstruction* predecessor) { - if (std::find(predecessors.begin(), predecessors.end(), predecessor) != - predecessors.end()) { + if (unique.find(predecessor) != unique.end()) { return; } - if (!metadata_.IsCompanionInstruction(predecessor)) { - predecessors.push_back(predecessor); + if (metadata_.IsCompanionInstruction(predecessor)) { + for (HloInstruction* instr : metadata_.Companions(predecessor)) { + if (unique.insert(instr).second) { + predecessors.push_back(instr); + } + } return; } - for (HloInstruction* companion : metadata_.Companions(predecessor)) { - predecessors.push_back(companion); + if (predecessor->IsCrossModuleAllReduce()) { + for (HloInstruction* instr : + metadata_.GetAllReduceGroup(*predecessor->all_reduce_id())) { + if (unique.insert(instr).second) { + predecessors.push_back(instr); + } + } + return; } + unique.insert(predecessor); + predecessors.push_back(predecessor); }; - // If the given instruction is a companion instruction, we need to find the // predecessors of all of its companion instructions. If the instruction is an // all-reduce, we need to find the predecessors of all the peer all-reduce @@ -98,22 +113,37 @@ std::vector HloModuleGroupUtil::GlobalPredecessors( std::vector HloModuleGroupUtil::GlobalSuccessors( HloInstruction* instruction) { - std::vector successors; - - // Adds to the unique successors list and also add companion instructions - // if the given successor has those. + std::vector + successors; // Use a vector to avoid non-determinism. + tensorflow::gtl::FlatSet unique; + + // Adds to the unique successors list; if the successor is a companion + // instruction, also add companion instructions; if the successor is a + // cross-module all-reduce, also add the all-reduce instructions in the same + // group. auto add_unique_successor = [&](HloInstruction* successor) { - if (std::find(successors.begin(), successors.end(), successor) != - successors.end()) { + if (unique.find(successor) != unique.end()) { return; } - if (!metadata_.IsCompanionInstruction(successor)) { - successors.push_back(successor); + if (metadata_.IsCompanionInstruction(successor)) { + for (HloInstruction* instr : metadata_.Companions(successor)) { + if (unique.insert(instr).second) { + successors.push_back(instr); + } + } return; } - for (HloInstruction* companion : metadata_.Companions(successor)) { - successors.push_back(companion); + if (successor->IsCrossModuleAllReduce()) { + for (HloInstruction* instr : + metadata_.GetAllReduceGroup(*successor->all_reduce_id())) { + if (unique.insert(instr).second) { + successors.push_back(instr); + } + } + return; } + unique.insert(successor); + successors.push_back(successor); }; // If the given instruction is a companion instruction, we need to find the @@ -302,7 +332,7 @@ HloModuleGroupUtil::ComputeReachability( TF_RETURN_IF_ERROR( VisitTopologicalOrder(&visit_states, visit_function, root)); } - auto reachability = MakeUnique(post_order); + auto reachability = absl::make_unique(post_order); for (HloInstruction* hlo : post_order) { reachability->FastSetReachabilityToUnion(GlobalPredecessors(hlo), hlo); } diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 236f4500860a8673e61cbd2f861a8fc40c7861f7..209ad5e58c9360fafc3d63606e61a553de73be13 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 59e9a5a94aa4fc6270bde76c19dbd0d4506a563c..0e0d96ab09cd6f92ff2919bf8e9a9d920ddd884c 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -47,6 +47,7 @@ namespace xla { #define HLO_OPCODE_LIST(V) \ V(kAbs, "abs") \ V(kAdd, "add") \ + V(kAllToAll, "all-to-all") \ V(kAtan2, "atan2") \ V(kBatchNormGrad, "batch-norm-grad") \ V(kBatchNormInference, "batch-norm-inference") \ @@ -118,6 +119,7 @@ namespace xla { V(kReverse, "reverse") \ V(kRng, "rng") \ V(kRoundNearestAfz, "round-nearest-afz") \ + V(kScatter, "scatter") \ V(kSelect, "select") \ V(kSelectAndScatter, "select-and-scatter") \ V(kSend, "send") \ @@ -154,7 +156,7 @@ enum HloOpcodeProperty { // Returns a string representation of the opcode. string HloOpcodeString(HloOpcode opcode); -// Returns a string representation of the opcode. +// Retrieves the opcode enum by name if the opcode exists. StatusOr StringToHloOpcode(const string& opcode_name); inline std::ostream& operator<<(std::ostream& os, HloOpcode opcode) { diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index e8eaf54949d6e41ebffabe7963cf737ce5ad4567..3768da8a731efa5e8f4866baaf166386f52a96ee 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" @@ -125,6 +127,7 @@ class HloParser { kFloat, kString, kBracedInt64List, + kBracedInt64ListList, kHloComputation, kFftType, kWindow, @@ -205,6 +208,10 @@ class HloParser { bool ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, std::vector* result); + // 'parse_and_add_item' is an lambda to parse an element in the list and add + // the parsed element to the result. It's supposed to capture the result. + bool ParseList(const TokKind start, const TokKind end, const TokKind delim, + const std::function& parse_and_add_item); bool ParseParamListToShape(Shape* shape, LocTy* shape_loc); bool ParseParamList(); @@ -299,7 +306,7 @@ bool HloParser::ParseHloModule() { return false; } - module_ = MakeUnique(name, config_); + module_ = absl::make_unique(name, config_); return ParseComputations(); } @@ -352,7 +359,7 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { if (!ParseName(&name)) { return false; } - auto builder = MakeUnique(name); + auto builder = absl::make_unique(name); LocTy shape_loc = nullptr; Shape shape; @@ -619,6 +626,29 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } break; } + case HloOpcode::kAllToAll: { + optional>> tmp_groups; + optional barrier; + attrs["replica_groups"] = {/*required=*/false, + AttrTy::kBracedInt64ListList, &tmp_groups}; + attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + std::vector replica_groups; + if (tmp_groups) { + absl::c_transform( + *tmp_groups, std::back_inserter(replica_groups), + [](const std::vector& ids) { + ReplicaGroup group; + *group.mutable_replica_ids() = {ids.begin(), ids.end()}; + return group; + }); + } + instruction = builder->AddInstruction(HloInstruction::CreateAllToAll( + shape, operands, replica_groups, barrier ? *barrier : "")); + break; + } case HloOpcode::kReshape: { if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -798,9 +828,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kConvolution: { optional window; optional dnums; + optional feature_group_count; attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; attrs["dim_labels"] = {/*required=*/true, AttrTy::kConvolutionDimensionNumbers, &dnums}; + attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, + &feature_group_count}; if (!ParseOperands(&operands, /*expected_size=*/2) || !ParseAttributes(attrs)) { return false; @@ -808,8 +841,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (!window) { window.emplace(); } + if (!feature_group_count) { + feature_group_count = 1; + } instruction = builder->AddInstruction(HloInstruction::CreateConvolve( - shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums)); + shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums, + feature_group_count.value())); break; } case HloOpcode::kFft: { @@ -865,18 +902,28 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kReduce: { + auto loc = lexer_.GetLoc(); + optional reduce_computation; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &reduce_computation}; optional> dimensions_to_reduce; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions_to_reduce}; - if (!ParseOperands(&operands, /*expected_size=*/2) || - !ParseAttributes(attrs)) { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } + if (operands.size() % 2) { + return Error(loc, StrCat("expects an even number of operands, but has ", + operands.size(), " operands")); + } instruction = builder->AddInstruction(HloInstruction::CreateReduce( - shape, /*operand=*/operands[0], /*init_value=*/operands[1], + shape, /*operands=*/ + tensorflow::gtl::ArraySlice(operands, 0, + operands.size() / 2), + /*init_values=*/ + tensorflow::gtl::ArraySlice( + operands, operands.size() / 2, operands.size()), *dimensions_to_reduce, *reduce_computation)); break; } @@ -1036,7 +1083,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kInfeed: { optional config; attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config}; - if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { return false; } // We need to know the infeed data shape to construct the infeed @@ -1048,41 +1096,21 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, return Error(lexer_.GetLoc(), "infeed must have a non-empty tuple shape"); } - - if (operands.empty()) { - // TODO(b/80000000): Remove this when all uses of infeed are - // converted to take tokens. - instruction = builder->AddInstruction(HloInstruction::CreateInfeed( - ShapeUtil::GetTupleElementShape(shape, 0), config ? *config : "")); - } else if (operands.size() == 1) { - instruction = builder->AddInstruction(HloInstruction::CreateInfeed( - ShapeUtil::GetTupleElementShape(shape, 0), operands[0], - config ? *config : "")); - } else { - return Error(lexer_.GetLoc(), - "infeed must have exactly zero or one operands"); - } + instruction = builder->AddInstruction(HloInstruction::CreateInfeed( + ShapeUtil::GetTupleElementShape(shape, 0), operands[0], + config ? *config : "")); break; } case HloOpcode::kOutfeed: { optional config; attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config}; - if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { return false; } - if (operands.size() == 1) { - // TODO(b/80000000): Remove this when all uses of outfeed are - // converted to take tokens. - instruction = builder->AddInstruction(HloInstruction::CreateOutfeed( - operands[0]->shape(), operands[0], config ? *config : "")); - } else if (operands.size() == 2) { - instruction = builder->AddInstruction( - HloInstruction::CreateOutfeed(operands[0]->shape(), operands[0], - operands[1], config ? *config : "")); - } else { - return Error(lexer_.GetLoc(), - "outfeed must have exactly one or two operands"); - } + instruction = builder->AddInstruction( + HloInstruction::CreateOutfeed(operands[0]->shape(), operands[0], + operands[1], config ? *config : "")); break; } case HloOpcode::kRng: { @@ -1132,13 +1160,24 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } case HloOpcode::kCustomCall: { optional custom_call_target; + optional window; + optional dnums; attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString, &custom_call_target}; + attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; + attrs["dim_labels"] = {/*required=*/false, + AttrTy::kConvolutionDimensionNumbers, &dnums}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction(HloInstruction::CreateCustomCall( shape, operands, *custom_call_target)); + if (window.has_value()) { + instruction->set_window(*window); + } + if (dnums.has_value()) { + instruction->set_convolution_dimension_numbers(*dnums); + } break; } case HloOpcode::kHostCompute: { @@ -1197,22 +1236,21 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kGather: { - optional> output_window_dims; - attrs["output_window_dims"] = { - /*required=*/true, AttrTy::kBracedInt64List, &output_window_dims}; - optional> elided_window_dims; - attrs["elided_window_dims"] = { - /*required=*/true, AttrTy::kBracedInt64List, &elided_window_dims}; - optional> gather_dims_to_operand_dims; - attrs["gather_dims_to_operand_dims"] = {/*required=*/true, - AttrTy::kBracedInt64List, - &gather_dims_to_operand_dims}; + optional> offset_dims; + attrs["offset_dims"] = {/*required=*/true, AttrTy::kBracedInt64List, + &offset_dims}; + optional> collapsed_slice_dims; + attrs["collapsed_slice_dims"] = { + /*required=*/true, AttrTy::kBracedInt64List, &collapsed_slice_dims}; + optional> start_index_map; + attrs["start_index_map"] = {/*required=*/true, AttrTy::kBracedInt64List, + &start_index_map}; optional index_vector_dim; attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64, &index_vector_dim}; - optional> window_bounds; - attrs["window_bounds"] = {/*required=*/true, AttrTy::kBracedInt64List, - &window_bounds}; + optional> slice_sizes; + attrs["slice_sizes"] = {/*required=*/true, AttrTy::kBracedInt64List, + &slice_sizes}; if (!ParseOperands(&operands, /*expected_size=*/2) || !ParseAttributes(attrs)) { @@ -1221,14 +1259,50 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, GatherDimensionNumbers dim_numbers = HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/*output_window_dims, - /*elided_window_dims=*/*elided_window_dims, - /*gather_dims_to_operand_dims=*/*gather_dims_to_operand_dims, + /*offset_dims=*/*offset_dims, + /*collapsed_slice_dims=*/*collapsed_slice_dims, + /*start_index_map=*/*start_index_map, /*index_vector_dim=*/*index_vector_dim); instruction = builder->AddInstruction(HloInstruction::CreateGather( - shape, /*operand=*/operands[0], /*gather_indices=*/operands[1], - dim_numbers, *window_bounds)); + shape, /*operand=*/operands[0], /*start_indices=*/operands[1], + dim_numbers, *slice_sizes)); + break; + } + case HloOpcode::kScatter: { + optional> update_window_dims; + attrs["update_window_dims"] = { + /*required=*/true, AttrTy::kBracedInt64List, &update_window_dims}; + optional> inserted_window_dims; + attrs["inserted_window_dims"] = { + /*required=*/true, AttrTy::kBracedInt64List, &inserted_window_dims}; + optional> scatter_dims_to_operand_dims; + attrs["scatter_dims_to_operand_dims"] = {/*required=*/true, + AttrTy::kBracedInt64List, + &scatter_dims_to_operand_dims}; + optional index_vector_dim; + attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64, + &index_vector_dim}; + + optional update_computation; + attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, + &update_computation}; + + if (!ParseOperands(&operands, /*expected_size=*/3) || + !ParseAttributes(attrs)) { + return false; + } + + ScatterDimensionNumbers dim_numbers = + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/*update_window_dims, + /*inserted_window_dims=*/*inserted_window_dims, + /*scatter_dims_to_operand_dims=*/*scatter_dims_to_operand_dims, + /*index_vector_dim=*/*index_vector_dim); + + instruction = builder->AddInstruction(HloInstruction::CreateScatter( + shape, /*operand=*/operands[0], /*scatter_indices=*/operands[1], + /*updates=*/operands[2], *update_computation, dim_numbers)); break; } case HloOpcode::kDomain: { @@ -1326,7 +1400,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, bool replicated = false; std::vector devices; std::vector tile_assignment_dimensions; - Shape tile_shape; while (lexer_.GetKind() != TokKind::kRbrace) { switch (lexer_.GetKind()) { case TokKind::kw_maximal: @@ -1377,7 +1450,8 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, break; } case TokKind::kShape: - tile_shape = lexer_.GetShapeVal(); + // TODO(b/112302613): Left here for backward compatibility to ignore the + // removed tile shape data. lexer_.Lex(); break; case TokKind::kRbrace: @@ -1392,19 +1466,12 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return Error(loc, "replicated shardings should not have any devices assigned"); } - if (!ShapeUtil::Equal(tile_shape, Shape())) { - return Error(loc, - "replicated shardings should not have any tile shape set"); - } sharding->set_type(OpSharding::Type::OpSharding_Type_REPLICATED); } else if (maximal) { if (devices.size() != 1) { return Error(loc, "maximal shardings should have exactly one device assigned"); } - if (!ShapeUtil::Equal(tile_shape, Shape())) { - return Error(loc, "maximal shardings should not have any tile shape set"); - } sharding->set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); sharding->add_tile_assignment_devices(devices[0]); } else { @@ -1412,9 +1479,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return Error( loc, "non-maximal shardings must have more than one device assigned"); } - if (ShapeUtil::Equal(tile_shape, Shape())) { - return Error(loc, "non-maximal shardings should have a tile shape set"); - } if (tile_assignment_dimensions.empty()) { return Error( loc, @@ -1422,7 +1486,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, "dimensions"); } sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER); - *sharding->mutable_tile_shape() = tile_shape; for (tensorflow::int64 dim : tile_assignment_dimensions) { sharding->add_tile_assignment_dimensions(dim); } @@ -1449,14 +1512,14 @@ bool HloParser::ParseDomain(DomainData* domain) { return false; } if (*kind == ShardingMetadata::KindName()) { - auto entry_sharding_ptr = MakeUnique( + auto entry_sharding_ptr = absl::make_unique( HloSharding::FromProto(*entry_sharding).ValueOrDie()); - auto exit_sharding_ptr = MakeUnique( + auto exit_sharding_ptr = absl::make_unique( HloSharding::FromProto(*exit_sharding).ValueOrDie()); domain->entry_metadata = - MakeUnique(std::move(entry_sharding_ptr)); + absl::make_unique(std::move(entry_sharding_ptr)); domain->exit_metadata = - MakeUnique(std::move(exit_sharding_ptr)); + absl::make_unique(std::move(exit_sharding_ptr)); } else { return TokenError(StrCat("unsupported domain kind: ", *kind)); } @@ -1579,6 +1642,24 @@ bool HloParser::SetValueInLiteralHelper(ParsedElemT value, "value ", value, " is out of range for literal's primitive type ", PrimitiveType_Name(literal->shape().element_type()))); } + } else if (std::is_unsigned::value) { + CHECK((std::is_same::value || + std::is_same::value)) + << "Unimplemented checking for ParsedElemT"; + + ParsedElemT upper_bound; + if (sizeof(LiteralNativeT) >= sizeof(ParsedElemT)) { + upper_bound = std::numeric_limits::max(); + } else { + upper_bound = + static_cast(std::numeric_limits::max()); + } + if (value > upper_bound || value < 0) { + // Value is out of range for LiteralNativeT. + return TokenError(StrCat( + "value ", value, " is out of range for literal's primitive type ", + PrimitiveType_Name(literal->shape().element_type()))); + } } else if (value > static_cast( std::numeric_limits::max()) || value < static_cast( @@ -1733,7 +1814,6 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, break; } case TokKind::kComma: - case TokKind::kComment: // Skip. lexer_.Lex(); break; @@ -1848,7 +1928,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, tensorflow::int64 rank = ShapeUtil::Rank(shape); - *literal = MakeUnique(shape); + *literal = absl::make_unique(shape); if (!ParseToken(TokKind::kLbrace, "expects '{' at the beginning of a sparse literal")) { @@ -2180,6 +2260,26 @@ bool HloParser::ParseAttributeHelper( ->emplace(result); return true; } + case AttrTy::kBracedInt64ListList: { + std::vector> result; + auto parse_and_add_item = [&]() { + std::vector item; + if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, + TokKind::kComma, &item)) { + return false; + } + result.push_back(item); + return true; + }; + if (!ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, + parse_and_add_item)) { + return false; + } + static_cast>>*>( + attr_out_ptr) + ->emplace(result); + return true; + } case AttrTy::kSliceRanges: { SliceRanges result; if (!ParseSliceRanges(&result)) { @@ -2522,6 +2622,26 @@ bool HloParser::ParseInt64List(const TokKind start, const TokKind end, end, StrCat("expects an int64 list to end with ", TokKindToString(end))); } +bool HloParser::ParseList(const TokKind start, const TokKind end, + const TokKind delim, + const std::function& parse_and_add_item) { + if (!ParseToken(start, StrCat("expects a list starting with ", + TokKindToString(start)))) { + return false; + } + if (lexer_.GetKind() == end) { + // empty + } else { + do { + if (!parse_and_add_item()) { + return false; + } + } while (EatIfPresent(delim)); + } + return ParseToken( + end, StrCat("expects a list to end with ", TokKindToString(end))); +} + // param_list_to_shape ::= param_list '->' shape bool HloParser::ParseParamListToShape(Shape* shape, LocTy* shape_loc) { if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) { diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index 3f3a51215e34bbdd667f1cb20d0ae968e0ce5efd..5f0f75c480ecb3fba0253ed07a30e43b08a56600 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_lexer.h" diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 1f0572c576c5b22cb7827ff26197e816132ce62e..0d7919346b13f2dcd227c5afe8972c610b69f829 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -380,7 +380,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 %input = f32[1,2,1]{2,1,0} parameter(0) %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) %filter = f32[1,1,1]{2,1,0} parameter(1) - ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=1 } )" @@ -393,7 +393,7 @@ R"(HloModule ConvolveR2_module ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[1,1]) -> f32[1,2] { %input = f32[1,2]{1,0} parameter(0) %filter = f32[1,1]{1,0} parameter(1) - ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf + ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf, feature_group_count=1 } )" @@ -406,7 +406,7 @@ R"(HloModule ConvolveBackward_module ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f32[128,14,14,512] { %input = f32[128,7,7,512]{0,3,2,1} parameter(0) %filter = f32[3,3,512,512]{3,2,1,0} parameter(1) - ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f + ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f, feature_group_count=1 } )" @@ -752,10 +752,50 @@ ENTRY %sparse_f32_r1 () -> f32[9] { "gather", R"(HloModule StringifyGather -ENTRY %Gather (input_tensor: f32[50,49,48,47,46], gather_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] { +ENTRY %Gather (input_tensor: f32[50,49,48,47,46], start_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] { %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0) - %gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) - ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26} + %start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) + ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26} +} + +)" +}, +{ +"scatter", +R"(HloModule StringifyScatter + +%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) +} + +ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7,5], updates: f32[10,9,8,7,30,29,28,27,26]) -> f32[50,49,48,47,46] { + %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0) + %scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) + %updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2) + ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, to_apply=%add_F32.v3 +} + +)" +}, +{ + "ConstantUnsignedNoUnderflow", + R"(HloModule ConstantUnsignedNoUnderflow_module + +ENTRY %ConstantUnsignedNoUnderflow () -> u64[] { + ROOT %constant = u64[] constant(1) +} + +)" +}, + +{ + "ConstantUnsignedNoOverflow", + R"(HloModule ConstantUnsignedNoOverflow_module + +ENTRY %ConstantUnsignedNoOverflow () -> u64[] { + ROOT %constant = u64[] constant(9223372036854775807) } )" @@ -803,6 +843,32 @@ ENTRY ReduceR3ToR2.v3 { ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 } +)" +}, +// tuple reduce +{ +"TupleReduce", +R"(HloModule TupleReduce + +max_argmax { + value = f32[] parameter(2) + prev_max = f32[] parameter(0) + is_next_larger = pred[] greater-than-or-equal-to(value, prev_max) + max = f32[] select(is_next_larger, value, prev_max) + index = s32[] parameter(3) + prev_argmax = s32[] parameter(1) + argmax = s32[] select(is_next_larger, index, prev_argmax) + ROOT pair = (f32[], s32[]) tuple(max, argmax) +} + +ENTRY reduce_entry { + values = f32[1024]{0} parameter(0) + indices = f32[1024]{0} parameter(1) + init_value = f32[] constant(-inf) + init_index = s32[] constant(-1) + ROOT result = (f32[], s32[]) reduce(values, indices, init_value, init_index), dimensions={0}, to_apply=max_argmax +} + )" }, // infeed/outfeed @@ -964,8 +1030,8 @@ R"(HloModule gather ENTRY Gather { input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0) - gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) - ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26} + start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) + ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26} } )" @@ -1004,6 +1070,30 @@ ENTRY CrossReplicaSumWithSubgroups { ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), replica_group_ids={0,0,1,1}, barrier="abc", to_apply=add } +)" +}, +// all-to-all +{ +"AllToAll", +R"(HloModule AllToAll + +ENTRY AllToAll { + input = f32[128,32]{0,1} parameter(0) + ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={} +} + +)" +}, +// all-to-all with subgroups +{ +"AllToAllWithSubgroups", +R"(HloModule AllToAllWithSubgroups + +ENTRY AllToAllWithSubgroups { + input = f32[128,32]{0,1} parameter(0) + ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={{1,2},{3,0}}, barrier="abc" +} + )" }, // Iota @@ -1015,6 +1105,17 @@ ENTRY Iota { ROOT iota = f32[100]{0} iota() } +)" +}, +// custom-call with window and dim_labels +{ +"CustomCallWithWindowAndDimLabels", +R"(HloModule CustomCallWithWindowAndDimLabels + +ENTRY Computation { + ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="target" +} + )" } }); @@ -1213,6 +1314,40 @@ ENTRY %ConstantF16Overflow.v4 () -> f16[] { "is out of range for literal's primitive type F16"); } +TEST_F(HloParserTest, ConstantUnsignedUnderflow) { + const string original = R"( + HloModule ConstantUnsignedUnderflow_module + ENTRY %ConstantUnsignedUnderflow () -> u64[] { + ROOT %constant = u64[] constant(-1) + })"; + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); + ExpectHasSubstr(result.status().error_message(), + "is out of range for literal's primitive type U64"); +} + +TEST_F(HloParserTest, ConstantUnsignedOverflow) { + const string original = R"( + HloModule ConstantUnsignedOverflow_module + ENTRY %ConstantUnsignedOverflow () -> u32[] { + ROOT %constant = u32[] constant(4294967296) + })"; + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); + ExpectHasSubstr(result.status().error_message(), + "is out of range for literal's primitive type U32"); +} + +TEST_F(HloParserTest, ConstantUnsignedInt64Overflow) { + const string original = R"( + HloModule ConstantUnsignedOverflow_module + ENTRY %ConstantUnsignedOverflow () -> u64[] { + ROOT %constant = u64[] constant(9223372036854775808) + })"; + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); +} + TEST_F(HloParserTest, ConstantWithExp) { const string original = R"(HloModule ConstantWithExp_module @@ -1235,7 +1370,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 %input = f32[1,2,1]{2,1,0} parameter(0) %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) %filter = f32[1,1,1]{2,1,0} parameter(1) - ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=2} + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), feature_group_count=1, sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=2} } )"; @@ -1425,6 +1560,81 @@ ENTRY consts { "last"); } +TEST_F(HloParserTest, Comments) { + const string original = R"(/* module description. */ +HloModule comments: + +ENTRY /*comment*/ c1 { + /* blah */ + ROOT const1 = /*foo*/f32[1]{0} constant({12345 /*bar*/}) + /* comment */ +} + +/* something else */ + +)"; + auto module = ParseHloString(original); + TF_ASSERT_OK(module.status()); +} + +TEST_F(HloParserTest, MultilineComments) { + const string original = R"(HloModule multiline_comment: +ENTRY c1 { + /* + ROOT foo = f32[1]{0} constant({12345}) + */ + ROOT const1 = f32[1]{0} constant({12345}) +/* +a +b +c +d + +*/ +})"; + auto module = ParseHloString(original); + TF_ASSERT_OK(module.status()); +} + +TEST_F(HloParserTest, UnterminatedComment) { + const string original = R"(HloModule unterminated_comment: +ENTRY c1 { +/* unterminated + ROOT const1 = f32[1]{0} constant({12345}) +})"; + // Verify that the error message points to the beginning of the unterminated + // comment. + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "/* unterminated\n^"); +} + +TEST_F(HloParserTest, SlashSlashComments) { + const string original = R"(HloModule slash_slash_comment: +// Garbage +ENTRY c1 { + // Foo bar + ROOT const1 = f32[1]{0} constant({12345}) // Something else +})"; + auto module = ParseHloString(original); + TF_ASSERT_OK(module.status()); +} + +TEST_F(HloParserTest, SlashSlashCommentMsDosEolFormat) { + const string original = + "HloModule slash_slash_comment:\r\n// Garbage\r\nENTRY c1 {\r\n// Foo " + "bar\r\nROOT const1 = f32[1]{0} constant({12345}) // Something else\r\n}"; + auto module = ParseHloString(original); + TF_ASSERT_OK(module.status()); +} + +TEST_F(HloParserTest, SlashSlashCommentMacEolFormat) { + const string original = + "HloModule slash_slash_comment:\r// Garbage\rENTRY c1 {\r// Foo " + "bar\rROOT const1 = f32[1]{0} constant({12345}) // Something else\r}"; + auto module = ParseHloString(original); + TF_ASSERT_OK(module.status()); +} + TEST_F(HloParserTest, MultipleEntries) { const string original = R"(HloModule multiple_entries: ENTRY c1 { diff --git a/tensorflow/compiler/xla/service/hlo_pass_fix.h b/tensorflow/compiler/xla/service/hlo_pass_fix.h index b3d0a07add39968c6310392ea01daeab8a7dd9af..791b1a97b0b82edf19ff1588fd8d5d996ac0fef4 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_fix.h +++ b/tensorflow/compiler/xla/service/hlo_pass_fix.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_FIX_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_FIX_H_ +#include + #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -34,9 +36,19 @@ class HloPassFix : public Pass { StatusOr Run(HloModule* module) override { bool changed = false; bool changed_this_iteration = true; + int64 iteration_count = 0; + int64 limit = + std::max(static_cast(1000), module->instruction_count()); while (changed_this_iteration) { TF_ASSIGN_OR_RETURN(changed_this_iteration, Pass::Run(module)); changed |= changed_this_iteration; + ++iteration_count; + if (iteration_count == limit) { + LOG(ERROR) + << "Unexpectedly high number of iterations in HLO passes (" + << iteration_count + << ")\nIf compilation hangs here, please file a bug with XLA."; + } } return changed; } diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h index a42d7e59fed2d838dfe3cb7f99e6b946edfdb0b4..3bb1342aa370c09dc5cd180e6b0abade4a62c91d 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h @@ -21,7 +21,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index b2725e2918ce76248d9f2cdbb2a6e5a63226bf9a..8f3ae9c62127d8bd79f272f801d9aa9a3043ab6a 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -19,9 +19,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -233,7 +233,7 @@ StatusOr>> HloRunner::ExecuteReplicated( int64 device = device_assignment(i, 0); TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device)); - streams.push_back(MakeUnique(executor)); + streams.push_back(absl::make_unique(executor)); streams.back()->Init(); service_run_options.emplace_back(GetServiceRunOptionsForDevice( device, streams.back().get(), &device_assignment)); @@ -260,7 +260,7 @@ StatusOr>> HloRunner::ExecuteReplicated( num_threads += options.num_replicas; } if (num_threads > 0) { - pool = MakeUnique( + pool = absl::make_unique( tensorflow::Env::Default(), "infeed_outfeed", /*num_threads=*/num_threads); } @@ -291,7 +291,7 @@ StatusOr>> HloRunner::ExecuteReplicated( VLOG(1) << "Starting outfeed on device " << device; for (int64 step = 1; options.infeed_steps < 0 || step <= options.infeed_steps; ++step) { - auto literal = MakeUnique(); + auto literal = absl::make_unique(); TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed( executor, options.outfeed_shape, literal.get())); if (options.outfeed_values != nullptr) { diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc index cf9ceed5b2fb49eb91fea96d89c8e1efc2a3dad1..9ec983c2bc353955cb23d441d200ac8aa36951b1 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -282,7 +282,7 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { TF_ASSERT_OK_AND_ASSIGN( SequentialHloOrdering::HloModuleSequence sequence, ScheduleComputationsInModule(*module, - [&TUPLE_SIZE](const BufferValue& buffer) { + [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf( buffer.shape(), TUPLE_SIZE); }, diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 393944c20faa0b09ebc8544543b62566c836739f..0cba9ebbcb03598ed6a6c2603941c8950260a143 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -31,12 +31,9 @@ HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) { CHECK_EQ(1, ShapeUtil::Rank(input_shape)); CHECK_GT(num_tiles, 1); std::vector dimensions(1, num_tiles); - Shape tile_shape = input_shape; - auto& tile_dimension = (*tile_shape.mutable_dimensions())[0]; - tile_dimension = CeilOfRatio(static_cast(tile_dimension), num_tiles); Array assignment(dimensions); std::iota(assignment.begin(), assignment.end(), 0); - return HloSharding(tile_shape, assignment); + return HloSharding(assignment); } HloSharding HloSharding::Tuple(const ShapeTree& sub_shardings) { @@ -104,8 +101,7 @@ string HloSharding::ToString() const { return StrCat( "{maximal device=", static_cast(*tile_assignment_.begin()), "}"); } else { - return StrCat("{", ShapeUtil::HumanString(tile_shape_), " ", "devices=[", - Join(tile_assignment_.dimensions(), ","), "]", + return StrCat("{devices=[", Join(tile_assignment_.dimensions(), ","), "]", Join(tile_assignment_, ","), "}"); } } @@ -127,15 +123,15 @@ std::map HloSharding::UsedDevices(int64* count) const { if (IsTuple()) { for (auto& tuple_element_sharding : tuple_elements()) { auto unique_device = tuple_element_sharding.UniqueDevice(); - if (unique_device.ok()) { - device_map[unique_device.ValueOrDie()] += 1; + if (unique_device) { + device_map[*unique_device] += 1; } } element_count = tuple_elements().size(); } else { auto unique_device = UniqueDevice(); - if (unique_device.ok()) { - device_map[unique_device.ValueOrDie()] += 1; + if (unique_device) { + device_map[*unique_device] += 1; } } if (count != nullptr) { @@ -145,7 +141,6 @@ std::map HloSharding::UsedDevices(int64* count) const { } std::vector HloSharding::TileIndexForDevice(int64 device) const { - CHECK(!ShapeUtil::IsTuple(tile_shape_)); CHECK(!maximal_); CHECK(!IsTuple()); std::vector ret_index; @@ -165,32 +160,43 @@ int64 HloSharding::DeviceForTileIndex( if (maximal_) { return *tile_assignment_.begin(); } - CHECK_EQ(ShapeUtil::Rank(tile_shape_), tile_assignment_.dimensions().size()); return tile_assignment_(index); } -std::vector HloSharding::TileOffsetForDevice(int64 device) const { +std::vector HloSharding::TileOffsetForDevice(const Shape& shape, + int64 device) const { CHECK(!IsTuple()); - std::vector index = TileIndexForDevice(device); if (maximal_) { - // Index will always be all zeroes if we're maximal, and tile_shape_ is not - // valid. - return index; + return std::vector(shape.dimensions_size(), 0); } + + CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions()); + std::vector index = TileIndexForDevice(device); for (int64 i = 0; i < index.size(); ++i) { - index[i] *= tile_shape_.dimensions(i); + const int64 shape_dim = shape.dimensions(i); + index[i] = std::min( + index[i] * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), shape_dim); } return index; } -std::vector HloSharding::TileLimitForDevice(int64 device) const { +std::vector HloSharding::TileLimitForDevice(const Shape& shape, + int64 device) const { CHECK(!IsTuple()); - CHECK(!maximal_); // Maximal shardings do not have a valid tile shape. + if (maximal_) { + return std::vector(shape.dimensions().begin(), + shape.dimensions().end()); + } + + CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions()); std::vector index = TileIndexForDevice(device); for (int64 i = 0; i < index.size(); ++i) { - index[i] = (index[i] + 1) * tile_shape_.dimensions(i); + const int64 shape_dim = shape.dimensions(i); + index[i] = std::min( + (index[i] + 1) * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), + shape_dim); } return index; } @@ -238,40 +244,31 @@ StatusOr HloSharding::GetTupleSharding(const Shape& shape) const { return Tuple(ShapeTree(shape, *this)); } -StatusOr HloSharding::UniqueDevice() const { +tensorflow::gtl::optional HloSharding::UniqueDevice() const { if (IsTuple()) { if (tuple_elements_.empty()) { - return tensorflow::errors::InvalidArgument( - "UniqueDevice() called on empty tuple"); + return tensorflow::gtl::nullopt; } - std::vector> results; - std::transform(tuple_elements_.begin(), tuple_elements_.end(), - std::back_inserter(results), - [](const HloSharding& s) { return s.UniqueDevice(); }); - if (std::all_of(results.begin(), results.end(), - [&](const StatusOr& s) { - return s.ok() && results[0].ok() && - s.ValueOrDie() == results[0].ValueOrDie(); - })) { - return results[0]; - } else { - return tensorflow::errors::InvalidArgument( - "Tuple did not contain a unique device"); + tensorflow::gtl::optional unique_device; + for (auto& tuple_sharding : tuple_elements_) { + auto device = tuple_sharding.UniqueDevice(); + if (!device || (unique_device && *device != *unique_device)) { + return tensorflow::gtl::nullopt; + } + unique_device = device; } + return unique_device; } - if (!replicated_ && maximal_ && !IsTuple()) { + if (!replicated_ && maximal_) { return static_cast(*tile_assignment_.begin()); } - return tensorflow::errors::InvalidArgument( - "UniqueDevice() called on sharding that executes on multiple devices"); + return tensorflow::gtl::nullopt; } -bool HloSharding::HasUniqueDevice() const { - if (IsTuple()) { - return UniqueDevice().status().ok(); - } else { - return !IsReplicated() && IsTileMaximal(); - } +int64 HloSharding::GetUniqueDevice() const { + auto device = UniqueDevice(); + CHECK(device) << "Sharding does not have a unique device: " << *this; + return *device; } Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const { @@ -345,11 +342,12 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, return Status::OK(); } - // The tile rank must be the same as the input rank. - if (ShapeUtil::Rank(shape) != ShapeUtil::Rank(tile_shape_)) { + // The tile assignment tensor must have the same rank as the input. + if (ShapeUtil::Rank(shape) != tile_assignment_.num_dimensions()) { return tensorflow::errors::InvalidArgument( - "Tile rank is different to the input rank. sharding=", ToString(), - ", input_shape=", ShapeUtil::HumanString(shape)); + "Number of tile assignment dimensions is different to the input rank. " + "sharding=", + ToString(), ", input_shape=", ShapeUtil::HumanString(shape)); } // The correct constructor have to be used to create tile maximal shardings. @@ -359,20 +357,6 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, "sharding was intended, use HloSharding::Replicated(). If a device " "placement was intended, use HloSharding::AssignDevice()"); } - - // The tile assignment tensor must contain enough element to cover the full - // shape with tiles of the specified size. - for (int64 i = 0, e = tile_assignment_.dimensions().size(); i != e; ++i) { - int64 total_tile_size = tile_assignment_.dim(i) * tile_shape_.dimensions(i); - if (shape.dimensions(i) > total_tile_size) { - return tensorflow::errors::InvalidArgument( - StrCat("Tile assignment tensor has too few element to cover the full " - "shape. Dimension ", - i, ", shape ", shape.dimensions(i), ", total size ", - total_tile_size)); - } - } - return Status::OK(); } @@ -402,7 +386,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, proto.tile_assignment_dimensions().end())); std::copy(proto.tile_assignment_devices().begin(), proto.tile_assignment_devices().end(), tile_assignment.begin()); - return HloSharding(proto.tile_shape(), tile_assignment); + return HloSharding(tile_assignment); } OpSharding HloSharding::ToProto() const { @@ -416,7 +400,6 @@ OpSharding HloSharding::ToProto() const { return result; } - *result.mutable_tile_shape() = tile_shape_; for (int64 dim : tile_assignment_.dimensions()) { result.add_tile_assignment_dimensions(dim); } @@ -433,30 +416,16 @@ OpSharding HloSharding::ToProto() const { return result; } -HloSharding HloSharding::TransformShardedTileShape( - const Shape& new_shape, - const std::function& transform) const { - CHECK(!IsTuple()); +Shape HloSharding::TileShape(const Shape& shape) const { if (IsTileMaximal()) { - return *this; + return shape; } - CHECK_EQ(ShapeUtil::Rank(new_shape), ShapeUtil::Rank(tile_shape())); - Shape new_tile_shape; - new_tile_shape.set_element_type(tile_shape().element_type()); - for (int64 i = 0; i < ShapeUtil::Rank(new_shape); ++i) { - int64 dim; - if (tile_assignment().dim(i) == 1) { - dim = new_shape.dimensions(i); - } else if (transform) { - dim = transform(i, tile_shape().dimensions(i)); - } else { - dim = tile_shape().dimensions(i); - } - new_tile_shape.add_dimensions(dim); + Shape result_shape = shape; + for (int64 i = 0; i < shape.dimensions_size(); ++i) { + (*result_shape.mutable_dimensions())[i] = + CeilOfRatio(shape.dimensions(i), tile_assignment_.dim(i)); } - TF_CHECK_OK( - LayoutUtil::CopyLayoutBetweenShapes(tile_shape_, &new_tile_shape)); - return HloSharding::Tile(new_tile_shape, tile_assignment()); + return result_shape; } HloSharding HloSharding::GetSubSharding(const Shape& shape, @@ -484,7 +453,7 @@ tensorflow::gtl::optional HloSharding::ExtractSingleSharding() } size_t HloSharding::Hash() const { - if (!tuple_) { + if (tuple_) { size_t h = 0; for (const auto& element : tuple_elements_) { h = tensorflow::Hash64Combine(h, element.Hash()); @@ -498,9 +467,6 @@ size_t HloSharding::Hash() const { for (uint32 v : tile_assignment_) { h = tensorflow::Hash64Combine(h, std::hash{}(v)); } - for (uint32 v : tile_shape_.dimensions()) { - h = tensorflow::Hash64Combine(h, std::hash{}(v)); - } return h; } diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 6f672b0f28d2b85411d70f33da9a9f270aefc0d0..894783e5d1538fa4e8e91b65827121f32040af83 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -48,22 +48,10 @@ class HloSharding { // the input shape (one tile) assigned to a single device. static HloSharding AssignDevice(int64 device_id); - // Creates a new sharding which splits a shape into tiles each with shape - // `tile_shape`. Each tile is assigned to one device, which is specified by - // `tile_assignment`. Any tensor not a multiple of the tile size in any - // dimension is implicitly padded to the tile size. - // - // e.g. Tile({2, 2}, {0, 1}) on a tensor of shape {3, 2} would look like: - // 2 1 padding - // <------><-> - // +----+----+ - // | 0 | 1 | - // +----+----+ - // - // Split into two tiles, one of which is implicitly padded by one. - static HloSharding Tile(const Shape& tile_shape, - const Array& tile_assignment) { - return HloSharding(tile_shape, tile_assignment); + // Creates a new sharding which splits a shape into tiles amongst the devices + // specified by `tile_assignment`. + static HloSharding Tile(const Array& tile_assignment) { + return HloSharding(tile_assignment); } // Creates a new sharding which splits a one-dimensional input shape into @@ -146,24 +134,30 @@ class HloSharding { // REQUIRES: !IsTuple() int64 DeviceForTileIndex(tensorflow::gtl::ArraySlice index) const; - // Given a device ID, returns the offset within the input space of the + // Given a device ID, returns the offset within the specified shape of the // tile that should be executed on the given core. This returns the lower // extent of the tile in the input space. // REQUIRES: !IsTuple() - std::vector TileOffsetForDevice(int64 device) const; + std::vector TileOffsetForDevice(const Shape& shape, + int64 device) const; - // Given a device ID, returns the limit within the input space of the + // Given a device ID, returns the limit within the specified shape of the // tile that should be executed on the given core. This returns the upper // extent of the tile in the input space. // REQUIRES: !IsTuple() - std::vector TileLimitForDevice(int64 device) const; + std::vector TileLimitForDevice(const Shape& shape, int64 device) const; - // Returns the single device this op operates on. - // REQUIRES: !IsTuple&& !Replicated() && IsTileMaximal() - StatusOr UniqueDevice() const; + // Returns the single device this op operates on. If the sharding does not + // span a single device, the return value will be empty. + // In order for a sharding to span a single device, every leaf sharding must + // be maximal and not replicated, and the used device must match. + tensorflow::gtl::optional UniqueDevice() const; + + // Retrieves the unique device or fails with a CHECK. + int64 GetUniqueDevice() const; // Returns true if this op only uses a single device. - bool HasUniqueDevice() const; + bool HasUniqueDevice() const { return UniqueDevice().has_value(); } // Returns the ShapeTree containing the shardings for each element of this // tuple, if IsTuple, or a ShapeTree with a single element containing this @@ -192,7 +186,6 @@ class HloSharding { bool operator==(const HloSharding& other) const { return replicated_ == other.replicated_ && maximal_ == other.maximal_ && - ShapeUtil::Compatible(tile_shape_, other.tile_shape_) && tile_assignment_ == other.tile_assignment_ && tuple_elements_ == other.tuple_elements_; } @@ -206,9 +199,6 @@ class HloSharding { } }; - // Gets the tile shape. - // REQUIRES: !IsTileMaximal() && !IsTuple() - const Shape& tile_shape() const { return tile_shape_; } // Gets the tile assignment tensor. // REQUIRES: !IsReplicated() && !IsTuple() const Array& tile_assignment() const { return tile_assignment_; } @@ -220,25 +210,15 @@ class HloSharding { return tuple_elements_; } - // Return a new sharding that can apply to the given new shape. - // If this sharding is tile-maximal, the returned sharding will be the same as - // this sharding. If this sharding is not tile-maximal, the returned - // sharding's tile size will differ: - // - Non-sharded dimensions will be adapted to be the same as `new_shape`; - // tile_dimension(i) = new_shape.dimensions(i); - // - Sharded dimensions will be kept the same unless `transform` is supplied - // in which case tile_dimension(i) = transform(i, tile_dimension(i)); - // REQUIRES: !IsTuple(). - HloSharding TransformShardedTileShape( - const Shape& new_shape, - const std::function& transform = nullptr) const; + // Gets the tile shape. + // REQUIRES: !IsTuple() + Shape TileShape(const Shape& shape) const; private: HloSharding() : replicated_(true), maximal_(true), tuple_(false), - tile_shape_(), tile_assignment_({0}) {} // device_id values: // -2: magic number to mean unassigned device, used by spatial partitioning @@ -250,15 +230,13 @@ class HloSharding { : replicated_(false), maximal_(true), tuple_(false), - tile_shape_(), tile_assignment_({1}, device_id) {} - HloSharding(const Shape& tile_shape, const Array& tile_assignment) + explicit HloSharding(const Array& tile_assignment) : replicated_(false), maximal_(false), tuple_(false), - tile_shape_(tile_shape), tile_assignment_(tile_assignment) {} - HloSharding(const std::vector& tuple_shardings) + explicit HloSharding(const std::vector& tuple_shardings) : replicated_(false), maximal_(false), tuple_(true), @@ -281,7 +259,6 @@ class HloSharding { bool replicated_; bool maximal_; bool tuple_; - Shape tile_shape_; Array tile_assignment_; // Only non-empty when tuple_ is true, but because empty tuples are allowed // may also be empty even then. This is a flattened list of all the leaf diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index 94f5a3b273b2fd7e545472c42f3863f549dd3db1..4e19557f8295d38de639f06e8402e38316aa3fc5 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -121,9 +122,9 @@ std::unique_ptr CloneShardingForDomain( const HloSharding& sharding) { auto single_sharding = sharding.ExtractSingleSharding(); if (!single_sharding) { - return MakeUnique(sharding); + return absl::make_unique(sharding); } - return MakeUnique(*single_sharding); + return absl::make_unique(*single_sharding); } Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain, @@ -158,7 +159,6 @@ ShapeTree GetTupleSharding(HloInstruction* tuple) { const HloSharding* GetOperandSharding(const HloInstruction* operand, const DomainMetadata::Domain& domain, const HloSharding& sharding) { - DCHECK_EQ(domain.reach_set.count(const_cast(operand)), 1); // Here the user of operand is within the domain instruction set, and since it // is user of operand, we need to look into the enter_domains set. If this is // not a kDomain within the user domains set, then return the operand @@ -203,10 +203,17 @@ StatusOr ApplyDomainShardingPass(const DomainMetadata::Domain& domain, for (int64 i = 0; i < instruction->operand_count(); ++i) { const HloSharding* operand_sharding = GetOperandSharding(instruction->operand(i), domain, sharding); - if (operand_sharding != nullptr && - shape_tree.element({i}) != *operand_sharding) { - *shape_tree.mutable_element({i}) = *operand_sharding; - ++tuple_assigned; + if (operand_sharding != nullptr) { + HloSharding operand_subsharding = HloSharding::Replicate(); + if (operand_sharding == &sharding) { + operand_subsharding = + sharding.GetSubSharding(instruction->shape(), {i}); + operand_sharding = &operand_subsharding; + } + if (shape_tree.element({i}) != *operand_sharding) { + *shape_tree.mutable_element({i}) = *operand_sharding; + ++tuple_assigned; + } } } if (tuple_assigned > 0) { @@ -312,9 +319,9 @@ std::unique_ptr CreateDomain(HloInstruction* instruction, : "None"); std::unique_ptr operand_side_metadata = - MakeUnique(std::move(real_operand_sharding)); + absl::make_unique(std::move(real_operand_sharding)); std::unique_ptr user_side_metadata = - MakeUnique(std::move(real_instruction_sharding)); + absl::make_unique(std::move(real_instruction_sharding)); return HloInstruction::CreateDomain(operand->shape(), operand, std::move(operand_side_metadata), std::move(user_side_metadata)); @@ -351,9 +358,9 @@ StatusOr> ExtractOriginalCommonSharding( std::unique_ptr ShardingMetadata::Clone() const { std::unique_ptr sharding; if (sharding_ != nullptr) { - sharding = MakeUnique(*sharding_); + sharding = absl::make_unique(*sharding_); } - return MakeUnique(std::move(sharding)); + return absl::make_unique(std::move(sharding)); } bool ShardingMetadata::Matches(const DomainMetadata& other) const { diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 7baa927d0e2b1abbbb2333633d16dd605ae8c8ef..45fc300fcaf5a301fe11768da77a7c0907919c39 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -39,7 +39,6 @@ Array MakeArray(tensorflow::gtl::ArraySlice dimensions, class HloShardingTest : public HloTestBase {}; TEST_F(HloShardingTest, Replicate) { - Shape tile_shape = ShapeUtil::MakeShape(U32, {4}); HloSharding sharding = HloSharding::Replicate(); EXPECT_TRUE(sharding.IsReplicated()); EXPECT_TRUE(sharding.IsTileMaximal()); @@ -51,7 +50,7 @@ TEST_F(HloShardingTest, Replicate) { EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4}), /*num_devices=*/2)); - EXPECT_IS_NOT_OK(sharding.UniqueDevice()); + EXPECT_FALSE(sharding.HasUniqueDevice()); } TEST_F(HloShardingTest, DevicePlacement) { @@ -60,7 +59,7 @@ TEST_F(HloShardingTest, DevicePlacement) { EXPECT_TRUE(sharding.IsTileMaximal()); EXPECT_FALSE(sharding.UsesDevice(0)); EXPECT_TRUE(sharding.UsesDevice(5)); - EXPECT_EQ(5, sharding.UniqueDevice().ValueOrDie()); + EXPECT_EQ(5, sharding.GetUniqueDevice()); HloSharding other = HloSharding::Replicate(); EXPECT_NE(other, sharding); @@ -79,37 +78,22 @@ TEST_F(HloShardingTest, DevicePlacement) { TEST_F(HloShardingTest, Tile) { { // Test should fail because of a duplicate tile assignment. - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 0, 2, 3})); + HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 0, 2, 3})); EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {4, 6}), /*num_devices=*/4)); } { // Test should fail because of more devices used then `num_device`. - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3})); + HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 1, 2, 3})); EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4, 6}), /*num_devices=*/2)); } - { - // Test should fail because the total tiled size in dimension 0 is 4 but we - // have 6 elements along that dimensions. - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3})); - EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {6, 3}), - /*num_devices=*/4)); - } - { // Test should pass. - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1})); + Shape shape = ShapeUtil::MakeShape(U32, {4, 5}); + HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1})); EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {3, 5}), /*num_devices=*/5)); @@ -118,12 +102,16 @@ TEST_F(HloShardingTest, Tile) { EXPECT_EQ(2, sharding.DeviceForTileIndex({1, 0})); EXPECT_EQ(1, sharding.DeviceForTileIndex({1, 1})); - EXPECT_EQ(sharding.TileOffsetForDevice(0), (std::vector{0, 0})); - EXPECT_EQ(sharding.TileOffsetForDevice(3), (std::vector{0, 3})); - EXPECT_EQ(sharding.TileOffsetForDevice(2), (std::vector{2, 0})); - EXPECT_EQ(sharding.TileOffsetForDevice(1), (std::vector{2, 3})); + EXPECT_EQ(sharding.TileOffsetForDevice(shape, 0), + (std::vector{0, 0})); + EXPECT_EQ(sharding.TileOffsetForDevice(shape, 3), + (std::vector{0, 3})); + EXPECT_EQ(sharding.TileOffsetForDevice(shape, 2), + (std::vector{2, 0})); + EXPECT_EQ(sharding.TileOffsetForDevice(shape, 1), + (std::vector{2, 3})); - EXPECT_IS_NOT_OK(sharding.UniqueDevice()); + EXPECT_FALSE(sharding.HasUniqueDevice()); } } @@ -135,8 +123,7 @@ TEST_F(HloShardingTest, NestedTuple) { ShapeUtil::MakeShape(F32, {4, 6}), }); - HloSharding tiled_sharding = HloSharding::Tile( - ShapeUtil::MakeShape(F32, {4, 3}), Array({{0, 1}})); + HloSharding tiled_sharding = HloSharding::Tile(Array({{0, 1}})); OpSharding proto; proto.set_type(OpSharding::Type::OpSharding_Type_TUPLE); *proto.add_tuple_shardings() = HloSharding::Replicate().ToProto(); @@ -187,32 +174,11 @@ TEST_F(HloShardingTest, Hash) { } { - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding1 = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1})); - HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}), - MakeArray({2, 2}, {0, 3, 2, 1})); - EXPECT_TRUE(hash_compare_equal(sharding1, sharding2)); - } - - { - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding1 = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1})); - HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}), - MakeArray({2, 2}, {0, 3, 2, 1})); + HloSharding sharding1 = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1})); + HloSharding sharding2 = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1})); EXPECT_TRUE(hash_compare_equal(sharding1, sharding2)); } - { - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding1 = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1})); - HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}), - MakeArray({2, 2}, {0, 3, 1, 2})); - EXPECT_FALSE(hash_compare_equal(sharding1, sharding2)); - } - HloSharding default_sharding = HloSharding::Replicate(); { ShapeTree shape_tree(ShapeUtil::MakeTupleShape({}), @@ -259,19 +225,6 @@ TEST_F(HloShardingTest, Hash) { } } -TEST_F(HloShardingTest, TransformShardedTileShapeTest) { - HloSharding sharding = - HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}), - Array4D({{{{0, 1}, {2, 3}}}})); - HloSharding result = sharding.TransformShardedTileShape( - ShapeUtil::MakeShape(F32, {13, 15, 17, 19}), - [](int dim, int value) { return dim * 111; }); - HloSharding expected = - HloSharding::Tile(ShapeUtil::MakeShape(F32, {13, 15, 222, 333}), - Array4D({{{{0, 1}, {2, 3}}}})); - EXPECT_EQ(result, expected); -} - TEST_F(HloShardingTest, ToStringReplicatedTest) { HloSharding sharding = HloSharding::Replicate(); EXPECT_EQ(sharding.ToString(), "{replicated}"); @@ -284,9 +237,8 @@ TEST_F(HloShardingTest, ToStringAssignDeviceTest) { TEST_F(HloShardingTest, ToStringTiledTest) { HloSharding sharding = - HloSharding::Tile(ShapeUtil::MakeShape(S32, {7, 11, 13}), - Array3D({{{2, 3}}, {{5, 7}}})); - EXPECT_EQ(sharding.ToString(), "{s32[7,11,13] devices=[2,1,2]2,3,5,7}"); + HloSharding::Tile(Array3D({{{2, 3}}, {{5, 7}}})); + EXPECT_EQ(sharding.ToString(), "{devices=[2,1,2]2,3,5,7}"); } TEST_F(HloShardingTest, ToStringTupleTest) { @@ -294,21 +246,18 @@ TEST_F(HloShardingTest, ToStringTupleTest) { ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5}), ShapeUtil::MakeShape(U32, {7, 25}), ShapeUtil::MakeShape(S32, {9, 11})}), - {HloSharding::Replicate(), - HloSharding::Tile(ShapeUtil::MakeShape(U32, {7, 13}), - Array2D({{3, 5}})), + {HloSharding::Replicate(), HloSharding::Tile(Array2D({{3, 5}})), HloSharding::AssignDevice(3)}); EXPECT_EQ(sharding.ToString(), - "{{replicated}, {u32[7,13] devices=[1,2]3,5}, {maximal device=3}}"); + "{{replicated}, {devices=[1,2]3,5}, {maximal device=3}}"); } TEST_F(HloShardingTest, OstreamTest) { HloSharding sharding = - HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}), - Array4D({{{{0, 1}, {2, 3}}}})); + HloSharding::Tile(Array4D({{{{0, 1}, {2, 3}}}})); std::ostringstream oss; oss << sharding; - EXPECT_EQ(oss.str(), "{f32[3,5,7,11] devices=[1,1,2,2]0,1,2,3}"); + EXPECT_EQ(oss.str(), "{devices=[1,1,2,2]0,1,2,3}"); } TEST_F(HloShardingTest, ParseHloString) { @@ -319,8 +268,7 @@ TEST_F(HloShardingTest, ParseHloString) { }; check(HloSharding::Replicate()); check(HloSharding::AssignDevice(2)); - check(HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), - Array4D({{{{0}, {1}}}}))); + check(HloSharding::Tile(Array4D({{{{0}, {1}}}}))); // Empty tuple. One sharding is required for empty tuples, as we need to be // able to assign sharding to them, even though they have no leaves. check(HloSharding::Tuple(ShapeUtil::MakeTupleShape({}), @@ -332,8 +280,7 @@ TEST_F(HloShardingTest, ParseHloString) { ShapeUtil::MakeShape(F32, {3, 5, 7}), ShapeUtil::MakeShape(F32, {3, 7})}); check(HloSharding::Tuple( - tuple_shape, {HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), - Array4D({{{{0}, {1}}}})), + tuple_shape, {HloSharding::Tile(Array4D({{{{0}, {1}}}})), HloSharding::Replicate(), HloSharding::AssignDevice(1)})); } { @@ -343,8 +290,7 @@ TEST_F(HloShardingTest, ParseHloString) { ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5, 7}), ShapeUtil::MakeShape(F32, {3, 7})})}); std::vector leaf_shardings = { - HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), - Array4D({{{{0}, {1}}}})), + HloSharding::Tile(Array4D({{{{0}, {1}}}})), HloSharding::Replicate(), HloSharding::AssignDevice(1)}; ShapeTree sharding_tree(tuple_shape, HloSharding::Replicate()); // Assign leaf_shardings to sharding_tree leaves. diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index 48f676db85ab5e7711d9e9ac900306a9ea85ef10..b78bfa0cdf4db605576fa11e18ce6c654c6a0b6d 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -101,11 +101,11 @@ const string& HloTfGraphBuilder::GetNodeNameForInstruction( } }; string node_name; - if (debug_options_.xla_hlo_tfgraph_device_scopes() && - instruction->has_sharding() && - instruction->sharding().HasUniqueDevice()) { - node_name = StrCat( - "dev", instruction->sharding().UniqueDevice().ConsumeValueOrDie()); + if (debug_options_.xla_hlo_tfgraph_device_scopes()) { + auto device = instruction->sharding_unique_device(); + if (device) { + node_name = StrCat("dev", *device); + } } // If an instruction is fused, put it in the subgraph of the fusion; // otherwise, put it in the computation subgraph. @@ -215,10 +215,10 @@ Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) { NodeDef* node_def = graph_def_.add_node(); node_def->set_name(GetNodeNameForInstruction(instruction)); node_def->set_op(GetOpDefName(instruction)); - if (instruction->has_sharding() && - instruction->sharding().HasUniqueDevice()) { - TF_ASSIGN_OR_RETURN(int64 device, instruction->sharding().UniqueDevice()); - node_def->set_device(GetDeviceName(device)); + + auto device = instruction->sharding_unique_device(); + if (device) { + node_def->set_device(GetDeviceName(*device)); } SetNodeAttrs(instruction, node_def); if (instruction->opcode() == HloOpcode::kFusion) { diff --git a/tensorflow/compiler/xla/service/hlo_token.h b/tensorflow/compiler/xla/service/hlo_token.h index 533429608bc2e13626a3e746fbe465398e1f4bb4..4458c251dee4af365e39027dd4289925c8890efd 100644 --- a/tensorflow/compiler/xla/service/hlo_token.h +++ b/tensorflow/compiler/xla/service/hlo_token.h @@ -44,7 +44,6 @@ enum class TokKind { kRparen, // ( ) kArrow, // -> - kComment, // /*xxx*/ // Keywords kw_HloModule, diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index 4e3c9df3a036890ce25f5b14603d275263e8659b..14703aaf64bdbfee4e737331dd47d5def95e1d4b 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -283,8 +283,7 @@ std::ostream& operator<<(std::ostream& out, string InstructionValueSet::ToString() const { string out = StrCat("InstructionValueSet(", ShapeUtil::HumanString(shape()), ")\n"); - ForEachElement([this, &out](const ShapeIndex& index, - const HloValueSet& value_set) { + ForEachElement([&out](const ShapeIndex& index, const HloValueSet& value_set) { StrAppend(&out, " ", index.ToString(), " : ", value_set.ToString(), "\n"); }); return out; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 25fa319faf13d8bef69381c869f08f4948fc3519..ac1a663633796860b38a3f9035cc1d3362060736 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -84,7 +84,8 @@ Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) { const Shape expected, ShapeInference::InferConvolveShape( convolution->operand(0)->shape(), convolution->operand(1)->shape(), - convolution->window(), convolution->convolution_dimension_numbers())); + convolution->window(), convolution->convolution_dimension_numbers(), + convolution->feature_group_count())); return CheckShape(convolution, expected); } @@ -105,6 +106,15 @@ Status ShapeVerifier::HandleCrossReplicaSum(HloInstruction* crs) { ShapeInference::InferCrossReplicaSumShape(operand_shapes)); } +Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) { + std::vector operand_shapes; + for (const HloInstruction* operand : hlo->operands()) { + operand_shapes.push_back(&operand->shape()); + } + return CheckShape(hlo, + ShapeInference::InferAllToAllTupleShape(operand_shapes)); +} + Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape( reduce_precision->operand(0)->shape(), @@ -147,11 +157,7 @@ Status CheckOperandAndParameter(const HloInstruction* instruction, Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { HloInfeedInstruction* infeed = Cast(instruction); - // Infeed has an optional single token operand. - // TODO(b/80000000): Update when token is not optional. - if (infeed->operand_count() == 1) { - TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0)); - } + TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0)); // The output of infeed is a tuple containing the data value and a token. return CheckShape(infeed, @@ -161,11 +167,7 @@ Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) { HloOutfeedInstruction* outfeed = Cast(instruction); - // Outfeed has an optional token operand (operand 1). - // TODO(b/80000000): Update when token is not optional. - if (outfeed->operand_count() == 2) { - TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1)); - } + TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1)); // Outfeed has a separate shape field for the value which is outfed to the // host. The shape of the instruction itself is always a token. @@ -185,7 +187,67 @@ Status ShapeVerifier::HandleHostCompute(HloInstruction*) { return Status::OK(); } -Status ShapeVerifier::HandleRng(HloInstruction*) { return Status::OK(); } +bool ShapeVerifier::HasCompatibleElementTypes(const Shape& shape_0, + const Shape& shape_1, + const Shape& result_shape) { + return ShapeUtil::SameElementType(shape_0, shape_1) && + (ShapeUtil::SameElementType(shape_0, result_shape) || + (allow_mixed_precision_ && + ShapeUtil::SameElementTypeIgnoringFpPrecision(shape_0, + result_shape))); +} + +Status ShapeVerifier::HandleRng(HloInstruction* instruction) { + if (instruction->operand_count() != 2) { + return InternalError("Expected two operands for Rng instruction: %s", + instruction->ToString().c_str()); + } + + const Shape& shape_0 = instruction->operand(0)->shape(); + const Shape& shape_1 = instruction->operand(1)->shape(); + if (!ShapeUtil::IsScalar(shape_0) || !ShapeUtil::IsScalar(shape_1)) { + return InternalError( + "Expected scalar types for the two operands of Rng instruction: %s", + instruction->ToString().c_str()); + } + + if (!HasCompatibleElementTypes(shape_0, shape_1, instruction->shape())) { + return InternalError( + "Expected compatible element types for the result and the two operands" + " of Rng instruction: %s", + instruction->ToString().c_str()); + } + + PrimitiveType element_type = shape_0.element_type(); + switch (instruction->random_distribution()) { + case RNG_UNIFORM: + if (!primitive_util::IsFloatingPointType(element_type) && + !primitive_util::IsIntegralType(element_type) && + element_type != PRED) { + return InternalError( + "Element type not supported." + " Expected element to be of floating point type, integral type or" + " predicate type for RngUniform: %s", + instruction->ToString().c_str()); + } + break; + + case RNG_NORMAL: + if (!primitive_util::IsFloatingPointType(element_type)) { + return InternalError( + "Element type not supported." + " Expected element to be FloatingPointType for RngNormal: %s", + instruction->ToString().c_str()); + } + break; + default: + return InternalError( + "Invalid Rng distribution %s", + RandomDistribution_Name(instruction->random_distribution()).c_str()); + } + + return Status::OK(); +} Status ShapeVerifier::HandleReverse(HloInstruction* reverse) { return CheckShape( @@ -224,10 +286,13 @@ Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { } Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { + if (!ShapeUtil::IsArray(reduce->shape())) { + return InvalidArgument("Variadic reduce is not supported."); + } return CheckShape( reduce, ShapeInference::InferReduceShape( - reduce->operand(0)->shape(), reduce->operand(1)->shape(), + {&reduce->operand(0)->shape(), &reduce->operand(1)->shape()}, reduce->dimensions(), reduce->to_apply()->ComputeProgramShape())); } @@ -451,9 +516,9 @@ namespace { // inputs. Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { switch (instruction->opcode()) { - // White list the following opcodes for mixed-precision check, because they - // involve data pass through or grouping via tuples, where the precisions - // of buffers can be different. + // White list the following opcodes for mixed-precision check, because + // they involve data pass through or grouping via tuples, where the + // precisions of buffers can be different. case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kConstant: @@ -507,7 +572,16 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) { gather, ShapeInference::InferGatherShape( gather->operand(0)->shape(), gather->operand(1)->shape(), - gather->gather_dimension_numbers(), gather->gather_window_bounds())); + gather->gather_dimension_numbers(), gather->gather_slice_sizes())); +} + +Status ShapeVerifier::HandleScatter(HloInstruction* scatter) { + return CheckShape( + scatter, ShapeInference::InferScatterShape( + scatter->operand(0)->shape(), scatter->operand(1)->shape(), + scatter->operand(2)->shape(), + scatter->to_apply()->ComputeProgramShape(), + scatter->scatter_dimension_numbers())); } Status ShapeVerifier::HandleAfterAll(HloInstruction* token) { @@ -626,7 +700,8 @@ string ComputationsToString( // Verifies various invariants about the structure of the HLO: // -// (1) each instruction has a non-null parent() set to the HloComputation which +// (1) each instruction has a non-null parent() set to the HloComputation +// which // contains it. // // (2) each computation has a non-null parent() set to the HloModule which @@ -660,9 +735,9 @@ Status VerifyHloStructure(HloModule* module) { } // Check that operands are in the same computation separately from verifying - // parent() correctness so conditions like a null HloInstruction::parent() are - // identified and reported explicitly above rather than reporting a mismatched - // operand. + // parent() correctness so conditions like a null HloInstruction::parent() + // are identified and reported explicitly above rather than reporting a + // mismatched operand. for (const HloComputation* computation : module->computations()) { for (const HloInstruction* instruction : computation->instructions()) { for (int i = 0; i < instruction->operand_count(); ++i) { @@ -686,13 +761,14 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { HloComputation* fused_computation = fusion->fused_instructions_computation(); if (fusion != fused_computation->FusionInstruction()) { return InternalError( - "Instruction of fused computation does not match expected instruction " + "Instruction of fused computation does not match expected " + "instruction " "%s.", fusion->ToString().c_str()); } - // Fused root instruction and fused parameters must all be owned by the fusion - // computation. + // Fused root instruction and fused parameters must all be owned by the + // fusion computation. bool root_owned = false; const std::vector& fused_parameters = fusion->fused_parameters(); @@ -734,8 +810,8 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { fusion->ToString().c_str()); } - // All uses of fused instructions must be in the fusion computation, and every - // non-root instruction must have at least one use. + // All uses of fused instructions must be in the fusion computation, and + // every non-root instruction must have at least one use. for (auto* instruction : fusion->fused_instructions_computation()->instructions()) { if (instruction != fused_root) { @@ -779,7 +855,8 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { if (!ShapeUtil::Compatible(fused_param->shape(), fusion->operand(param_no)->shape())) { return InternalError( - "Shape mismatch between parameter number %lld and its operand in %s.", + "Shape mismatch between parameter number %lld and its operand in " + "%s.", param_no, fusion->ToString().c_str()); } } @@ -897,8 +974,9 @@ Status CheckSameChannel(const HloInstruction* instr1, return Status::OK(); } -// Checks if the given two instructions have the same is_host_transfer attribute -// value. Intsructions must be send/recv instructions or their 'done' variant. +// Checks if the given two instructions have the same is_host_transfer +// attribute value. Intsructions must be send/recv instructions or their +// 'done' variant. Status CheckSameIsHostTransfer(const HloInstruction* instr1, const HloInstruction* instr2) { const HloSendRecvInstruction* send_recv1 = @@ -909,7 +987,8 @@ Status CheckSameIsHostTransfer(const HloInstruction* instr1, TF_RET_CHECK(send_recv2 != nullptr); if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) { return InternalError( - "Expected instructions to have the same is-host-transfer property: %s, " + "Expected instructions to have the same is-host-transfer property: " + "%s, " "%s ", instr1->ToString().c_str(), instr2->ToString().c_str()); } @@ -928,7 +1007,8 @@ Status VerifySendsAndRecvs(const HloModule& module) { host_channels.insert({sendrecv->channel_id(), sendrecv}); if (!it_inserted.second) { return FailedPrecondition( - "Channel %lld is used for multiple host send/recv instructions: %s " + "Channel %lld is used for multiple host send/recv instructions: " + "%s " "and " "%s", sendrecv->channel_id(), sendrecv->ToString().c_str(), diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 79f7aa9f4ce66cc9b53d016f2e126033492c81e9..9e54b54b26ad97aea212ea5730073dea0d79e0f3 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/shape_inference.h" namespace xla { @@ -45,6 +46,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleConvolution(HloInstruction* convolution) override; Status HandleFft(HloInstruction* fft) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleAllToAll(HloInstruction* hlo) override; Status HandleReducePrecision(HloInstruction* reduce_precision) override; Status HandleInfeed(HloInstruction*) override; Status HandleOutfeed(HloInstruction*) override; @@ -83,6 +85,7 @@ class ShapeVerifier : public DfsHloVisitor { HloInstruction* batch_norm_inference) override; Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; Status HandleGather(HloInstruction* gather) override; + Status HandleScatter(HloInstruction* scatter) override; Status HandleAfterAll(HloInstruction* token) override; Status FinishVisit(HloInstruction*) override { return Status::OK(); } @@ -104,6 +107,13 @@ class ShapeVerifier : public DfsHloVisitor { Status CheckVariadicShape(const HloInstruction* instruction); private: + // Return true if the shapes of the two operands have the same element type, + // and the result shape either has the same element type as the operand + // shapes or mixed precision is allowed and the result shape and the operand + // shapes have floating point element types. + bool HasCompatibleElementTypes(const Shape& shape_0, const Shape& shape_1, + const Shape& result_shape); + // Whether the inputs and output of an instruction can contain both F32s and // BF16s. Tuples that include both F32s and BF16s are allowed regardless of // this flag. @@ -119,11 +129,11 @@ class HloVerifier : public HloPassInterface { // Uses standard shape inference. explicit HloVerifier() : shape_verifier_factory_( - [] { return MakeUnique(false); }) {} + [] { return absl::make_unique(false); }) {} explicit HloVerifier(bool allow_mixed_precision) : shape_verifier_factory_([allow_mixed_precision] { - return MakeUnique(allow_mixed_precision); + return absl::make_unique(allow_mixed_precision); }) {} // Uses custom shape verification. diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 04c6ba3eeb92bad2b5b69f7f56e73e1f7a8148aa..d764964f3c3dc58a54bd0307f8b625076c14f3e5 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -34,7 +34,17 @@ namespace { using ::testing::HasSubstr; -using HloVerifierTest = HloTestBase; +class HloVerifierTest : public HloTestBase { + public: + HloVerifierTest() + : HloTestBase(/*allow_mixed_precision_in_hlo_verifier=*/false) {} +}; + +class HloVerifierTestAllowMixedPrecision : public HloTestBase { + public: + HloVerifierTestAllowMixedPrecision() + : HloTestBase(/*allow_mixed_precision_in_hlo_verifier=*/true) {} +}; TEST_F(HloVerifierTest, NullInstructionParent) { HloComputation::Builder builder(TestName()); @@ -174,5 +184,96 @@ ENTRY entry { HasSubstr("shape does not match parameter")); } +TEST_F(HloVerifierTest, RngOpnd0NotScalar) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY RngOpnd0NotScalar { + constant.0 = f32[] constant(0) + constant.1 = f16[2] constant({1, 3}) + ROOT rng.0 = f32[10]{0} rng(f32[] constant.0, f16[2] constant.1), + distribution=rng_uniform + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), HasSubstr("Expected scalar type")); +} + +TEST_F(HloVerifierTest, RngOperandElementTypesDoNotMatch) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY RngOperandElementTypesNotMatch { + constant.0 = f32[] constant(0) + constant.1 = f16[] constant(1) + ROOT rng.0 = f32[10]{0} rng(f32[] constant.0, f16[] constant.1), + distribution=rng_normal + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Expected compatible element types")); +} + +TEST_F(HloVerifierTest, RngMixedPrecisionNotAllowed) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY RngResultElementTypeNotMatch { + constant.0 = f32[] constant(0) + constant.1 = f32[] constant(1) + ROOT rng.0 = f16[10]{0} rng(f32[] constant.0, f32[] constant.1), + distribution=rng_normal + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Expected compatible element types")); +} + +TEST_F(HloVerifierTestAllowMixedPrecision, RngMixedPrecisionAllowed) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY RngResultElementTypeNotMatch { + constant.0 = f32[] constant(0) + constant.1 = f32[] constant(1) + ROOT rng.0 = f16[10]{0} rng(f32[] constant.0, f32[] constant.1), + distribution=rng_normal + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + +TEST_F(HloVerifierTest, RngElementTypeNotSupported) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY RngElementTypeNotSupported { + constant.0 = s32[] constant(0) + constant.1 = s32[] constant(1) + ROOT rng.0 = s32[10]{0} rng(s32[] constant.0, s32[] constant.1), + distribution=rng_normal + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), HasSubstr("Element type not supported")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index d7458c338e9f1df9fac90270845aae0b8f779ee2..bb5b40a8a87c5eab5a5b1599581a81bbd064511b 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -36,7 +36,8 @@ string HumanReadableProfileBuilder::ToString() const { computation_name_.c_str(), HumanReadableElapsedTime(CyclesToSeconds(total_cycles_)).c_str()); - auto print_op = [&](const OpInfo& op) { + int64 cumulative_cycles = 0; + auto print_op = [&](const OpInfo& op, bool is_total = false) { // Skip ops with 0 optimal seconds and 0 actual cycles. These are ops that // were expected to be free and are actually free -- things like (on most // backends) kParameter or kConstant HLOs. There's no need to clutter the @@ -59,27 +60,44 @@ string HumanReadableProfileBuilder::ToString() const { } } + double cumulative_cycles_percent = 0; double cycles_percent = 0; + if (!is_total) { + cumulative_cycles += op.cycles; + } if (total_cycles_ > 0) { cycles_percent = op.cycles / static_cast(total_cycles_) * 100; + cumulative_cycles_percent = + cumulative_cycles / static_cast(total_cycles_) * 100; + } + + string cycles_percent_str; + if (is_total) { + // Leaving off the two trailing decimal points of "100.%" lets us save two + // columns in the output. + cycles_percent_str = "100.% 100Σ"; + } else { + cycles_percent_str = + Printf("%5.2f%% %2.0fΣ", cycles_percent, cumulative_cycles_percent); } double nsecs = op.cycles / clock_rate_ghz_; - Appendf(&s, - "%15lld cycles (%6.2f%%) :: %12.1f usec %22s :: %18s " - ":: %18s :: %14s :: %16s :: %s\n", - op.cycles, cycles_percent, CyclesToMicroseconds(op.cycles), - op.optimal_seconds < 0 - ? "" - : Printf("(%12.1f optimal)", op.optimal_seconds * 1e6).c_str(), - op.flop_count <= 0 - ? "" - : HumanReadableNumFlops(op.flop_count, nsecs).c_str(), - op.transcendental_count <= 0 ? "" - : HumanReadableNumTranscendentalOps( - op.transcendental_count, nsecs) - .c_str(), - bytes_per_sec.c_str(), bytes_per_cycle.c_str(), op.name.c_str()); + Appendf( + &s, + "%15lld cycles (%s) :: %12.1f usec %22s :: %18s :: %18s :: %14s :: " + "%16s :: %s\n", + op.cycles, cycles_percent_str.c_str(), CyclesToMicroseconds(op.cycles), + op.optimal_seconds < 0 + ? "" + : Printf("(%12.1f optimal)", op.optimal_seconds * 1e6).c_str(), + op.flop_count <= 0 + ? "" + : HumanReadableNumFlops(op.flop_count, nsecs).c_str(), + op.transcendental_count <= 0 + ? "" + : HumanReadableNumTranscendentalOps(op.transcendental_count, nsecs) + .c_str(), + bytes_per_sec.c_str(), bytes_per_cycle.c_str(), op.name.c_str()); }; float optimal_seconds_sum = 0.0; @@ -98,7 +116,8 @@ string HumanReadableProfileBuilder::ToString() const { VLOG(1) << "Total floating point ops: " << total_flops; print_op({"[total]", "[total]", /*category=*/"", total_cycles_, total_flops, - total_transcendentals, total_bytes, optimal_seconds_sum}); + total_transcendentals, total_bytes, optimal_seconds_sum}, + /*is_total=*/true); // Sort ops in decreasing order of cycles, and print them. std::vector sorted_ops(op_infos_); diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 8b2df3256776a7d77517daff1fe282b0dbde7045..39dff567d4f58924f54738a1fcbd1088f27d491d 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/util.h" @@ -153,7 +154,7 @@ StatusOr IndexedArrayAnalysis::ComputeArrayFor( TF_ASSIGN_OR_RETURN( computed_array, ComputeArrayForGather(instr->shape(), instr->gather_dimension_numbers(), - instr->gather_window_bounds(), + instr->gather_slice_sizes(), FindOrDie(cache_, instr->operand(0)), FindOrDie(cache_, instr->operand(1)))); } else if (instr->opcode() == HloOpcode::kReshape) { @@ -251,24 +252,23 @@ StatusOr IndexedArrayAnalysis::FoldGatherOfGather( StatusOr IndexedArrayAnalysis::ComputeArrayForGather( const Shape& shape, const GatherDimensionNumbers& dim_numbers, - tensorflow::gtl::ArraySlice window_bounds, Array* source, + tensorflow::gtl::ArraySlice slice_sizes, Array* source, Array* indices) { if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) { VLOG(3) << "ComputeArrayForGather: indices are not scalar"; return nullptr; } - CHECK_EQ(dim_numbers.gather_dims_to_operand_dims_size(), 1); + CHECK_EQ(dim_numbers.start_index_map_size(), 1); - // We can also handle dim_numbers.elided_window_dims_size() == 0 here, should - // it become relevant. + // We can also handle dim_numbers.collapsed_slice_dims_size() == 0 here, + // should it become relevant. - if (dim_numbers.elided_window_dims_size() != 1 || - dim_numbers.elided_window_dims(0) != - dim_numbers.gather_dims_to_operand_dims(0)) { + if (dim_numbers.collapsed_slice_dims_size() != 1 || + dim_numbers.collapsed_slice_dims(0) != dim_numbers.start_index_map(0)) { VLOG(3) << "ComputeArrayForGather: gather operations must elide " - "gather_dims_to_operand_dims[0] and " - "gather_dims_to_operand_dims[0] only"; + "start_index_map[0] and " + "start_index_map[0] only"; return nullptr; } @@ -277,27 +277,27 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForGather( // arrays from an array of size [7,4,6]. We check that condition down below: for (int64 i = 0, e = source->shape().dimensions_size(); i < e; i++) { - if (i != dim_numbers.elided_window_dims(0) && - source->shape().dimensions(i) != window_bounds[i]) { - VLOG(3) << "ComputeArrayForGather: window_bounds[" << i + if (i != dim_numbers.collapsed_slice_dims(0) && + source->shape().dimensions(i) != slice_sizes[i]) { + VLOG(3) << "ComputeArrayForGather: slice_sizes[" << i << "] != source->shape().dimensions(" << i << ") -- " - << source->shape().dimensions(i) << " vs. " << window_bounds[i] - << " with dim_numbers.elided_window_dims(0) = " - << dim_numbers.elided_window_dims(0); + << source->shape().dimensions(i) << " vs. " << slice_sizes[i] + << " with dim_numbers.collapsed_slice_dims(0) = " + << dim_numbers.collapsed_slice_dims(0); return nullptr; } } - int64 source_dim = dim_numbers.gather_dims_to_operand_dims(0); + int64 source_dim = dim_numbers.start_index_map(0); std::vector output_dims; for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) { - if (!c_binary_search(dim_numbers.output_window_dims(), i)) { + if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) { output_dims.push_back(i); } } if (auto* indexed = dynamic_cast(source)) { - if (c_linear_search(indexed->output_dims(), source_dim)) { + if (absl::c_linear_search(indexed->output_dims(), source_dim)) { return FoldGatherOfGather(indexed, indices, source_dim, output_dims, shape); } @@ -315,7 +315,7 @@ namespace { // [values.begin()+index, values.end()) is equal to `product`. If there is no // such index, return -1. All integers in `values` must be positive. int64 FindSuffixWithProduct(ArraySlice values, int64 product) { - DCHECK(c_all_of(values, [](int64 value) { return value > 0; })); + DCHECK(absl::c_all_of(values, [](int64 value) { return value > 0; })); int64 current_product = 1; int64 i; @@ -389,26 +389,26 @@ std::vector ComputeReshapePassthroughDimPairs( result_subarray_size *= result_shape[result_dim]; } - c_reverse(result); + absl::c_reverse(result); if (VLOG_IS_ON(3)) { std::vector result_strings; - c_transform(result, std::back_inserter(result_strings), - [](ReshapePassthroughDimPair value) { - return tensorflow::strings::StrCat(value.result_dim, "->", - value.operand_dim); - }); + absl::c_transform(result, std::back_inserter(result_strings), + [](ReshapePassthroughDimPair value) { + return tensorflow::strings::StrCat( + value.result_dim, "->", value.operand_dim); + }); VLOG(3) << "For a reshape from [" << Join(operand_shape, ",") << "] to [" << Join(result_shape, ",") << "] passthrough indices are [" << Join(result_strings, ",") << "] (legend: `result`->`operand`)"; } - DCHECK(c_is_sorted( + DCHECK(absl::c_is_sorted( result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) { return lhs.result_dim < rhs.result_dim; })); - DCHECK(c_is_sorted( + DCHECK(absl::c_is_sorted( result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) { return lhs.operand_dim < rhs.operand_dim; })); @@ -420,20 +420,20 @@ std::vector ComputeReshapePassthroughDimPairs( // `passthrough_dims`. bool IsReshapePassthroughOperandDim( ArraySlice passthrough_dims, int64 dim) { - return c_any_of(passthrough_dims, - [&](ReshapePassthroughDimPair passthrough_dim_pair) { - return passthrough_dim_pair.operand_dim == dim; - }); + return absl::c_any_of(passthrough_dims, + [&](ReshapePassthroughDimPair passthrough_dim_pair) { + return passthrough_dim_pair.operand_dim == dim; + }); } // Maps `operand_dim` which must be an passthrough operand dimension to its // corresponding passthrough result dimension based on `passthrough_dims`. int64 MapPassthroughOperandDimToResultDim( ArraySlice passthrough_dims, int64 operand_dim) { - auto it = c_find_if(passthrough_dims, - [&](ReshapePassthroughDimPair passthrough_dim_pair) { - return passthrough_dim_pair.operand_dim == operand_dim; - }); + auto it = absl::c_find_if( + passthrough_dims, [&](ReshapePassthroughDimPair passthrough_dim_pair) { + return passthrough_dim_pair.operand_dim == operand_dim; + }); CHECK(it != passthrough_dims.end()); return it->result_dim; } @@ -447,15 +447,15 @@ int64 FindSourcePositionForPassthroughResultDim(ArraySlice operand_shape, int64 indexed_source_subarray_size = std::accumulate(operand_shape.begin() + source_passthrough_dim + 1, - operand_shape.end(), 1, std::multiplies()); + operand_shape.end(), 1LL, std::multiplies()); return FindSuffixWithProduct(result_shape, indexed_source_subarray_size); } Shape StripDegenerateDimensions(const Shape& shape) { DimensionVector new_dims; - c_copy_if(shape.dimensions(), std::back_inserter(new_dims), - [](int64 dim) { return dim != 1; }); + absl::c_copy_if(shape.dimensions(), std::back_inserter(new_dims), + [](int64 dim) { return dim != 1; }); return ShapeUtil::MakeShape(shape.element_type(), new_dims); } }; // namespace @@ -553,8 +553,8 @@ StatusOr IndexedArrayAnalysis::ReshapeToAddDegenerateDims( }(); DimensionVector new_result_shape_dims; - c_copy(operand->shape().dimensions(), - std::back_inserter(new_result_shape_dims)); + absl::c_copy(operand->shape().dimensions(), + std::back_inserter(new_result_shape_dims)); for (int64 degenerate_dim : degenerate_dims) { InsertAt(&new_result_shape_dims, degenerate_dim, 1); } @@ -695,8 +695,8 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( operand_dim); }; - if (!c_all_of(scalar_indexed->output_dims(), - is_reshape_passthrough_operand_dim)) { + if (!absl::c_all_of(scalar_indexed->output_dims(), + is_reshape_passthrough_operand_dim)) { VLOG(3) << "Not all output dims are passthrough dims " << ToString(scalar_indexed); return nullptr; @@ -735,11 +735,11 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( // operand = s32[3,5,2] constant({...}) // indices = s32[7] parameter(0) // gather = s32[3,2,7] gather(operand, indices), - // output_window_dims={0,1}, - // elided_window_dims={1}, - // gather_dims_to_operand_dims={1}, + // offset_dims={0,1}, + // collapsed_slice_dims={1}, + // start_index_map={1}, // index_vector_dim=1, - // window_bounds={3,1,2} + // slice_sizes={3,1,2} // reshape = s32[6,7] reshape(gather) // // In this case the gather maps to: @@ -764,8 +764,8 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( &new_scalar_indexed_source_shape, source_dim_for_new_scalar_indexed_node, scalar_indexed_source_shape.dimensions(scalar_indexed->source_dim())); - CHECK_EQ(c_accumulate(new_scalar_indexed_source_shape, 1l, - std::multiplies()), + CHECK_EQ(absl::c_accumulate(new_scalar_indexed_source_shape, 1LL, + std::multiplies()), ShapeUtil::ElementsIn(scalar_indexed_source_shape)); CHECK(IsReshapePassthroughOperandDim( @@ -781,9 +781,9 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( }; std::vector output_dims_for_new_scalar_indexed_node; - c_transform(scalar_indexed->output_dims(), - std::back_inserter(output_dims_for_new_scalar_indexed_node), - map_passthrough_operand_dim_to_result_dim); + absl::c_transform(scalar_indexed->output_dims(), + std::back_inserter(output_dims_for_new_scalar_indexed_node), + map_passthrough_operand_dim_to_result_dim); TF_ASSIGN_OR_RETURN(const Literal* new_scalar_indexed_source_literal, TakeOwnership(scalar_indexed->literal().Reshape( @@ -874,11 +874,12 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, ArraySlice broadcast_dims = broadcast_instr->dimensions(); auto is_broadcasted_dim = [&](int64 output_dim) { - return c_find(broadcast_dims, output_dim) == broadcast_dims.end(); + return absl::c_find(broadcast_dims, output_dim) == broadcast_dims.end(); }; // All of the output dims must be "broadcasted" dims for the other operand. - if (!c_all_of(scalar_indexed_const->output_dims(), is_broadcasted_dim)) { + if (!absl::c_all_of(scalar_indexed_const->output_dims(), + is_broadcasted_dim)) { return nullptr; } diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h index e923dc39f7f464a8d3c400294499a6f5efda3991..675eb31d2666b52e21394a06ff95e7dc7cd1987a 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.h +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -265,7 +265,7 @@ class IndexedArrayAnalysis { StatusOr ComputeArrayForGather( const Shape& shape, const GatherDimensionNumbers& dim_numbers, - tensorflow::gtl::ArraySlice window_bounds, Array* source, + tensorflow::gtl::ArraySlice slice_sizes, Array* source, Array* indices); StatusOr ComputeArrayForDotWithIndexedLhs( diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index 5f4b42799b1c26ea544f9d4447cc45b5ae9d5a48..97052edf7d783491888cad3f57621e4cd6b045bc 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -82,11 +82,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[5] parameter(1) ROOT gather = s32[5,3] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,3} + slice_sizes={1,3} } )"; @@ -102,11 +102,11 @@ ENTRY main { operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) indices = s32[5] parameter(0) ROOT gather = s32[5,3] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,3} + slice_sizes={1,3} } )"; @@ -122,11 +122,11 @@ ENTRY main { operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) indices = s32[5,2] parameter(0) ROOT gather = s32[5] gather(operand, indices), - output_window_dims={}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1} + slice_sizes={1,1} } )"; @@ -141,11 +141,11 @@ ENTRY main { operand = s32[3,3,1] parameter(0) indices = s32[5] parameter(1) ROOT gather = s32[5,3] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,2}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0,2}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,3,1} + slice_sizes={1,3,1} } )"; @@ -160,11 +160,11 @@ ENTRY main { operand = s32[3,3,1] parameter(0) indices = s32[5] parameter(1) ROOT gather = s32[5,2,3] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={2}, - gather_dims_to_operand_dims={0}, + offset_dims={1,2}, + collapsed_slice_dims={2}, + start_index_map={0}, index_vector_dim=1, - window_bounds={2,3,1} + slice_sizes={2,3,1} } )"; @@ -179,11 +179,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[5] parameter(1) ROOT gather = s32[5,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,2} + slice_sizes={1,2} } )"; @@ -199,17 +199,17 @@ ENTRY main { indices_a = s32[5] parameter(0) indices_b = s32[2] parameter(1) gather_a = s32[5,3] gather(operand, indices_a), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,3} + slice_sizes={1,3} ROOT gather_b = s32[2,3] gather(gather_a, indices_b), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,3} + slice_sizes={1,3} } )"; @@ -228,17 +228,17 @@ ENTRY main { indices_a = s32[5,7] parameter(1) indices_b = s32[2] parameter(2) gather_a = s32[5,3,7] gather(operand, indices_a), - output_window_dims={1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=2, - window_bounds={3,1} + slice_sizes={3,1} ROOT gather_b = s32[5,3,2] gather(gather_a, indices_b), - output_window_dims={0,1}, - elided_window_dims={2}, - gather_dims_to_operand_dims={2}, + offset_dims={0,1}, + collapsed_slice_dims={2}, + start_index_map={2}, index_vector_dim=1, - window_bounds={5,3,1} + slice_sizes={5,3,1} } )"; @@ -256,17 +256,17 @@ ENTRY main { indices_a = s32[2] parameter(1) indices_b = s32[5,7] parameter(2) gather_a = s32[2,6] gather(operand, indices_a), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,6} + slice_sizes={1,6} ROOT gather_b = s32[5,6,7] gather(gather_a, indices_b), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=2, - window_bounds={1,6} + slice_sizes={1,6} } )"; @@ -284,17 +284,17 @@ ENTRY main { indices_a = s32[5,7] parameter(1) indices_b = s32[4,8] parameter(2) gather_a = s32[5,3,7] gather(operand, indices_a), - output_window_dims={1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=2, - window_bounds={3,1} + slice_sizes={3,1} ROOT gather_b = s32[4,5,3,8] gather(gather_a, indices_b), - output_window_dims={1,2}, - elided_window_dims={2}, - gather_dims_to_operand_dims={2}, + offset_dims={1,2}, + collapsed_slice_dims={2}, + start_index_map={2}, index_vector_dim=2, - window_bounds={5,3,1} + slice_sizes={5,3,1} } )"; @@ -312,11 +312,11 @@ ENTRY main { operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) indices = s32[5] parameter(0) gather = s32[5,4] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT reshape = s32[5,2,2] reshape(gather) } )"; @@ -333,11 +333,11 @@ ENTRY main { operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) indices = s32[5,7] parameter(0) gather = s32[5,4,7] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=2, - window_bounds={1,4} + slice_sizes={1,4} ROOT reshape = s32[5,2,2,7] reshape(gather) } )"; @@ -358,11 +358,11 @@ ENTRY main { {{1,2,3,4,5,6},{1,2,3,4,5,6}}}) indices = s32[5,7] parameter(0) gather = s32[5,2,6,7] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1,2}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=2, - window_bounds={1,2,6} + slice_sizes={1,2,6} ROOT reshape = s32[5,3,4,7] reshape(gather) } )"; @@ -381,11 +381,11 @@ ENTRY main { {1,2,3,4,5,6},{1,2,3,4,5,6}}) indices = s32[1] parameter(0) gather = s32[1,6] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,6} + slice_sizes={1,6} ROOT reshape = s32[1,1,6] reshape(gather) } )"; @@ -408,14 +408,14 @@ ENTRY main { operand = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 1, 2, 3 } }) i.0 = s64[1,3]{1,0} parameter(0) - g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), output_window_dims={2}, - elided_window_dims={0}, gather_dims_to_operand_dims={0}, - index_vector_dim=2, window_bounds={1,3} + g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), offset_dims={2}, + collapsed_slice_dims={0}, start_index_map={0}, + index_vector_dim=2, slice_sizes={1,3} i.1 = s64[1] parameter(1) - g.1 = s32[1,1,3]{2,1,0} gather(g.0, i.1), output_window_dims={0,2}, - elided_window_dims={1}, gather_dims_to_operand_dims={1}, - index_vector_dim=1, window_bounds={1,1,3} + g.1 = s32[1,1,3]{2,1,0} gather(g.0, i.1), offset_dims={0,2}, + collapsed_slice_dims={1}, start_index_map={1}, + index_vector_dim=1, slice_sizes={1,1,3} ROOT reshape = s32[1,3]{1,0} reshape(g.1) } @@ -441,11 +441,11 @@ ENTRY main { operand = s32[1,6] constant(s32[1,6]{{1,2,3,4,5,6}}) indices = s32[1] parameter(0) gather = s32[1,6] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,6} + slice_sizes={1,6} ROOT reshape = s32[1,1,6] reshape(gather) } )"; @@ -469,11 +469,11 @@ ENTRY main { {1,2,3,4,5,6},{1,2,3,4,5,6}}}) indices = s32[1] parameter(0) gather = s32[1,1,6] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1,2}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={1,1,6} + slice_sizes={1,1,6} ROOT reshape = s32[1,1,1,6] reshape(gather) } )"; @@ -500,11 +500,11 @@ ENTRY main { {1,2,3,4,5,6},{1,2,3,4,5,6}}) indices = s32[1,5] parameter(0) gather = s32[1,5,6] gather(operand, indices), - output_window_dims={2}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={2}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=2, - window_bounds={1,6} + slice_sizes={1,6} ROOT reshape = s32[1,1,5,6] reshape(gather) } )"; @@ -530,11 +530,11 @@ ENTRY main { operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) indices = s32[5,6] parameter(0) gather = s32[5,4,6] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=2, - window_bounds={1,4} + slice_sizes={1,4} ROOT reshape = s32[5,2,2,2,3] reshape(gather) } )"; @@ -562,11 +562,11 @@ ENTRY main { {{1,2},{3,4},{5,6},{7,8},{9,10}}}) indices = s32[7] parameter(0) gather = s32[3,2,7] gather(operand, indices), - output_window_dims={0,1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0,1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3,1,2} + slice_sizes={3,1,2} ROOT reshape = s32[6,7] reshape(gather) } )"; @@ -594,11 +594,11 @@ ENTRY main { {{1},{2},{3},{4}}}) indices = s32[5,6] parameter(0) gather = s32[5,4,6,1] gather(operand, indices), - output_window_dims={1,3}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1,3}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=2, - window_bounds={1,4,1} + slice_sizes={1,4,1} ROOT reshape = s32[5,2,2,2,3,1] reshape(gather) } )"; @@ -623,11 +623,11 @@ ENTRY main { operand = f32[3,4] constant(f32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) indices = s32[5] parameter(0) gather = f32[5,4] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT tanh = f32[5,4] tanh(gather) } )"; @@ -650,11 +650,11 @@ ENTRY main { constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} indices = s32[5] parameter(0) gather = s32[5,4] gather(gather_operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT add = s32[5,4] add(gather, constant_broadcasted) } )"; @@ -678,11 +678,11 @@ ENTRY main { constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} indices = s32[5] parameter(0) gather = s32[5,4] gather(gather_operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT sub = s32[5,4] subtract(gather, constant_broadcasted) } )"; @@ -706,11 +706,11 @@ ENTRY main { constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} indices = s32[5] parameter(0) gather = s32[5,4] gather(gather_operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT sub = s32[5,4] subtract(constant_broadcasted, gather) } )"; @@ -733,11 +733,11 @@ ENTRY main { constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={1} indices = s32[5] parameter(0) gather = s32[5,4] gather(gather_operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT add = s32[5,4] add(gather, constant_broadcasted) } )"; @@ -760,11 +760,11 @@ ENTRY main { constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={0} indices = s32[5] parameter(0) gather = s32[5,4] gather(gather_operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT add = s32[5,4] add(gather, constant_broadcasted) } )"; @@ -808,11 +808,11 @@ ENTRY main { dot_rhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) indices = s32[5] parameter(0) dot_lhs = s32[5,4] gather(gather_operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0} } )"; @@ -835,11 +835,11 @@ ENTRY main { dot_rhs_constant = s32[3,3] constant(s32[3,3]{{1,2,3},{4,5,6},{7,8,9}}) indices = s32[5] parameter(0) dot_lhs = s32[3,5] gather(gather_operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3,1} + slice_sizes={3,1} ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={0}, rhs_contracting_dims={0} } )"; @@ -863,11 +863,11 @@ ENTRY main { dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) indices = s32[5] parameter(0) dot_rhs = s32[3,5] gather(gather_operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3,1} + slice_sizes={3,1} ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0} } )"; @@ -892,11 +892,11 @@ ENTRY main { dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) indices = s32[5] parameter(0) dot_rhs = s32[5,3] gather(gather_operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,3} + slice_sizes={1,3} ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={1}, rhs_contracting_dims={1} } )"; @@ -921,11 +921,11 @@ ENTRY main { dot_lhs_constant = s32[2,2,3] constant(s32[2,2,3]{{{1,2,3},{4,5,6}},{{7,8,9},{10,11,12}}}) indices = s32[4] parameter(0) dot_rhs = s32[2,3,4] gather(gather_operand, indices), - output_window_dims={0,1}, - elided_window_dims={2}, - gather_dims_to_operand_dims={2}, + offset_dims={0,1}, + collapsed_slice_dims={2}, + start_index_map={2}, index_vector_dim=1, - window_bounds={2,3,1} + slice_sizes={2,3,1} ROOT dot = s32[2,2,4] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0} @@ -952,11 +952,11 @@ ENTRY main { dot_rhs_constant = s32[2,3] constant(s32[2,3]{{1,2,3},{4,5,6}}) indices = s32[2] parameter(0) dot_lhs = s32[3,2] gather(gather_operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3,1} + slice_sizes={3,1} ROOT dot = s32[3,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0} } )"; diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc index 32937b33b3737482f07d4c7607f7f1c5c183a56b..5695bc242057c037a1999e7d63f5b4f21b5f658a 100644 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ b/tensorflow/compiler/xla/service/inliner_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index af07370135ca2b2e53fcbcb53696e0aa12bf7a6f..2fd221480634004a0371f42aab1247fce33cde90 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/core/lib/core/errors.h" @@ -120,6 +121,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kConditional: case HloOpcode::kConvolution: case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllToAll: case HloOpcode::kCustomCall: case HloOpcode::kDivide: case HloOpcode::kDomain: @@ -141,6 +143,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kReduceWindow: case HloOpcode::kRemainder: case HloOpcode::kRng: + case HloOpcode::kScatter: case HloOpcode::kSelectAndScatter: case HloOpcode::kSend: case HloOpcode::kSendDone: @@ -495,7 +498,7 @@ HloInstruction* InstructionFusion::FuseIntoMultiOutput( bool InstructionFusion::MultiOutputFusionCreatesCycle( HloInstruction* producer, HloInstruction* consumer) { - return c_any_of( + return absl::c_any_of( consumer->operands(), [&](const HloInstruction* consumer_operand) { // The fusion algorithm traverses the HLO graph in reverse post order. // Thus `cosumers` is visited before its operands (including diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 8652599dc6d48ff8c2aaa703fead161f891a57d1..581f8d2e92b9d7c4350360282cbd9e69824841ca 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -12,12 +12,11 @@ cc_library( srcs = ["interpreter_transfer_manager.cc"], hdrs = ["interpreter_transfer_manager.h"], deps = [ - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:generic_transfer_manager", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/service/interpreter:platform_id", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], alwayslink = True, # Contains per-platform transfer manager registration ) @@ -32,8 +31,6 @@ cc_library( "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:algebraic_simplifier", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", @@ -54,6 +51,7 @@ cc_library( "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/core:lib", "//tensorflow/stream_executor", + "@com_google_absl//absl/memory", ], alwayslink = True, # Contains compiler registration ) @@ -79,7 +77,6 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:hlo", @@ -91,6 +88,7 @@ cc_library( "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index 9f8f4bda875cdff5e20fa8ca8eeecaa1140e2b9c..bb69cb9c47ff2c7de8d13832c4b8e6216c62da73 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" @@ -69,8 +69,8 @@ StatusOr> InterpreterCompiler::RunBackend( // Create executable from only the Hlo module. std::unique_ptr executable = - xla::MakeUnique(std::move(hlo_module), - xla::MakeUnique()); + absl::make_unique( + std::move(hlo_module), absl::make_unique()); return std::move(executable); } @@ -103,11 +103,11 @@ HloCostAnalysis::ShapeSizeFunction InterpreterCompiler::ShapeSizeBytesFunction() static bool InitModule() { xla::Compiler::RegisterCompilerFactory( se::interpreter::kXlaInterpreterPlatformId, []() { - return xla::MakeUnique(); + return absl::make_unique(); }); xla::ComputationPlacer::RegisterComputationPlacer( se::interpreter::kXlaInterpreterPlatformId, - []() { return xla::MakeUnique(); }); + []() { return absl::make_unique(); }); return true; } diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 8d40c08d555a232b7cf3b81cc0f9970804c2f896..2259dc1083e6d1ca64cc7d7b8d9c566a27183ac7 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -21,8 +21,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/interpreter/executor.h" diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h index 9b109022fbfc698f7dadc678ef837da270a5e74a..db6b910b32f8ec234c4cf1c331a1aa3bb2f9389f 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.h +++ b/tensorflow/compiler/xla/service/interpreter/executor.h @@ -104,7 +104,7 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { } // No "synchronize all activity" implemented for this platform at the moment. - bool SynchronizeAllActivity() override { return false; } + bool SynchronizeAllActivity() override { return true; } bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) override { return false; } diff --git a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc index d27cd7502f10a1f615fc5b0d610acafdf55e3e43..7955ee5cf37f3fa45b942d8ab05a60076857dc6c 100644 --- a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/interpreter/platform_id.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -31,7 +31,7 @@ InterpreterTransferManager::InterpreterTransferManager() static std::unique_ptr CreateInterpreterTransferManager() { - return xla::MakeUnique(); + return absl::make_unique(); } static bool InitModule() { diff --git a/tensorflow/compiler/xla/service/interpreter/platform.cc b/tensorflow/compiler/xla/service/interpreter/platform.cc index 42c2c28997d5f3b02f1fe4effca164c893e4071d..e57a9b3672391e11b130b1c16307a80a0a5b5e77 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform.cc +++ b/tensorflow/compiler/xla/service/interpreter/platform.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/interpreter/executor.h" #include "tensorflow/stream_executor/device_options.h" #include "tensorflow/stream_executor/lib/initialize.h" @@ -70,8 +71,8 @@ port::StatusOr XlaInterpreterPlatform::GetExecutor( port::StatusOr> XlaInterpreterPlatform::GetUncachedExecutor( const StreamExecutorConfig& config) { - auto executor = MakeUnique( - this, MakeUnique(config.plugin_config)); + auto executor = absl::make_unique( + this, absl::make_unique(config.plugin_config)); auto init_status = executor->Init(config.ordinal, config.device_options); if (!init_status.ok()) { return port::Status{ diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 9705687b004976fc5d35ddeb1c2a69c65ed50358..c75bffc63d71c8018ad71b035d4e9a0886c0f4a6 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -26,9 +26,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -137,7 +137,7 @@ PointsToSet::BufferSet* LayoutConstraints::GetBufferSet( } auto& buffer_set = buffer_sets_cache_ - .emplace(instruction, MakeUnique()) + .emplace(instruction, absl::make_unique()) .first->second; const auto& points_to_set = points_to_analysis_.GetPointsToSet(instruction); points_to_set.ForEachElement( @@ -874,8 +874,8 @@ void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction, // HostCompute module. // Otherwise it is preferable to leave the new instruction without device, // and let the automatic device placer to choose the best location. - if (!sharding.HasUniqueDevice() || - HloSharding::IsReservedDevice(sharding.UniqueDevice().ValueOrDie())) { + auto device = sharding.UniqueDevice(); + if (!device || HloSharding::IsReservedDevice(*device)) { copy->set_sharding(sharding); } } @@ -1008,7 +1008,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( // // TODO(jingyue): Other operations, such as kSlice and kConcat, can benefit // from assigning the same layout to input and output. - return MakeUnique(output_layout); + return absl::make_unique(output_layout); } if (instruction->opcode() == HloOpcode::kReshape) { @@ -1031,13 +1031,13 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( *operand_shape.mutable_layout() = LayoutUtil::GetDefaultLayoutForShape(operand_shape); if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) { - return MakeUnique(operand_shape.layout()); + return absl::make_unique(operand_shape.layout()); } if (ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)) { *operand_shape.mutable_layout() = output_layout; if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) { - return MakeUnique(output_layout); + return absl::make_unique(output_layout); } } auto aligned_operand_shape = @@ -1046,7 +1046,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( auto operand_layout = aligned_operand_shape.value().layout(); TF_CHECK_OK( LayoutUtil::ValidateLayoutForShape(operand_layout, operand_shape)); - return MakeUnique(operand_layout); + return absl::make_unique(operand_layout); } } @@ -1062,7 +1062,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( Layout operand_layout = LayoutUtil::MakeLayout(new_minor_to_major); TF_CHECK_OK( LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape())); - return MakeUnique(operand_layout); + return absl::make_unique(operand_layout); } return nullptr; @@ -1080,7 +1080,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( !ShapeUtil::IsScalar(operand->shape()) && ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape())) { // Assign users the same layout as the operand. - return MakeUnique(operand_layout); + return absl::make_unique(operand_layout); } if (user->opcode() == HloOpcode::kReshape) { @@ -1103,13 +1103,13 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( *output_shape.mutable_layout() = LayoutUtil::GetDefaultLayoutForShape(output_shape); if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) { - return MakeUnique(output_shape.layout()); + return absl::make_unique(output_shape.layout()); } if (ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape)) { *output_shape.mutable_layout() = operand_layout; if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) { - return MakeUnique(operand_layout); + return absl::make_unique(operand_layout); } } auto aligned_user_shape = @@ -1118,7 +1118,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( auto user_layout = aligned_user_shape.value().layout(); TF_CHECK_OK( LayoutUtil::ValidateLayoutForShape(user_layout, output_shape)); - return MakeUnique(user_layout); + return absl::make_unique(user_layout); } } @@ -1134,7 +1134,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( } Layout user_layout = LayoutUtil::MakeLayout(new_minor_to_major); TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(user_layout, user->shape())); - return MakeUnique(user_layout); + return absl::make_unique(user_layout); } return nullptr; @@ -1228,7 +1228,7 @@ Status LayoutAssignment::PropagateUseConstraintToDefs( const PointsToSet& points_to_set = constraints->points_to_analysis().GetPointsToSet(instruction); return points_to_set.ForEachElementWithStatus( - [this, &shape_layout, constraints]( + [&shape_layout, constraints]( const ShapeIndex& index, const PointsToSet::BufferList& buffers) -> Status { if (ShapeUtil::IsLeafIndex(shape_layout.shape(), index)) { @@ -1563,7 +1563,7 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) { // and the computation result. The latter two are specified in // computation_layout, so we only need to keep the existing layouts for // infeeds. Clearing the layouts here avoids hiding potential bugs in the - // layout assignment pass that may accidently use the existing layout. + // layout assignment pass that may accidentally use the existing layout. for (HloInstruction* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kBitcast) { // bitcasts are inherently layout sensitive and so a bitcast instruction diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 309a186e589dd5eabe0686def8a759a99fea276e..ce2d6678a5e8b71c5e2dacd2bc052d7cdb4cd292 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -88,6 +88,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", "@llvm//:core", ], ) @@ -224,6 +225,15 @@ cc_library( ], ) +cc_library( + name = "buffer_assignment_util", + srcs = ["buffer_assignment_util.cc"], + hdrs = ["buffer_assignment_util.h"], + deps = [ + "//tensorflow/compiler/xla/service:buffer_assignment", + ], +) + cc_library( name = "math_ops", srcs = ["math_ops.cc"], diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc index 2552ff4a6a06d18f34b4ba224b66d6d97ddd74d3..fe5ec1cc66d06e85ce70625ef7cf764a37b29166 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc @@ -56,12 +56,12 @@ ENTRY while3 { )"; CompileAndVerifyIr(hlo_string, R"( -; CHECK-LABEL: @body(i8* align 4 dereferenceable(4) %retval +; CHECK-LABEL: @body(i8* %retval ; CHECK: %[[add_result:.*]] = fadd fast float %[[fadd_lhs:.*]], %[[fadd_rhs:.*]] -; CHECK: store float %[[add_result]], float* %[[store_dest:.*]], !alias.scope ![[alias_scope_md_for_store:.*]] +; CHECK: store float %[[add_result]], float* %[[store_dest:.*]], !alias.scope ![[alias_scope_md_for_store:[0-9]+]] ; -; CHECK-LABEL: @condition(i8* align 1 dereferenceable(1) %fusion, i8* noalias %run_options, i8** noalias %params -; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %params, i64 0 +; CHECK-LABEL: @condition(i8* %retval, i8* noalias %run_options, i8** noalias %params +; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %temps, i64 0 ; CHECK: %[[cond_state_buf_untyped:.*]] = load i8*, i8** %[[cond_state_buf_ptr]] ; CHECK: %[[cond_state_buf_typed:.*]] = bitcast i8* %[[cond_state_buf_untyped]] to float* ; CHECK: load float, float* %[[cond_state_buf_typed]], !alias.scope ![[alias_scope_md_for_store]], !noalias ![[noalias_md_for_load:.*]] diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..4eb5d9fb4750927ca189e02f312b2d6be7fdd418 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" + +namespace xla { +namespace llvm_ir { +static const HloInstruction& InstrForConstantBufferAllocation( + const BufferAllocation& allocation) { + CHECK(allocation.is_constant()); + HloInstruction* const_instr = nullptr; + for (const auto& buffer_offset_pair : allocation.assigned_buffers()) { + const LogicalBuffer* buffer = buffer_offset_pair.first; + // BufferAssignment may have assigned non-constant instructions to this + // allocation too so we can't CHECK this condition. E.g. for + // + // while(init = constant, body = identity, cond = ...) + // + // the LogicalBuffer for the kWhile instruction will have the same + // BufferAllocation as the LogicalBuffer for the (init) constant. + if (buffer->instruction()->opcode() == HloOpcode::kConstant) { + CHECK_EQ(const_instr, nullptr) + << const_instr->ToString() << " " << buffer->ToString(); + const_instr = buffer->instruction(); + } + } + CHECK_NE(const_instr, nullptr); + return *const_instr; +} + +string ConstantBufferAllocationToGlobalName( + const BufferAllocation& allocation) { + string instr_name = InstrForConstantBufferAllocation(allocation).name(); + for (char& c : instr_name) { + if (c == '.') { + c = '_'; + } + } + return tensorflow::strings::StrCat("buffer_for_", instr_name); +} + +const Literal& LiteralForConstantAllocation( + const BufferAllocation& allocation) { + return InstrForConstantBufferAllocation(allocation).literal(); +} +} // namespace llvm_ir +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h new file mode 100644 index 0000000000000000000000000000000000000000..bfb6eecb87f6a1b756b3a8da3377f608dd7f0be7 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_BUFFER_ASSIGNMENT_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_BUFFER_ASSIGNMENT_UTIL_H_ + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" + +namespace xla { +namespace llvm_ir { +// In XLA:GPU we map constant buffer allocations to globals in the generated +// LLVM IR. This function gives us the name of the global variable a constant +// buffer is mapped to. Not used on XLA:CPU. +string ConstantBufferAllocationToGlobalName(const BufferAllocation& allocation); + +// Returns the Literal corresponding to `allocation`, which must be a constant +// allocation. +const Literal& LiteralForConstantAllocation(const BufferAllocation& allocation); +} // namespace llvm_ir +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_BUFFER_ASSIGNMENT_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index 28ca793e3eeaed86664bfa6aa859a38f2c4dc6f3..cbfd2e701235c9a5e65378eab4e1be469b1e9256 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/map_util.h" @@ -81,7 +82,7 @@ class IrArray { } } CHECK_NE(index_type_, nullptr); - CHECK(c_all_of(multidim, [&](llvm::Value* v) { + CHECK(absl::c_all_of(multidim, [&](llvm::Value* v) { return index_type_ == v->getType(); })); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc index 6f261c32f4181a6c4107f7fbcf782feb4347e587..e546f5cc4ae305b40c1bdbcae090daadee11241b 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -38,19 +39,18 @@ namespace llvm_ir { namespace { // Adds the inner comparison loop where we compare elements pointed to by // 'keys_index' and 'compare_keys_index'. -void EmitCompareLoop(int64 dimension_to_sort, - const llvm_ir::IrArray::Index& keys_index, - const llvm_ir::IrArray::Index& compare_keys_index, - const llvm_ir::IrArray& keys_array, llvm::IRBuilder<>* b) { - // TODO(b/26783907): parallelize this loop. - +void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index, + const IrArray::Index& compare_keys_index, + const IrArray& keys_array, + const tensorflow::gtl::optional& values_array, + llvm::IRBuilder<>* b) { // if (is_smaller_index && // compare_keys[dimension_to_sort] < dimension_to_sort_bound) llvm::Value* is_smaller_index = b->CreateICmpSLT( keys_index[dimension_to_sort], compare_keys_index[dimension_to_sort]); int64 dimension_to_sort_bound = keys_array.GetShape().dimensions(dimension_to_sort); - auto if_data = llvm_ir::EmitIfThenElse( + auto if_data = EmitIfThenElse( b->CreateAnd(is_smaller_index, b->CreateICmpSLT(compare_keys_index[dimension_to_sort], keys_index.GetConstantWithIndexType( @@ -63,30 +63,36 @@ void EmitCompareLoop(int64 dimension_to_sort, auto comparison = primitive_util::IsFloatingPointType(key_type) // TODO(b/26783907): Figure out how to handle NaNs. - ? b->CreateFCmp(llvm::FCmpInst::FCMP_ULT, key1, key2) + ? b->CreateFCmp(llvm::FCmpInst::FCMP_ULT, key2, key1) : b->CreateICmp(primitive_util::IsSignedIntegralType(key_type) ? llvm::ICmpInst::ICMP_SLT : llvm::ICmpInst::ICMP_ULT, - key1, key2); - auto min_key = b->CreateSelect(comparison, key1, key2); - auto max_key = b->CreateSelect(comparison, key2, key1); - keys_array.EmitWriteArrayElement(keys_index, min_key, b); - keys_array.EmitWriteArrayElement(compare_keys_index, max_key, b); + key2, key1); + // If key2 < key1 + auto if_smaller_data = + EmitIfThenElse(comparison, "is_smaller_than", b, /*emit_else=*/false); + SetToFirstInsertPoint(if_smaller_data.true_block, b); + // Swap key1 with key2. + keys_array.EmitWriteArrayElement(keys_index, key2, b); + keys_array.EmitWriteArrayElement(compare_keys_index, key1, b); + if (values_array.has_value()) { + // Also swap the values. + auto value1 = values_array.value().EmitReadArrayElement(keys_index, b); + auto value2 = + values_array.value().EmitReadArrayElement(compare_keys_index, b); + values_array.value().EmitWriteArrayElement(keys_index, value2, b); + values_array.value().EmitWriteArrayElement(compare_keys_index, value1, b); + } } } // namespace Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, + const tensorflow::gtl::optional& values_array, tensorflow::StringPiece name, llvm::Value* xor_mask, llvm::IRBuilder<>* b, const gpu::LaunchDimensions* launch_dimensions) { const Shape& keys_shape = keys_array.GetShape(); - // TODO(b/26783907): This case can probably be avoided with the Algebraic - // Simplifier. - if (ShapeUtil::IsScalar(keys_shape)) { - return Status::OK(); - } - // Create loop nests which loop through the operand dimensions. The sort // dimension is handled in the innermost loop which performs the sorting. ForLoopNest loop_nest(name, b); @@ -131,7 +137,7 @@ Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, compare_keys_index[dimension_to_sort] = b->CreateXor(compare_index[0], xor_mask); EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index, - keys_array, b); + keys_array, values_array, b); return Status::OK(); }; if (launch_dimensions != nullptr) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h index e75f9b08fbba7c79b8354698ad17e79c154bd67e..8458744c6bc0e50a1c1cc8d3e66e29c7d4f74d73 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -30,6 +31,7 @@ namespace llvm_ir { // implements the inner loop of BitonicSort. If 'launch_dimensions' is nullptr, // the inner compare loop will not be parallelized. Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, + const tensorflow::gtl::optional& values_array, tensorflow::StringPiece name, llvm::Value* xor_mask, llvm::IRBuilder<>* b, const gpu::LaunchDimensions* launch_dimensions); diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 5e02096ee501b23a7976a50f13bb7e7f3c5e2d34..597a788c5d7d5488d3193fe9af1d85884c41500e 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -19,10 +19,10 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/executable.h" diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc index d631fb5ee42df6525681a5cd1fe1a8241824121d..eaa09591b72ee5202e0a9d1225d92eca92904adc 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" @@ -89,7 +90,7 @@ void LogicalBufferAnalysis::NewLogicalBuffer(HloInstruction* instruction, const ShapeIndex& index) { CHECK_EQ(logical_buffers_.size(), next_buffer_id_); logical_buffers_.emplace_back( - MakeUnique(instruction, index, next_buffer_id_)); + absl::make_unique(instruction, index, next_buffer_id_)); output_buffers_[std::make_pair(instruction, index)] = logical_buffers_.back().get(); diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h index 0019cd725417d81900974b462c3b05075ce3e893..6aa639a954d3a359ff3b3de69b454fc6c0ec1792 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/multi_output_fusion.h @@ -104,17 +104,17 @@ class MultiOutputFusion : public HloPassInterface { // InstructionFusion instead. virtual bool DoProducerConsumerMultiOutputFusion(); - private: - // Update the internal data structures after instr1 and instr2 are fused into - // one fusion instruction. - void Update(HloInstruction* instr1, HloInstruction* instr2); - // Optimization fuel is a compiler debugging technique that makes an // optimization pass stop what it is doing after having made N changes to the // program, where N is the fuel. By varying N, this can be used to find the // first single change that makes a test fail. int64 fuel_; + private: + // Update the internal data structures after instr1 and instr2 are fused into + // one fusion instruction. + void Update(HloInstruction* instr1, HloInstruction* instr2); + // Computation for the pass. HloComputation* computation_; diff --git a/tensorflow/compiler/xla/service/pool.h b/tensorflow/compiler/xla/service/pool.h deleted file mode 100644 index 8e710ebb6dc17e0e204ba6ab3c6c159627cd9d3b..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/pool.h +++ /dev/null @@ -1,84 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_POOL_H_ -#define TENSORFLOW_COMPILER_XLA_POOL_H_ - -#include -#include - -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/core/platform/mutex.h" - -namespace xla { - -// Pool of values, which are created as needed and destroyed when the `Pool` is -// destroyed -template -class Pool { - public: - struct Deleter { - void operator()(T* ptr) { pool->Deallocate(ptr); } - - Pool* pool; - }; - - // A pointer to a taken element of a `Pool` which returns it to the pool on - // destruction - using SmartPtr = std::unique_ptr; - - // Constructs a `Pool` with given factory function, which need not be - // thread-safe. - explicit Pool(std::function()> factory) - : factory_(factory) {} - - explicit Pool() : Pool([]() { return MakeUnique(); }) {} - - // Returns a pointer to a value in the pool, creating a new value if none is - // free. The returned smart pointer returns the element to the pool on - // destruction. - // - // This method is thread-safe. - SmartPtr Allocate() { - tensorflow::mutex_lock lock(mu_); - T* ptr; - if (!xs_.empty()) { - ptr = std::move(xs_.back()).release(); - xs_.pop_back(); - } else { - ptr = factory_().release(); - } - Deleter del = {this}; - return std::unique_ptr(ptr, del); - } - - private: - // Puts a pointer to a value back into the pool, leaving it free for future - // use. - // - // This method is thread-safe. - void Deallocate(T* ptr) { - tensorflow::mutex_lock lock(mu_); - xs_.push_back(std::unique_ptr(ptr)); - } - - const std::function()> factory_ GUARDED_BY(mu_); - std::vector> xs_ GUARDED_BY(mu_); - tensorflow::mutex mu_; -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_POOL_H_ diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index ca86c5d13e98a98c62d0c9e8e32e28fe99e0fa1f..4df746fca9f8320eed72911726f33bb01f06fed5 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -38,6 +38,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" #include + +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -374,7 +376,7 @@ StatusOr TryReshapeMoveOnCandidates( removed = false; for (auto operand : nontrivial_operands) { - if (c_any_of(operand->users(), [&](HloInstruction* user) { + if (absl::c_any_of(operand->users(), [&](HloInstruction* user) { return !reshape_candidates->count(user); })) { for (auto* user : operand->users()) { diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index ad3b662c20ac53b0a6d634b16b3b908f730f3d2d..7534a3f7e32aa84e5b47695b3eef23a8e749ee63 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -15,9 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -76,9 +76,13 @@ TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { TEST_F(ReshapeMoverTest, 1ConstantAnd1ReshapesOnRngNotMoved) { HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); - auto rng0 = builder.AddInstruction( - HloInstruction::CreateRng(ShapeUtil::MakeShape(F32, {1, 8, 1, 7, 1}), - RandomDistribution::RNG_UNIFORM, {})); + auto rng0 = builder.AddInstruction(HloInstruction::CreateRng( + ShapeUtil::MakeShape(F32, {1, 8, 1, 7, 1}), + RandomDistribution::RNG_UNIFORM, + {builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))), + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(1.0f)))})); auto reshape0 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, rng0)); diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc new file mode 100644 index 0000000000000000000000000000000000000000..338f0c09e9e7f59127023144ff30ac62aff55ee1 --- /dev/null +++ b/tensorflow/compiler/xla/service/scatter_expander.cc @@ -0,0 +1,351 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/scatter_expander.h" + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +using tensorflow::gtl::ArraySlice; + +// Transposes the given scatter_indices such that the index_vector_dim becomes +// the most-minor dimension. +static StatusOr TransposeIndexVectorDimToLast( + HloInstruction* scatter_indices, int64 index_vector_dim) { + const Shape& scatter_indices_shape = scatter_indices->shape(); + + if (scatter_indices_shape.dimensions_size() == index_vector_dim) { + return scatter_indices; + } + + if (index_vector_dim == (scatter_indices_shape.dimensions_size() - 1)) { + return scatter_indices; + } + + std::vector permutation; + permutation.reserve(scatter_indices_shape.dimensions_size()); + for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) { + if (i != index_vector_dim) { + permutation.push_back(i); + } + } + permutation.push_back(index_vector_dim); + return MakeTransposeHlo(scatter_indices, permutation); +} + +// Canonicalizes the scatter_indices tensor in order to keep them uniform while +// performing the scatter operation. +static StatusOr CanonicalizeScatterIndices( + HloInstruction* scatter_indices, int64 index_vector_dim) { + // Transpose the non-index-vector dimensions to the front. + TF_ASSIGN_OR_RETURN( + HloInstruction * transposed_scatter_indices, + TransposeIndexVectorDimToLast(scatter_indices, index_vector_dim)); + bool indices_are_scalar = + index_vector_dim == scatter_indices->shape().dimensions_size(); + + // The number of dimensions in scatter_indices that are index dimensions. + const int64 index_dims_in_scatter_indices = indices_are_scalar ? 0 : 1; + + // If there is only one index (i.e. scatter_indices has rank 1 and this + // scatter is really just a dynamic update slice) add a leading degenerate + // dimension for uniformity. Otherwise create a "collapsed" leading dimension + // that subsumes all of the non-index-vector dimensions. + const Shape& shape = transposed_scatter_indices->shape(); + if (shape.dimensions_size() == index_dims_in_scatter_indices) { + return PrependDegenerateDims(transposed_scatter_indices, 1); + } else { + // Collapse all but the dimensions (0 or 1) in scatter_indices containing + // the index vectors. + return CollapseFirstNDims( + transposed_scatter_indices, + shape.dimensions_size() - index_dims_in_scatter_indices); + } +} + +// Permutes the `updates` tensor such that all the scatter dims appear in the +// major dimensions and all the window dimensions appear in the minor +// dimensions. +static StatusOr PermuteScatterAndWindowDims( + HloInstruction* updates, ArraySlice update_window_dims) { + std::vector permutation; + const int64 updates_rank = ShapeUtil::Rank(updates->shape()); + permutation.reserve(updates_rank); + + for (int64 i = 0; i < updates_rank; ++i) { + bool is_scatter_dim = !absl::c_binary_search(update_window_dims, i); + if (is_scatter_dim) { + permutation.push_back(i); + } + } + for (auto window_dim : update_window_dims) { + permutation.push_back(window_dim); + } + + return MakeTransposeHlo(updates, permutation); +} + +// Expands or contracts the scatter indices in the updates tensor. +static StatusOr AdjustScatterDims( + const Shape& scatter_indices_shape, HloInstruction* updates, + int64 index_vector_dim) { + int64 num_scatter_dims = scatter_indices_shape.dimensions_size(); + if (index_vector_dim < scatter_indices_shape.dimensions_size()) { + --num_scatter_dims; + } + if (num_scatter_dims == 0) { + // If there are no scatter dims, this must be a dynamic-update-slice kind of + // scatter. In this case, we prepend a degenerate dimension to work + // uniformly in the while loop. + return PrependDegenerateDims(updates, 1); + } + return CollapseFirstNDims(updates, num_scatter_dims); +} + +// Expands an index vector from the scatter_indices tensor into a vector that +// can be used to dynamic-update-slice to perform the scatter update. +static StatusOr ExpandIndexVectorIntoOperandSpace( + HloInstruction* index_vector, const ScatterDimensionNumbers& dim_numbers, + int64 operand_rank) { + HloComputation* computation = index_vector->parent(); + const Shape& index_shape = index_vector->shape(); + HloInstruction* zero = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1}))); + + // We extract out individual components from the smaller index and concatenate + // them (interspersing zeros as needed) into the larger index. + std::vector expanded_index_components; + + for (int i = 0; i < operand_rank; i++) { + int64 index_vector_dim_index = + FindIndex(dim_numbers.scatter_dims_to_operand_dims(), i); + if (index_vector_dim_index != + dim_numbers.scatter_dims_to_operand_dims_size()) { + TF_ASSIGN_OR_RETURN( + HloInstruction * component_to_concat, + MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index}, + /*limit_indices=*/{index_vector_dim_index + 1}, + /*strides=*/{1})); + expanded_index_components.push_back(component_to_concat); + } else { + expanded_index_components.push_back(zero); + } + } + + return MakeConcatHlo(expanded_index_components, /*dimension=*/0); +} + +// Body of the while loop that performs the scatter operation using other HLOs. +static StatusOr> ScatterLoopBody( + HloInstruction* scatter, HloInstruction* induction_var, + const std::vector& loop_state) { + const ScatterDimensionNumbers& dim_numbers = + scatter->scatter_dimension_numbers(); + CHECK_EQ(loop_state.size(), 3); + HloInstruction* operand = loop_state[0]; + HloInstruction* scatter_indices = loop_state[1]; + HloInstruction* updates = loop_state[2]; + + bool has_scalar_indices = scatter_indices->shape().dimensions_size() == 1; + CHECK_EQ(has_scalar_indices, + dim_numbers.index_vector_dim() == + scatter->operand(1)->shape().dimensions_size()); + + // Build a vector form of the induction variable of the while loop. + TF_ASSIGN_OR_RETURN( + HloInstruction * induction_var_as_vector, + MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{}, + /*result_shape_bounds=*/{1})); + + // Pick the index to scatter from scatter_indices based on the induction_var + // and transform that to an index into the `operand` space. + HloInstruction* index_vector; + if (has_scalar_indices) { + TF_ASSIGN_OR_RETURN( + index_vector, + MakeDynamicSliceHlo(scatter_indices, induction_var_as_vector, {1})); + } else { + TF_ASSIGN_OR_RETURN( + HloInstruction * index_into_scatter_indices, + PadVectorWithZeros(induction_var_as_vector, + /*zeros_to_prepend=*/0, /*zeros_to_append=*/1)); + int index_vector_size = scatter_indices->shape().dimensions(1); + TF_ASSIGN_OR_RETURN( + HloInstruction * index_vector_2d, + MakeDynamicSliceHlo(scatter_indices, index_into_scatter_indices, + {1, index_vector_size})); + TF_ASSIGN_OR_RETURN(index_vector, + ElideDegenerateDims(index_vector_2d, {0})); + } + TF_ASSIGN_OR_RETURN( + HloInstruction * scatter_slice_start, + ExpandIndexVectorIntoOperandSpace(index_vector, dim_numbers, + operand->shape().dimensions_size())); + + // Extract the slice to be used to update from `updates` tensor for the + // induction_var corresponding to this iteration of the while loop. + TF_ASSIGN_OR_RETURN( + HloInstruction * index_into_updates, + PadVectorWithZeros( + induction_var_as_vector, /*zeros_to_prepend=*/0, + /*zeros_to_append=*/updates->shape().dimensions_size() - 1)); + std::vector update_slice_bounds(updates->shape().dimensions().begin(), + updates->shape().dimensions().end()); + update_slice_bounds[0] = 1; + TF_ASSIGN_OR_RETURN( + HloInstruction * update_slice, + MakeDynamicSliceHlo(updates, index_into_updates, update_slice_bounds)); + TF_ASSIGN_OR_RETURN(HloInstruction * update_slice_for_scatter, + ElideDegenerateDims(update_slice, {0})); + TF_ASSIGN_OR_RETURN( + HloInstruction * update_slice_with_dims_inserted, + InsertDegenerateDims(update_slice_for_scatter, + AsInt64Slice(dim_numbers.inserted_window_dims()))); + + // Extact the slice to update from `operand` tensor. + const Shape& update_slice_shape = update_slice_with_dims_inserted->shape(); + TF_ASSIGN_OR_RETURN( + HloInstruction * operand_slice_to_update, + MakeDynamicSliceHlo(operand, scatter_slice_start, + AsInt64Slice(update_slice_shape.dimensions()))); + + // Compute the new value for the slice to be updated in `operand` tensor by + // combining the existing value and the update value using the update + // computation. + TF_ASSIGN_OR_RETURN( + HloInstruction * updated_operand_slice, + MakeMapHlo({operand_slice_to_update, update_slice_with_dims_inserted}, + scatter->to_apply())); + + // Write the updated value of the slice into `operand` tensor. + TF_ASSIGN_OR_RETURN(HloInstruction * updated_operand, + MakeDynamicUpdateSliceHlo(operand, updated_operand_slice, + scatter_slice_start)); + + return StatusOr>{ + {updated_operand, scatter_indices, updates}}; +} + +// High Level Algorithm. +// +// 1. Canonicalize the scatter_indices tensor such that it has rank 2, where +// each row is an index into the operand. +// 2. Canonicalize the updates tensor such that is has rank `num_window_dims+1` +// and the scatter dim is the most-major dimension. +// 3. Iterate over the set of indices in the canonicalized scatter_indices +// tensor using a while loop, updating the operand for each such index. Each +// iteration of this while loop performs the following: +// a. Pick the index from scatter_indices for this iteration. +// b. Transfrom this index into an index into the operand space. +// c. Extract the slice to be used to update from the updates tensor. +// d. Extract the slice to update from the operand tensor. +// e. Compute the new value for the slice to update by combining the slices +// from c. and d. using the update_computation of scatter. +// f. Write the updated value of the slice into the operand tensor. + +StatusOr ScatterExpander::ExpandScatter( + HloInstruction* scatter) { + HloInstruction* operand = scatter->mutable_operand(0); + HloInstruction* scatter_indices = scatter->mutable_operand(1); + HloInstruction* updates = scatter->mutable_operand(2); + const ScatterDimensionNumbers& dim_numbers = + scatter->scatter_dimension_numbers(); + + // If the updates tensor is empty, there is no need to update the operand. We + // can return the operand as is. + if (ShapeUtil::IsZeroElementArray(updates->shape())) { + return operand; + } + + // Compute the trip count for the while loop to be used for scatter. This + // should be the number of indices we should scatter into the operand. + const Shape& scatter_indices_shape = scatter_indices->shape(); + int64 scatter_loop_trip_count = 1; + for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) { + if (i != dim_numbers.index_vector_dim()) { + scatter_loop_trip_count *= scatter_indices_shape.dimensions(i); + } + } + if (!IsInt32(scatter_loop_trip_count)) { + return Unimplemented( + "Scatter operations with more than 2147483647 scatter indices are not " + "supported. This error occurred for %s.", + scatter->ToString().c_str()); + } + + // Canonicalize the scatter_indices, after which the size of its most-major + // dimension must be same as the while loop trip count. + TF_ASSIGN_OR_RETURN(HloInstruction * canonical_scatter_indices, + CanonicalizeScatterIndices( + scatter_indices, dim_numbers.index_vector_dim())); + CHECK_EQ(scatter_loop_trip_count, + canonical_scatter_indices->shape().dimensions(0)); + + // Canonicalize the updates, after which the size of its most-major dimension + // must be same as the while loop trip count. + TF_ASSIGN_OR_RETURN( + HloInstruction * canonical_updates, + PermuteScatterAndWindowDims( + updates, AsInt64Slice(dim_numbers.update_window_dims()))); + TF_ASSIGN_OR_RETURN( + HloInstruction * adjusted_canonical_updates, + AdjustScatterDims(scatter_indices->shape(), canonical_updates, + dim_numbers.index_vector_dim())); + CHECK_EQ(scatter_loop_trip_count, + adjusted_canonical_updates->shape().dimensions(0)); + + // The while loop that implements the scatter operation. + StatusOr> scatter_loop_result_status = + WhileUtil::MakeCountedLoop( + scatter->parent(), scatter_loop_trip_count, + {operand, canonical_scatter_indices, adjusted_canonical_updates}, + [&](HloInstruction* induction_var, + const std::vector& loop_state) { + return ScatterLoopBody(scatter, induction_var, loop_state); + }); + TF_ASSIGN_OR_RETURN(std::vector scatter_loop_result, + scatter_loop_result_status); + return scatter_loop_result.front(); +} + +StatusOr ScatterExpander::Run(HloModule* module) { + std::vector scatter_instrs; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + for (HloInstruction* instr : computation->instructions()) { + if (instr->opcode() == HloOpcode::kScatter) { + scatter_instrs.push_back(instr); + } + } + } + + for (auto instr : scatter_instrs) { + TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, ExpandScatter(instr)); + TF_RETURN_IF_ERROR( + instr->parent()->ReplaceInstruction(instr, expanded_root)); + } + + return !scatter_instrs.empty(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/scatter_expander.h b/tensorflow/compiler/xla/service/scatter_expander.h new file mode 100644 index 0000000000000000000000000000000000000000..8f735e877d270c10b494e1cd974904c4e2d960c9 --- /dev/null +++ b/tensorflow/compiler/xla/service/scatter_expander.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +class ScatterExpander : public HloPassInterface { + public: + tensorflow::StringPiece name() const override { return "scatter_expander"; } + StatusOr Run(HloModule* module) override; + + private: + StatusOr ExpandScatter(HloInstruction* scatter); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_ diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 636013cbb561f8506e173bd634e07b48a8dc570e..18d1b7732bb2f54eb4b1bf74e1eed1d96221913c 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -20,10 +20,10 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/service/source_map_util.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -52,10 +53,10 @@ limitations under the License. #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/ptr_util.h" using ::tensorflow::strings::Printf; using ::tensorflow::strings::StrCat; -using ::xla::source_map_util::InvalidParameterArgument; namespace xla { @@ -244,7 +245,7 @@ StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, const ExecutionOptions* execution_options) { - auto config = MakeUnique(program_shape); + auto config = absl::make_unique(program_shape); ComputationLayout* computation_layout = config->mutable_entry_computation_layout(); if (program_shape.parameters_size() != argument_shapes.size()) { @@ -325,7 +326,7 @@ StatusOr>> Service::BuildExecutables( if (directory_path.empty() && execution_directory_path.empty()) { continue; } - auto hlo_snapshot = MakeUnique(); + auto hlo_snapshot = absl::make_unique(); *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = *module_protos[i]; if (!directory_path.empty()) { string filename = @@ -376,7 +377,7 @@ Service::ExecuteParallelAndRegisterResult( ExecutionProfile* profile) { // Streams where the computation are launched, so we can wait on the streams // to complete. - std::vector::SmartPtr> streams; + std::vector streams; std::vector> timers; // Global data handles for the computation results, one for each computation. @@ -403,12 +404,13 @@ Service::ExecuteParallelAndRegisterResult( CHECK_EQ(replicas.size(), arguments[i].size()); std::vector result_buffers; for (int64 replica = 0; replica < replicas.size(); ++replica) { - TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, + TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream, backend->BorrowStream(replicas[replica])); streams.push_back(std::move(stream)); if (replica == 0 && profile != nullptr) { - timers.emplace_back(new se::Timer(streams.back()->parent())); + timers.push_back( + absl::make_unique(streams.back()->parent())); streams.back() ->InitTimer(timers.back().get()) .ThenStartTimer(timers.back().get()); @@ -440,7 +442,7 @@ Service::ExecuteParallelAndRegisterResult( streams.back()->ThenStopTimer(timers.back().get()); } - result_buffers.emplace_back(std::move(result)); + result_buffers.push_back(std::move(result)); } TF_ASSIGN_OR_RETURN(GlobalDataHandle handle, allocation_tracker_.RegisterReplicatedBuffers( @@ -515,13 +517,13 @@ StatusOr Service::ExecuteAndRegisterResult( arguments, Backend* backend, const string& result_tag, ExecutionProfile* profile) { // Set up streams. - std::vector::SmartPtr> streams; + std::vector streams; TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, SingleComputationDeviceHandle())); TF_RET_CHECK(!replicas.empty()); for (se::StreamExecutor* executor : replicas) { - TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, + TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream, backend->BorrowStream(executor)); streams.push_back(std::move(stream)); } @@ -533,7 +535,7 @@ StatusOr Service::ExecuteAndRegisterResult( // Set up run options. std::vector run_options; - for (const Pool::SmartPtr& stream : streams) { + for (const StreamPool::Ptr& stream : streams) { ExecutableRunOptions options; options.set_stream(stream.get()); options.set_device_ordinal(stream->parent()->device_ordinal()); @@ -558,7 +560,7 @@ StatusOr Service::ExecuteAndRegisterResult( std::vector> replicated_arguments; for (const auto& arg : arguments) { - replicated_arguments.emplace_back(arg); + replicated_arguments.push_back(arg); } TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams( @@ -799,7 +801,7 @@ StatusOr> Service::BuildExecutable( module_proto.name().c_str()); // Dump computation proto state if flag is set. - auto hlo_snapshot = MakeUnique(); + auto hlo_snapshot = absl::make_unique(); const string& directory_path = module_config->debug_options().xla_dump_computations_to(); const string& execution_directory_path = @@ -953,7 +955,7 @@ namespace { // shape and DeviceMemoryBase values of the clone are identical to the original. std::unique_ptr CloneShapedBufferOnDevice( const ShapedBuffer& shaped_buffer, int device_ordinal) { - auto clone = MakeUnique( + auto clone = absl::make_unique( shaped_buffer.on_host_shape(), shaped_buffer.on_device_shape(), shaped_buffer.platform(), device_ordinal); clone->buffers() = shaped_buffer.buffers(); @@ -1052,11 +1054,12 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, executor = replicas[arg->replica_id()]; } - Literal literal; + auto literal = Literal::CreateFromShape(arg->shape_with_layout()); + TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( - executor, arg->shape_with_layout(), &literal)); - *result->mutable_literal() = literal.ToProto(); + executor, arg->shape_with_layout(), *literal)); + *result->mutable_literal() = literal->ToProto(); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/service_executable_run_options.h b/tensorflow/compiler/xla/service/service_executable_run_options.h index 7f3910cdb0366078b97fb5f6a2dc498b37570926..dbfed628bfcabffe66bef41a82e0e2430897d80d 100644 --- a/tensorflow/compiler/xla/service/service_executable_run_options.h +++ b/tensorflow/compiler/xla/service/service_executable_run_options.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_EXECUTABLE_RUN_OPTIONS_H_ #include "tensorflow/compiler/xla/executable_run_options.h" -#include "tensorflow/compiler/xla/service/pool.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/stream_executor/stream_executor.h" @@ -27,8 +27,7 @@ namespace xla { // data, now only a stream cache for GPU backend. class ServiceExecutableRunOptions { public: - using StreamBorrower = - std::function::SmartPtr>(int)>; + using StreamBorrower = std::function(int)>; ServiceExecutableRunOptions() : ServiceExecutableRunOptions(ExecutableRunOptions()) {} @@ -51,7 +50,7 @@ class ServiceExecutableRunOptions { // Borrows a stream and returns a smart pointer which returns the stream on // destruction. - StatusOr::SmartPtr> BorrowStream(int device_ordinal) const { + StatusOr BorrowStream(int device_ordinal) const { return borrow_stream_ ? borrow_stream_(device_ordinal) : Status(tensorflow::error::UNIMPLEMENTED, "No stream cache"); diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 35df792b07022b2338fcecc25eb8a0718626e464..ec6aa6df55460fb9bb5d468dbc4fa69be34524b2 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -58,66 +59,101 @@ Status ExpectArray(const Shape& shape, tensorflow::StringPiece op_type) { return Status::OK(); } -Status VerifyReducerShape(const ProgramShape& reducer_shape, - const Shape& init_value_shape, - const PrimitiveType& input_element_type) { - if (reducer_shape.parameters_size() != 2) { +Status VerifyReducerShape( + const ProgramShape& reducer_shape, + tensorflow::gtl::ArraySlice init_value_shapes, + tensorflow::gtl::ArraySlice input_element_types, + int64 inputs) { + if (reducer_shape.parameters_size() != inputs * 2) { return InvalidArgument( - "Reduction function must take 2 parameters, but " + "Reduction function must take %lld parameters, but " "takes %d parameter(s).", - reducer_shape.parameters_size()); + inputs * 2, reducer_shape.parameters_size()); } const Shape& accumulator_shape = reducer_shape.result(); - if (!ShapeUtil::IsArray(accumulator_shape) || - ShapeUtil::Rank(accumulator_shape) != 0) { - return InvalidArgument( - "Reduction function must produce a scalar but has shape: %s", - ShapeUtil::HumanString(accumulator_shape).c_str()); - } - - // Check that the accumulator can be passed in as the first argument. - // Note: comparing here and below with Compatible since we don't care about - // layout in scalars - see b/26668201 for a longer-term vision. - if (!ShapeUtil::Compatible(accumulator_shape, reducer_shape.parameters(0))) { + std::vector accumulator_subshapes; + if (ShapeUtil::IsArray(accumulator_shape)) { + if (inputs != 1) { + return InvalidArgument( + "Reduction function must produce a tuple with %lld elements, but " + "produces a scalar", + inputs); + } + accumulator_subshapes.push_back(&accumulator_shape); + } else if (ShapeUtil::IsTuple(accumulator_shape)) { + if (ShapeUtil::TupleElementCount(accumulator_shape) != inputs) { + return InvalidArgument( + "Reduction function must produce a tuple with %lld elements, but has " + "%lld elements", + inputs, ShapeUtil::TupleElementCount(accumulator_shape)); + } + for (const Shape& element_shape : accumulator_shape.tuple_shapes()) { + accumulator_subshapes.push_back(&element_shape); + } + } else { return InvalidArgument( - "Reduction function's first parameter shape differs from the " - "result shape: %s vs %s", - ShapeUtil::HumanString(reducer_shape.parameters(0)).c_str(), + "Reduction function must produce a scalar or tuple of scalars, but has " + "shape: %s", ShapeUtil::HumanString(accumulator_shape).c_str()); } - // Check that init_value's shape is suitable for reducer_shape. - if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape, - init_value_shape)) { - return InvalidArgument( - "Reduction function's accumulator shape differs from the " - "init_value shape: %s vs %s", - ShapeUtil::HumanString(accumulator_shape).c_str(), - ShapeUtil::HumanString(init_value_shape).c_str()); - } - - // Check that the inputs can be passed in as the second argument. - const Shape& input_element_shape = - ShapeUtil::MakeShape(input_element_type, {}); - if (!ShapeUtil::CompatibleIgnoringFpPrecision(input_element_shape, - reducer_shape.parameters(1))) { - return InvalidArgument( - "Reduction function's second parameter shape differs from the " - "input type element type: %s vs %s", - ShapeUtil::HumanString(reducer_shape.parameters(1)).c_str(), - ShapeUtil::HumanString(input_element_shape).c_str()); + for (const Shape* element_shape : accumulator_subshapes) { + if (ShapeUtil::Rank(*element_shape) != 0) { + return InvalidArgument( + "Reduction function must return a scalar or tuple of scalars but " + "returns shape: %s", + ShapeUtil::HumanString(accumulator_shape).c_str()); + } } - // Currently the accumulator and inputs must be the same type, - // though that restriction could be relaxed. - if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape, - reducer_shape.parameters(1))) { - return InvalidArgument( - "Reduction function's second parameter shape must " - "match the result shape, but got %s vs %s.", - ShapeUtil::HumanString(reducer_shape.parameters(1)).c_str(), - ShapeUtil::HumanString(accumulator_shape).c_str()); + for (int64 i = 0; i < inputs; ++i) { + // Check that the accumulator can be passed in as the first argument. + // Note: comparing here and below with Compatible since we don't care about + // layout in scalars - see b/26668201 for a longer-term vision. + if (!ShapeUtil::Compatible(*accumulator_subshapes[i], + reducer_shape.parameters(i))) { + return InvalidArgument( + "Reduction function's %lld-th parameter shape differs from the " + "result shape: %s vs %s", + i, ShapeUtil::HumanString(reducer_shape.parameters(i)).c_str(), + ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str()); + } + // Check that init_value's shapes are suitable for reducer_shape. + if (!ShapeUtil::CompatibleIgnoringFpPrecision(*accumulator_subshapes[i], + *init_value_shapes[i])) { + return InvalidArgument( + "Reduction function's accumulator shape at index %lld differs from " + "the init_value shape: %s vs %s", + i, ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str(), + ShapeUtil::HumanString(*init_value_shapes[i]).c_str()); + } + // Check that the inputs can be passed in as the non-accumulator arguments. + const Shape input_element_shape = + ShapeUtil::MakeShape(input_element_types[i], {}); + if (!ShapeUtil::CompatibleIgnoringFpPrecision( + input_element_shape, reducer_shape.parameters(inputs + i))) { + return InvalidArgument( + "Reduction function's %lld-th parameter shape differs from the " + "input type element type: %s vs %s", + inputs + i, + ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(), + ShapeUtil::HumanString(input_element_shape).c_str()); + } + // Check that the accumulator and inputs to the reducer function match. + // If the accumulator is scalar, it must have the same type as the inputs + // (up to fp precision). If it is a tuple, then the k-th element of the + // tuple must have the same type as the K-th input (again, up to fp + // precision.) + if (!ShapeUtil::CompatibleIgnoringFpPrecision( + *accumulator_subshapes[i], reducer_shape.parameters(inputs + i))) { + return InvalidArgument( + "Reduction function's %lld-th parameter shape must " + "match the result shape, but got %s vs %s.", + inputs + i, + ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(), + ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str()); + } } return Status::OK(); @@ -1495,7 +1531,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferConvolveShape( const Shape& lhs, const Shape& rhs, const Window& window, - const ConvolutionDimensionNumbers& dnums) { + const ConvolutionDimensionNumbers& dnums, int64 feature_group_count) { TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution")); TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution")); @@ -1605,12 +1641,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 kernel_output_features = rhs.dimensions(dnums.kernel_output_feature_dimension()); - if (input_features != kernel_input_features) { + if (input_features != kernel_input_features * feature_group_count) { return InvalidArgument( "Expected LHS feature dimension (value %lld) to match RHS " - "input feature dimension (value %lld); got (%s, %s)\n" + "input feature dimension * feature_group_count (value %lld); " + "got (%s, %s)\n" "Dimension numbers: {%s}.", - input_features, kernel_input_features, + input_features, kernel_input_features * feature_group_count, ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str(), dnums.DebugString().c_str()); } @@ -1744,11 +1781,83 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::MakeTupleShape(operand_shape_values); } +/* static */ StatusOr ShapeInference::InferAllToAllShape( + const Shape& shape, int64 split_dimension, int64 concat_dimension, + int64 split_count) { + TF_RET_CHECK(split_count > 0); + if (split_dimension >= ShapeUtil::Rank(shape) || split_dimension < 0) { + return InvalidArgument( + "AllToAll split_dimension %lld is out-of-bounds in shape %s.", + split_dimension, ShapeUtil::HumanString(shape).c_str()); + } + if (concat_dimension >= ShapeUtil::Rank(shape) || concat_dimension < 0) { + return InvalidArgument( + "AllToAll concat_dimension %lld is out-of-bounds in shape %s.", + concat_dimension, ShapeUtil::HumanString(shape).c_str()); + } + if (shape.dimensions(split_dimension) % split_count != 0) { + return InvalidArgument( + "AllToAll split dimension size %lld must be dividable by split_count " + "%lld.", + shape.dimensions(split_dimension), split_count); + } + std::vector new_dimensions(shape.dimensions().begin(), + shape.dimensions().end()); + new_dimensions[split_dimension] /= split_count; + new_dimensions[concat_dimension] *= split_count; + return ShapeUtil::MakeShape(shape.element_type(), new_dimensions); +} + +/* static */ StatusOr ShapeInference::InferAllToAllTupleShape( + tensorflow::gtl::ArraySlice operand_shapes) { + // An Alltoall HLO instruction receives N operands (with the same shape) and + // returns a tuple that contains N array shapes. + TF_RET_CHECK(!operand_shapes.empty()); + for (int i = 0; i < operand_shapes.size(); i++) { + if (!ShapeUtil::Equal(*operand_shapes[0], *operand_shapes[i])) { + return InvalidArgument( + "HLO all-to-all has operands with different shapes: the 0th " + "operand shape %s, but the %dth operand has shape %s.", + ShapeUtil::HumanString(*operand_shapes[0]).c_str(), i, + ShapeUtil::HumanString(*operand_shapes[i]).c_str()); + } + } + + return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes); +} + /* static */ StatusOr ShapeInference::InferReduceShape( - const Shape& arg, const Shape& init_value, + tensorflow::gtl::ArraySlice arg_shapes, tensorflow::gtl::ArraySlice dimensions_to_reduce, const ProgramShape& to_apply) { - // Check that the dimension to reduce are in-bounds for the given shape. + if (arg_shapes.empty()) { + return InvalidArgument("Reduce must have at least 2 arguments, has 0"); + } + if (arg_shapes.size() % 2) { + return InvalidArgument( + "Reduce must have an even number of arguments, has %lu", + arg_shapes.size()); + } + int64 num_reduced_args = arg_shapes.size() / 2; + + tensorflow::gtl::ArraySlice reduced_args(arg_shapes, 0, + num_reduced_args); + // Check that all of the reduced tensors have the same dimensions. The element + // types may be different. + for (int64 i = 1; i < num_reduced_args; ++i) { + if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) { + return InvalidArgument( + "All reduced tensors must have the sime dimension. Tensor 0 has " + "shape %s, Tensor %lld has shape %s", + ShapeUtil::HumanString(*reduced_args[0]).c_str(), i, + ShapeUtil::HumanString(*reduced_args[i]).c_str()); + } + } + + // Check that the dimensions to reduce are in-bounds for the given shape. + // We've already verified all reduced tensors have the same dimensions, so it + // doesn't matter which one we choose. + const Shape& arg = *reduced_args[0]; for (int64 dimension : dimensions_to_reduce) { if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) { return InvalidArgument( @@ -1756,8 +1865,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(arg).c_str()); } } - TF_RETURN_IF_ERROR( - VerifyReducerShape(to_apply, init_value, arg.element_type())); + + tensorflow::gtl::ArraySlice init_values( + arg_shapes, num_reduced_args, arg_shapes.size()); + std::vector element_types; + for (const Shape* arg : reduced_args) { + element_types.push_back(arg->element_type()); + } + TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply, init_values, element_types, + num_reduced_args)); std::set dimensions_to_reduce_set(dimensions_to_reduce.begin(), dimensions_to_reduce.end()); @@ -1768,15 +1884,26 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } } - return ShapeUtil::MakeShape(to_apply.result().element_type(), new_dimensions); + if (ShapeUtil::IsScalar(to_apply.result())) { + return ShapeUtil::MakeShape(to_apply.result().element_type(), + new_dimensions); + } else { + std::vector result_subshapes; + for (const Shape& subshape : to_apply.result().tuple_shapes()) { + result_subshapes.push_back( + ShapeUtil::MakeShape(subshape.element_type(), new_dimensions)); + } + return ShapeUtil::MakeTupleShape(result_subshapes); + } } /* static */ StatusOr ShapeInference::InferReduceWindowShape( const Shape& operand_shape, const Shape& init_value_shape, const Window& window, const ProgramShape& to_apply_shape) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window")); - TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_value_shape, - operand_shape.element_type())); + TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape}, + {operand_shape.element_type()}, + /*inputs=*/1)); return InferWindowOutputShape(operand_shape, window, init_value_shape.element_type(), /*allow_negative_padding=*/false); @@ -1821,8 +1948,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } // Check if the scatter function has a proper shape as a reduction. - TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, init_value_shape, - source_shape.element_type())); + TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, {&init_value_shape}, + {source_shape.element_type()}, + /*inputs=*/1)); // Check if the result shape of window operation matches the source shape. TF_ASSIGN_OR_RETURN(const Shape& window_result_shape, @@ -2365,201 +2493,198 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, static Status ValidateGatherDimensionNumbers( const Shape& input_shape, - tensorflow::gtl::ArraySlice gather_indices_shape, + tensorflow::gtl::ArraySlice start_indices_shape, const GatherDimensionNumbers& dim_numbers) { - if (!c_is_sorted(dim_numbers.output_window_dims())) { + if (!absl::c_is_sorted(dim_numbers.offset_dims())) { return InvalidArgument( "Output window dimensions in gather op must be ascending; got: %s.", - Join(dim_numbers.output_window_dims(), ", ").c_str()); + Join(dim_numbers.offset_dims(), ", ").c_str()); } - if (c_adjacent_find(dim_numbers.output_window_dims()) != - dim_numbers.output_window_dims().end()) { + if (absl::c_adjacent_find(dim_numbers.offset_dims()) != + dim_numbers.offset_dims().end()) { return InvalidArgument( "Output window dimensions in gather op must not repeat; got: %s.", - Join(dim_numbers.output_window_dims(), ", ").c_str()); + Join(dim_numbers.offset_dims(), ", ").c_str()); } - const int64 output_window_dim_count = dim_numbers.output_window_dims_size(); + const int64 output_offset_dim_count = dim_numbers.offset_dims_size(); const int64 output_shape_rank = - output_window_dim_count + gather_indices_shape.size() - 1; + output_offset_dim_count + start_indices_shape.size() - 1; - for (int i = 0; i < dim_numbers.output_window_dims_size(); ++i) { - int64 window_index = dim_numbers.output_window_dims(i); - if (window_index < 0 || window_index >= output_shape_rank) { + for (int i = 0; i < dim_numbers.offset_dims_size(); ++i) { + int64 offset_dim = dim_numbers.offset_dims(i); + if (offset_dim < 0 || offset_dim >= output_shape_rank) { return InvalidArgument( - "Window index %d in gather op is out of bounds; got %lld, but should " + "Offset dimension %d in gather op is out of bounds; got %lld, but " + "should " "have been in [0,%lld).", - i, window_index, output_shape_rank); + i, offset_dim, output_shape_rank); } } - if (dim_numbers.gather_dims_to_operand_dims_size() != - gather_indices_shape[dim_numbers.index_vector_dim()]) { + if (dim_numbers.start_index_map_size() != + start_indices_shape[dim_numbers.index_vector_dim()]) { return InvalidArgument( - "Gather op has %d elements in gather_dims_to_operand_dims and the " - "bound of dimension index_vector_dim=%lld of gather_indices is " + "Gather op has %d elements in start_index_map and the " + "bound of dimension index_vector_dim=%lld of start_indices is " "%lld. These two numbers must be equal.", - dim_numbers.gather_dims_to_operand_dims_size(), - dim_numbers.index_vector_dim(), - gather_indices_shape[dim_numbers.index_vector_dim()]); + dim_numbers.start_index_map_size(), dim_numbers.index_vector_dim(), + start_indices_shape[dim_numbers.index_vector_dim()]); } - for (int i = 0; i < dim_numbers.gather_dims_to_operand_dims_size(); i++) { - int64 gather_dim_to_input_dim = dim_numbers.gather_dims_to_operand_dims(i); - if (gather_dim_to_input_dim < 0 || - gather_dim_to_input_dim >= input_shape.dimensions_size()) { + for (int i = 0; i < dim_numbers.start_index_map_size(); i++) { + int64 operand_dim_for_start_index_i = dim_numbers.start_index_map(i); + if (operand_dim_for_start_index_i < 0 || + operand_dim_for_start_index_i >= input_shape.dimensions_size()) { return InvalidArgument( - "Invalid gather_dims_to_operand_dims mapping; domain is [0, %d), " - "got: %d->%lld.", - input_shape.dimensions_size(), i, gather_dim_to_input_dim); + "Invalid start_index_map; domain is [0, %d), got: %d->%lld.", + input_shape.dimensions_size(), i, operand_dim_for_start_index_i); } } - std::vector sorted_gather_dims_to_operand_dims( - dim_numbers.gather_dims_to_operand_dims().begin(), - dim_numbers.gather_dims_to_operand_dims().end()); + std::vector sorted_start_index_map( + dim_numbers.start_index_map().begin(), + dim_numbers.start_index_map().end()); - c_sort(sorted_gather_dims_to_operand_dims); + absl::c_sort(sorted_start_index_map); - if (c_adjacent_find(sorted_gather_dims_to_operand_dims) != - sorted_gather_dims_to_operand_dims.end()) { + if (absl::c_adjacent_find(sorted_start_index_map) != + sorted_start_index_map.end()) { return InvalidArgument( - "Repeated dimensions are not allowed in gather_dims_to_operand_dims; " + "Repeated dimensions are not allowed in start_index_map; " "got: %s.", - Join(dim_numbers.gather_dims_to_operand_dims(), ", ").c_str()); + Join(dim_numbers.start_index_map(), ", ").c_str()); } - for (int64 elided_dim : dim_numbers.elided_window_dims()) { - if (elided_dim < 0 || elided_dim >= input_shape.dimensions_size()) { + for (int64 collapsed_dim : dim_numbers.collapsed_slice_dims()) { + if (collapsed_dim < 0 || collapsed_dim >= input_shape.dimensions_size()) { return InvalidArgument( - "Invalid elided_window_dims set in gather op; valid range is [0, " + "Invalid collapsed_slice_dims set in gather op; valid range is [0, " "%d), got: %lld.", - input_shape.dimensions_size(), elided_dim); + input_shape.dimensions_size(), collapsed_dim); } } - if (!c_is_sorted(dim_numbers.elided_window_dims())) { + if (!absl::c_is_sorted(dim_numbers.collapsed_slice_dims())) { return InvalidArgument( - "elided_window_dims in gather op must be sorted; got: %s", - Join(dim_numbers.elided_window_dims(), ", ").c_str()); + "collapsed_slice_dims in gather op must be sorted; got: %s", + Join(dim_numbers.collapsed_slice_dims(), ", ").c_str()); } - if (c_adjacent_find(dim_numbers.elided_window_dims()) != - dim_numbers.elided_window_dims().end()) { + if (absl::c_adjacent_find(dim_numbers.collapsed_slice_dims()) != + dim_numbers.collapsed_slice_dims().end()) { return InvalidArgument( - "Repeated dimensions not allowed in elided_window_dims in gather op; " + "Repeated dimensions not allowed in collapsed_slice_dims in gather op; " "got: %s.", - Join(dim_numbers.elided_window_dims(), ", ").c_str()); + Join(dim_numbers.collapsed_slice_dims(), ", ").c_str()); } return Status::OK(); } /*static*/ StatusOr ShapeInference::InferGatherShape( - const Shape& input_shape, const Shape& gather_indices_shape, + const Shape& input_shape, const Shape& start_indices_shape, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice window_bounds) { + tensorflow::gtl::ArraySlice slice_sizes) { TF_RETURN_IF_ERROR( ExpectArray(input_shape, "input tensor operand gather op")); TF_RETURN_IF_ERROR( - ExpectArray(gather_indices_shape, "gather indices operand of gather op")); + ExpectArray(start_indices_shape, "gather indices operand of gather op")); - if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) { + if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { return InvalidArgument( "Gather indices parameter must be an integral tensor; got %s.", - ShapeUtil::HumanString(gather_indices_shape).c_str()); + ShapeUtil::HumanString(start_indices_shape).c_str()); } // We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if // index_vector_dim is rank(P). The bounds of this expanded shape is - // stored in expanded_gather_indices_shape. + // stored in expanded_start_indices_shape. - if (gather_indices_shape.dimensions_size() < + if (start_indices_shape.dimensions_size() < gather_dim_numbers.index_vector_dim() || gather_dim_numbers.index_vector_dim() < 0) { return InvalidArgument( - "Gather index leaf dimension must be within [0, rank(gather_indices) + " - "1). rank(gather_indices) is %d and gather index leaf dimension is " + "Gather index leaf dimension must be within [0, rank(start_indices) + " + "1). rank(start_indices) is %d and gather index leaf dimension is " "%lld.", - gather_indices_shape.dimensions_size(), + start_indices_shape.dimensions_size(), gather_dim_numbers.index_vector_dim()); } - std::vector expanded_gather_indices_shape; - expanded_gather_indices_shape.reserve(gather_indices_shape.dimensions_size()); - c_copy(gather_indices_shape.dimensions(), - std::back_inserter(expanded_gather_indices_shape)); - if (expanded_gather_indices_shape.size() == + std::vector expanded_start_indices_shape; + expanded_start_indices_shape.reserve(start_indices_shape.dimensions_size()); + absl::c_copy(start_indices_shape.dimensions(), + std::back_inserter(expanded_start_indices_shape)); + if (expanded_start_indices_shape.size() == gather_dim_numbers.index_vector_dim()) { - expanded_gather_indices_shape.push_back(1); + expanded_start_indices_shape.push_back(1); } TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers( - input_shape, expanded_gather_indices_shape, gather_dim_numbers)); + input_shape, expanded_start_indices_shape, gather_dim_numbers)); - if (window_bounds.size() != input_shape.dimensions_size()) { + if (slice_sizes.size() != input_shape.dimensions_size()) { return InvalidArgument( - "Gather op must have one window bound for every input dimension; got: " - "len(window_bounds)=%lu, input_shape.rank=%d.", - window_bounds.size(), input_shape.dimensions_size()); + "Gather op must have one slice size for every input dimension; got: " + "len(slice_sizes)=%lu, input_shape.rank=%d.", + slice_sizes.size(), input_shape.dimensions_size()); } - if (window_bounds.size() != - gather_dim_numbers.output_window_dims_size() + - gather_dim_numbers.elided_window_dims_size()) { + if (slice_sizes.size() != + gather_dim_numbers.offset_dims_size() + + gather_dim_numbers.collapsed_slice_dims_size()) { return InvalidArgument( - "All components of the window index in a gather op must either be a " - "output window index or explicitly elided; got len(window_bounds)=%lu, " - "output_window_bounds=%s, elided_window_bounds=%s.", - window_bounds.size(), - Join(gather_dim_numbers.output_window_dims(), ",").c_str(), - Join(gather_dim_numbers.elided_window_dims(), ",").c_str()); + "All components of the offset index in a gather op must either be a " + "offset dimension or explicitly collapsed; got len(slice_sizes)=%lu, " + "output_slice_sizes=%s, collapsed_slice_dims=%s.", + slice_sizes.size(), Join(gather_dim_numbers.offset_dims(), ",").c_str(), + Join(gather_dim_numbers.collapsed_slice_dims(), ",").c_str()); } - for (int i = 0; i < window_bounds.size(); i++) { - int64 window_bound = window_bounds[i]; - int64 corresponding_input_bound = input_shape.dimensions(i); - if (window_bound < 0 || window_bound > corresponding_input_bound) { + for (int i = 0; i < slice_sizes.size(); i++) { + int64 slice_size = slice_sizes[i]; + int64 corresponding_input_size = input_shape.dimensions(i); + if (slice_size < 0 || slice_size > corresponding_input_size) { return InvalidArgument( - "Window bound at index %d in gather op is out of range, must be " - "within " - "[0, %lld), got %lld.", - i, corresponding_input_bound + 1, window_bound); + "Slice size at index %d in gather op is out of range, must be " + "within [0, %lld), got %lld.", + i, corresponding_input_size + 1, slice_size); } } - for (int i = 0; i < gather_dim_numbers.elided_window_dims_size(); i++) { - if (window_bounds[gather_dim_numbers.elided_window_dims(i)] != 1) { + for (int i = 0; i < gather_dim_numbers.collapsed_slice_dims_size(); i++) { + if (slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)] != 1) { return InvalidArgument( - "Gather op can only elide window indices with bound 1, but bound is " + "Gather op can only collapse slice dims with bound 1, but bound is " "%lld for index %lld at position %d.", - window_bounds[gather_dim_numbers.elided_window_dims(i)], - gather_dim_numbers.elided_window_dims(i), i); + slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)], + gather_dim_numbers.collapsed_slice_dims(i), i); } } - int64 result_rank = gather_dim_numbers.output_window_dims_size() + - (expanded_gather_indices_shape.size() - 1); - int64 window_dims_seen = 0; + int64 result_rank = gather_dim_numbers.offset_dims_size() + + (expanded_start_indices_shape.size() - 1); + int64 offset_dims_seen = 0; int64 gather_dims_seen = 0; std::vector output_dim_bounds; output_dim_bounds.reserve(result_rank); for (int64 i = 0; i < result_rank; i++) { int64 current_bound; bool is_window_index = - c_binary_search(gather_dim_numbers.output_window_dims(), i); + absl::c_binary_search(gather_dim_numbers.offset_dims(), i); if (is_window_index) { - while (c_binary_search(gather_dim_numbers.elided_window_dims(), - window_dims_seen)) { - window_dims_seen++; + while (absl::c_binary_search(gather_dim_numbers.collapsed_slice_dims(), + offset_dims_seen)) { + offset_dims_seen++; } - current_bound = window_bounds[window_dims_seen++]; + current_bound = slice_sizes[offset_dims_seen++]; } else { if (gather_dims_seen == gather_dim_numbers.index_vector_dim()) { gather_dims_seen++; } - current_bound = expanded_gather_indices_shape[gather_dims_seen++]; + current_bound = expanded_start_indices_shape[gather_dims_seen++]; } output_dim_bounds.push_back(current_bound); @@ -2568,4 +2693,194 @@ static Status ValidateGatherDimensionNumbers( return ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds); } +namespace { + +Status ValidateScatterDimensionNumbers( + const Shape& operand_shape, + tensorflow::gtl::ArraySlice scatter_indices_shape, + const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { + // Validate update_window_dims in ScatterDimensionNumbers. + if (!absl::c_is_sorted(dim_numbers.update_window_dims())) { + return InvalidArgument( + "update_window_dims in scatter op must be sorted; got: %s.", + Join(dim_numbers.update_window_dims(), ", ").c_str()); + } + if (absl::c_adjacent_find(dim_numbers.update_window_dims()) != + dim_numbers.update_window_dims().end()) { + return InvalidArgument( + "update_window_dims in scatter op must not repeat; got: %s.", + Join(dim_numbers.update_window_dims(), ", ").c_str()); + } + const int64 updates_rank = ShapeUtil::Rank(updates_shape); + for (int64 window_dim : dim_numbers.update_window_dims()) { + if (window_dim < 0 || window_dim >= updates_rank) { + return InvalidArgument( + "Invalid update_window_dims set in scatter op; valid range is [0, " + "%lld). got: %lld.", + updates_rank, window_dim); + } + } + + // Validate inserted_window_dims in ScatterDimensionNumbers. + if (!absl::c_is_sorted(dim_numbers.inserted_window_dims())) { + return InvalidArgument( + "inserted_window_dims in scatter op must be sorted; got: %s.", + Join(dim_numbers.inserted_window_dims(), ", ").c_str()); + } + if (absl::c_adjacent_find(dim_numbers.inserted_window_dims()) != + dim_numbers.inserted_window_dims().end()) { + return InvalidArgument( + "inserted_window_dims in scatter op must not repeat; got: %s.", + Join(dim_numbers.inserted_window_dims(), ", ").c_str()); + } + for (int64 inserted_dim : dim_numbers.inserted_window_dims()) { + if (inserted_dim < 0 || inserted_dim >= operand_shape.dimensions_size()) { + return InvalidArgument( + "Invalid inserted_window_dims set in scatter op; valid range is [0, " + "%d), got: %lld.", + operand_shape.dimensions_size(), inserted_dim); + } + } + + // Validate scatter_dims_to_operand_dims in ScatterDimensionNumbers. + if (dim_numbers.scatter_dims_to_operand_dims_size() != + scatter_indices_shape[dim_numbers.index_vector_dim()]) { + return InvalidArgument( + "Scatter op has %d elements in scatter_dims_to_operand_dims and the " + "bound of dimension index_vector_dim=%lld of scatter_indices is %lld. " + "These two numbers must be equal.", + dim_numbers.scatter_dims_to_operand_dims_size(), + dim_numbers.index_vector_dim(), + scatter_indices_shape[dim_numbers.index_vector_dim()]); + } + for (int i = 0; i < dim_numbers.scatter_dims_to_operand_dims_size(); ++i) { + int64 scatter_dim_to_operand_dim = + dim_numbers.scatter_dims_to_operand_dims(i); + if (scatter_dim_to_operand_dim < 0 || + scatter_dim_to_operand_dim >= operand_shape.dimensions_size()) { + return InvalidArgument( + "Invalid scatter_dims_to_operand_dims mapping; domain is [0, %d), " + "got: %d->%lld.", + operand_shape.dimensions_size(), i, scatter_dim_to_operand_dim); + } + } + std::vector sorted_scatter_dims_to_operand_dims( + dim_numbers.scatter_dims_to_operand_dims().begin(), + dim_numbers.scatter_dims_to_operand_dims().end()); + absl::c_sort(sorted_scatter_dims_to_operand_dims); + if (absl::c_adjacent_find(sorted_scatter_dims_to_operand_dims) != + sorted_scatter_dims_to_operand_dims.end()) { + return InvalidArgument( + "Repeated dimensions not allowed in scatter_dims_to_operand_dims; " + "got: %s.", + Join(dim_numbers.scatter_dims_to_operand_dims(), ", ").c_str()); + } + + return Status::OK(); +} + +} // namespace + +/*static*/ StatusOr ShapeInference::InferScatterShape( + const Shape& operand_shape, const Shape& scatter_indices_shape, + const Shape& updates_shape, const ProgramShape& to_apply_shape, + const ScatterDimensionNumbers& scatter_dim_numbers) { + TF_RETURN_IF_ERROR( + ExpectArray(operand_shape, "operand tensor of scatter op")); + TF_RETURN_IF_ERROR( + ExpectArray(scatter_indices_shape, "scatter indices of scatter op")); + TF_RETURN_IF_ERROR(ExpectArray(updates_shape, "updates of scatter op")); + + if (!ShapeUtil::ElementIsIntegral(scatter_indices_shape)) { + return InvalidArgument( + "Scatter indices parameter must be an integral tensor; got %s.", + ShapeUtil::HumanString(scatter_indices_shape).c_str()); + } + + if (scatter_indices_shape.dimensions_size() < + scatter_dim_numbers.index_vector_dim() || + scatter_dim_numbers.index_vector_dim() < 0) { + return InvalidArgument( + "Scatter index leaf dimension must be within [0, rank(scatter_indices)" + " + 1). rank(scatter_indices) is %d and scatter index leaf dimension " + "is %lld.", + scatter_indices_shape.dimensions_size(), + scatter_dim_numbers.index_vector_dim()); + } + + // Check if the update computation has a proper shape as a reduction. + const Shape init_value_shape = + ShapeUtil::MakeShape(operand_shape.element_type(), {}); + TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape}, + {updates_shape.element_type()}, + /*inputs=*/1)); + + std::vector expanded_scatter_indices_shape = + ArraySliceToVector(AsInt64Slice(scatter_indices_shape.dimensions())); + if (expanded_scatter_indices_shape.size() == + scatter_dim_numbers.index_vector_dim()) { + expanded_scatter_indices_shape.push_back(1); + } + + int64 expected_updates_rank = expanded_scatter_indices_shape.size() - 1 + + scatter_dim_numbers.update_window_dims_size(); + if (ShapeUtil::Rank(updates_shape) != expected_updates_rank) { + return InvalidArgument("Updates tensor must be of rank %lld; got %lld.", + expected_updates_rank, + ShapeUtil::Rank(updates_shape)); + } + + TF_RETURN_IF_ERROR(ValidateScatterDimensionNumbers( + operand_shape, expanded_scatter_indices_shape, updates_shape, + scatter_dim_numbers)); + + int64 inserted_dims_seen = 0; + std::vector max_update_slice_sizes; + for (int i = 0; i < operand_shape.dimensions_size(); ++i) { + if (inserted_dims_seen < scatter_dim_numbers.inserted_window_dims_size() && + scatter_dim_numbers.inserted_window_dims(inserted_dims_seen) == i) { + ++inserted_dims_seen; + } else { + max_update_slice_sizes.push_back(operand_shape.dimensions(i)); + } + } + for (int i = 0; i < scatter_dim_numbers.update_window_dims_size(); ++i) { + auto update_window_dim = scatter_dim_numbers.update_window_dims(i); + if (updates_shape.dimensions(update_window_dim) > + max_update_slice_sizes[i]) { + return InvalidArgument( + "Bounds of the window dimensions of updates must not exceed the " + "bounds of the corresponding dimensions of operand. For dimension " + "%lld, updates bound is %lld, operand bound is %lld.", + update_window_dim, updates_shape.dimensions(update_window_dim), + max_update_slice_sizes[i]); + } + } + + int64 scatter_dims_seen = 0; + for (int64 i = 0; i < ShapeUtil::Rank(updates_shape); ++i) { + bool is_update_window_dim = + absl::c_binary_search(scatter_dim_numbers.update_window_dims(), i); + if (is_update_window_dim) { + continue; + } + if (scatter_dims_seen == scatter_dim_numbers.index_vector_dim()) { + ++scatter_dims_seen; + } + if (updates_shape.dimensions(i) != + expanded_scatter_indices_shape[scatter_dims_seen]) { + return InvalidArgument( + "Bounds of the scatter dimensions of updates must be same as the " + "bounds of the corresponding dimensions of scatter indices. For " + "scatter dimension %lld, updates bound is %lld, scatter_indices " + "bound is %lld.", + i, updates_shape.dimensions(i), + expanded_scatter_indices_shape[scatter_dims_seen]); + } + ++scatter_dims_seen; + } + + return operand_shape; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 1a5684e3c306eef90fd1bfdf4565b0dcde2fbab6..4974ac9916abaea25f8d455b24f7c0904277f5f7 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -112,18 +112,30 @@ class ShapeInference { // filter (rhs) to lhs in the way specified by the fields on window. static StatusOr InferConvolveShape( const Shape& lhs, const Shape& rhs, const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1); // Infers the shape produced by the given FFT type on the given operand. static StatusOr InferFftShape( const Shape& in, FftType fft_type, tensorflow::gtl::ArraySlice fft_length); - // Infers the shape produced a cross replica sum with the given operand + // Infers the shape produced by a cross replica sum with the given operand // shapes. static StatusOr InferCrossReplicaSumShape( tensorflow::gtl::ArraySlice operand_shapes); + // Infers final shape of an Alltoall operation that is created by the xla + // builder. + static StatusOr InferAllToAllShape(const Shape& shape, + int64 split_dimension, + int64 concat_dimension, + int64 split_count); + + // Infers the shape of an HLO all-to-all instruction. + static StatusOr InferAllToAllTupleShape( + tensorflow::gtl::ArraySlice operand_shapes); + // Infers the shape produced by applying the given reduction computation // shape to the given input operand shape. // @@ -131,7 +143,7 @@ class ShapeInference { // index as the leading parameter, and the program shape should match // accordingly (or an error will result). static StatusOr InferReduceShape( - const Shape& arg, const Shape& init_value, + tensorflow::gtl::ArraySlice arg_shapes, tensorflow::gtl::ArraySlice dimensions_to_reduce, const ProgramShape& to_apply); @@ -264,9 +276,17 @@ class ShapeInference { // with the given input shape, gather indices shape and gather dimension // numbers. static StatusOr InferGatherShape( - const Shape& input_shape, const Shape& gather_indices_shape, + const Shape& input_shape, const Shape& start_indices_shape, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice window_bounds); + tensorflow::gtl::ArraySlice slice_sizes); + + // Helper that validates the given input shape, scatter indices shape, updates + // shape, and scatter dimension numbers that constitute a scatter operation, + // and returns the result shape of the scatter operation. + static StatusOr InferScatterShape( + const Shape& operand_shape, const Shape& scatter_indices_shape, + const Shape& updates_shape, const ProgramShape& to_apply_shape, + const ScatterDimensionNumbers& scatter_dim_numbers); private: // Helper that infers the shape produced by performing an element-wise binary diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 6046d50c6d41a3956b996a3320848784ffd59068..4ed8fc6b8654fb87701a629c1ded397fe23e52cd 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -63,7 +63,7 @@ class ReduceShapeInferenceTest : public ShapeInferenceTest { tensorflow::gtl::ArraySlice dimensions_to_reduce) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); auto inferred_status = ShapeInference::InferReduceShape( - arg, f32_, dimensions_to_reduce, to_apply); + {&arg, &f32_}, dimensions_to_reduce, to_apply); EXPECT_IS_OK(inferred_status.status()); EXPECT_TRUE(ShapeUtil::Equal(expected_inferred_shape, inferred_status.ValueOrDie())); @@ -703,11 +703,99 @@ TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongAllDimensions) { /*dimensions_to_reduce=*/{0, 1, 2}); } +TEST_F(ReduceShapeInferenceTest, ReduceMultiOutput) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_IS_OK(inferred_status.status()); + EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeTupleShape({f32_, s32_}), + inferred_status.ValueOrDie())); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput1) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = + ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_, f32_, s32_}, + ShapeUtil::MakeTupleShape({f32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("must take 4 parameters, but takes 6 parameter(s)")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput2) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT( + inferred_status.status().error_message(), + HasSubstr( + "parameter shape differs from the result shape: s32[] vs f32[]")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput3) { + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape({}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("must have at least 2 arguments, has 0")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput1) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = + ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_}, f32_); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT( + inferred_status.status().error_message(), + HasSubstr("must produce a tuple with 2 elements, but produces a scalar")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput2) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT( + inferred_status.status().error_message(), + HasSubstr("must produce a tuple with 2 elements, but has 3 elements")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerBoth) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {s32_, s32_, s32_, s32_}, ShapeUtil::MakeTupleShape({s32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("accumulator shape at index 0 differs from the " + "init_value shape: s32[] vs f32[]")); +} + TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); + Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); auto inferred_status = ShapeInference::InferReduceShape( - ShapeUtil::MakeShape(F32, {5, 3}), f32_, /*dimensions_to_reduce=*/{3, 4}, - to_apply); + {&arg_shape, &f32_}, + /*dimensions_to_reduce=*/{3, 4}, to_apply); EXPECT_FALSE(inferred_status.ok()); EXPECT_THAT(inferred_status.status().error_message(), HasSubstr("out-of-bounds dimension")); @@ -715,8 +803,9 @@ TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) { TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_, f32_}, f32_); + Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); auto inferred_status = - ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_, + ShapeInference::InferReduceShape({&arg_shape, &f32_}, /*dimensions_to_reduce=*/{0}, to_apply); EXPECT_FALSE(inferred_status.ok()); EXPECT_THAT(inferred_status.status().error_message(), @@ -725,12 +814,13 @@ TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) { TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, s32_); + Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); auto inferred_status = - ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_, + ShapeInference::InferReduceShape({&arg_shape, &f32_}, /*dimensions_to_reduce=*/{0}, to_apply); EXPECT_FALSE(inferred_status.ok()); EXPECT_THAT(inferred_status.status().error_message(), - HasSubstr("first parameter shape differs")); + HasSubstr("0-th parameter shape differs")); } TEST_F(ShapeInferenceTest, InferSliceShapeRank2) { @@ -1536,7 +1626,7 @@ TEST_F(ShapeInferenceTest, BadSort) { << statusor.status(); } -class GatherShapeInferenceTest : public ShapeInferenceTest { +class ScatterGatherShapeInferenceTest : public ShapeInferenceTest { protected: const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {}); const Shape s64_vector_5_ = ShapeUtil::MakeShape(S64, {5}); @@ -1553,81 +1643,85 @@ class GatherShapeInferenceTest : public ShapeInferenceTest { ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( {s64_4d_tensor_10_9_8_7_1_, s64_4d_tensor_10_9_8_7_1_}); + const ProgramShape to_apply_ = + ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); }; -TEST_F(GatherShapeInferenceTest, TensorFlowGather) { +// Shape inference tests for Gather. + +TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGather) { TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, ShapeInference::InferGatherShape( matrix_64_48_, s64_vector_32_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, + /*offset_dims=*/{0}, + /*collapsed_slice_dims=*/{1}, + /*start_index_map=*/{1}, /*index_vector_dim=*/1), - /*window_bounds=*/{64, 1})); + /*slice_sizes=*/{64, 1})); EXPECT_TRUE( ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32}))) << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) { +TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherV2) { TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, ShapeInference::InferGatherShape( matrix_64_48_, s64_vector_32_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{1}, - /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}, + /*offset_dims=*/{1}, + /*collapsed_slice_dims=*/{0}, + /*start_index_map=*/{0}, /*index_vector_dim=*/1), - /*window_bounds=*/{1, 48})); + /*slice_sizes=*/{1, 48})); EXPECT_TRUE( ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48}))) << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) { +TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherNd) { TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, ShapeInference::InferGatherShape( matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4}, - /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}, + /*offset_dims=*/{4}, + /*collapsed_slice_dims=*/{0}, + /*start_index_map=*/{0}, /*index_vector_dim=*/4), - /*window_bounds=*/{1, 48})); + /*slice_sizes=*/{1, 48})); EXPECT_TRUE(ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48}))) << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { +TEST_F(ScatterGatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { TF_ASSERT_OK_AND_ASSIGN( Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26})); + /*slice_sizes=*/{30, 29, 28, 27, 26})); EXPECT_TRUE(ShapeUtil::Equal( gather_shape, ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}))) << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { +TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { TF_ASSERT_OK_AND_ASSIGN( Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/2), - /*window_bounds=*/{30, 29, 28, 27, 26})); + /*slice_sizes=*/{30, 29, 28, 27, 26})); EXPECT_TRUE(ShapeUtil::Equal( gather_shape, @@ -1635,17 +1729,17 @@ TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { +TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { TF_ASSERT_OK_AND_ASSIGN( Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/0), - /*window_bounds=*/{30, 29, 28, 27, 26})); + /*slice_sizes=*/{30, 29, 28, 27, 26})); EXPECT_TRUE(ShapeUtil::Equal( gather_shape, @@ -1653,97 +1747,96 @@ TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) { +TEST_F(ScatterGatherShapeInferenceTest, NoOutputGatherDims) { // This is equivalent to a dynamic slice. - TF_ASSERT_OK_AND_ASSIGN( - Shape gather_shape, - ShapeInference::InferGatherShape( - f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, - HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0, 1, 2, 3, 4}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, - /*index_vector_dim=*/0), - /*window_bounds=*/{30, 29, 28, 27, 26})); + TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, + HloGatherInstruction::MakeGatherDimNumbers( + /*offset_dims=*/{0, 1, 2, 3, 4}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/0), + /*slice_sizes=*/{30, 29, 28, 27, 26})); EXPECT_TRUE(ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26}))) << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) { +TEST_F(ScatterGatherShapeInferenceTest, ScalarGatherIndices) { // The gather indices "tensor" is a scalar S here that's used to slice out // [S,0,0,0,0]..[S,30,29,28,27] into a [30,29,28,27] shaped result. TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_scalar_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0, 1, 2, 3}, - /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}, + /*offset_dims=*/{0, 1, 2, 3}, + /*collapsed_slice_dims=*/{0}, + /*start_index_map=*/{0}, /*index_vector_dim=*/0), - /*window_bounds=*/{1, 30, 29, 28, 27})); + /*slice_sizes=*/{1, 30, 29, 28, 27})); EXPECT_TRUE(ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {30, 29, 28, 27}))) << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { +TEST_F(ScatterGatherShapeInferenceTest, TupleShapedTensorInput) { StatusOr statusor = ShapeInference::InferGatherShape( tuple_shape_, s64_vector_32_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, + /*offset_dims=*/{0}, + /*collapsed_slice_dims=*/{1}, + /*start_index_map=*/{1}, /*index_vector_dim=*/1), - /*window_bounds=*/{64, 1}); + /*slice_sizes=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Expected array argument for input")) << statusor.status(); } -TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { +TEST_F(ScatterGatherShapeInferenceTest, TupleShapedGatherIndicesInput) { StatusOr statusor = ShapeInference::InferGatherShape( s64_vector_32_, tuple_shape_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, + /*offset_dims=*/{0}, + /*collapsed_slice_dims=*/{1}, + /*start_index_map=*/{1}, /*index_vector_dim=*/0), - /*window_bounds=*/{64, 1}); + /*slice_sizes=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Expected array argument for gather indices")) << statusor.status(); } -TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) { +TEST_F(ScatterGatherShapeInferenceTest, FloatingPointGatherIndicesInput) { StatusOr statusor = ShapeInference::InferGatherShape( s64_vector_32_, vector_32_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, + /*offset_dims=*/{0}, + /*collapsed_slice_dims=*/{1}, + /*start_index_map=*/{1}, /*index_vector_dim=*/0), - /*window_bounds=*/{64, 1}); + /*slice_sizes=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Gather indices parameter must be an integral tensor")) << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_NonAscendingWindowIndices) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 8, 7}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 8, 7}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), @@ -1751,16 +1844,16 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedWindowIndices) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 7}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 7}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), @@ -1768,227 +1861,792 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_WindowIndexOutOfBounds) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 99, 100, 101}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 99, 100, 101}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Window index 2 in gather op is out of bounds")) + HasSubstr("Offset dimension 2 in gather op is out of bounds")) << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 9}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 9}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Window index 4 in gather op is out of bounds")) + HasSubstr("Offset dimension 4 in gather op is out of bounds")) << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingElidedWindowDims) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{4}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{4}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), - HasSubstr("All components of the window index in a gather op must either " - "be a output window index or explicitly elided")) + HasSubstr("All components of the offset index in a gather op must either " + "be a offset dimension or explicitly collapsed")) << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{0, 1, 2, 3, 19}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{0, 1, 2, 3, 19}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Invalid elided_window_dims set in gather op; valid " + HasSubstr("Invalid collapsed_slice_dims set in gather op; valid " "range is [0, 5), got: 19")) << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedWindowToInputMapping) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{0, 1, 2, 3, 3}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{0, 1, 2, 3, 3}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); - EXPECT_THAT( - statusor.status().error_message(), - HasSubstr( - "Repeated dimensions not allowed in elided_window_dims in gather op")) + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Repeated dimensions not allowed in " + "collapsed_slice_dims in gather op")) << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingGatherToInputMapping) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); - EXPECT_THAT( - statusor.status().error_message(), - HasSubstr("Gather op has 4 elements in gather_dims_to_operand_dims and " - "the bound of dimension index_vector_dim=4 of " - "gather_indices is 5. These two numbers must be equal.")) + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Gather op has 4 elements in start_index_map and " + "the bound of dimension index_vector_dim=4 of " + "start_indices is 5. These two numbers must be equal.")) << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 7}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); - EXPECT_THAT( - statusor.status().error_message(), - HasSubstr("Invalid gather_dims_to_operand_dims mapping; domain is " - "[0, 5), got: 4->7")) + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Invalid start_index_map; domain is [0, 5), got: 4->7")) << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedGatherToInputMapping) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 3}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), - HasSubstr( - "Repeated dimensions are not allowed in gather_dims_to_operand_dims")) + HasSubstr("Repeated dimensions are not allowed in start_index_map")) << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_NonAscendingElidedWindowDims) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{2, 1}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{2, 1}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{1, 1, 28, 27, 26}); + /*slice_sizes=*/{1, 1, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("elided_window_dims in gather op must be sorted")) + HasSubstr("collapsed_slice_dims in gather op must be sorted")) << statusor.status(); } -TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) { +TEST_F(ScatterGatherShapeInferenceTest, + InvalidGatherDimNumbers_WindowBoundsTooLarge) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7}, - /*elided_window_dims=*/{2}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7}, + /*collapsed_slice_dims=*/{2}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 1, 300, 26}); + /*slice_sizes=*/{30, 29, 1, 300, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Window bound at index 3 in gather op is out of range, " - "must be within [0, 48), got 300")) + HasSubstr("Slice size at index 3 in gather op is out of range, " + "must be within [0, 48), got 300.")) << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 26}); + /*slice_sizes=*/{30, 29, 28, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), - HasSubstr( - "Gather op must have one window bound for every input dimension")) + HasSubstr("Gather op must have one slice size for every input dimension")) << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7}, + /*collapsed_slice_dims=*/{1}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 26, 20}); + /*slice_sizes=*/{30, 29, 28, 26, 20}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Gather op can only elide window indices with bound 1, " - "but bound is 29 for index 1 at position 0")) + HasSubstr("Gather op can only collapse slice dims with bound 1, " + "but bound is 29 for index 1 at position 0.")) << statusor.status(); } -TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) { +TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/32), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Gather index leaf dimension must be within [0, " - "rank(gather_indices) + 1)")) + "rank(start_indices) + 1)")) + << statusor.status(); +} + +// Shape inference tests for Scatter. + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithFullUpdates) { + TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, + ShapeUtil::MakeShape(F32, {64, 32}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithFullUpdatesV2) { + TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, + ShapeUtil::MakeShape(F32, {32, 48}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{1}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/1))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithPartialUpdates) { + TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, + ShapeUtil::MakeShape(F32, {10, 32}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithPartialUpdatesV2) { + TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, + ShapeUtil::MakeShape(F32, {32, 8}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{1}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/1))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithUpdatesBiggerThanInput) { + StatusOr statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {65, 32}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Bounds of the window dimensions of updates must not exceed " + "the bounds of the corresponding dimensions of operand.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithUpdatesBiggerThanInputV2) { + StatusOr statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {32, 49}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{1}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Bounds of the window dimensions of updates must not exceed " + "the bounds of the corresponding dimensions of operand.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + TfScatterWithUpdatesNotMatchingIndices) { + StatusOr statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {64, 31}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Bounds of the scatter dimensions of updates must be same as the " + "bounds of the corresponding dimensions of scatter indices.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + TfScatterWithUpdatesNotMatchingIndicesV2) { + StatusOr statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {31, 48}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{1}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Bounds of the scatter dimensions of updates must be same as the " + "bounds of the corresponding dimensions of scatter indices.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithFullUpdates) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithFullUpdatesV2) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 64}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithPartialUpdates) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 10}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithPartialUpdatesV2) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 12}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithUpdatesBiggerThanInput) { + StatusOr statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 65}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Bounds of the window dimensions of updates must not exceed " + "the bounds of the corresponding dimensions of operand.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + TfScatterNdWithUpdatesNotMatchingIndices) { + StatusOr statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {9, 9, 8, 7, 64}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Bounds of the scatter dimensions of updates must be same as the " + "bounds of the corresponding dimensions of scatter indices.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfBatchDynamicUpdateSlice) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 7, 8}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, NonDefaultScatterIndicesLeafDim) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 7, 8}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/2))); + + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, NonDefaultScatterIndicesLeafDimV2) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_, + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 7, 8}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/0))); + + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, NoUpdateScatterDims) { + // This is equivalent to a dynamic update slice. + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, + ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0, 1, 2, 3, 4}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/0))); + + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, ScalarScatterIndices) { + // The scalar indices "tensor" is a scalar S here that's used to update a + // [30,29,28,27] shaped tensor within the operand at position S. + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_scalar_, + ShapeUtil::MakeShape(F32, {30, 29, 28, 27}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0, 1, 2, 3}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/0))); + + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, ScatterWithTupleShapedTensorInput) { + StatusOr statusor = ShapeInference::InferScatterShape( + tuple_shape_, s64_vector_32_, s64_vector_32_, to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Expected array argument for operand")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + ScatterWithTupleShapedScatterIndicesInput) { + StatusOr statusor = ShapeInference::InferScatterShape( + s64_vector_32_, tuple_shape_, s64_vector_32_, to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Expected array argument for scatter indices")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, ScatterWithTupleShapedUpdatesInput) { + StatusOr statusor = ShapeInference::InferScatterShape( + s64_vector_32_, s64_vector_32_, tuple_shape_, to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Expected array argument for updates")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, FloatingPointScatterIndicesInput) { + StatusOr statusor = ShapeInference::InferScatterShape( + s64_vector_32_, vector_32_, s64_vector_32_, to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Scatter indices parameter must be an integral tensor")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsScatterIndicesLeafDim) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/10)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Scatter index leaf dimension must be within [0, " + "rank(scatter_indices) + 1)")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, InvalidUpdates) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 50}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Updates tensor must be of rank 7; got 8.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, InvalidUpdateComputation) { + const ProgramShape invalid_update_computation = + ShapeUtil::MakeProgramShape({f32_}, f32_); + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), + invalid_update_computation, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Reduction function must take 2 parameters, but takes 1")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_NonAscendingUpdateWindowDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 8, 7}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("update_window_dims in scatter op must be sorted")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_RepeatedUpdateWindowDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 7, 7}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("update_window_dims in scatter op must not repeat")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_OutOfBoundsUpdateWindowDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 7, 9}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Invalid update_window_dims set in scatter op; valid " + "range is [0, 9)")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_NonAscendingInsertedWindowDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{2, 1}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("inserted_window_dims in scatter op must be sorted")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_RepeatedInsertedWindowDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 1}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("inserted_window_dims in scatter op must not repeat")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_OutOfBoundsInsertedWindowDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 5}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Invalid inserted_window_dims set in scatter op; valid " + "range is [0, 5)")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_MismatchingScatterDimsToOperandDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Scatter op has 4 elements in scatter_dims_to_operand_dims and " + "the bound of dimension index_vector_dim=4 of scatter_indices " + "is 5. These two numbers must be equal")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_OutOfBoundsScatterDimsToOperandDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 10}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Invalid scatter_dims_to_operand_dims mapping; domain " + "is [0, 5), got: 4->10")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_RepeatedValuesInScatterDimsToOperandDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 2, 3}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Repeated dimensions not allowed in scatter_dims_to_operand_dims")) << statusor.status(); } diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index 7d7dcac10b65933d1c81b8aca77465932694bfdb..70714ffff06b4ba4c13aae22290eff049ed3385c 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/shaped_buffer_test.cc b/tensorflow/compiler/xla/service/shaped_buffer_test.cc index 0fc243667911651c788e3c1e5f1d39d86170f1ad..d69e6362e91e4696dab3c46d99a981c67b593a1c 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer_test.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -34,7 +35,7 @@ TEST(ShapedBufferTest, ScopedShapeBufferAsShapedBufferB71629047) { xla::StreamExecutorMemoryAllocator allocator(platform, executors); const xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {}); const int kDeviceOrdinal = 0; - auto scoped_buffer = tensorflow::MakeUnique( + auto scoped_buffer = absl::make_unique( shape, shape, &allocator, kDeviceOrdinal); std::unique_ptr buffer = std::move(scoped_buffer); buffer = nullptr; diff --git a/tensorflow/compiler/xla/service/stream_pool.cc b/tensorflow/compiler/xla/service/stream_pool.cc new file mode 100644 index 0000000000000000000000000000000000000000..5d1cd1c4422a10e3b9e6ce6fac2c83594bb58b30 --- /dev/null +++ b/tensorflow/compiler/xla/service/stream_pool.cc @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/stream_pool.h" + +#include "absl/memory/memory.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) { + std::unique_ptr stream; + { + tensorflow::mutex_lock lock(mu_); + if (!streams_.empty()) { + // Re-use an existing stream from the pool. + stream = std::move(streams_.back()); + streams_.pop_back(); + VLOG(1) << stream->DebugStreamPointers() + << " StreamPool reusing existing stream"; + } + } + + if (!stream) { + // Create a new stream. + stream = absl::make_unique(executor); + stream->Init(); + VLOG(1) << stream->DebugStreamPointers() + << " StreamPool created new stream"; + } + + // Return the stream wrapped in Ptr, which has our special deleter semantics. + PtrDeleter deleter = {this}; + return Ptr(stream.release(), deleter); +} + +void StreamPool::ReturnStream(se::Stream* stream) { + if (stream->ok()) { + VLOG(1) << stream->DebugStreamPointers() + << " StreamPool returning ok stream"; + tensorflow::mutex_lock lock(mu_); + streams_.emplace_back(stream); + } else { + // If the stream has encountered any errors, all subsequent operations on it + // will fail. So just delete the stream, and rely on new streams to be + // created in the future. + VLOG(1) << stream->DebugStreamPointers() + << " StreamPool deleting !ok stream"; + delete stream; + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/stream_pool.h b/tensorflow/compiler/xla/service/stream_pool.h new file mode 100644 index 0000000000000000000000000000000000000000..7221d323a61593ac4b203a81b6046d81a5beaaf0 --- /dev/null +++ b/tensorflow/compiler/xla/service/stream_pool.h @@ -0,0 +1,64 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_STREAM_POOL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_STREAM_POOL_H_ + +#include +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { + +// Pool of stream_executor::Streams, which are created as needed and +// destroyed when the pool is destroyed. +class StreamPool { + public: + struct PtrDeleter { + void operator()(se::Stream* stream) { pool->ReturnStream(stream); } + StreamPool* pool; + }; + + // Stream pointer type returned by BorrowStream, which returns the + // stream to the pool on destruction. + using Ptr = std::unique_ptr; + + StreamPool() {} + + // Returns a pointer to a stream in the pool, creating a new stream + // if none are available in the pool. The returned smart pointer + // returns the stream to the pool on destruction. + // + // This method is thread-safe. + Ptr BorrowStream(se::StreamExecutor* executor); + + private: + // Puts a pointer to a stream back into the pool, leaving it free + // for future use. Streams that have previously encountered errors + // are deleted, and not returned to the pool. + // + // This method is thread-safe. + void ReturnStream(se::Stream* stream); + + tensorflow::mutex mu_; + std::vector> streams_ GUARDED_BY(mu_); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_STREAM_POOL_H_ diff --git a/tensorflow/compiler/xla/service/stream_pool_test.cc b/tensorflow/compiler/xla/service/stream_pool_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..aaf5c37b0d250f78cb57639255ac9b59e1b462f7 --- /dev/null +++ b/tensorflow/compiler/xla/service/stream_pool_test.cc @@ -0,0 +1,136 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/stream_pool.h" + +#include + +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace { + +class StreamPoolTest : public ::testing::Test { + protected: + std::unique_ptr NewStreamExecutor() { + se::Platform* platform = + se::MultiPlatformManager::PlatformWithName("Host").ConsumeValueOrDie(); + se::StreamExecutorConfig config(/*ordinal=*/0); + return platform->GetUncachedExecutor(config).ConsumeValueOrDie(); + } +}; + +TEST_F(StreamPoolTest, EmptyPool) { StreamPool pool; } + +TEST_F(StreamPoolTest, OneStreamPool) { + std::unique_ptr executor = NewStreamExecutor(); + StreamPool pool; + + // Borrow and return a stream. + StreamPool::Ptr stream1 = pool.BorrowStream(executor.get()); + se::Stream* stream1_ptr = stream1.get(); + EXPECT_TRUE(stream1->ok()); + stream1 = nullptr; + + // Borrow and return another stream. + StreamPool::Ptr stream2 = pool.BorrowStream(executor.get()); + se::Stream* stream2_ptr = stream2.get(); + EXPECT_TRUE(stream2->ok()); + stream2 = nullptr; + + // The underlying streams should be the same, since stream1 was the + // only stream available in the pool when stream2 was borrowed. + EXPECT_EQ(stream1_ptr, stream2_ptr); +} + +TEST_F(StreamPoolTest, TwoStreamPool) { + std::unique_ptr executor = NewStreamExecutor(); + StreamPool pool; + + // Borrow two streams. + StreamPool::Ptr stream1 = pool.BorrowStream(executor.get()); + se::Stream* stream1_ptr = stream1.get(); + EXPECT_TRUE(stream1->ok()); + StreamPool::Ptr stream2 = pool.BorrowStream(executor.get()); + se::Stream* stream2_ptr = stream2.get(); + EXPECT_TRUE(stream2->ok()); + + // The underlying streams should be different, since we haven't + // returned either of them yet. + EXPECT_NE(stream1_ptr, stream2_ptr); + + // Return stream1 and borrow stream3. + stream1 = nullptr; + StreamPool::Ptr stream3 = pool.BorrowStream(executor.get()); + se::Stream* stream3_ptr = stream3.get(); + EXPECT_TRUE(stream3->ok()); + + // stream1 and stream3 should be the same. + EXPECT_EQ(stream1_ptr, stream3_ptr); + EXPECT_NE(stream2_ptr, stream3_ptr); + + // Return stream2, and borrow stream4. + stream2 = nullptr; + StreamPool::Ptr stream4 = pool.BorrowStream(executor.get()); + se::Stream* stream4_ptr = stream4.get(); + EXPECT_TRUE(stream4->ok()); + + // Stream2 and stream4 should be the same. + EXPECT_EQ(stream2_ptr, stream4_ptr); + EXPECT_NE(stream3_ptr, stream4_ptr); +} + +TEST_F(StreamPoolTest, BadStreamDiscarded) { + std::unique_ptr executor = NewStreamExecutor(); + StreamPool pool; + + // Borrow a stream. + StreamPool::Ptr stream1 = pool.BorrowStream(executor.get()); + EXPECT_TRUE(stream1->ok()); + + // Force an error on the stream; here we call a method that requires + // DNN support, which we know the Host platform doesn't support. + stream1->ThenDepthConcatenate({}, {}, nullptr); + EXPECT_FALSE(stream1->ok()); + + // Return stream1 and borrow stream2. + stream1 = nullptr; + StreamPool::Ptr stream2 = pool.BorrowStream(executor.get()); + se::Stream* stream2_ptr = stream2.get(); + EXPECT_TRUE(stream2->ok()); + + // The underlying streams should be different. They would have been + // the same, but since we forced an error on stream1, it cannot be + // put back into the pool. Sadly we can't just check: + // EXPECT_NE(stream1_ptr, stream2_ptr); + // + // The above should hold logically, but it may fail if the new + // stream instance allocated for stream2 happens to reside in the + // same memory address as stream1, which has been deleted. + // + // The check that stream2->ok() serves as a good-enough check. + + // Return stream2 and borrow stream3. The previous error on stream1 + // has no effect on these streams, and they are the same. + stream2 = nullptr; + StreamPool::Ptr stream3 = pool.BorrowStream(executor.get()); + se::Stream* stream3_ptr = stream3.get(); + EXPECT_TRUE(stream3->ok()); + EXPECT_EQ(stream2_ptr, stream3_ptr); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index 7232c658b3f0687ac93a83e46a200f88bf202084..e0f995fd0d7cbabe5d1abd6af3d0c0005a8c9d48 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -43,15 +44,39 @@ TransferManager::GetPlatformTransferManagers() { StatusOr> TransferManager::TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer) { StatusOr> ret; + se::Stream* substream = stream->GetOrCreateSubStream(); substream->ThenWaitFor(stream); auto cleanup = tensorflow::gtl::MakeCleanup( [&]() { stream->ReturnSubStream(substream); }); tensorflow::Notification n; - TransferLiteralFromDevice(substream, device_buffer, - [&](StatusOr> arg) { - ret = std::move(arg); + Status s; + Literal literal(device_buffer.on_host_shape()); + TransferLiteralFromDevice(substream, device_buffer, literal, + [&](Status status) { + s = status; + n.Notify(); + }); + n.WaitForNotification(); + if (!s.ok()) { + return s; + } + return absl::make_unique(std::move(literal)); +} + +Status TransferManager::TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer, + const MutableBorrowingLiteral& literal) { + se::Stream* substream = stream->GetOrCreateSubStream(); + auto cleanup = tensorflow::gtl::MakeCleanup( + [&]() { stream->ReturnSubStream(substream); }); + + Status ret; + tensorflow::Notification n; + TransferLiteralFromDevice(substream, device_buffer, literal, + [&](Status status) { + ret = status; n.Notify(); }); n.WaitForNotification(); @@ -76,22 +101,27 @@ Status TransferManager::TransferLiteralToDevice( StatusOr> TransferManager::TransferArrayFromDevice( se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source) { + StatusOr> ret; // Implement the synchronous version by waiting on the asynchronous version. // Use a substream so that if we are called from a HostCallback we don't // deadlock. - StatusOr> ret; se::Stream* substream = stream->GetOrCreateSubStream(); auto cleanup = tensorflow::gtl::MakeCleanup( [&]() { stream->ReturnSubStream(substream); }); tensorflow::Notification n; - TransferArrayFromDevice(substream, shape, source, - [&](StatusOr> arg) { - ret = std::move(arg); + Literal literal(shape); + Status s; + TransferArrayFromDevice(substream, shape, source, literal, + [&](Status status) { + s = status; n.Notify(); }); n.WaitForNotification(); - return ret; + if (!s.ok()) { + return s; + } + return absl::make_unique(std::move(literal)); } Status TransferManager::TransferArrayToDevice( @@ -130,7 +160,7 @@ Status TransferManager::TransferArrayToDeviceAsync( void TransferManager::TransferArrayFromDevice( se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source, - std::function>)> done) { + const MutableBorrowingLiteral& literal, std::function done) { if (!ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape)) { auto error = StrCat("Shape ", ShapeUtil::HumanString(shape), " has a differently shaped representation on-device: ", @@ -147,7 +177,8 @@ void TransferManager::TransferArrayFromDevice( stream->parent()->platform(), stream->parent()->device_ordinal()); shaped_buffer.set_buffer(source, /*index=*/{}); - return TransferLiteralFromDevice(stream, shaped_buffer, std::move(done)); + return TransferLiteralFromDevice(stream, shaped_buffer, literal, + std::move(done)); } /* static */ void TransferManager::RegisterTransferManager( diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 82c599e482d85fc5bbe5a5a48c6c6b053186803b..475a2e5c141d66fa689fb402da1ee81fb4ab80f7 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -59,6 +59,9 @@ class TransferManager { // This function should be avoided in favor of the asynchronous version below. virtual StatusOr> TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer); + virtual Status TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer, + const MutableBorrowingLiteral& literal); // Begins transferring a literal containing the data held in the given // ShapedBuffer using the provided executor. @@ -69,9 +72,10 @@ class TransferManager { // // device_buffer is copied by reference and must live at least until done() is // invoked. - virtual void TransferLiteralFromDevice( - se::Stream* stream, const ShapedBuffer& device_buffer, - std::function>)> done) = 0; + virtual void TransferLiteralFromDevice(se::Stream* stream, + const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal, + std::function done) = 0; // Transfers the given literal into the previously allocated device memory // represented by the given ShapedBuffer using the given executor. The shape @@ -101,10 +105,10 @@ class TransferManager { // transfer an array at a known address. Status TransferArrayToDevice(se::Stream* stream, const LiteralSlice& literal, const se::DeviceMemoryBase& dest); - void TransferArrayFromDevice( - se::Stream* stream, const Shape& shape, - const se::DeviceMemoryBase& source, - std::function>)> done); + void TransferArrayFromDevice(se::Stream* stream, const Shape& shape, + const se::DeviceMemoryBase& source, + const MutableBorrowingLiteral& literal, + std::function done); Status TransferArrayToDeviceAsync(se::Stream* stream, const LiteralSlice& literal, @@ -120,9 +124,9 @@ class TransferManager { // Transfers the given literal from the Outfeed interface of the device, // using the given executor. - virtual Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, - const Shape& literal_shape, - Literal* literal) = 0; + virtual Status TransferLiteralFromOutfeed( + se::StreamExecutor* executor, const Shape& literal_shape, + MutableBorrowingLiteral literal) = 0; // Resets the devices associated with this transfer manager. virtual Status ResetDevices( diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 7051a4cf51749d294478cf9a34d4700cb52ae312..58f767e913fbc0023e0c45a4f0e82ecefeeef2d6 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 990dfc410ccf6ab84af00f4a16dc783c11985844..0c2f2112af5cdebe998f0d723528076b3c73d260 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -232,8 +233,7 @@ Status TuplePointsToAnalysis::HandleGetTupleElement( // Copy the points-to set (and tuple sources) at index {element_index} of the // operand to the points-to set for this GetTupleElement instruction. points_to_set.ForEachMutableElement( - [&, this](const ShapeIndex& target_index, - PointsToSet::BufferList* points_to) { + [&](const ShapeIndex& target_index, PointsToSet::BufferList* points_to) { // Construct an index into the operand by prepending element_index to // the index for the GetTupleElement instruction's points-to set. ShapeIndex src_index; @@ -308,7 +308,7 @@ Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) { // Recursively copy the points to set of the operand tuple {0} to the output // element {0}. points_to_set.ForEachMutableElement( - [this, &points_to_set, &operand_points_to_set]( + [&points_to_set, &operand_points_to_set]( const ShapeIndex& index, PointsToSet::BufferList* buffers) { if (index.empty() || index[0] != 0) { return; @@ -442,7 +442,7 @@ PointsToSet& TuplePointsToAnalysis::CreateEmptyPointsToSet( PerInstruction* pi = PerInst(instruction); CHECK(pi->points_to_set == nullptr) << "instruction should not have been present in the map."; - auto set = MakeUnique(&instruction->shape()); + auto set = absl::make_unique(&instruction->shape()); pi->points_to_set = std::move(set); // Return *set using the iterator returned by emplace. return *pi->points_to_set; @@ -517,7 +517,7 @@ Status TuplePointsToAnalysis::GatherBuffersDefinedByInstruction( const HloInstruction* instruction, TuplePointsToAnalysis::BufferDefinitionVector* buffers) { GetPointsToSet(instruction) - .ForEachElement([this, buffers, instruction]( + .ForEachElement([buffers, instruction]( const ShapeIndex& index, const PointsToSet::BufferList& source_buffers) { // Add buffers which 'instruction' is the source of. @@ -547,7 +547,7 @@ PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet( PointsToSet& dst_points_to_set = CreateEmptyPointsToSet(instruction); const PointsToSet& src_points_to_set = GetPointsToSet(src); dst_points_to_set.ForEachMutableElement( - [this, &dst_points_to_set, &src_points_to_set]( + [&dst_points_to_set, &src_points_to_set]( const ShapeIndex& index, PointsToSet::BufferList* buffers) { *buffers = src_points_to_set.element(index); for (auto& tuple_source : src_points_to_set.tuple_sources(index)) { @@ -718,6 +718,7 @@ bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt( // root at operand 0 or 1. Or... // (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index // 0. +// (5) The 'user' of 'operand' is Sort, and it is the only user. // // (2) and (3) can only be determined if points-to analysis is available. bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( @@ -783,6 +784,21 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( std::vector operand_indices = user->OperandIndices(operand); return operand_indices.size() == 1 && operand_indices[0] == 0; } + if (user->opcode() == HloOpcode::kSort) { + // Only valid if there are no other users. + if (operand->users().size() != 1) { + return false; + } + // If we only sort keys, the output of sort is not a tuple, so we can always + // share the buffer. + if (user->operand_count() == 1) { + return true; + } + CHECK(!user_index.empty()); + // Only share with the right tuple element buffer. + std::vector operand_indices = user->OperandIndices(operand); + return operand_indices.size() == 1 && user_index[0] == operand_indices[0]; + } if (user->opcode() == HloOpcode::kCall) { // TODO(b/62548313): Remove when buffer assignment is module scoped and // does not assign buffers to calls. diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 0ac8df42714a1550d36560cbff901f6a8a4b3a8d..10d382e8abc92145c1804cbf18bbed714fa34571 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -1012,6 +1012,48 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { points_to_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {})); } +TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { + auto builder = HloComputation::Builder(TestName()); + + Shape keys_shape = ShapeUtil::MakeShape(F32, {8}); + auto keys = builder.AddInstruction( + HloInstruction::CreateParameter(0, keys_shape, "keys")); + auto sort = + builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { + auto builder = HloComputation::Builder(TestName()); + + Shape keys_shape = ShapeUtil::MakeShape(F32, {8}); + Shape values_shape = ShapeUtil::MakeShape(F32, {8}); + auto keys = builder.AddInstruction( + HloInstruction::CreateParameter(0, keys_shape, "keys")); + auto values = builder.AddInstruction( + HloInstruction::CreateParameter(1, values_shape, "values")); + auto sort = builder.AddInstruction(HloInstruction::CreateSort( + ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // The buffer for the keys can be shared with the first tuple entry. + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {0})); + // The buffer for the values can be shared with the second tuple entry. + EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(values, {}, + sort, {1})); + // Verify that the buffers are not shared with the "wrong" tuple entry. + EXPECT_FALSE( + points_to_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {1})); + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(values, {}, + sort, {0})); +} + TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { auto builder = HloComputation::Builder(TestName()); Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); @@ -1076,7 +1118,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - auto make_cond = [this, &data_shape]() { + auto make_cond = [&data_shape]() { auto builder = HloComputation::Builder(TestName() + ".Cond"); auto data = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "data")); @@ -1085,7 +1127,7 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { return builder.Build(); }; - auto make_body = [this, &data_shape]() { + auto make_body = [&data_shape]() { auto builder = HloComputation::Builder(TestName() + ".Body"); auto data = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "data")); diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc new file mode 100644 index 0000000000000000000000000000000000000000..af2cb6dc2a3f4a004351acc62796e0daf46719c2 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc @@ -0,0 +1,238 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { + +using tensorflow::gtl::nullopt; +using tensorflow::gtl::optional; + +// Finds and returns the non-constant operand in instr. +// +// CHECK-fails if instr doesn't have exactly one unique non-constant operand. +static const HloInstruction* NonConstantOperand(const HloInstruction* instr) { + const HloInstruction* result = nullptr; + for (const HloInstruction* operand : instr->operands()) { + if (!operand->IsConstant()) { + if (result != nullptr) { + CHECK_EQ(result, operand); + } + result = operand; + } + } + CHECK_NE(result, nullptr); + return result; +} + +// If all of instr's operands are either constants or have the form +// get-tuple-element(gte_operand, N) +// for the same value N, returns N. Otherwise, returns nullopt. +static optional GetGTEOperandIndex(const HloInstruction* instr, + const HloInstruction* gte_operand) { + VLOG(2) << "GetGTEOperandIndex(" << instr->ToString() << ", " + << gte_operand->ToString() << ")"; + optional tuple_idx; + for (const HloInstruction* operand : instr->operands()) { + if (operand->IsConstant()) { + continue; + } + // Look through copies. + // TODO(b/68830972): We wouldn't need this if for loop matching on the GPU + // would run before copy insertion. + if (operand->opcode() == HloOpcode::kCopy) { + operand = operand->operand(0); + } + if (operand->opcode() != HloOpcode::kGetTupleElement) { + VLOG(2) << "instr uses something other than gte(gte_operand): " + << operand->ToString(); + return nullopt; + } + if (operand->operand(0) != gte_operand) { + VLOG(2) << "instr has gte whose operand is not gte_operand: " + << operand->ToString(); + return nullopt; + } + if (tuple_idx && tuple_idx != operand->tuple_index()) { + VLOG(2) << "instr has operands with conflicting gte indices, " + << *tuple_idx << " vs " << operand->tuple_index(); + return nullopt; + } + + tuple_idx = operand->tuple_index(); + } + return tuple_idx; +} + +// Tries to get the tuple index of the induction variable of a while loop. +// +// Checks that the loop condition and root both plumb the induction variable +// through the same tuple index, and that they both apply exactly one op to the +// induction variable before deciding whether to do another loop iteration (in +// the loop condition's case) or packing the induction variable into the result +// tuple (in the loop body's case). +// +// Specifically, checks that the loop condition has structure +// +// root = op(constants, get-tuple-elem(param0, N), constants) +// +// and the loop body has the structure +// +// inc = op(constants, get-tuple-elem(param0, N), constants) +// root = tuple(..., inc, ...) // inc is N'th operand of tuple(). +// +// If so, returns N. Otherwise, returns nullopt. +static optional GetLoopInductionVarTupleIdx( + const HloInstruction* while_op) { + CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + VLOG(2) << "Finding induction variable for loop " + << while_op->ToShortString(); + + // The while_cond computation should have the form + // + // while_cond_root = + // op(constants, get-tuple-elem(while_cond_param, N), constants). + // + // If it does, set indvar_tuple_idx to N. + auto* while_cond = while_op->while_condition(); + auto* while_cond_root = while_cond->root_instruction(); + auto* while_cond_param = while_cond->parameter_instruction(0); + optional indvar_tuple_idx = + GetGTEOperandIndex(while_cond_root, while_cond_param); + if (!indvar_tuple_idx) { + VLOG(2) << "Induction variable not found in loop condition: " + << while_cond->root_instruction()->ToString(); + return nullopt; + } + + // The while_body computation should have the form + // + // while_body_inc = + // op(constants, get-tuple-elem(while_body_param, N), constants) + // while_body_root = tuple(..., while_body_inc, ...) + // + // where while_body_inc is operand N of while_body_root. + auto* while_body = while_op->while_body(); + auto* while_body_root = while_body->root_instruction(); + if (while_body_root->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While body's root is not a tuple instruction: " + << while_body_root->ToString(); + return nullopt; + } + + auto* while_body_inc = while_body_root->operand(*indvar_tuple_idx); + auto* while_body_param = while_body->parameter_instruction(0); + optional while_body_indvar_tuple_idx = + GetGTEOperandIndex(while_body_inc, while_body_param); + if (!while_body_indvar_tuple_idx) { + VLOG(2) + << "Induction variable not found in while body increment instruction: " + << while_body_inc->ToString(); + return nullopt; + } + if (while_body_indvar_tuple_idx != indvar_tuple_idx) { + VLOG(2) << "Tuple index of induction variable does not match between loop " + "condition (" + << *indvar_tuple_idx << ") and while body (" + << *while_body_indvar_tuple_idx << ")"; + return nullopt; + } + + // Finally, check that the while loop's initial value is a tuple with enough + // elements. + auto* while_init = while_op->operand(0); + if (while_init->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While init expected to be a tuple: " << while_init->ToString(); + return nullopt; + } + + VLOG(2) << "Induction variable's tuple index: " << *indvar_tuple_idx; + return indvar_tuple_idx; +} + +optional ComputeWhileLoopTripCount(HloInstruction* while_op, + int64 max_value_returned) { + VLOG(2) << "Getting trip count for loop " << while_op->ToString(); + + // The loop's induction variable is found at + // + // get-tuple-elem(comp->parameter_instruction(0), *indvar_tuple_idx), + // + // where comp is while_op->while_body() or while_op->while_condition(). + optional indvar_tuple_idx = GetLoopInductionVarTupleIdx(while_op); + if (!indvar_tuple_idx) { + return nullopt; + } + + // Now that we know the index of the induction variable, we can we can try to + // compute how many times the loop executes. Start by computing the induction + // variable's initial value. + HloEvaluator evaluator(/*max_loop_iterations=*/0); + auto* while_init = while_op->mutable_operand(0); + auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx); + StatusOr> indvar_init_result = + evaluator.Evaluate(indvar_init); + if (!indvar_init_result.ok()) { + VLOG(2) << "Couldn't evaluate induction variable init: " + << indvar_init_result.status(); + return nullopt; + } + + auto* while_body = while_op->while_body(); + auto* while_body_indvar_update = + while_body->root_instruction()->operand(*indvar_tuple_idx); + auto* while_body_indvar = NonConstantOperand(while_body_indvar_update); + + // The initial value of the induction variable. + std::unique_ptr indvar_iter_val = + std::move(indvar_init_result).ValueOrDie(); + for (int64 trip_count = 0; trip_count != max_value_returned + 1; + ++trip_count) { + auto* while_cond = while_op->while_condition(); + auto* while_cond_root = while_cond->root_instruction(); + auto* while_cond_indvar = NonConstantOperand(while_cond_root); + StatusOr> result = + evaluator.EvaluateWithSubstitutions( + while_cond_root, {{while_cond_indvar, indvar_iter_val.get()}}); + if (!result.ok()) { + VLOG(2) << "Couldn't evaluate while cond: " << result.status(); + return nullopt; + } + if (result.ValueOrDie()->data() == + tensorflow::gtl::ArraySlice{false}) { + VLOG(2) << "Loop has static trip count of " << trip_count; + return trip_count; + } + + // Calculate the value of the induction variable after one iteration of the + // loop, and check whether the while condition is true with this new value. + StatusOr> indvar_next_result = + evaluator.EvaluateWithSubstitutions( + while_body_indvar_update, + {{while_body_indvar, indvar_iter_val.get()}}); + if (!indvar_next_result.ok()) { + VLOG(2) << "Couldn't evaluate induction variable update: " + << indvar_next_result.status(); + return nullopt; + } + indvar_iter_val = std::move(indvar_next_result).ValueOrDie(); + } + + VLOG(2) << "Loop has unknown trip count."; + return nullopt; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.h b/tensorflow/compiler/xla/service/while_loop_analysis.h new file mode 100644 index 0000000000000000000000000000000000000000..bf59813e8c405a8709446bf8457729348ceae4ec --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_analysis.h @@ -0,0 +1,33 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/core/lib/gtl/optional.h" + +namespace xla { + +// Returns the precise trip count of the loop if it's statically known, +// nullopt otherwise. max_value_returned limits the number of steps that are +// evaluated while trying to brute force a loop trip count, trip counts larger +// than max_value_returned result in nullopt. +tensorflow::gtl::optional ComputeWhileLoopTripCount( + HloInstruction *while_op, int64 max_value_returned = 128); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc index 62af45128ad2fb7bf886bef78ec3ab42529a181e..aab11806621746141f4302f39a780fcdbab99fc1 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -32,7 +33,7 @@ static Status ReplaceUsesWhileKeepingLoopInvariance( std::vector users; users.reserve(old_instr->user_count()); - c_copy(old_instr->users(), std::back_inserter(users)); + absl::c_copy(old_instr->users(), std::back_inserter(users)); for (auto* user : users) { for (int64 i = 0, e = user->operand_count(); i < e; i++) { @@ -108,10 +109,10 @@ StatusOr WhileLoopConstantSinking::Run(HloModule* module) { // // This will let us sink the constant into the outer while first and then // into the inner while in a single run of this pass. - c_copy_if(comp->instructions(), std::back_inserter(while_instrs), - [](const HloInstruction* instr) { - return instr->opcode() == HloOpcode::kWhile; - }); + absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs), + [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile; + }); } for (HloInstruction* while_instr : while_instrs) { diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc index 266039d2ff8ef4befba0d1023ac1914737207d4f..0e7667de832c54f647d071e3c9563091d0f994aa 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc @@ -206,7 +206,8 @@ body { p_body.0 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=0 p_body.1 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=1 - outfeed = token[] outfeed(p_body.0) + token = token[] after-all() + outfeed = token[] outfeed(p_body.0, token) ROOT root = (f32[2],f32[2],f32[2]) tuple(p_body.0, p_body.1, p_body.1) } diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index 09ddcffb22c2184262adf87d570870ec000c0e6f..cb132d4f16aeb963b783e6e985aa038b90072f9d 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/service/tuple_util.h" #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/util.h" @@ -65,8 +66,8 @@ static void CreateLoopInvariantCopy( }; InlinedVector new_operands; - c_transform(old_instruction->operands(), std::back_inserter(new_operands), - get_new_operand); + absl::c_transform(old_instruction->operands(), + std::back_inserter(new_operands), get_new_operand); HloInstruction* new_instruction = parent_of_while->AddInstruction(old_instruction->CloneWithNewOperands( @@ -197,7 +198,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( op->opcode() == HloOpcode::kConstant; }; - if (!c_all_of(instruction->operands(), is_invariant)) { + if (!absl::c_all_of(instruction->operands(), is_invariant)) { continue; } @@ -257,10 +258,10 @@ StatusOr WhileLoopInvariantCodeMotion::Run(HloModule* module) { bool changed = false; std::vector while_instrs; for (auto* comp : module->computations()) { - c_copy_if(comp->instructions(), std::back_inserter(while_instrs), - [](const HloInstruction* instr) { - return instr->opcode() == HloOpcode::kWhile; - }); + absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs), + [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile; + }); } for (HloInstruction* while_instr : while_instrs) { diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index ec05a74e286c89dd8db5ae07580e461938d7c087..dd8697e680c56165f87c365a721eda2de1ebc085 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/service/call_inliner.h" -#include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -26,23 +26,6 @@ namespace xla { using tensorflow::gtl::nullopt; using tensorflow::gtl::optional; -// Finds and returns the non-constant operand in instr. -// -// CHECK-fails if instr doesn't have exactly one unique non-constant operand. -static const HloInstruction* NonConstantOperand(const HloInstruction* instr) { - const HloInstruction* result = nullptr; - for (const HloInstruction* operand : instr->operands()) { - if (!operand->IsConstant()) { - if (result != nullptr) { - CHECK_EQ(result, operand); - } - result = operand; - } - } - CHECK_NE(result, nullptr); - return result; -} - // Determines whether the given instruction is a send/recv node, or has a // subcomputation which contains a send/recv node. static bool IsOrContainsSendOrRecv(const HloInstruction* instr); @@ -72,211 +55,6 @@ static bool IsOrContainsSendOrRecv(const HloInstruction* instr) { return false; } -// If all of instr's operands are either constants or have the form -// get-tuple-element(gte_operand, N) -// for the same value N, returns N. Otherwise, returns nullopt. -static optional GetGTEOperandIndex(const HloInstruction* instr, - const HloInstruction* gte_operand) { - VLOG(2) << "GetGTEOperandIndex(" << instr->ToString() << ", " - << gte_operand->ToString() << ")"; - optional tuple_idx; - for (const HloInstruction* operand : instr->operands()) { - if (operand->IsConstant()) { - continue; - } - if (operand->opcode() != HloOpcode::kGetTupleElement) { - VLOG(2) << "instr uses something other than gte(gte_operand): " - << operand->ToString(); - return nullopt; - } - if (operand->operand(0) != gte_operand) { - VLOG(2) << "instr has gte whose operand is not gte_operand: " - << operand->ToString(); - return nullopt; - } - if (tuple_idx && tuple_idx != operand->tuple_index()) { - VLOG(2) << "instr has operands with conflicting gte indices, " - << *tuple_idx << " vs " << operand->tuple_index(); - return nullopt; - } - - tuple_idx = operand->tuple_index(); - } - return tuple_idx; -} - -// Tries to get the tuple index of the induction variable of a while loop. -// -// Checks that the loop condition and root both plumb the induction variable -// through the same tuple index, and that they both apply exactly one op to the -// induction variable before deciding whether to do another loop iteration (in -// the loop condition's case) or packing the induction variable into the result -// tuple (in the loop body's case). -// -// Specifically, checks that the loop condition has structure -// -// root = op(constants, get-tuple-elem(param0, N), constants) -// -// and the loop body has the structure -// -// inc = op(constants, get-tuple-elem(param0, N), constants) -// root = tuple(..., inc, ...) // inc is N'th operand of tuple(). -// -// If so, returns N. Otherwise, returns nullopt. -static optional GetLoopInductionVarTupleIdx( - const HloInstruction* while_op) { - CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); - VLOG(2) << "Finding induction variable for loop " - << while_op->ToShortString(); - - // The while_cond computation should have the form - // - // while_cond_root = - // op(constants, get-tuple-elem(while_cond_param, N), constants). - // - // If it does, set indvar_tuple_idx to N. - auto* while_cond = while_op->while_condition(); - auto* while_cond_root = while_cond->root_instruction(); - auto* while_cond_param = while_cond->parameter_instruction(0); - optional indvar_tuple_idx = - GetGTEOperandIndex(while_cond_root, while_cond_param); - if (!indvar_tuple_idx) { - VLOG(2) << "Induction variable not found in loop condition: " - << while_cond->root_instruction()->ToString(); - return nullopt; - } - - // The while_body computation should have the form - // - // while_body_inc = - // op(constants, get-tuple-elem(while_body_param, N), constants) - // while_body_root = tuple(..., while_body_inc, ...) - // - // where while_body_inc is operand N of while_body_root. - auto* while_body = while_op->while_body(); - auto* while_body_root = while_body->root_instruction(); - if (while_body_root->opcode() != HloOpcode::kTuple) { - VLOG(2) << "While body's root is not a tuple instruction: " - << while_body_root->ToString(); - return nullopt; - } - - auto* while_body_inc = while_body_root->operand(*indvar_tuple_idx); - auto* while_body_param = while_body->parameter_instruction(0); - optional while_body_indvar_tuple_idx = - GetGTEOperandIndex(while_body_inc, while_body_param); - if (!while_body_indvar_tuple_idx) { - VLOG(2) - << "Induction variable not found in while body increment instruction: " - << while_body_inc->ToString(); - return nullopt; - } - if (while_body_indvar_tuple_idx != indvar_tuple_idx) { - VLOG(2) << "Tuple index of induction variable does not match between loop " - "condition (" - << *indvar_tuple_idx << ") and while body (" - << *while_body_indvar_tuple_idx << ")"; - return nullopt; - } - - // Finally, check that the while loop's initial value is a tuple with enough - // elements. - auto* while_init = while_op->operand(0); - if (while_init->opcode() != HloOpcode::kTuple) { - VLOG(2) << "While init expected to be a tuple: " << while_init->ToString(); - return nullopt; - } - - VLOG(2) << "Induction variable's tuple index: " << *indvar_tuple_idx; - return indvar_tuple_idx; -} - -// Tries to determine the number of times the given loop executes. Currently -// simply returns 0, 1, or "can't tell" (nullopt). -static optional GetLoopTripCount(HloInstruction* while_op) { - CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); - VLOG(2) << "Getting trip count for loop " << while_op->ToString(); - - // The loop's induction variable is found at - // - // get-tuple-elem(comp->parameter_instruction(0), *indvar_tuple_idx), - // - // where comp is while_op->while_body() or while_op->while_condition(). - optional indvar_tuple_idx = GetLoopInductionVarTupleIdx(while_op); - if (!indvar_tuple_idx) { - return nullopt; - } - - VLOG(2) << "Induction variable is at index " << *indvar_tuple_idx - << " in input tuple."; - - // Now that we know the index of the induction variable, we can we can try to - // compute how many times the loop executes. Start by computing the induction - // variable's initial value. - HloEvaluator evaluator(/*max_loop_iterations=*/0); - auto* while_init = while_op->mutable_operand(0); - auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx); - StatusOr> indvar_init_result = - evaluator.Evaluate(indvar_init); - if (!indvar_init_result.ok()) { - VLOG(2) << "Couldn't evaluate induction variable init: " - << indvar_init_result.status(); - return nullopt; - } - - // Evaluates the while loop's condition, returning either "true" (continue - // looping), "false" (stop looping), or nullopt (can't evaluate). - auto evaluate_while_cond = [&](const Literal& indvar) -> optional { - auto* while_cond = while_op->while_condition(); - auto* while_cond_root = while_cond->root_instruction(); - auto* while_cond_indvar = NonConstantOperand(while_cond_root); - StatusOr> result = - evaluator.EvaluateWithSubstitutions(while_cond_root, - {{while_cond_indvar, &indvar}}); - if (!result.ok()) { - VLOG(2) << "Couldn't evaluate while cond: " << result.status(); - return nullopt; - } - return result.ValueOrDie()->data() == - tensorflow::gtl::ArraySlice{true}; - }; - - // The initial value of the induction variable. - const Literal& indvar_iter0_val = *indvar_init_result.ValueOrDie(); - - // Evaluate whether the while condition is true when seeded with - // indvar_iter0_val. - optional while_cond_iter0_val = evaluate_while_cond(indvar_iter0_val); - if (while_cond_iter0_val == false) { - VLOG(2) << "Loop has static trip count of 0."; - return 0; - } - - // Calculate the value of the induction variable after one iteration of the - // loop, and check whether the while condition is true with this new value. - auto* while_body = while_op->while_body(); - auto* while_body_indvar_update = - while_body->root_instruction()->operand(*indvar_tuple_idx); - auto* while_body_indvar = NonConstantOperand(while_body_indvar_update); - StatusOr> indvar_iter1_result = - evaluator.EvaluateWithSubstitutions( - while_body_indvar_update, {{while_body_indvar, &indvar_iter0_val}}); - if (!indvar_iter1_result.ok()) { - VLOG(2) << "Couldn't evaluate induction variable update: " - << indvar_iter1_result.status(); - return nullopt; - } - const Literal& indvar_iter1_val = *indvar_iter1_result.ValueOrDie(); - optional while_cond_iter1_val = evaluate_while_cond(indvar_iter1_val); - if (while_cond_iter1_val == false) { - VLOG(2) << "Determined that loop has static trip count of 1."; - return 1; - } - - VLOG(2) << "Loop has unknown trip count >= 1."; - return nullopt; -} - // Tries to remove elements in a while loop's tuple that aren't used within the // loop. // @@ -577,7 +355,9 @@ static StatusOr TryRemoveWhileLoop(HloInstruction* while_op) { } // Remove while loops with static trip count of 0. - optional trip_count = GetLoopTripCount(while_op); + optional trip_count = + ComputeWhileLoopTripCount(while_op, + /*max_value_returned=*/1); if (trip_count && *trip_count == 0) { // The loop never executes, so the value of the loop is the value of its // "init" operand. diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index 1ef17b9d7d2e769aadf39f8a70f78200b88e9d2c..52d9c3e5ae71cc7d06acddd4717c16d3fbe9e8be 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/while_util.h" +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" @@ -206,7 +207,7 @@ static StatusOr MakeInitTupleFromInitValues( HloInstruction* zero = computation->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); init_values_with_indvar.push_back(zero); - c_copy(init_values, std::back_inserter(init_values_with_indvar)); + absl::c_copy(init_values, std::back_inserter(init_values_with_indvar)); return computation->AddInstruction( HloInstruction::CreateTuple(init_values_with_indvar)); } @@ -215,8 +216,9 @@ static Shape MakeLoopStateShape(const WhileUtil::LoopStateTy& init_values) { std::vector loop_state_shape_components; loop_state_shape_components.reserve(init_values.size() + 1); loop_state_shape_components.push_back(ShapeUtil::MakeShape(S32, {})); - c_transform(init_values, std::back_inserter(loop_state_shape_components), - [](HloInstruction* instr) { return instr->shape(); }); + absl::c_transform(init_values, + std::back_inserter(loop_state_shape_components), + [](HloInstruction* instr) { return instr->shape(); }); return ShapeUtil::MakeTupleShape(loop_state_shape_components); } diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc index 2ccb919acf9c4e7c59a1ebaf36f42a6781068b5e..5e6941933330fde29bc9c779aae4bb3c36914660 100644 --- a/tensorflow/compiler/xla/service/while_util_test.cc +++ b/tensorflow/compiler/xla/service/while_util_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_util.h" +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" @@ -206,7 +207,7 @@ ENTRY main { auto is_while = [](const HloInstruction* instr) { return instr->opcode() == HloOpcode::kWhile; }; - EXPECT_EQ(c_count_if(main->instructions(), is_while), 1); + EXPECT_EQ(absl::c_count_if(main->instructions(), is_while), 1); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index c74dd648addd70633edc2ec10a60879a00942716..186c42ed13089954ada5504a60ed1a4f189f9e79 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -21,8 +21,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index 4391078b6484f25ba81aefa2c1d1f69d7d2774f4..c8ff55e7845785d9292516b823fb591cc28cbfad 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_tree.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -172,7 +173,7 @@ TEST_F(ShapeTreeTest, TupleShape) { // Write zero to all data elements. shape_tree.ForEachMutableElement( - [&sum](const ShapeIndex& /*index*/, int* data) { *data = 0; }); + [](const ShapeIndex& /*index*/, int* data) { *data = 0; }); EXPECT_EQ(0, shape_tree.element({})); EXPECT_EQ(0, shape_tree.element({0})); EXPECT_EQ(0, shape_tree.element({1})); @@ -242,7 +243,7 @@ TEST_F(ShapeTreeTest, InvalidIndexingNestedTuple) { TEST_F(ShapeTreeTest, ShapeTreeOfNonCopyableType) { ShapeTree> shape_tree{tuple_shape_}; EXPECT_EQ(shape_tree.element({2}).get(), nullptr); - *shape_tree.mutable_element({2}) = MakeUnique(42); + *shape_tree.mutable_element({2}) = absl::make_unique(42); EXPECT_EQ(*shape_tree.element({2}), 42); } diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index ec901af1e2057449452c4c65243593b016a26f61..b69c346f1e62b78d4dd0c509a4bede50ed6aff14 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -596,8 +596,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { }; auto comma_list_to_int64s = - [&s, - string_to_int64](const string& input) -> StatusOr> { + [string_to_int64](const string& input) -> StatusOr> { std::vector results; for (const string& piece : tensorflow::str_util::Split(input, ',')) { TF_ASSIGN_OR_RETURN(int64 element, string_to_int64(piece)); @@ -792,7 +791,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { if (LayoutUtil::IsSparseArray(shape)) { allocated_element_count = LayoutUtil::MaxSparseElements(shape.layout()); } else { - CHECK(LayoutUtil::IsDenseArray(shape)); + CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString(); tensorflow::gtl::ArraySlice padded_dimensions = LayoutUtil::PaddedDimensions(shape); if (!padded_dimensions.empty()) { @@ -1015,12 +1014,13 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { } /* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) { + if (!IsTuple(shape)) { + return 1; + } int64 count = 0; - ForEachSubshape(shape, [&](const Shape&, const ShapeIndex& index) { - if (IsLeafIndex(shape, index)) { - ++count; - } - }); + for (const Shape& subshape : shape.tuple_shapes()) { + count += GetLeafCount(subshape); + } return count; } diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 099431d949e9cf78b7bc89490b525b8dea5e7841..4d5c9efe9ba4b0209aee08e612a6545d447207d5 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -113,7 +113,6 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:backend", @@ -127,6 +126,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", ], ) @@ -144,6 +145,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -154,8 +156,8 @@ tf_cc_binary( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", @@ -187,13 +189,12 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:interpreter_plugin", # reference backend "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -201,6 +202,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -274,6 +276,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", + "@com_google_absl//absl/memory", ], ) @@ -290,8 +293,8 @@ xla_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -314,8 +317,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -334,8 +337,8 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -356,9 +359,9 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -376,14 +379,16 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/service:stream_pool", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", "//tensorflow/core:regexp_internal", "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", ], ) @@ -395,8 +400,8 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -419,9 +424,9 @@ xla_test( "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -445,8 +450,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -464,9 +469,9 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -483,8 +488,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -501,8 +506,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -519,9 +524,9 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -543,8 +548,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -562,8 +567,8 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -586,8 +591,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -612,8 +617,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -638,7 +643,7 @@ xla_test( deps = [ ":client_library_test_base", ":literal_test_util", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -658,7 +663,7 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -681,8 +686,8 @@ xla_test( "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -702,8 +707,22 @@ xla_test( "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + +xla_test( + name = "scatter_test", + srcs = ["scatter_test.cc"], + deps = [ + ":client_library_test_base", + ":hlo_test_base", + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -726,8 +745,8 @@ xla_test( "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -750,8 +769,8 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -774,8 +793,8 @@ xla_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -796,8 +815,9 @@ CONVOLUTION_TEST_DEPS = [ "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -809,7 +829,7 @@ xla_test( timeout = "long", srcs = ["convolution_test.cc"], shard_count = 25, - deps = CONVOLUTION_TEST_DEPS, + deps = CONVOLUTION_TEST_DEPS + ["@com_google_absl//absl/memory"], ) xla_test( @@ -819,7 +839,7 @@ xla_test( backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_disable_layout_heuristic"]}, backends = ["gpu"], shard_count = 25, - deps = CONVOLUTION_TEST_DEPS, + deps = CONVOLUTION_TEST_DEPS + ["@com_google_absl//absl/memory"], ) xla_test( @@ -839,8 +859,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -863,13 +883,14 @@ xla_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -892,10 +913,10 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:math", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -925,9 +946,9 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -951,8 +972,8 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -973,7 +994,7 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -992,8 +1013,8 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1014,7 +1035,7 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", @@ -1045,13 +1066,14 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1066,9 +1088,9 @@ xla_test( "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1097,9 +1119,9 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1124,15 +1146,16 @@ xla_test_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1140,6 +1163,7 @@ xla_test( name = "reduce_window_test", timeout = "long", srcs = [], + shard_count = 20, tags = [ "enable_for_xla_interpreter", "optonly", @@ -1165,9 +1189,9 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1188,13 +1212,14 @@ xla_test( "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1242,8 +1267,8 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1261,7 +1286,7 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", "//tensorflow/compiler/xla/tests:client_library_test_base", @@ -1270,6 +1295,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1284,8 +1310,8 @@ xla_test( "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1307,8 +1333,8 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1328,14 +1354,14 @@ xla_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1347,8 +1373,8 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1364,8 +1390,8 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1389,14 +1415,15 @@ xla_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1410,8 +1437,8 @@ xla_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -1440,8 +1467,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1460,7 +1487,7 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1483,9 +1510,9 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1509,8 +1536,8 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1526,17 +1553,16 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1551,8 +1577,8 @@ xla_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -1574,7 +1600,7 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1595,8 +1621,8 @@ xla_test( "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -1614,8 +1640,8 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1637,8 +1663,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1658,8 +1684,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -1675,8 +1701,8 @@ xla_test( deps = [ ":client_library_test_base", "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -1689,8 +1715,8 @@ xla_test( deps = [ ":client_library_test_base", "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -1710,8 +1736,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1737,6 +1763,7 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1758,6 +1785,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/stream_executor", + "@com_google_absl//absl/memory", "@llvm//:core", ], ) @@ -1795,8 +1823,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_runner", @@ -1809,6 +1837,7 @@ xla_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//third_party/eigen3", + "@com_google_absl//absl/memory", ], ) @@ -1823,8 +1852,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/service:platform_util", @@ -1835,6 +1864,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1860,8 +1890,8 @@ xla_test( "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1888,8 +1918,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:platform_util", @@ -1924,7 +1954,7 @@ tf_cc_test( ":local_client_test_base", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/core:test_main", @@ -1966,8 +1996,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1980,7 +2010,7 @@ xla_test( name = "deep_graph_test", srcs = ["deep_graph_test.cc"], deps = [ - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -2013,6 +2043,7 @@ xla_test( "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:generic_transfer_manager", "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service:stream_pool", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", @@ -2061,14 +2092,17 @@ tf_cc_test( xla_test( name = "test_utils_test", srcs = ["test_utils_test.cc"], + # There is nothing backend specific in this test, so just pick an arbitrary backend. + backends = ["cpu"], deps = [ ":local_client_test_base", ":test_utils", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", "//tensorflow/core:test", ], ) @@ -2087,7 +2121,7 @@ xla_test( ":client_library_test_base", ":literal_test_util", ":xla_internal_test_main", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/core:lib", "//tensorflow/core:test", ], diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 3ae96fa1bcb1057653a75db62def5556ae37f886..74f2e36f826cd82ce4015df857f3de67950beaeb 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/tests/axpy_simple_test.cc b/tensorflow/compiler/xla/tests/axpy_simple_test.cc index 8d15b7841bc7298cd6865d8689cc496c0459e4b9..caeb0bf49a0dde9eeac02037b2ea04fd024d100c 100644 --- a/tensorflow/compiler/xla/tests/axpy_simple_test.cc +++ b/tensorflow/compiler/xla/tests/axpy_simple_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc index 71dbe4f0b6df1a7278d90f4e82313e4bd4c4d793..af0b8522394a0c591e6c42ad12db8853ef66243c 100644 --- a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc +++ b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index 033382708a2b4368dbc7c42d51d6c7f3cd854b1c..24b17b71007a1872462bed1f6b86ae1a5bb9922c 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -733,7 +733,7 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) { var4D, [epsilon](float a) { return a + epsilon; }); auto rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D( - var_add_epsilon, [epsilon](float a) { return 1 / std::sqrt(a); }); + var_add_epsilon, [](float a) { return 1 / std::sqrt(a); }); auto grad_output_times_var = *ReferenceUtil::MapArray4D(grad_output_array, var_add_epsilon, diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc index 747c82b502c8ec9f8121641382d9fd3c9552b010..6c20f654fe3df6a28e9633cd832c11b487894bad 100644 --- a/tensorflow/compiler/xla/tests/bfloat16_test.cc +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" diff --git a/tensorflow/compiler/xla/tests/binop_scaling_test.cc b/tensorflow/compiler/xla/tests/binop_scaling_test.cc index 20cb989751ad69e2f3cf97c87c43293951f599ab..0d7a3aa46a9c12c19d954c11ae3a2cccbed886ef 100644 --- a/tensorflow/compiler/xla/tests/binop_scaling_test.cc +++ b/tensorflow/compiler/xla/tests/binop_scaling_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" diff --git a/tensorflow/compiler/xla/tests/bitcast_convert_test.cc b/tensorflow/compiler/xla/tests/bitcast_convert_test.cc index d531e8fa82e47f7bcd278f10da2c205e44db0ac1..c6b5108fe9e5bcf843982676d822f1942359da71 100644 --- a/tensorflow/compiler/xla/tests/bitcast_convert_test.cc +++ b/tensorflow/compiler/xla/tests/bitcast_convert_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 50dd574624bb3874e682be5a272fb5bdefa4adc4..1d28e85b16596b0ec2717138fb2081878203e8b2 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index c7b94b5bbaaa512ad36056f9e68a87cc706c24b1..74d4d2eb10c32b270a83aa04dd2e6025d7a56c26 100644 --- a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc index 05c1c361bb815a5246d9169559eaac2a5020d166..b1d18210eaafdfec0920c0cccaa0dfdbd6de5609 100644 --- a/tensorflow/compiler/xla/tests/call_test.cc +++ b/tensorflow/compiler/xla/tests/call_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc index 0bc8facfe2cfcfab094f483137f6d8e241c6aaf9..a4eb57fc7b9abd460a7d158d0dc629eba88018cd 100644 --- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc +++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 515c0201d1c08771a2346af3d9f7b5df6dc8701d..2cab3264a7ebe6ef515783a5df55ac5609cbe106 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -17,13 +17,12 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -547,7 +546,7 @@ XlaComputation ClientLibraryTestBase::CreateScalarReluSensitivity() { std::unique_ptr> ClientLibraryTestBase::CreatePatternedMatrix( int rows, int cols, float offset) { - auto array = MakeUnique>(rows, cols); + auto array = absl::make_unique>(rows, cols); for (int64 row = 0; row < rows; ++row) { for (int64 col = 0; col < cols; ++col) { (*array)(row, col) = col + (row * 1000.0f) + offset; @@ -562,7 +561,7 @@ ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols, int cols_padded) { CHECK_GE(rows_padded, rows); CHECK_GE(cols_padded, cols); - auto array = MakeUnique>(rows_padded, cols_padded, 0.0); + auto array = absl::make_unique>(rows_padded, cols_padded, 0.0); for (int64 row = 0; row < rows; ++row) { for (int64 col = 0; col < cols; ++col) { (*array)(row, col) = col + (row * 1000.0f); diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index edc1ba8a5724a1a544d4eb605bc7b3d2bf28fcd4..24d0325929b66659f6b02ee5fd26ed6558b276e1 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -21,16 +21,16 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/global_data.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" @@ -74,8 +74,9 @@ class ClientLibraryTestBase : public ::testing::Test { string TestName() const; void SetFastMathDisabled(bool disabled) { - execution_options_.mutable_debug_options()->set_xla_enable_fast_math( - !disabled); + auto* opts = execution_options_.mutable_debug_options(); + opts->set_xla_cpu_enable_fast_math(!disabled); + opts->set_xla_gpu_enable_fast_math(!disabled); } void SetSeed(uint64 seed) { execution_options_.set_seed(seed); } @@ -612,7 +613,7 @@ template std::unique_ptr> ClientLibraryTestBase::CreatePseudorandomR2( const int rows, const int cols, NativeT min_value, NativeT max_value, uint32 seed) { - auto result = MakeUnique>(rows, cols); + auto result = absl::make_unique>(rows, cols); PseudorandomGenerator generator(min_value, max_value, seed); for (int y = 0; y < rows; ++y) { for (int x = 0; x < cols; ++x) { diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index f97008bee26cdf6a33d1b6007e351fbda518260f..c898dacf489db97223e2918414daf5de88bece64 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc index 2b407ed2639bd883bad8314118fc4fba4e8ce05f..7c52c9fbbb57f9291ea9f0966e2efa715819fb67 100644 --- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc +++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index 672fb06de6cd5171641956a77a81099fd07f2ad0..5a06d061f0d83fff547502495ff8ab13fb421b70 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/global_data.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index e63d2480b6c2fca8343af411b7722155bfbe8ea7..be017477d84eb9faf5aa79dcdf54d6b6aaf6fd8e 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc index d9d42bf061d326686fb68ecb84f073e506f08b6a..b27c1044baf2c0002f166c53a81e4361c60d012a 100644 --- a/tensorflow/compiler/xla/tests/conditional_test.cc +++ b/tensorflow/compiler/xla/tests/conditional_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index 71d72a9828c5445be2cb1f559cf31363507bcd8d..49375748319ad5fe40db507a034ec4b07adb7e84 100644 --- a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 0fb6853e3f408488e79e52050f8b00c1ca073fef..7a203d6873dbb5b69f96c50048c2c5ff3150c544 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -19,8 +19,9 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -447,11 +448,11 @@ std::vector GetInterestingF16ConversionTestCases() { XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) { std::vector test_cases = GetInterestingF16ConversionTestCases(); std::vector input; - c_transform(test_cases, std::back_inserter(input), - [](float f) { return Eigen::half(f); }); + absl::c_transform(test_cases, std::back_inserter(input), + [](float f) { return Eigen::half(f); }); std::vector expected_output; - c_transform(input, std::back_inserter(expected_output), - [](Eigen::half h) { return static_cast(h); }); + absl::c_transform(input, std::back_inserter(expected_output), + [](Eigen::half h) { return static_cast(h); }); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr dot_lhs_handle, @@ -470,8 +471,8 @@ XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) { XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) { std::vector input = GetInterestingF16ConversionTestCases(); std::vector expected_output; - c_transform(input, std::back_inserter(expected_output), - [](float f) { return Eigen::half(f); }); + absl::c_transform(input, std::back_inserter(expected_output), + [](float f) { return Eigen::half(f); }); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr dot_lhs_handle, diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc index 944366410b14439aa33999185525f1029735e95b..38b6da4fa96b0f6b7ed2d56852eb3ab2872f3520 100644 --- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc @@ -17,11 +17,11 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -88,9 +88,9 @@ TEST_F(ConvolutionDimensionNumbersTest, InvalidOutputDimensionNumbers) { XLA_TEST_F(ConvolutionDimensionNumbersTest, TwoConvsWithDifferentDimensionNumbers) { - auto input_array = MakeUnique>(2, 3, 5, 5); + auto input_array = absl::make_unique>(2, 3, 5, 5); input_array->FillWithMultiples(0.1); - auto weight_array = MakeUnique>(4, 3, 1, 1); + auto weight_array = absl::make_unique>(4, 3, 1, 1); weight_array->FillWithMultiples(0.2); auto weight_data = client_ diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index a8b8f74ca9603a71acefc0be2141d7b9caf2b73b..40658c3b775de0a38df4d6a629cab29b1fc83f2b 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -18,19 +18,20 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -70,16 +71,16 @@ class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest { const int kKernelSizeY = 2; const int kOutputActivationSizeZ = 256; const int kMiniBatchSize = 4; - auto alhs = - MakeUnique>(kMiniBatchSize, kInputActivationSizeZ, - kInputActivationSizeY, kInputActivationSizeX); + auto alhs = absl::make_unique>( + kMiniBatchSize, kInputActivationSizeZ, kInputActivationSizeY, + kInputActivationSizeX); alhs->FillWithMultiples(static_cast(1.0f)); ASSERT_EQ(3, alhs->width()); ASSERT_EQ(3, alhs->height()); - auto arhs = - MakeUnique>(kOutputActivationSizeZ, kInputActivationSizeZ, - kKernelSizeY, kKernelSizeX); + auto arhs = absl::make_unique>(kOutputActivationSizeZ, + kInputActivationSizeZ, + kKernelSizeY, kKernelSizeX); Array2D rhs_raster({ {1.0f, 0.0f}, // row 0 {0.0f, 0.0f}, // row 1 @@ -465,7 +466,7 @@ void iota_int_init_value(std::vector& values, int init_value) { } template -class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest { +class Convolve2D_1x3x3x5_3x3x5x3_Valid : public ConvolutionTest { public: void RunTest() { XlaBuilder builder(TestName()); @@ -520,8 +521,139 @@ class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest { } }; -TYPED_TEST_CASE(Convolve2D_1x3x3x5_3x3x5x5_Valid, TestTypes); -TYPED_TEST(Convolve2D_1x3x3x5_3x3x5x5_Valid, Types) { this->RunTest(); } +TYPED_TEST_CASE(Convolve2D_1x3x3x5_3x3x5x3_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x3x3x5_3x3x5x3_Valid, Types) { this->RunTest(); } + +template +class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 3, 3, 5}; + std::vector filter_dims = {3, 3, 1, 15}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/5); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + iota_int_init_value(input_elems, 1); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + iota_int_init_value(filter_elems, 1); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + + auto expected_r1 = LiteralUtil::CreateR1( + {static_cast(16029), static_cast(16218), static_cast(16407), + static_cast(17172), static_cast(17370), static_cast(17568), + static_cast(18369), static_cast(18576), static_cast(18783), + static_cast(19620), static_cast(19836), static_cast(20052), + static_cast(20925), static_cast(21150), static_cast(21375)}); + auto expected_r4 = expected_r1->Reshape({1, 1, 1, 15}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(*input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(*filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, *expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid, Types) { + this->RunTest(); +} + +template +class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 2, 2, 6}; + std::vector filter_dims = {2, 2, 2, 12}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/3); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + iota_int_init_value(input_elems, 1); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + iota_int_init_value(filter_elems, 1); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + + auto expected_r1 = LiteralUtil::CreateR1( + {static_cast(5076), static_cast(5160), static_cast(5244), + static_cast(5328), static_cast(6164), static_cast(6264), + static_cast(6364), static_cast(6464), static_cast(7380), + static_cast(7496), static_cast(7612), static_cast(7728)}); + auto expected_r4 = expected_r1->Reshape({1, 1, 1, 12}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(*input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(*filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, *expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid, Types) { + this->RunTest(); +} // Test fixture to run convolution tests with and without convolution // canonicalization enabled. @@ -765,5 +897,44 @@ XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) { std::move(*LiteralUtil::CreateFromArray(filter_data))}); } +class ConvolutionHloTest : public HloTestBase {}; + +XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64Forward)) { + constexpr char kHlo[] = R"( +HloModule TestModule + +ENTRY Test { + %arg0 = f64[3,56,56,16] parameter(0) + %arg1 = f64[3,3,3,64] parameter(1) + ROOT %conv = f64[54,54,16,64] convolution(%arg0, %arg1), window={size=3x3}, dim_labels=f01b_i01o->01bf +})"; + EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001})); +} + +XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64BackwardFilter)) { + constexpr char kHlo[] = R"( +HloModule TestModule + +ENTRY Test { + %arg0 = f64[2,5,8,1] parameter(0) + %arg1 = f64[2,5,8,2] parameter(1) + ROOT %conv = f64[4,4,1,2] convolution(%arg0, %arg1), window={size=5x8 pad=1_2x1_2}, dim_labels=f01b_i01o->01bf +})"; + EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001})); +} + +XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64BackwardInput)) { + constexpr char kHlo[] = R"( +HloModule TestModule + +ENTRY Test { + %output = f64[4,5,16,16] parameter(0) + %kernel = f64[5,3,7,7] parameter(1) + %reverse = f64[5,3,7,7] reverse(f64[5,3,7,7] %kernel), dimensions={2,3} + ROOT %convolution = f64[4,3,16,16] convolution(%output, %reverse), window={size=7x7 pad=3_3x3_3}, dim_labels=bf01_io01->bf01 +})"; + EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc index 8792e7781b17465d94ae8ac8375a4523f368d720..6784c16715da72d337edf70fa51db42c59404136 100644 --- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -27,7 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 1dc6ff0f4f51b51002cfb868a51457c08a259a80..50a9ebc1e9915d5e8ad8d02276987784fe30b8fc 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -16,10 +16,10 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 90f3d1b874f4da09104dc066c6642db1d2e77997..6f7fc0e6e52a69387a4c491871b6fcd97ac638b6 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "absl/memory/memory.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc index 062b8cb8c408028d3dfcc7ad6c7821b25985890d..5f234f36a8543ad408fb3430b27844beb16a54b5 100644 --- a/tensorflow/compiler/xla/tests/deallocation_test.cc +++ b/tensorflow/compiler/xla/tests/deallocation_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc index 6795130cd10933d745171acc7c44fed90a6cb87d..2db6503afab748d7b778e26b2f9350ac64c7778b 100644 --- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc +++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" diff --git a/tensorflow/compiler/xla/tests/deep_graph_test.cc b/tensorflow/compiler/xla/tests/deep_graph_test.cc index 810947ab01b69b10b6ae60c551bd7aba10a6313d..3f3e8ab712fea14be9e4a7015effdf8ce518309b 100644 --- a/tensorflow/compiler/xla/tests/deep_graph_test.cc +++ b/tensorflow/compiler/xla/tests/deep_graph_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" namespace xla { diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index d86fd7cc2d4da10ed726ca11a6d9f86287a5d11e..0e9e92ed996fbb34826d19b670c7c4920a1aad13 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -111,7 +111,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TrivialMatrixVectorDot) { this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, OneElementVectorDot) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, OneElementVectorDot) { using T = TypeParam; XlaBuilder builder(this->TestName()); auto lhs = ConstantR1(&builder, {static_cast(2.0f)}); @@ -137,7 +137,7 @@ std::vector MinorToMajorForIsRowMajor(bool row_major) { return {row_major ? 1 : 0, row_major ? 0 : 1}; } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x0) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_0x2_2x0) { using T = TypeParam; XlaBuilder builder(this->TestName()); auto lhs = ConstantR2FromArray2D(&builder, Array2D(0, 2)); @@ -148,7 +148,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x0) { this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x3) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_0x2_2x3) { using T = TypeParam; XlaBuilder builder(this->TestName()); auto lhs = ConstantR2FromArray2D(&builder, Array2D(0, 2)); @@ -160,7 +160,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x3) { this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_3x2_2x0) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_3x2_2x0) { using T = TypeParam; XlaBuilder builder(this->TestName()); auto lhs = ConstantR2FromArray2D( @@ -172,7 +172,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_3x2_2x0) { this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_2x0_0x2) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_2x0_0x2) { using T = TypeParam; XlaBuilder builder(this->TestName()); auto lhs = ConstantR2FromArray2D(&builder, Array2D(2, 0)); @@ -183,7 +183,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_2x0_0x2) { &builder, Array2D(2, 2, static_cast(0.0f)), {}, this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, FusedDot) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, FusedDot) { using T = TypeParam; XlaBuilder builder(this->TestName()); auto param0 = @@ -533,7 +533,7 @@ XLA_TEST_F(DotOperationTest, MatrixVectorC64) { &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, ConcurrentMatMult) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, ConcurrentMatMult) { using T = TypeParam; XlaBuilder builder(this->TestName()); @@ -612,7 +612,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { {x_data.get(), y_data.get()}, this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMul) { using T = TypeParam; XlaBuilder builder(this->TestName()); @@ -648,7 +648,49 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) { {x_data.get(), y_data.get()}, this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) { + using T = TypeParam; + + XlaBuilder builder(this->TestName()); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShapeWithType({2, 2, 2, 2}), + "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShapeWithType({2, 2, 2, 2}), + "y"); + + DotDimensionNumbers dnums; + dnums.add_lhs_contracting_dimensions(3); + dnums.add_rhs_contracting_dimensions(2); + dnums.add_lhs_batch_dimensions(0); + dnums.add_lhs_batch_dimensions(1); + dnums.add_rhs_batch_dimensions(0); + dnums.add_rhs_batch_dimensions(1); + + DotGeneral(x, y, dnums); + + auto x_data = + this->client_ + ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( + {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, + {{{9.0f, 10.0f}, {11.0f, 12.0f}}, + {{13.0f, 14.0f}, {15.0f, 16.0f}}}})) + .ConsumeValueOrDie(); + + auto y_data = + this->client_ + ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( + {{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}}, + {{{0.0f, 1.0f}, {1.0f, 0.0f}}, {{0.0f, 1.0f}, {1.0f, 0.0f}}}})) + .ConsumeValueOrDie(); + + this->template ComputeAndCompareR4( + &builder, + /*expected=*/ + {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, + {{{10.0f, 9.0f}, {12.0f, 11.0f}}, {{14.0f, 13.0f}, {16.0f, 15.0f}}}}, + {x_data.get(), y_data.get()}, this->error_spec_); +} + +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TransposeFolding) { using T = TypeParam; for (bool transpose_lhs : {false, true}) { for (bool transpose_rhs : {false, true}) { @@ -708,7 +750,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) { } } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, DotOfConcatOptimizationWithConstLHS) { using T = TypeParam; auto prim_type = primitive_util::NativeToPrimitiveType(); @@ -754,7 +796,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, DotOfConcatOptimizationWithConstRHS) { using T = TypeParam; std::unique_ptr> constant_rhs_array( diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 88ac96d6b0f9206ef1ed0e4135495d7903ebf3f4..7f6f203a1ba48e0053f799c58bbbeae87aef1f7f 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/local_service.h" diff --git a/tensorflow/compiler/xla/tests/execution_profile_test.cc b/tensorflow/compiler/xla/tests/execution_profile_test.cc index e2c145b795c3efab0e220834c5d2f962e27a6333..5116e60ca63ef5f94b25b15e6616086fb9e44bbb 100644 --- a/tensorflow/compiler/xla/tests/execution_profile_test.cc +++ b/tensorflow/compiler/xla/tests/execution_profile_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/global_data.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc index 86bfaea4ef43ad382e497fd281ec5439f001b56f..bf1de02ba9dbd97db9ee31484402fe9b92385219 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/tests/floor_ceil_test.cc b/tensorflow/compiler/xla/tests/floor_ceil_test.cc index 30dc639f117b9871238f0bf1628502cf8bef2e0c..39cc6c5927f1d416e31f689487efc10c20371abe 100644 --- a/tensorflow/compiler/xla/tests/floor_ceil_test.cc +++ b/tensorflow/compiler/xla/tests/floor_ceil_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/tests/fmax_test.cc b/tensorflow/compiler/xla/tests/fmax_test.cc index 0254ae1baaa864b38c3b217a5c2026d34b7f7d12..c5bbbe778df15d63a2586bd6291a7a33fc82aa52 100644 --- a/tensorflow/compiler/xla/tests/fmax_test.cc +++ b/tensorflow/compiler/xla/tests/fmax_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 607bcdd51ee8ff678cd84622a81f45d5525bb683..341124170a5f6768720032394c42205f9185920a 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -22,13 +22,13 @@ limitations under the License. #define EIGEN_USE_THREADS +#include "absl/memory/memory.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index 2008d69237caf2e00c21645388ae6b648fdab2cd..f866ed6519e0e0da87806e26abfa771583261d19 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -13,8 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -31,8 +30,8 @@ using tensorflow::gtl::nullopt; class GatherOperationTest : public HloTestBase { protected: void RunTest(const string& hlo_text, Literal* operand, - Literal* gather_indices) { - RunTest(hlo_text, {operand, gather_indices}); + Literal* start_indices) { + RunTest(hlo_text, {operand, start_indices}); } void RunTest(const string& hlo_text, @@ -53,18 +52,17 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[2,3] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1, 3} + slice_sizes={1, 3} } )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({0, 2}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherV2) { @@ -75,18 +73,17 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[3,2] gather(operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3, 1} + slice_sizes={3, 1} } )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({0, 2}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherMultipleBatchDims) { @@ -97,18 +94,18 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,3,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=2, - window_bounds={3, 1} + slice_sizes={3, 1} } )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_0) { @@ -119,18 +116,18 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2,2] parameter(1) ROOT gather = s32[2,2] gather(operand, indices), - output_window_dims={}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=2, - window_bounds={1, 1} + slice_sizes={1, 1} } )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_1) { @@ -141,18 +138,18 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2,2] parameter(1) ROOT gather = s32[2,1,1,2] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=2, - window_bounds={1, 1} + slice_sizes={1, 1} } )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNd) { @@ -163,20 +160,20 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1,2} + slice_sizes={1,1,2} } )"; std::unique_ptr operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdNonDefaultIndexVectorDim) { @@ -187,20 +184,20 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1,2} + slice_sizes={1,1,2} } )"; std::unique_ptr operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, DynamicSlice) { @@ -211,18 +208,17 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[1,1] gather(operand, indices), - output_window_dims={0,1}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={0,1}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} } )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({1, 1}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + std::unique_ptr start_indices = LiteralUtil::CreateR1({1, 1}); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, BatchDynamicSlice) { @@ -233,18 +229,18 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,1,1] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} } )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, ZeroDimBounds) { @@ -255,17 +251,16 @@ ENTRY main { operand = s32[3,0] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[2,0] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1, 0} + slice_sizes={1, 0} } )"; std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({0, 2}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex) { @@ -279,19 +274,19 @@ ENTRY main { operand = s32[3,3]{1,0} parameter(0) indices = s32[6,2]{1,0} parameter(1) gather = s32[6,1,1]{2,1,0} gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1} + slice_sizes={1,1} ROOT result = s32[6]{0} reshape(gather) } )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = LiteralUtil::CreateR2( + std::unique_ptr start_indices = LiteralUtil::CreateR2( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, OutOfBoundsUnsignedIndex) { @@ -305,19 +300,19 @@ ENTRY main { operand = s32[3,3]{1,0} parameter(0) indices = u32[6,2]{1,0} parameter(1) gather = s32[6,1,1]{2,1,0} gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1} + slice_sizes={1,1} ROOT result = s32[6]{0} reshape(gather) } )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = LiteralUtil::CreateR2( + std::unique_ptr start_indices = LiteralUtil::CreateR2( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, NegativeIndex) { @@ -331,19 +326,19 @@ ENTRY main { operand = s32[3,3]{1,0} parameter(0) indices = s32[6,2]{1,0} parameter(1) gather = s32[6,1,1]{2,1,0} gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1} + slice_sizes={1,1} ROOT result = s32[6]{0} reshape(gather) } )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = LiteralUtil::CreateR2( + std::unique_ptr start_indices = LiteralUtil::CreateR2( {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, NegativeIndexIntoUnsignedOperand) { @@ -357,19 +352,19 @@ ENTRY main { operand = u32[3,3]{1,0} parameter(0) indices = s32[6,2]{1,0} parameter(1) gather = u32[6,1,1]{2,1,0} gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1} + slice_sizes={1,1} ROOT result = u32[6]{0} reshape(gather) } )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = LiteralUtil::CreateR2( + std::unique_ptr start_indices = LiteralUtil::CreateR2( {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, OneScalarIndex) { @@ -380,17 +375,17 @@ ENTRY main { operand = s32[2,3,2]{2,1,0} parameter(0) index = s32[] parameter(1) ROOT gather = s32[1,3,2]{2,1,0} gather(operand, index), - output_window_dims={0,1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0}, + offset_dims={0,1,2}, + collapsed_slice_dims={}, + start_index_map={0}, index_vector_dim=0, - window_bounds={1,3,2} + slice_sizes={1,3,2} } )"; std::unique_ptr operand = LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); - std::unique_ptr gather_indices = LiteralUtil::CreateR0(1); - RunTest(hlo_text, operand.get(), gather_indices.get()); + std::unique_ptr start_indices = LiteralUtil::CreateR0(1); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, ScalarResult) { @@ -401,16 +396,16 @@ ENTRY main { operand = s32[4]{0} parameter(0) index = s32[] parameter(1) ROOT gather = s32[] gather(operand, index), - output_window_dims={}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=0, - window_bounds={1} + slice_sizes={1} } )"; std::unique_ptr operand = LiteralUtil::CreateR1({1, 2, 3, 4}); - std::unique_ptr gather_indices = LiteralUtil::CreateR0(1); - RunTest(hlo_text, operand.get(), gather_indices.get()); + std::unique_ptr start_indices = LiteralUtil::CreateR0(1); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, ZeroSizedResult) { @@ -421,17 +416,17 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[0] parameter(1) ROOT gather = s32[0,3] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1, 3} + slice_sizes={1, 3} } )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = LiteralUtil::CreateR1({}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + std::unique_ptr start_indices = LiteralUtil::CreateR1({}); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherV2) { @@ -442,11 +437,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) gather = s32[3,2] gather(operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3, 1} + slice_sizes={3, 1} one = s32[] constant(1) one_broadcasted = s32[3,2] broadcast(one), dimensions={} ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted) @@ -454,9 +449,8 @@ ENTRY main { )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({0, 2}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherMultipleBatchDims) { @@ -467,11 +461,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,3,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=2, - window_bounds={3, 1} + slice_sizes={3, 1} one = s32[] constant(1) one_broadcasted = s32[2,3,2] broadcast(one), dimensions={} ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted) @@ -479,9 +473,9 @@ ENTRY main { )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNdMultipleBatchDims) { @@ -492,11 +486,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2,2] parameter(1) gather = s32[2,2] gather(operand, indices), - output_window_dims={}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=2, - window_bounds={1, 1} + slice_sizes={1, 1} one = s32[] constant(1) one_broadcasted = s32[2,2] broadcast(one), dimensions={} ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) @@ -504,9 +498,9 @@ ENTRY main { )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNd) { @@ -517,11 +511,11 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1,2} + slice_sizes={1,1,2} one = s32[] constant(1) one_broadcasted = s32[2,2] broadcast(one), dimensions={} ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) @@ -531,9 +525,9 @@ ENTRY main { LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, @@ -545,11 +539,11 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1,2} + slice_sizes={1,1,2} one = s32[] constant(1) one_broadcasted = s32[2,2] broadcast(one), dimensions={} ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) @@ -559,9 +553,9 @@ ENTRY main { LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, FusedDynamicSlice) { @@ -572,11 +566,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) gather = s32[1,1] gather(operand, indices), - output_window_dims={0,1}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={0,1}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} one = s32[] constant(1) one_broadcasted = s32[1,1] broadcast(one), dimensions={} ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted) @@ -584,9 +578,8 @@ ENTRY main { )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({1, 1}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + std::unique_ptr start_indices = LiteralUtil::CreateR1({1, 1}); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, FusedBatchDynamicSlice) { @@ -597,11 +590,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,1,1] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} one = s32[] constant(1) one_broadcasted = s32[2,1,1] broadcast(one), dimensions={} ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted) @@ -609,9 +602,9 @@ ENTRY main { )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } class GatherClientLibraryTest : public ClientLibraryTestBase {}; @@ -623,11 +616,11 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { // operand = s32[3,3] parameter(0) // indices = s32[2] parameter(1) // ROOT gather = s32[2,3] gather(operand, indices), - // output_window_dims={1}, - // elided_window_dims={0}, - // gather_dims_to_operand_dims={0}, + // offset_dims={1}, + // collapsed_slice_dims={0}, + // start_index_map={0}, // index_vector_dim=1, - // window_bounds={1, 3} + // slice_sizes={1, 3} // } XlaBuilder builder("gather_basic"); @@ -638,9 +631,9 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { auto operand = Parameter(&builder, 0, operand_shape, "operand"); auto indices = Parameter(&builder, 1, indices_shape, "indices"); GatherDimensionNumbers dim_numbers; - dim_numbers.add_output_window_dims(1); - dim_numbers.add_elided_window_dims(0); - dim_numbers.add_gather_dims_to_operand_dims(0); + dim_numbers.add_offset_dims(1); + dim_numbers.add_collapsed_slice_dims(0); + dim_numbers.add_start_index_map(0); dim_numbers.set_index_vector_dim(1); Gather(operand, indices, dim_numbers, {1, 3}); diff --git a/tensorflow/compiler/xla/tests/half_test.cc b/tensorflow/compiler/xla/tests/half_test.cc index 249a4b2493fdc28adf349eb3578c404f347dc892..51450314b611b49c643fb6fd5b0c0d2e7205a2d2 100644 --- a/tensorflow/compiler/xla/tests/half_test.cc +++ b/tensorflow/compiler/xla/tests/half_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" diff --git a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc index 4d82442f7e3630c115eff1f17544e2b892c5e7eb..5511190caf95544e2ac48d91c0a138db06a2544c 100644 --- a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc +++ b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/local_client_test_base.h" diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index b662e837168c8b16daea0181786be19fa0237a8c..2167d4240e4341e686c135c05d440eea6d9a8ce9 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -20,12 +20,15 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" @@ -83,21 +86,38 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) { } // namespace -HloTestBase::HloTestBase() - : HloTestBase(GetTestPlatform(), GetReferencePlatform()) {} +HloTestBase::HloTestBase(bool allow_mixed_precision_in_hlo_verifier) + : HloTestBase(GetTestPlatform(), GetReferencePlatform(), + allow_mixed_precision_in_hlo_verifier) {} HloTestBase::HloTestBase(se::Platform* test_platform, - se::Platform* reference_platform) + se::Platform* reference_platform, + bool allow_mixed_precision_in_hlo_verifier) : test_runner_(test_platform), reference_runner_(reference_platform) { - hlo_verifier_ = MakeUnique(/*allow_mixed_precision=*/true); + hlo_verifier_ = + absl::make_unique(allow_mixed_precision_in_hlo_verifier); } -/* static */ std::unique_ptr HloTestBase::CreateNewModule(const string& name) { - return MakeUnique(name, GetModuleConfigForTest()); + return absl::make_unique(name, GetModuleConfigForTest()); +} + +/* static */ +StatusOr HloTestBase::RunHloPass(HloPassInterface* hlo_pass, + HloModule* module) { + const string module_str_before_run = module->ToProto().ShortDebugString(); + const auto status_or = hlo_pass->Run(module); + if (status_or.status().ok()) { + const string module_str_after_run = module->ToProto().ShortDebugString(); + if (!status_or.ValueOrDie()) { + // Check that the proto remains same. + EXPECT_EQ(module_str_after_run, module_str_before_run); + } + } + return status_or; } -/*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() { +DebugOptions HloTestBase::GetDebugOptionsForTest() { auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); // TODO(b/38354253): Change tests to use Parameters instead of Constants. debug_options.add_xla_disable_hlo_passes("constant_folding"); @@ -196,7 +216,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( MakeFakeArguments(module.get()).ConsumeValueOrDie(); std::vector fake_argument_ptrs; - c_transform( + absl::c_transform( fake_arguments, std::back_inserter(fake_argument_ptrs), [](const std::unique_ptr& literal) { return literal.get(); }); @@ -210,7 +230,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( const auto& fake_arguments = MakeFakeArguments(module.get()).ConsumeValueOrDie(); std::vector fake_argument_ptrs; - c_transform( + absl::c_transform( fake_arguments, std::back_inserter(fake_argument_ptrs), [](const std::unique_ptr& literal) { return literal.get(); }); @@ -233,6 +253,29 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( reference_preprocessor); } +::testing::AssertionResult HloTestBase::Run(const StringPiece hlo_string) { + auto module_or_status = + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + if (!module_or_status.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_or_status.status().ToString(); + } + const auto& fake_arguments = + MakeFakeArguments(module_or_status.ValueOrDie().get()) + .ConsumeValueOrDie(); + std::vector fake_argument_ptrs; + absl::c_transform( + fake_arguments, std::back_inserter(fake_argument_ptrs), + [](const std::unique_ptr& literal) { return literal.get(); }); + return test_runner_ + .Execute(std::move(module_or_status.ValueOrDie()), + fake_argument_ptrs, /*run_hlo_passes=*/true) + .ok() + ? ::testing::AssertionSuccess() + : ::testing::AssertionFailure(); +} + ::testing::AssertionResult HloTestBase::RunAndCompareFromFile( const string& filename, const tensorflow::gtl::optional& error, const std::function& reference_preprocessor) { @@ -277,8 +320,8 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( HloComputation* HloTestBase::FindComputation(HloModule* module, tensorflow::StringPiece name) { auto computations = module->computations(); - auto it = c_find_if(computations, - [&](HloComputation* c) { return c->name() == name; }); + auto it = absl::c_find_if( + computations, [&](HloComputation* c) { return c->name() == name; }); if (it == computations.end()) { return nullptr; } @@ -289,8 +332,8 @@ HloInstruction* HloTestBase::FindInstruction(HloModule* module, tensorflow::StringPiece name) { for (const HloComputation* c : module->computations()) { auto instructions = c->instructions(); - auto it = c_find_if(instructions, - [&](HloInstruction* i) { return i->name() == name; }); + auto it = absl::c_find_if( + instructions, [&](HloInstruction* i) { return i->name() == name; }); if (it != instructions.end()) { return *it; } diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 66719b1460063a61541535ff7507468ae0ca1ada..5c7304b4de9c82260ce8589da2aa6a91806f39ab 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -72,30 +72,39 @@ class HloTestBase : public ::testing::Test { // options from command-line flags. If you want a fresh HloModule object and // then add HloComputations to it, it's recommended to use this method in your // tests. - static std::unique_ptr CreateNewModule( - const string& name = TestName()); + std::unique_ptr CreateNewModule(const string& name = TestName()); + + // Runs the hlo_pass with the provided module and returns the result. This + // function also verifies that the module remains unchanged when hlo_pass + // returns false as the StatusOr value. + static StatusOr RunHloPass(HloPassInterface* hlo_pass, + HloModule* module); protected: // This uses the interpreter backend as the reference backend and // automatically finds another supported backend as the test backend. If the // interpreter is the only supported backend, it will be both the test backend // and the reference backend. - HloTestBase(); + HloTestBase(bool allow_mixed_precision_in_hlo_verifier = true); // If your test doesn't use interpreter as the reference backend, you can use // this constructor. Note that your test target is responsible for linking in // both needed backends. - HloTestBase(se::Platform* test_platform, se::Platform* reference_platform); + HloTestBase(se::Platform* test_platform, se::Platform* reference_platform, + bool allow_mixed_precision_in_hlo_verifier = true); ~HloTestBase() override {} // Populates debug options from command-line flags and adjusts the options for // testing. It is recommended to use this when you need to pass in // DebugOptions, e.g. when creating a module from a string or a file. - static DebugOptions GetDebugOptionsForTest(); + // + // This function is virtual so tests can specify an alternative set of debug + // options (e.g. disabling additional passes). + virtual DebugOptions GetDebugOptionsForTest(); // Gets an HloModuleConfig with options appropriate for tests. - static HloModuleConfig GetModuleConfigForTest() { + HloModuleConfig GetModuleConfigForTest() { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); return config; @@ -166,6 +175,8 @@ class HloTestBase : public ::testing::Test { const tensorflow::gtl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; + ::testing::AssertionResult Run(const tensorflow::StringPiece hlo_string) + TF_MUST_USE_RESULT; ::testing::AssertionResult RunAndCompareFromFile( const string& filename, const tensorflow::gtl::optional& error, const std::function& reference_preprocessor = nullptr) diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc index ad1f5b9eed8b5b140100c1fa35dc7d698e3db48b..a509ee32078551c850232d0f36380e25321e00a0 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -25,7 +26,7 @@ limitations under the License. namespace xla { HloVerifiedTestBase::HloVerifiedTestBase() - : shape_verifier_(MakeUnique()) {} + : shape_verifier_(absl::make_unique()) {} HloVerifiedTestBase::~HloVerifiedTestBase() { // We can't call the ASSERT or EXPECT test macros in destructors, so we diff --git a/tensorflow/compiler/xla/tests/iota_test.cc b/tensorflow/compiler/xla/tests/iota_test.cc index f950aa1e8fe745075234a5ebff52d92be7378a5d..17ac95ae0198d98490b25f7f2edd32d1e0495803 100644 --- a/tensorflow/compiler/xla/tests/iota_test.cc +++ b/tensorflow/compiler/xla/tests/iota_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { @@ -34,7 +35,7 @@ class IotaTest : public ClientLibraryTestBase { } }; -TEST_F(IotaTest, SimpleR1) { +XLA_TEST_F(IotaTest, SimpleR1) { for (int num_elements = 1; num_elements < 10000001; num_elements *= 10) { { XlaBuilder builder(TestName() + "_f32"); diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc index e719da54d45d3e6eb3f3e14d3fa3076db2081e04..8d658695576035cdc34a213847460dd80de5f67e 100644 --- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc +++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/llvm_compiler.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" @@ -125,7 +126,7 @@ class LLVMCompilerTest : public ::testing::Test { static std::unique_ptr CreateNewModule() { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - return MakeUnique(TestName(), config); + return absl::make_unique(TestName(), config); } }; diff --git a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc index 6fc11150978931f980349799372872f9fb68f292..0487d314094edcab61a92de32f14113dd19673fa 100644 --- a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc +++ b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc @@ -51,8 +51,9 @@ void LlvmIrGenTestBase::CompileAndVerifyIr( std::unique_ptr hlo_module, const string& pattern, bool match_optimized_ir) { SetIrHook(match_optimized_ir); - TF_ASSERT_OK(CompileToExecutable(std::move(hlo_module)).status()); + Status status = CompileToExecutable(std::move(hlo_module)).status(); ResetIrHook(); + TF_ASSERT_OK(status); StatusOr filecheck_result = RunFileCheck(ir_, pattern); TF_ASSERT_OK(filecheck_result.status()); @@ -73,9 +74,10 @@ void LlvmIrGenTestBase::CompileAheadOfTimeAndVerifyIr( std::unique_ptr hlo_module, const AotCompilationOptions& options, const string& pattern, bool match_optimized_ir) { SetIrHook(match_optimized_ir); - TF_ASSERT_OK( - CompileToAotCompilationResult(std::move(hlo_module), options).status()); + Status status = + CompileToAotCompilationResult(std::move(hlo_module), options).status(); ResetIrHook(); + TF_ASSERT_OK(status); StatusOr filecheck_result = RunFileCheck(ir_, pattern); ASSERT_TRUE(filecheck_result.ok()); diff --git a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc index 0df50150aee69749beea79ff522fb6f820d1945d..e2cd5bcc5a95f692dcf4a43d717252bfe876aa81 100644 --- a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test.cc b/tensorflow/compiler/xla/tests/local_client_aot_test.cc index 47cab796041e9669affaebd7866d0d80100730f1..115448c908ac9e7f0b01772ce348d23bf4d838ed 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test.cc @@ -42,13 +42,12 @@ extern "C" void SumStructElements(float* out, void** parameters) { TEST_F(LocalClientAotTest, Constant) { xla::ExecutableRunOptions run_options; OpaqueData opaque_data{100, 20, 3}; - void* parameters[] = {&opaque_data}; float out = 0; - void* temporary_buffers[] = {nullptr, &out}; - SumAndDouble(&out, &run_options, parameters, temporary_buffers); + void* temporary_buffers[] = {&opaque_data, &out}; + SumAndDouble(&out, &run_options, nullptr, temporary_buffers); EXPECT_EQ(out, 246.0f); opaque_data = {1, 2, 3}; - SumAndDouble(&out, &run_options, parameters, temporary_buffers); + SumAndDouble(&out, &run_options, nullptr, temporary_buffers); EXPECT_EQ(out, 12.0f); } diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc index 0b44090702c793cacbc363a0701c35a12150975e..60eb21aafd23a8d724d1f08d5c87098b7c3dcd6b 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc @@ -21,7 +21,7 @@ limitations under the License. #include "llvm/ADT/Triple.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -92,9 +92,10 @@ int main(int argc, char** argv) { // It's lame to hard-code the buffer assignments, but we need // local_client_aot_test.cc to be able to easily invoke the function. CHECK_EQ(result->result_buffer_index(), 1); - CHECK_EQ(result->buffer_sizes().size(), 2); - CHECK_EQ(result->buffer_sizes()[0], -1); // param buffer - CHECK_EQ(result->buffer_sizes()[1], sizeof(float)); // result buffer + CHECK_EQ(result->buffer_infos().size(), 3); + CHECK(result->buffer_infos()[0].is_entry_parameter()); // param buffer + CHECK_EQ(result->buffer_infos()[1].size(), sizeof(float)); // result buffer + CHECK(result->buffer_infos()[2].is_constant()); // const buffer if (triple.isOSBinFormatELF()) { // Check the ELF magic. CHECK_EQ(result->object_file_data()[0], 0x7F); diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index 5c3498c84cb68e4b1c4a7814284418f1ebbc0e98..1a823cf189b310c62c735419936544ea99fcfbaf 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index eaddf756dbc913dd9668cd22228fbd18c2c33309..948b60061e2f47c73c7c7a2d6cbc65baf1b4411c 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -18,11 +18,11 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test_helpers.h" diff --git a/tensorflow/compiler/xla/tests/log_test.cc b/tensorflow/compiler/xla/tests/log_test.cc index cdf70ee4185be2ecd9dcb2d21fbd98c2ab6cc0ad..2d622242e657ce032a17f7b26c94227d343e2a38 100644 --- a/tensorflow/compiler/xla/tests/log_test.cc +++ b/tensorflow/compiler/xla/tests/log_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index 34bcaef513e352d75553ad370ac99f309de13475..0732e195d44d738b264361e43d38259c26a4116e 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index 4fca90af770f075d112534d45e9e3af87dec3d14..b6035a21a6709120c4b950382a6d248435f970c8 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -17,12 +17,12 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -133,7 +133,7 @@ class TestLinspaceMaxParametric float from = -128.0, to = 256.0; std::unique_ptr> alhs = MakeLinspaceArray2D(from, to, rows, cols); - auto arhs = MakeUnique>(rows, cols, static_cast(1.0f)); + auto arhs = absl::make_unique>(rows, cols, static_cast(1.0f)); XlaBuilder builder( tensorflow::strings::Printf("max_%lldx%lld_linspace", rows, cols)); diff --git a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc index e576f000ef23e761d6fa818457eec2144d4bcb00..955dbef6dcd28421fb351c6ee064ac53eda1fd08 100644 --- a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc +++ b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index eb06b115daa96bccd73de30bb7fa30733a6fd947..cadf1c5523afdd61e4252185a123defdd8aa2c27 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" diff --git a/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc b/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc index cea7006526f0c56ade3cedead489ea12c0ab3922..0a0426adcbc1b5b89be0841fa2c4204e2b65abf4 100644 --- a/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc +++ b/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/tests/local_client_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { @@ -22,9 +23,9 @@ namespace { // Tests that ensure outfeed instructions that are contained in nested // computations in non-root positions are executed. -class LocalClientExecuteTest : public LocalClientTestBase {}; +class OutfeedInNestedComputationTest : public LocalClientTestBase {}; -TEST_F(LocalClientExecuteTest, OutfeedInWhile) { +XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInWhile) { XlaBuilder b(TestName()); Shape state_tuple_array_shape = ShapeUtil::MakeShape(xla::S32, {10, 5}); @@ -117,7 +118,7 @@ TEST_F(LocalClientExecuteTest, OutfeedInWhile) { EXPECT_EQ(comp_result->Get({}), 0); } -TEST_F(LocalClientExecuteTest, OutfeedInConditional) { +XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) { XlaBuilder b(TestName()); Shape condition_shape = ShapeUtil::MakeShape(xla::PRED, {}); diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc index d8c17202f20f9318c4cbef707b82a644f4802160..cbeddffacfa4a0fc560e8b9f9a8d7bd23ff32e55 100644 --- a/tensorflow/compiler/xla/tests/pad_test.cc +++ b/tensorflow/compiler/xla/tests/pad_test.cc @@ -16,13 +16,12 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -141,7 +140,7 @@ XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) { TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) { XlaBuilder b(TestName()); - auto input = MakeUnique>(1, 1, 3, 2); + auto input = absl::make_unique>(1, 1, 3, 2); Array2D input_xy({ {1.0f, 2.0f}, // row 0 {3.0f, 4.0f}, // row 1 @@ -152,7 +151,7 @@ TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) { Pad(AddParam(*input, &b), AddParam(*LiteralUtil::CreateR0(1.5), &b), r4_padding_on_dim0_dim1_); - auto expected = MakeUnique>(2, 3, 3, 2); + auto expected = absl::make_unique>(2, 3, 3, 2); expected->Fill(1.5); (*expected)(1, 0, 0, 0) = 1.0f; (*expected)(1, 0, 0, 1) = 2.0f; @@ -172,7 +171,7 @@ TEST_P(PadTestFloat, Pad4DFloatArrayWithInteriorPadding) { AddParam(*LiteralUtil::CreateR0(pad_value), &b), r4_padding_on_dim0_dim1_); - auto expected = MakeUnique>(8, 5, 1, 1); + auto expected = absl::make_unique>(8, 5, 1, 1); expected->Fill(pad_value); (*expected)(1, 0, 0, 0) = 1.0f; (*expected)(1, 2, 0, 0) = 2.0f; @@ -270,7 +269,7 @@ XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { XLA_TEST_F(PadTest, Pad4DU8Array) { XlaBuilder b(TestName()); - auto input = MakeUnique>(1, 1, 3, 2); + auto input = absl::make_unique>(1, 1, 3, 2); Array2D input_xy({ {1, 2}, // row 0 {3, 4}, // row 1 @@ -281,7 +280,7 @@ XLA_TEST_F(PadTest, Pad4DU8Array) { Pad(AddParam(*input, &b), ConstantR0(&b, 35), r4_padding_on_dim0_dim1_); - auto expected = MakeUnique>(2, 3, 3, 2); + auto expected = absl::make_unique>(2, 3, 3, 2); expected->Fill(35); (*expected)(1, 0, 0, 0) = 1; (*expected)(1, 0, 0, 1) = 2; @@ -302,13 +301,13 @@ XLA_TEST_F(PadTest, Pad4DPredArray) { Pad(input, ConstantR0(&b, false), r4_padding_on_dim0_dim1_); // For the same reason, use Select to convert boolean values to int32. - auto zeros = MakeUnique>(2, 3, 3, 2); - auto ones = MakeUnique>(2, 3, 3, 2); + auto zeros = absl::make_unique>(2, 3, 3, 2); + auto ones = absl::make_unique>(2, 3, 3, 2); zeros->Fill(0); ones->Fill(1); Select(padded, AddParam(*ones, &b), AddParam(*zeros, &b)); - auto expected = MakeUnique>(2, 3, 3, 2); + auto expected = absl::make_unique>(2, 3, 3, 2); expected->Fill(0); (*expected)(1, 0, 0, 0) = 1; (*expected)(1, 0, 0, 1) = 1; @@ -322,7 +321,7 @@ XLA_TEST_F(PadTest, Pad4DPredArray) { XLA_TEST_P(PadTestFloat, Large2DPad) { XlaBuilder b(TestName()); - auto ones = MakeUnique>(4, 4); + auto ones = absl::make_unique>(4, 4); ones->Fill(1.0f); auto input = AddParam(*ones, &b); PaddingConfig padding_config = MakeNoPaddingConfig(2); @@ -343,7 +342,7 @@ XLA_TEST_P(PadTestFloat, AllTypes2DPad) { constexpr int64 in_rows = 35; constexpr int64 in_cols = 35; - auto operand = MakeUnique>(in_rows, in_cols); + auto operand = absl::make_unique>(in_rows, in_cols); operand->FillUnique(0.0f); auto input = AddParam(*operand, &b); @@ -369,7 +368,7 @@ XLA_TEST_P(PadTestFloat, High2DPad) { constexpr int64 low_padding = 0; int64 high_padding[2] = {5, 7}; constexpr int64 interior_padding = 0; - auto operand = MakeUnique>(in_rows, in_cols); + auto operand = absl::make_unique>(in_rows, in_cols); operand->FillUnique(1.0f); auto input = AddParam(*operand, &b); PaddingConfig padding_config = MakeNoPaddingConfig(2); @@ -396,7 +395,7 @@ XLA_TEST_P(PadTestFloat, NegativePadding2D) { int64 low_padding[2] = {-1, -2}; int64 high_padding[2] = {-3, 4}; constexpr int64 interior_padding = 0; - auto operand = MakeUnique>(in_rows, in_cols); + auto operand = absl::make_unique>(in_rows, in_cols); operand->FillUnique(1.0f); auto input = AddParam(*operand, &b); PaddingConfig padding_config = MakeNoPaddingConfig(2); @@ -424,7 +423,7 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) { int64 low_padding[2] = {4, -1}; int64 high_padding[2] = {-2, -4}; int64 interior_padding[2] = {1, 2}; - auto operand = MakeUnique>(in_rows, in_cols); + auto operand = absl::make_unique>(in_rows, in_cols); operand->FillUnique(1.0f); auto input = AddParam(*operand, &b); PaddingConfig padding_config = MakeNoPaddingConfig(2); @@ -447,7 +446,7 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) { // Regression test for b/31827337. XLA_TEST_P(PadTestFloat, ReducePad) { XlaBuilder b(TestName()); - auto ones = MakeUnique>(2, 2, 2, 2); + auto ones = absl::make_unique>(2, 2, 2, 2); ones->Fill(1.0); auto input = AddParam(*ones, &b); diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index bf3b5f2b6592b4ac47dfa7a0af4b8e5978e31ec8..f6c762e7a4bee91a26c4c2e033c3717fef6d91d0 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc index 5c351b2d113709105244de4aafa49d7cc535ced1..2fc7f816b56db6f57ca835d1847476b6d622ce5e 100644 --- a/tensorflow/compiler/xla/tests/pred_test.cc +++ b/tensorflow/compiler/xla/tests/pred_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 3f98099be60ee4694a75f3200f130e49ed39fe67..326e13b3867f2f804e882e00e35850d0189ad8d7 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -182,7 +182,7 @@ XLA_TEST_F(PrngTest, Uniformity256) { XLA_TEST_F(PrngTest, MapUsingRng) { // Build a x -> (x + U[0,1)) computation. - auto build_sum_rng = [this](XlaBuilder& builder) { + auto build_sum_rng = [](XlaBuilder& builder) { auto b = builder.CreateSubBuilder("sum_with_rng"); auto x = Parameter(b.get(), 0, ShapeUtil::MakeShape(F32, {}), "input"); Add(x, diff --git a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc index 526a38e8d1dbed9cdd4a31bfbec49bc5c6bb174b..fab2a65de109c670a6854c0fc1118162acf3d312 100644 --- a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc +++ b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc index 04c7f316463441d1bd458393b29ea5eb2acb9c9b..531648fe3eb8e3941c5e3c012847ee68c616590f 100644 --- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 638b0825a158cb1db458daea40d91ef88d1abe5d..2065271a7f686c52c88df80b0efe8f2e1542d198 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -37,7 +37,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 161b74a5c8f4e93f8944fcfc9e92b38afe5951f6..09acadb2c27e68024f55c017197fba79f7c6cab8 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -18,13 +18,14 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -357,7 +358,7 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { std::vector input_dims(6, 8); auto shape = ShapeUtil::MakeShape(F32, input_dims); - auto arg_literal = MakeUnique(shape); + auto arg_literal = absl::make_unique(shape); arg_literal->PopulateWithValue(1.0f); const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); @@ -368,7 +369,7 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { std::vector output_dims = {6, 8, 6, 6, 8, 8}; Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout); - auto expected = MakeUnique(result_shape); + auto expected = absl::make_unique(result_shape); expected->PopulateWithValue(27.0f); ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); } @@ -1261,6 +1262,12 @@ struct R1ReduceWindowTestData { /*pad_low=*/{5}, /*pad_high=*/{0}, /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{4096}, /*window_bounds=*/{4096}, + /*strides=*/{1}, + /*pad_low=*/{4095}, + /*pad_high=*/{0}, + /*reducer=*/Reducer::kMax}, }; string R1ReduceWindowTestDataToString( @@ -1341,7 +1348,7 @@ INSTANTIATE_TEST_CASE_P( // results on the interpreter backend. class ReduceWindowTextTest : public HloTestBase {}; -TEST_F(ReduceWindowTextTest, R2General256x384) { +XLA_TEST_F(ReduceWindowTextTest, R2General256x384) { const string hlo_string = R"( HloModule R2Window mul { @@ -1358,7 +1365,7 @@ ENTRY R2Window { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); } -TEST_F(ReduceWindowTextTest, R2General256x384Layout01) { +XLA_TEST_F(ReduceWindowTextTest, R2General256x384Layout01) { const string hlo_string = R"( HloModule R2Window mul { @@ -1375,7 +1382,7 @@ ROOT reduce-window = f32[256,384]{0,1} reduce-window(operand, constant), window= EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); } -TEST_F(ReduceWindowTextTest, R2General2x5) { +XLA_TEST_F(ReduceWindowTextTest, R2General2x5) { const string hlo_string = R"( HloModule R2Window mul { @@ -1392,7 +1399,7 @@ ENTRY R2Window { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); } -TEST_F(ReduceWindowTextTest, R2EffectiveScalar) { +XLA_TEST_F(ReduceWindowTextTest, R2EffectiveScalar) { const string hlo_string = R"( HloModule R2Window mul { @@ -1410,7 +1417,7 @@ ENTRY R2Window { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); } -TEST_F(ReduceWindowTextTest, R3EffectiveScalar) { +XLA_TEST_F(ReduceWindowTextTest, R3EffectiveScalar) { const string hlo_string = R"( HloModule R3Window mul { @@ -1428,7 +1435,7 @@ ENTRY R3Window { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); } -TEST_F(HloTestBase, ReduceWindowIdentity) { +XLA_TEST_F(HloTestBase, ReduceWindowIdentity) { const string hlo_string = R"( HloModule ReduceWindowIdentity identity.pad_to_reduce_window { @@ -1445,7 +1452,7 @@ ENTRY reduce-window-identity { EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt)); } -TEST_F(HloTestBase, ReduceWindowS32) { +XLA_TEST_F(HloTestBase, ReduceWindowS32) { const string hlo_string = R"( HloModule reduce-window @@ -1464,5 +1471,24 @@ ENTRY %reduce-window (parameter.0: s32[81,8], parameter.1: s32[]) -> s32[82,8] { EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt)); } +XLA_TEST_F(HloTestBase, ReduceWindowF16) { + const string hlo_string = R"( +HloModule reduce-window + +%identity.pad_to_reduce_window (param0: f16[], param1: f16[]) -> f16[] { + %param0 = f16[] parameter(0) + ROOT %param1 = f16[] parameter(1) +} + +ENTRY %reduce-window (parameter.0: f16[81,8], parameter.1: f16[]) -> f16[82,8] { + %parameter.0 = f16[81,8]{1,0} parameter(0) + %parameter.1 = f16[] parameter(1) + ROOT %reduce-window = f16[82,8]{1,0} reduce-window(f16[81,8]{1,0} %parameter.0, f16[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window +} + +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc index f026ad6c42d4123c0b6fb930b7a5566ff76d3557..d8914513819415368a628eab1f482f9644dd46b1 100644 --- a/tensorflow/compiler/xla/tests/replay_test.cc +++ b/tensorflow/compiler/xla/tests/replay_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc index 7c0389cfa3251a6b62f83a78e986d870177d4d91..368f5583c9ce3773e57b858ff7606f679346529a 100644 --- a/tensorflow/compiler/xla/tests/reshape_motion_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index a6e985293a7824cb5b4bbd7883e961b4f362e241..382d1b1ae741285dcd1f7761edb82a5c333887af 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc index 23f0d26d93bf979970d112993c0a945fb4fe7d53..41e49b4003236d55d85592315652a0ddefd5c485 100644 --- a/tensorflow/compiler/xla/tests/reverse_test.cc +++ b/tensorflow/compiler/xla/tests/reverse_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index 5a3bcaf0865883537e296ea8725f693730bf1776..e42c71eb284deb2e50d6ea4b47fa707e4bc14ffc 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..922d70b7526f228b0559161167eeae8214d14476 --- /dev/null +++ b/tensorflow/compiler/xla/tests/scatter_test.cc @@ -0,0 +1,615 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +using tensorflow::gtl::nullopt; + +class ScatterTest : public HloTestBase { + protected: + void RunTest(const string& hlo_text, Literal* operand, + Literal* scatter_indices, Literal* updates) { + RunTest(hlo_text, {operand, scatter_indices, updates}); + } + + void RunTest(const string& hlo_text, + tensorflow::gtl::ArraySlice args) { + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text, config)); + EXPECT_TRUE(RunAndCompare(std::move(module), args, nullopt)); + } +}; + +XLA_TEST_F(ScatterTest, TensorFlowScatterV1_Update) { + const string hlo_text = R"( +HloModule TensorFlowScatterV1 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatterV2_Update) { + const char* hlo_text = R"( +HloModule TensorFlowScatterV2 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[3,2] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={0}, + inserted_window_dims={1}, + scatter_dims_to_operand_dims={1}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 30}, {40, 60}, {70, 90}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatter_Add) { + const string hlo_text = R"( +HloModule TensorFlowScatter_Add + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatter_Mul) { + const string hlo_text = R"( +HloModule TensorFlowScatter_Mul + +mul_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT mul = s32[] multiply(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=mul_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatter_F32) { + const string hlo_text = R"( +HloModule TensorFlowScatter_F32 + +add_f32 (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(f32[] lhs, f32[] rhs) +} + +ENTRY main { + operand = f32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = f32[2,3] parameter(2) + ROOT scatter = f32[3,3] scatter(operand, indices, updates), + to_apply=add_f32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = LiteralUtil::CreateR2( + {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({2, 1}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatter_RepeatedIndices) { + const char* hlo_text = R"( +HloModule TensorFlowScatter + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({1, 1}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatter_MultipleBatchDims) { + const char* hlo_text = R"( +HloModule TensorFlowScatterMultipleBatchDims + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,3,2] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={1}, + scatter_dims_to_operand_dims={1}, + index_vector_dim=2 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + std::unique_ptr updates = LiteralUtil::CreateR3( + {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatterNd) { + const char* hlo_text = R"( +HloModule TensorFlowScatterNd + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3,3,2] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{-10, 10}, {-40, 40}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatterNd_NonDefaultIndexVectorDim) { + const char* hlo_text = R"( +HloModule TensorFlowScatterNdNonDefaultIndexVectorDim + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3,3,2] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{-10, 10}, {-20, 20}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, DynamicUpdateSlice) { + const char* hlo_text = R"( +HloModule DynamicUpdateSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[1,1] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={0,1}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({1, 1}); + std::unique_ptr updates = LiteralUtil::CreateR2({{10}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, BatchDynamicUpdateSlice) { + const char* hlo_text = R"( +HloModule BatchDynamicUpdateSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,1,1] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + std::unique_ptr updates = + LiteralUtil::CreateR3({{{10}}, {{20}}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, ZeroDimBounds) { + const char* hlo_text = R"( +HloModule TensorFlowScatter_ZeroDimBounds + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,0] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,0] parameter(2) + ROOT scatter = s32[3,0] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = LiteralUtil::CreateR2({{}, {}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, NoUpdateWindowDims) { + const string hlo_text = R"( +HloModule Scatter_NoUpdateWindowDims + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3] parameter(0) + indices = s32[2,2,1] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=2 +} +)"; + std::unique_ptr operand = LiteralUtil::CreateR1({0, 1, 2}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20}, {30, 40}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, OutOfBoundsIndex) { + const string hlo_text = R"( +HloModule BatchDynamicSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = s32[6,2]{1,0} parameter(1) + updates = s32[6,1,1]{2,1,0} parameter(2) + ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = LiteralUtil::CreateR2( + {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); + std::unique_ptr updates = LiteralUtil::CreateR3( + {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, OutOfBoundsUnsignedIndex) { + const string hlo_text = R"( +HloModule BatchDynamicSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = u32[6,2]{1,0} parameter(1) + updates = s32[6,1,1]{2,1,0} parameter(2) + ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = LiteralUtil::CreateR2( + {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}}); + std::unique_ptr updates = LiteralUtil::CreateR3( + {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, NegativeIndex) { + const string hlo_text = R"( +HloModule BatchDynamicSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = s32[6,2]{1,0} parameter(1) + updates = s32[6,1,1]{2,1,0} parameter(2) + ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = LiteralUtil::CreateR2( + {{2, 7}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); + std::unique_ptr updates = LiteralUtil::CreateR3( + {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, OneScalarIndex) { + const char* hlo_text = R"( +HloModule OneScalarIndex + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[2,3,2]{2,1,0} parameter(0) + index = s32[] parameter(1) + updates = s32[1,3,2]{2,1,0} parameter(2) + ROOT scatter = s32[2,3,2]{2,1,0} scatter(operand, index, updates), + to_apply=update_s32, + update_window_dims={0,1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=0 +} +)"; + std::unique_ptr operand = LiteralUtil::CreateR3( + {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); + std::unique_ptr scatter_indices = LiteralUtil::CreateR0(1); + std::unique_ptr updates = + LiteralUtil::CreateR3({{{10, 20}, {30, 40}, {50, 60}}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, ScalarUpdate) { + const char* hlo_text = R"( +HloModule ScalarUpdate + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[4]{0} parameter(0) + index = s32[] parameter(1) + updates = s32[] parameter(2) + ROOT scatter = s32[4]{0} scatter(operand, index, updates), + to_apply=update_s32, + update_window_dims={}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=0 +} +)"; + std::unique_ptr operand = LiteralUtil::CreateR1({1, 2, 3, 4}); + std::unique_ptr scatter_indices = LiteralUtil::CreateR0(1); + std::unique_ptr updates = LiteralUtil::CreateR0(25); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, EmptyIndices) { + const string hlo_text = R"( +HloModule EmptyIndices + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3] parameter(0) + indices = s32[0] parameter(1) + updates = s32[0] parameter(2) + ROOT scatter = s32[3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = LiteralUtil::CreateR1({1, 2, 3}); + std::unique_ptr scatter_indices = LiteralUtil::CreateR1({}); + std::unique_ptr updates = LiteralUtil::CreateR1({}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc index ceb795219ae8f31f0a38865e5a84fc975f7aa2d7..e3d4f98dd7432d1dce7e697586e8b17105dc82e7 100644 --- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc +++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" diff --git a/tensorflow/compiler/xla/tests/select_test.cc b/tensorflow/compiler/xla/tests/select_test.cc index 59409ab26e1c19a8271318c18e19caa7b8ddc3b7..1c01402798658877889527a5dd02d5c74787ff99 100644 --- a/tensorflow/compiler/xla/tests/select_test.cc +++ b/tensorflow/compiler/xla/tests/select_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index a593faca0035b64670f294f81fd5b6d95f35cd88..b8ad6668f80a3002eff3cc458997966ee67c8d4b 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 2647937013222ccfdae98b0c1d141f461020b5c9..2f1d97b25d5c3e5116256a6303859bbcdb45218e 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -13,12 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/tests/test_utils.h" +#include + +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" namespace xla { @@ -26,89 +29,101 @@ namespace { template void PopulateWithRandomFloatingPointDataImpl(Literal* literal, - std::minstd_rand0* engine) { + std::minstd_rand0* engine, + bool no_duplicates) { CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType()); - // Create uniform numbers between 1 and 1.125 to avoid creating denormal - // numbers. - std::uniform_real_distribution generator(1.0f, 1.125f); - const bool should_index_bias = ShapeUtil::ElementsIn(literal->shape()) > 1000; - TF_CHECK_OK(literal->Populate( - [&](tensorflow::gtl::ArraySlice indices) { - // Generate a random uniform number from -0.0625 and 0.0625 and bias it - // with a position dependent number with mean 0.037109375. These number - // should allow for long chains of accumulation without being too close - // to zero or too large to accumulate all numbers accurately. Only do - // this for large literals where the number of elements is much greater - // than 47 otherwise only negative values are produced. - // - // The value is positionally biased using a product of the indices. Add - // one to each index value to avoid collapsing to zero if any of the - // indices are zero. - int64 index_product = 1; - for (int64 i : indices) { - index_product *= (1 + i); - } - const int64 negative_bias = should_index_bias ? 47 : 0; - FloatT index_bias = - static_cast(index_product % 113 - negative_bias) / - static_cast(256.0f); - return static_cast(generator(*engine) - 1.0625f) + index_bias; - })); + if (no_duplicates) { + // Duplicates may be generated if the number of elements in the literal + // exceeds the number of positive values supported by the type. + FloatT next_value = std::numeric_limits::min(); + for (FloatT& value : literal->data()) { + value = next_value; + next_value = + std::nextafter(next_value, std::numeric_limits::max()); + } + std::shuffle(literal->data().begin(), literal->data().end(), + *engine); + } else { + std::uniform_real_distribution generator(-0.1f, 0.2f); + for (FloatT& value : literal->data()) { + value = static_cast(generator(*engine)); + } + } } template void PopulateWithRandomFloatingPointData(Literal* literal, - std::minstd_rand0* engine) { + std::minstd_rand0* engine, + bool no_duplicates) { CHECK(engine != nullptr); - PopulateWithRandomFloatingPointDataImpl(literal, engine); + PopulateWithRandomFloatingPointDataImpl(literal, engine, + no_duplicates); } template <> void PopulateWithRandomFloatingPointData(Literal* literal, - std::minstd_rand0* engine) { + std::minstd_rand0* engine, + bool no_duplicates) { + // no_duplicates is ignored for half types. Unique values can only be + // generated for arrays with fewer than ~2**16 elements and no_duplicates is + // best-effort anyway. CHECK(engine != nullptr); - PopulateWithRandomFloatingPointDataImpl(literal, engine); + std::uniform_real_distribution generator(-0.1f, 0.2f); + for (half& value : literal->data()) { + value = static_cast(generator(*engine)); + } } -// The standard library does not have a case for bfloat16, unsurprisingly, so we -// handle that one specially. template <> void PopulateWithRandomFloatingPointData(Literal* literal, - std::minstd_rand0* engine) { + std::minstd_rand0* engine, + bool no_duplicates) { + // no_duplicates is ignored for bfloat types. Unique values can only be + // generated for arrays with fewer than ~2**16 elements and no_duplicates is + // best-effort anyway. CHECK(engine != nullptr); - CHECK_EQ(literal->shape().element_type(), BF16); - std::uniform_real_distribution generator(-0.9f, 1.0f); - TF_CHECK_OK(literal->Populate( - [&](tensorflow::gtl::ArraySlice /*indices*/) { - return static_cast(generator(*engine)); - })); + std::uniform_real_distribution generator(-0.1f, 0.2f); + for (bfloat16& value : literal->data()) { + value = static_cast(generator(*engine)); + } } template -void PopulateWithRandomIntegralData(Literal* literal, - std::minstd_rand0* engine) { +void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine, + bool no_duplicates) { CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType()); - std::uniform_int_distribution generator( - std::numeric_limits::lowest(), std::numeric_limits::max()); - TF_CHECK_OK(literal->Populate( - [&](tensorflow::gtl::ArraySlice /*indices*/) { - return generator(*engine); - })); + if (no_duplicates && ShapeUtil::ElementsIn(literal->shape()) < + std::numeric_limits::max()) { + std::iota(literal->data().begin(), literal->data().end(), 0); + std::shuffle(literal->data().begin(), literal->data().end(), + *engine); + } else { + std::uniform_int_distribution generator( + std::numeric_limits::lowest(), std::numeric_limits::max()); + for (IntT& value : literal->data()) { + value = generator(*engine); + } + } } // Similar to MakeFakeLiteral but takes a random number generator engine to -// enable reusing the engine across randomly generated literals. +// enable reusing the engine across randomly generated literals. 'no_duplicates' +// indicates that there should be no duplicate values in each generated +// array. This is uniqueness is best-effort only. Some types (half and bfloat16) +// are not supported and uniqueness cannot be guaranteed if the number of +// elements exceeds the number of different values supported by the type. StatusOr> MakeFakeLiteralInternal( - const Shape& shape, std::minstd_rand0* engine) { + const Shape& shape, std::minstd_rand0* engine, bool no_duplicates) { if (ShapeUtil::IsTuple(shape)) { std::vector> elements; for (const Shape& element_shape : shape.tuple_shapes()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr element, - MakeFakeLiteralInternal(element_shape, engine)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr element, + MakeFakeLiteralInternal(element_shape, engine, no_duplicates)); elements.push_back(std::move(element)); } return LiteralUtil::MakeTupleOwned(std::move(elements)); @@ -116,43 +131,55 @@ StatusOr> MakeFakeLiteralInternal( if (engine == nullptr) { return Literal::CreateFromShape(shape); } - auto literal = MakeUnique(shape); + auto literal = absl::make_unique(shape); switch (shape.element_type()) { case BF16: - PopulateWithRandomFloatingPointData(literal.get(), engine); + PopulateWithRandomFloatingPointData(literal.get(), engine, + no_duplicates); break; case F16: - PopulateWithRandomFloatingPointData(literal.get(), engine); + PopulateWithRandomFloatingPointData(literal.get(), engine, + no_duplicates); break; case F32: - PopulateWithRandomFloatingPointData(literal.get(), engine); + PopulateWithRandomFloatingPointData(literal.get(), engine, + no_duplicates); break; case F64: - PopulateWithRandomFloatingPointData(literal.get(), engine); + PopulateWithRandomFloatingPointData(literal.get(), engine, + no_duplicates); break; case S8: - PopulateWithRandomIntegralData(literal.get(), engine); + PopulateWithRandomIntegralData(literal.get(), engine, + no_duplicates); break; case U8: - PopulateWithRandomIntegralData(literal.get(), engine); + PopulateWithRandomIntegralData(literal.get(), engine, + no_duplicates); break; case S16: - PopulateWithRandomIntegralData(literal.get(), engine); + PopulateWithRandomIntegralData(literal.get(), engine, + no_duplicates); break; case U16: - PopulateWithRandomIntegralData(literal.get(), engine); + PopulateWithRandomIntegralData(literal.get(), engine, + no_duplicates); break; case S32: - PopulateWithRandomIntegralData(literal.get(), engine); + PopulateWithRandomIntegralData(literal.get(), engine, + no_duplicates); break; case U32: - PopulateWithRandomIntegralData(literal.get(), engine); + PopulateWithRandomIntegralData(literal.get(), engine, + no_duplicates); break; case S64: - PopulateWithRandomIntegralData(literal.get(), engine); + PopulateWithRandomIntegralData(literal.get(), engine, + no_duplicates); break; case U64: - PopulateWithRandomIntegralData(literal.get(), engine); + PopulateWithRandomIntegralData(literal.get(), engine, + no_duplicates); break; case PRED: { std::uniform_int_distribution generator(0, 1); @@ -208,16 +235,12 @@ bool NeedsInitValue(const HloUse& use) { // Generate random values that are constrained to the input_shape minus the // output_shape so as not to produce wrapping slices, for instance. -std::unique_ptr MakeRandomNonwrappingSliceIndex( - const Shape& input_shape, const Shape& slice_shape, - std::minstd_rand0* engine) { - const int64 rank = ShapeUtil::Rank(input_shape); - std::vector start_indices(rank); +std::unique_ptr MakeRandomIndex( + tensorflow::gtl::ArraySlice index_space, std::minstd_rand0* engine) { + std::vector start_indices(index_space.size()); if (engine != nullptr) { - for (int i = 0; i < rank; ++i) { - const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) - - ShapeUtil::GetDimension(slice_shape, i); - std::uniform_int_distribution generator(0, upper_bound); + for (int i = 0; i < index_space.size(); ++i) { + std::uniform_int_distribution generator(0, index_space[i]); start_indices[i] = generator(*engine); } } @@ -254,6 +277,11 @@ std::vector FindConstrainedUses( auto converted_uses = FindConstrainedUses(dataflow, *instruction); constrained_uses.insert(constrained_uses.end(), converted_uses.begin(), converted_uses.end()); + } else if (opcode == HloOpcode::kSort && + instruction->operand_count() == 2 && op_num == 0) { + // Operand 0 of sort is the array of keys used for key/value + // (two-operand) kSort instructions. + constrained_uses.push_back(instruction); } } } @@ -267,56 +295,66 @@ std::vector FindConstrainedUses( StatusOr> CreateLiteralForConstrainedUses( const tensorflow::gtl::ArraySlice constrained_uses, const HloInstruction& param, std::minstd_rand0* engine) { - HloInstruction* needs_index = nullptr; - HloInstruction* needs_constant = nullptr; + std::vector index_space; + bool no_duplicates = false; + bool needs_constant = false; ConstantType constant_type = ConstantType::kUnknown; for (HloInstruction* use : constrained_uses) { switch (use->opcode()) { case HloOpcode::kDynamicSlice: - case HloOpcode::kDynamicUpdateSlice: - if (needs_index != nullptr) { - auto needs_index_shape = needs_index->shape(); - auto use_shape = use->shape(); - if (needs_index->opcode() == HloOpcode::kDynamicSlice) { - needs_index_shape = needs_index->operand(0)->shape(); - } - if (use->opcode() == HloOpcode::kDynamicSlice) { - use_shape = use->operand(0)->shape(); + case HloOpcode::kDynamicUpdateSlice: { + const Shape& indexed_shape = use->operand(0)->shape(); + const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice + ? use->shape() + : use->operand(1)->shape(); + const int64 rank = ShapeUtil::Rank(indexed_shape); + if (!index_space.empty()) { + TF_RET_CHECK(rank == index_space.size()); + for (int64 i = 0; i < rank; ++i) { + index_space[i] = std::min( + index_space[i], ShapeUtil::GetDimension(indexed_shape, i) - + ShapeUtil::GetDimension(slice_shape, i)); } - if (!ShapeUtil::Equal(needs_index_shape, use_shape)) { - return Unimplemented( - "Conflicting operand generation slice index constraints\n"); + } else { + index_space.resize(rank); + for (int64 i = 0; i < rank; ++i) { + index_space[i] = ShapeUtil::GetDimension(indexed_shape, i) - + ShapeUtil::GetDimension(slice_shape, i); } } - needs_index = use; break; + } case HloOpcode::kReduce: case HloOpcode::kReduceWindow: - needs_constant = use; + needs_constant = true; constant_type = GetInitValue(*use->to_apply()); break; case HloOpcode::kSelectAndScatter: - needs_constant = use; + needs_constant = true; constant_type = GetInitValue(*use->scatter()); break; + case HloOpcode::kSort: + no_duplicates = true; + break; + default: return Unimplemented( "Constrained operand generation not implemented for %s.", use->ToString().c_str()); } } - if (needs_index != nullptr && needs_constant != nullptr) { - return Unimplemented( - "Conflicting operand generation constraints.\nNeeds index: %s\nNeeds " - "constant: %s\n", - needs_index->ToString().c_str(), needs_constant->ToString().c_str()); + int constraint_count = 0; + constraint_count += no_duplicates ? 1 : 0; + constraint_count += !index_space.empty() ? 1 : 0; + constraint_count += needs_constant ? 1 : 0; + if (constraint_count > 1) { + return Unimplemented("Conflicting operand generation constraints."); } - if (needs_index != nullptr) { - return MakeRandomNonwrappingSliceIndex(needs_index->operand(0)->shape(), - needs_index->shape(), engine); - } else if (needs_constant != nullptr) { + if (!index_space.empty()) { + return MakeRandomIndex(index_space, engine); + } else if (needs_constant) { switch (constant_type) { case ConstantType::kZero: return LiteralUtil::Zero(param.shape().element_type()).CloneToUnique(); @@ -325,10 +363,11 @@ StatusOr> CreateLiteralForConstrainedUses( case ConstantType::kUnknown: // We want the identity element for the computation, but we don't really // know what it is - so any value we generate will be just as wrong. - return MakeFakeLiteralInternal(param.shape(), engine); + return MakeFakeLiteralInternal(param.shape(), engine, + /*no_duplicates=*/false); } } else { - return MakeFakeLiteralInternal(param.shape(), engine); + return MakeFakeLiteralInternal(param.shape(), engine, no_duplicates); } } @@ -345,19 +384,26 @@ StatusOr> MakeConstrainedArgument( StatusOr> MakeFakeLiteral(const Shape& shape, bool pseudo_random) { - auto engine = pseudo_random ? MakeUnique() : nullptr; - return MakeFakeLiteralInternal(shape, engine.get()); + auto engine = + pseudo_random ? absl::make_unique() : nullptr; + return MakeFakeLiteralInternal(shape, engine.get(), /*no_duplicates=*/false); } StatusOr>> MakeFakeArguments( HloModule* const module, bool pseudo_random) { + auto engine = + pseudo_random ? absl::make_unique() : nullptr; + return MakeFakeArguments(module, engine.get()); +} + +StatusOr>> MakeFakeArguments( + HloModule* const module, std::minstd_rand0* engine) { TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module)); const auto params = module->entry_computation()->parameter_instructions(); - auto engine = pseudo_random ? MakeUnique() : nullptr; std::vector> arguments(params.size()); for (int i = 0; i < params.size(); ++i) { - TF_ASSIGN_OR_RETURN(arguments[i], MakeConstrainedArgument( - *dataflow, *params[i], engine.get())); + arguments[i] = + MakeConstrainedArgument(*dataflow, *params[i], engine).ValueOrDie(); } return std::move(arguments); } diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index e59f215a9a3ace80d7a23e1bbc40970c7a63ea0d..1aca1d8ef7e714c7ebb4d522f0d2dd28992fd16b 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -20,9 +20,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -63,8 +63,17 @@ StatusOr> MakeFakeLiteral(const Shape& shape, // Generates a vector of arguments containing fake data. The number, shape and // layout of the arguments is appropriate for given HLO module. // -// Will handle special cases such as making sure that indices used for dynamic -// slices are bounded, reduces that call adds use 0 as an init value, etc. +// A best-effort attempt is made to generate the data in a way which produce +// stable computation results across platforms. Specifically: +// +// (1) Init values of reductions should be the identity of the reduction +// computation. +// +// (2) Indices of dynamic slices and update slices should be in bounds. +// +// (3) Keys of key/value sorts should contain no duplicates. +// +// These constraints are best-effort only. // // If pseudo_random is true, the generated numbers will be generated // deterministically in a pseudo random way unless the values are constrated to @@ -78,6 +87,12 @@ StatusOr> MakeFakeLiteral(const Shape& shape, StatusOr>> MakeFakeArguments( HloModule* const module, bool pseudo_random = true); +// Overload which accepts a random number generator. This enables generation of +// different random values with sequential calls to MakeFakeArguments by reusing +// the same generator. +StatusOr>> MakeFakeArguments( + HloModule* const module, std::minstd_rand0* engine); + // Check that a given module satisfies various constraints before trying to // execute it. Status VerifyHloModule(HloModule* const module, diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index 8f424ae81f592bfd8accd8decb8fc363f7561c73..322c8ef090cf867f65cada5cb1dbae188f83bad6 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -15,11 +15,12 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/local_client_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { @@ -72,5 +73,106 @@ XLA_TEST_F(TestUtilsTest, Token) { TF_ASSERT_OK(MakeFakeArguments(module.get()).status()); } +XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) { + auto module = ParseHloString( + R"(HloModule index_space_module + + ENTRY IndexSpace { + index_param = s32[3]{0} parameter(0) + array_param.1 = f32[123,4,789]{0,1,2} parameter(1) + array_param.2 = f32[3,3000,5]{0,1,2} parameter(2) + dynamic-slice.1 = f32[1,2,3] dynamic-slice(array_param.1, index_param), dynamic_slice_sizes={1,2,3} + ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param), dynamic_slice_sizes={3,2,2} + })") + .ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(std::vector> args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 3); + const Literal& index_arg = *args[0]; + + EXPECT_EQ(index_arg.Get({0}), 0); + + EXPECT_GE(index_arg.Get({1}), 0); + EXPECT_LE(index_arg.Get({1}), 2); + + EXPECT_GE(index_arg.Get({2}), 0); + EXPECT_LE(index_arg.Get({2}), 3); +} + +XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) { + auto module = ParseHloString( + R"(HloModule index_space_module + + ENTRY IndexSpace { + index_param = s32[3]{0} parameter(0) + array_param.1 = f32[123,4,789]{0,1,2} parameter(1) + array_param.2 = f32[3,3000,5]{0,1,2} parameter(2) + update_param.1 = f32[1,2,3]{0,1,2} parameter(3) + update_param.2 = f32[3,2,2]{0,1,2} parameter(4) + + dynamic-update-slice.1 = f32[123,4,789] dynamic-update-slice(array_param.1, update_param.1, index_param) + ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param) + })") + .ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(std::vector> args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 5); + const Literal& index_arg = *args[0]; + + EXPECT_EQ(index_arg.Get({0}), 0); + + EXPECT_GE(index_arg.Get({1}), 0); + EXPECT_LE(index_arg.Get({1}), 2); + + EXPECT_GE(index_arg.Get({2}), 0); + EXPECT_LE(index_arg.Get({2}), 3); +} + +XLA_TEST_F(TestUtilsTest, NoDuplicatesFloats) { + // Inputs which are sort keys in key/value sorts should have no duplicates. + auto module = ParseHloString(R"( +HloModule sort.148.1589 + +ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> (f32[1048576], s32[1048576]) { + %parameter.0 = f32[1048576]{0} parameter(0) + %parameter.1 = s32[1048576]{0} parameter(1) + ROOT %sort.148.1589 = (f32[1048576]{0}, s32[1048576]{0}) sort(f32[1048576]{0} %parameter.0, s32[1048576]{0} %parameter.1), dimensions={0} +} +)") + .ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(std::vector> args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 2); + const Literal& key_arg = *args[0]; + + tensorflow::gtl::FlatSet key_set; + for (const float& value : key_arg.data()) { + EXPECT_TRUE(key_set.insert(tensorflow::bit_cast(value)).second); + } +} + +XLA_TEST_F(TestUtilsTest, NoDuplicatesInt32) { + // Inputs which are sort keys in key/value sorts should have no duplicates. + auto module = ParseHloString(R"( +HloModule sort.148.1589 + +ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> (s32[1048576], s32[1048576]) { + %parameter.0 = s32[1048576]{0} parameter(0) + %parameter.1 = s32[1048576]{0} parameter(1) + ROOT %sort.148.1589 = (s32[1048576]{0}, s32[1048576]{0}) sort(s32[1048576]{0} %parameter.0, s32[1048576]{0} %parameter.1), dimensions={0} +} +)") + .ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(std::vector> args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 2); + const Literal& key_arg = *args[0]; + + tensorflow::gtl::FlatSet key_set; + for (const int32& value : key_arg.data()) { + EXPECT_TRUE(key_set.insert(tensorflow::bit_cast(value)).second); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc index 0f86b7f20f9bd7597ece713626ee0e9c23509e05..125513ddfd16cb4e742e7d589e22b721307621ee 100644 --- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/generic_transfer_manager.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -60,7 +61,7 @@ class TransferManagerTest : public LocalClientTestBase { } protected: - Backend::StreamPtr stream_ptr_; + StreamPool::Ptr stream_ptr_; se::Stream* stream_; private: diff --git a/tensorflow/compiler/xla/tests/transpose_test.cc b/tensorflow/compiler/xla/tests/transpose_test.cc index 6ebb4324f8d20ed9f8886d92b0513441685ed19b..fbe9d1b64aa0c06d65b547c45cfa981800d40ff3 100644 --- a/tensorflow/compiler/xla/tests/transpose_test.cc +++ b/tensorflow/compiler/xla/tests/transpose_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index ad46eaa1c30b90daf46128764accf7b22faec8c8..c101cd2d20131199801f755c96b629ccb65744db 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -16,9 +16,10 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -504,7 +505,7 @@ XLA_TEST_F(TupleTest, ComplexTuples) { LiteralUtil::CreateR2({{{111, 222}, {331, 442}}, {{1011, 2022}, {3031, 4042}}, {{10011, 20022}, {30031, 40042}}}); - auto prod = MakeUnique(sum->shape()); + auto prod = absl::make_unique(sum->shape()); ASSERT_TRUE(prod->Populate( [&sum](tensorflow::gtl::ArraySlice indexes) { return sum->Get(indexes) * @@ -586,9 +587,9 @@ XLA_TEST_F(TupleHloTest, })); auto expected = LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({2, 3})); - auto literal = MakeUnique(); + auto literal = Literal::CreateFromShape(expected->shape()); TF_EXPECT_OK(backend().transfer_manager()->TransferLiteralFromOutfeed( - backend().default_stream_executor(), expected->shape(), literal.get())); + backend().default_stream_executor(), expected->shape(), *literal)); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *literal)); } diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index a90a6fb0a5b5bb5119eee93c9c6a1377e3461b46..20ae68ab74026936c43e5f525eb796eb402a19cb 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc index ea3aba6df1d3fbd492a23b280309322b8524c0bf..ef1b1445bbe555da00db4446d59439b752735a80 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc index cacbe83b867e7310d11b641c8e1d7f0a8f7bff4f..3848ec1684cdc9186e14ac0b60315b7520d127f3 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 0a3977800263821e9c5d4e4c73832468e28f02c9..1bdf1867b9330b715b0ba4aca71d56307883c775 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/platform_util.h" @@ -1236,6 +1236,35 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) { {param_value.get()}, ErrorSpec(4e-5)); } +TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) { + auto while_shape = ShapeUtil::MakeShape(S32, {}); + + XlaComputation condition; + { + XlaBuilder builder("condition"); + Parameter(&builder, 0, while_shape, "state"); + Infeed(&builder, ShapeUtil::MakeShape(PRED, {})); + TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); + } + + XlaComputation body; + { + XlaBuilder builder("body"); + auto indvar = Parameter(&builder, 0, while_shape, "state"); + Add(indvar, ConstantR0(&builder, 1)); + TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); + } + + XlaBuilder builder(TestName()); + While(condition, body, ConstantR0(&builder, 0)); + + TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0(true))); + TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0(true))); + TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0(false))); + + ComputeAndCompareR0(&builder, 2, {}); +} + void BM_WhileLoop(int num_iters) { // Benchmark a simple kernel to measure while loop overheads. tensorflow::testing::StopTiming(); diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index 7a75e5102c2dacf9bbaadba5671bfe68895b1484..e12e095ecdef1d79d29e619f1cf88e91a577e0fd 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -16,12 +16,14 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -83,8 +85,8 @@ Status ParseOneProfileOutputLine( tensorflow::gtl::ArraySlice opcodes_to_ignore = {}) { string separator = "[^:]*:: +"; - string match_percentage = "\\d+\\.\\d\\d%"; - string match_cycles = "(\\d+) cycles +\\( *(" + match_percentage + ")\\)"; + string match_percentage = R"(\d+\.\d*% +\d+Σ)"; + string match_cycles = R"((\d+) cycles +\( *()" + match_percentage + R"()\))"; string match_usecs = "([0-9.]+) usec"; string match_flops = "([^ ]*)"; string match_trops = "([^ ]*)"; @@ -115,7 +117,7 @@ Status ParseOneProfileOutputLine( ", Regexp: ", regexp_pattern); } - if (!c_linear_search(opcodes_to_ignore, parsed_line.opcode)) { + if (!absl::c_linear_search(opcodes_to_ignore, parsed_line.opcode)) { InsertOrDie(parsed_results, parsed_line.opcode, parsed_line); } @@ -133,7 +135,7 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, DeviceMemoryAllocator* allocator = backend->memory_allocator(); auto* transfer_manager = backend->transfer_manager(); TF_ASSERT_OK_AND_ASSIGN( - Backend::StreamPtr stream_ptr, + StreamPool::Ptr stream_ptr, backend->BorrowStream(backend->default_device_ordinal())); TF_ASSERT_OK_AND_ASSIGN( @@ -224,7 +226,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { MaybeFind(parsed_profile_lines, "tanh")); EXPECT_GT(total_profile.cycles, 0); - EXPECT_EQ(total_profile.cycles_percentage, "100.00%"); + EXPECT_EQ(total_profile.cycles_percentage, "100.% 100Σ"); EXPECT_TRUE(HasFlops(total_profile)); EXPECT_TRUE(HasTrops(total_profile)); @@ -293,7 +295,7 @@ XLA_TEST_F(HloProfileTest, ProfileWhileComputation) { tensorflow::str_util::Split(profile_output, '\n'); auto while_body_profile_start = - c_find_if(profile_output_lines, [](tensorflow::StringPiece s) { + absl::c_find_if(profile_output_lines, [](tensorflow::StringPiece s) { return tensorflow::str_util::StartsWith(s, "Execution profile for body"); }); @@ -332,7 +334,7 @@ XLA_TEST_F(HloProfileTest, ProfileWhileComputation) { EXPECT_GT(total_while_body_profile.cycles, 0); EXPECT_EQ(total_while_body_profile.opcode, "[total]"); - EXPECT_EQ(total_while_body_profile.cycles_percentage, "100.00%"); + EXPECT_EQ(total_while_body_profile.cycles_percentage, "100.% 100Σ"); EXPECT_GT(total_while_body_profile.cycles, multiply_profile.cycles); EXPECT_NE(multiply_profile.cycles_percentage, "0.00%"); diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc index 897123d7606db60abc1105b03beb3f23ab249579..7de2c39b3892dc40d09adfed1c39e4aca449039d 100644 --- a/tensorflow/compiler/xla/text_literal_reader.cc +++ b/tensorflow/compiler/xla/text_literal_reader.cc @@ -20,8 +20,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -102,7 +102,7 @@ StatusOr> TextLiteralReader::ReadAllLines() { ShapeUtil::HumanString(shape).c_str()); } - auto result = MakeUnique(shape); + auto result = absl::make_unique(shape); const float fill = std::numeric_limits::quiet_NaN(); result->PopulateWithValue(fill); std::vector pieces; diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index d7cabbe876c662fc71237a0fb62141c93e69d14b..40d28a57bfddd3403cad8252df985b746362631f 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -87,6 +87,7 @@ cc_library( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:testing", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service/gpu:infeed_manager", diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 3bb2f3c0007bbe92aed6a995790284c89719be91..b4774233e588dc407bfb88defca9bf55e08eea09 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -30,6 +30,9 @@ limitations under the License. // The output format is: // // file_path: computation_name :: type:literal_str +// +// Note: If you pass multiple modules, they will be compiled in parallel but run +// in series. #include #include @@ -44,6 +47,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -75,6 +79,18 @@ struct Options { int num_runs = 1; }; +std::unique_ptr CompileExecutable(const HloSnapshot& module, + LocalClient* client) { + XlaComputation computation(module.hlo().hlo_module()); + std::vector argument_layouts; + for (const auto& param : computation.proto().program_shape().parameters()) { + argument_layouts.push_back(¶m); + } + return client + ->Compile(computation, argument_layouts, ExecutableBuildOptions()) + .ValueOrDie(); +} + // Invokes the given computation passing arbitrary data for every (unbound) // parameter if use_fake_data, Otherwise use recorded data if available. // @@ -85,6 +101,7 @@ struct Options { // If neither generate_fake_infeed is true nor a fake_infeed_shape is provided, // no infeed is performed. StatusOr ReplayComputation(const HloSnapshot& module, + LocalExecutable* executable, LocalClient* client, const Options& opts) { XlaComputation computation(module.hlo().hlo_module()); @@ -167,34 +184,34 @@ StatusOr ReplayComputation(const HloSnapshot& module, }); } - std::vector argument_layouts; - for (const auto& param : computation.proto().program_shape().parameters()) { - argument_layouts.push_back(¶m); - } - std::unique_ptr executable = - client->Compile(computation, argument_layouts, ExecutableBuildOptions()) - .ValueOrDie(); - - // Do not attmept to run the executable, if num_runs is less than 1. + // Do not attempt to run the executable if num_runs is less than 1. if (opts.num_runs < 1) { return Cancelled("Cancelled after compilation since --num_runs < 1."); } // Run the computation num_runs times, and return the result from the last // execution. + const bool xla_hlo_profile = + legacy_flags::GetDebugOptionsFromFlags().xla_hlo_profile(); StreamExecutorMemoryAllocator allocator( client->platform(), {client->platform()->ExecutorForDevice(0).ValueOrDie()}); tensorflow::gtl::optional result; for (int i = 0; i < opts.num_runs; ++i) { + // If xla_hlo_profile is enabled, print a noisy message before the last run, + // making it easier to separate this profile from the others in the logspam. + if (xla_hlo_profile && i == opts.num_runs - 1) { + LOG(INFO) << "\n\n***** Final run below ******"; + } ExecutionProfile profile; ExecutableRunOptions run_options; run_options.set_execution_profile(&profile); run_options.set_allocator(&allocator); TF_ASSIGN_OR_RETURN(result, executable->Run(argument_ptrs, run_options)); - LOG(INFO) << "Execution took " - << static_cast(profile.compute_time_ns()) / 1e9 << "s"; + LOG(INFO) << "Done executing in " + << static_cast(profile.compute_time_ns()) / 1e9 + << "s: " << module.hlo().hlo_module().name(); } TF_ASSIGN_OR_RETURN(std::unique_ptr result_literal, @@ -206,9 +223,13 @@ StatusOr ParseInputFile(const string& filename, const Options& opts) { tensorflow::Env* env = tensorflow::Env::Default(); HloSnapshot snapshot; - if (tensorflow::ReadBinaryProto(env, filename, &snapshot).ok()) { + auto s = tensorflow::ReadBinaryProto(env, filename, &snapshot); + if (s.ok()) { return snapshot; } + if (s.code() == tensorflow::error::NOT_FOUND) { + return s; + } CHECK(opts.use_fake_data) << "Without --use_fake_data, you must pass an HloSnapshot -- HloProto " "and textual HLO don't carry real data."; @@ -235,15 +256,42 @@ StatusOr ParseInputFile(const string& filename, int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { LocalClient* client = ClientLibrary::LocalClientOrDie(); int exit_status = EXIT_SUCCESS; + + std::vector snapshots; for (char* arg : args) { StatusOr maybe_snapshot = ParseInputFile(arg, opts); - if (!maybe_snapshot.ok()) { - continue; + if (maybe_snapshot.ok()) { + snapshots.push_back(std::move(maybe_snapshot).ValueOrDie()); + } else { + LOG(ERROR) << "Can't handle file " << arg << ": " + << maybe_snapshot.status(); } - HloSnapshot snapshot = std::move(maybe_snapshot).ValueOrDie(); - StatusOr result_status = ReplayComputation(snapshot, client, opts); + } + + // Compile all the modules in parallel. + LOG(INFO) << "Compiling " << snapshots.size() << " modules in parallel."; + std::vector> executables; + { + // ThreadPool CHECK-fails if we give it 0 threads. + tensorflow::thread::ThreadPool thread_pool( + tensorflow::Env::Default(), tensorflow::ThreadOptions(), + "compile_modules", std::max(size_t{1}, snapshots.size()), + /*low_latency_hint=*/false); + executables.resize(snapshots.size()); + for (int64 i = 0; i < snapshots.size(); ++i) { + thread_pool.Schedule([&snapshots, &executables, client, i] { + executables[i] = CompileExecutable(snapshots[i], client); + }); + } + } + LOG(INFO) << "Done compiling; now running the modules."; + + for (int64 i = 0; i < executables.size(); ++i) { + LocalExecutable* executable = executables[i].get(); + StatusOr result_status = + ReplayComputation(snapshots[i], executable, client, opts); if (!result_status.ok()) { - fprintf(stderr, "%s: error: %s\n", arg, + fprintf(stderr, "%s: error: %s\n", args[i], result_status.status().ToString().c_str()); exit_status = EXIT_FAILURE; continue; @@ -251,10 +299,11 @@ int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { if (opts.print_result) { Literal result = std::move(result_status).ValueOrDie(); - fprintf(stdout, "%s: %s :: %s:%s\n", arg, - snapshot.hlo().hlo_module().name().c_str(), + fprintf(stdout, "%s: %s :: %s:%s\n", args[i], + executable->executable()->module().name().c_str(), ShapeUtil::HumanString(result.shape()).c_str(), result.ToString().c_str()); + auto& snapshot = snapshots[i]; if (snapshot.has_result()) { std::unique_ptr literal = Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 5ae099a4622bb7116c7a17f93060b699ead6e3a6..cc07346ee50c320bd57b23d1c0f2b7396873f178 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -24,6 +24,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -434,122 +435,15 @@ std::vector> CommonFactors( // Removes illegal characters from filenames. string SanitizeFileName(string file_name); -template -bool c_all_of(const Container& container, Predicate&& predicate) { - return std::all_of(std::begin(container), std::end(container), - std::forward(predicate)); -} - -template -bool c_any_of(const Container& container, Predicate&& predicate) { - return std::any_of(std::begin(container), std::end(container), - std::forward(predicate)); -} - -template -OutputIterator c_transform(const InputContainer& input_container, - OutputIterator output_iterator, - UnaryOperation&& unary_op) { - return std::transform(std::begin(input_container), std::end(input_container), - output_iterator, - std::forward(unary_op)); -} - -template -OutputIterator c_copy_if(const InputContainer& input_container, - OutputIterator output_iterator, - UnaryPredicate&& predicate) { - return std::copy_if(std::begin(input_container), std::end(input_container), - output_iterator, std::forward(predicate)); -} - -template -OutputIterator c_copy(const InputContainer& input_container, - OutputIterator output_iterator) { - return std::copy(std::begin(input_container), std::end(input_container), - output_iterator); -} - -template -void c_sort(InputContainer& input_container) { - std::sort(std::begin(input_container), std::end(input_container)); -} - -template -void c_sort(InputContainer& input_container, Comparator&& comparator) { - std::sort(std::begin(input_container), std::end(input_container), - std::forward(comparator)); -} - -template -bool c_binary_search(const Sequence& sequence, T&& value) { - return std::binary_search(std::begin(sequence), std::end(sequence), - std::forward(value)); -} - -template -bool c_is_sorted(const C& c) { - return std::is_sorted(std::begin(c), std::end(c)); -} - -template -bool c_is_sorted(const C& c, Compare&& comp) { - return std::is_sorted(std::begin(c), std::end(c), - std::forward(comp)); -} - -template -auto c_adjacent_find(C& c) -> decltype(std::begin(c)) { - return std::adjacent_find(std::begin(c), std::end(c)); -} - -template -auto c_find_if(C& c, Pred&& pred) -> decltype(std::begin(c)) { - return std::find_if(std::begin(c), std::end(c), std::forward(pred)); -} - -template -auto c_find(C& c, Value&& value) -> decltype(std::begin(c)) { - return std::find(std::begin(c), std::end(c), std::forward(value)); -} - -template -void c_reverse(Sequence& sequence) { - std::reverse(std::begin(sequence), std::end(sequence)); -} - -template -typename std::decay::type c_accumulate(const Sequence& sequence, T&& init, - BinaryOp&& binary_op) { - return std::accumulate(std::begin(sequence), std::end(sequence), - std::forward(init), - std::forward(binary_op)); -} - -template -typename std::iterator_traits< - decltype(std::begin(std::declval()))>::difference_type -c_count_if(const C& c, Pred&& pred) { - return std::count_if(std::begin(c), std::end(c), std::forward(pred)); -} - -// Determines whether `value` is present in `c`. -template -bool c_linear_search(const C& c, T&& value) { - auto last = std::end(c); - return std::find(std::begin(c), last, std::forward(value)) != last; -} - template int64 FindIndex(const C& c, Value&& value) { - auto it = c_find(c, std::forward(value)); + auto it = absl::c_find(c, std::forward(value)); return std::distance(c.begin(), it); } template bool ArrayContains(tensorflow::gtl::ArraySlice c, const T& value) { - return c_find(c, value) != c.end(); + return absl::c_find(c, value) != c.end(); } template @@ -584,8 +478,8 @@ bool IsInt32(T x) { template Status EraseElementFromVector(std::vector* container, const T& value) { - // c_find returns a const_iterator which does not seem to work on gcc 4.8.4, - // and this breaks the ubuntu/xla_gpu build bot. + // absl::c_find returns a const_iterator which does not seem to work on + // gcc 4.8.4, and this breaks the ubuntu/xla_gpu build bot. auto it = std::find(container->begin(), container->end(), value); TF_RET_CHECK(it != container->end()); container->erase(it); diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 10c0adc6707f01fcee87303a6e2ec5c570601309..b53f89d63b1edb5fb01ae9e6e71385797ca0f904 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -104,15 +104,6 @@ message DebugOptions { // interpretation of this value is left to the backends. int32 xla_backend_optimization_level = 31; - // When true, "unsafe" mathematical optimizations are enabled. These - // transformations include but are not limited to: - // - // - Reducing the precision of operations (e.g. using an approximate sin - // function, or transforming x/y into x * (1/y)). - // - Assuming that operations never produce or consume NaN or +/- Inf. - // - Assuming that +0 and -0 are indistinguishable. - bool xla_enable_fast_math = 32; - // Embed the compiler IR as a string in the executable. bool xla_embed_ir_in_executable = 33; @@ -194,8 +185,23 @@ message DebugOptions { // Maximum kernel unroll factor for the GPU backend. int32 xla_gpu_max_kernel_unroll_factor = 98; - // Extra options to pass to the compilation backend; specific interpretation - // of these values is left to the backend. + // When true, "unsafe" mathematical optimizations are enabled. These + // transformations include but are not limited to: + // + // - Reducing the precision of operations (e.g. using an approximate sin + // function, or transforming x/y into x * (1/y)). + // - Assuming that operations never produce or consume NaN or +/- Inf. + // - Assuming that +0 and -0 are indistinguishable. + bool xla_cpu_enable_fast_math = 99; + bool xla_gpu_enable_fast_math = 100; + + // Crashes the program when any kind of verification fails, instead of just + // logging the failures. One example is cross checking of convolution results + // among different algorithms. + bool xla_gpu_crash_on_verification_failures = 101; + + // Extra options to pass to the compilation backend (e.g. LLVM); specific + // interpretation of these values is left to the backend. map xla_backend_extra_options = 500; } diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 0b300dc7b2d03cc8e1564f78412cc610cff518cd..27aa94c2cbc7f1aa3dd877e3b5d0e6d1b5380a1e 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -424,29 +424,43 @@ message GatherDimensionNumbers { // "Window indices" is a term for a set of indices that index into the // interior of a dynamic-slice from the input tensor, the starting indices for // which were computed from output_gather_dims (see the operation semantic for - // how this is defined) and the gather_indices tensor. + // how this is defined) and the start_indices tensor. // // The window indices for a specific output index Out is computed as: // // i = 0 // for (k : [0, input_tensor_shape.rank)) // window_indices[k] = - // if k in elided_window_dims + // if k in collapsed_slice_dims // then 0 - // else Out[output_window_dims[i++]] - repeated int64 output_window_dims = 1; - repeated int64 elided_window_dims = 2; + // else Out[offset_dims[i++]] + repeated int64 offset_dims = 1; + repeated int64 collapsed_slice_dims = 2; - // This is interpreted as a map from i to gather_dims_to_operand_dims[i]. It - // transforms the gather index looked up from the gather_indices tensor into + // This is interpreted as a map from i to start_index_map[i]. It + // transforms the gather index looked up from the start_indices tensor into // the starting index in the input space. - repeated int64 gather_dims_to_operand_dims = 3; + repeated int64 start_index_map = 3; - // The dimension in the gather_indices input that contains the starting + // The dimension in the start_indices input that contains the starting // indices. int64 index_vector_dim = 4; } +// Describes the dimension numbers for a scatter operation. +// +// All the fields are similar to the corresponding fields in +// GatherDimensionNumbers. Differences are noted below. +message ScatterDimensionNumbers { + // The set of dimensions in the updates shape that are window dimensions. + repeated int64 update_window_dims = 1; + // The set of window dimensions that must be inserted into the updates shape. + repeated int64 inserted_window_dims = 2; + + repeated int64 scatter_dims_to_operand_dims = 3; + int64 index_vector_dim = 4; +} + message ConvolutionDimensionNumbers { // The number of the dimension that represents batch in the input. int64 input_batch_dimension = 7; @@ -547,3 +561,11 @@ message OpSharding { // to. repeated OpSharding tuple_shardings = 5; } + +// Describes the replica groups in a cross replica op (e.g., all-reduce and +// all-to-all). +message ReplicaGroup { + // The ids of the replicas that belongs to the same group. The ordering of the + // ids matters in some op (e.g., all-to-all). + repeated int64 replica_ids = 1; +} diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 6a4e252b44881c679350e121b1793e3b797f0785..f7e3c8d8fb235bda622b637f0082eedd680185f5 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -46,6 +46,7 @@ py_library( "//tensorflow/contrib/gan", "//tensorflow/contrib/graph_editor:graph_editor_py", "//tensorflow/contrib/grid_rnn:grid_rnn_py", + "//tensorflow/contrib/hadoop", "//tensorflow/contrib/hooks", "//tensorflow/contrib/image:distort_image_py", "//tensorflow/contrib/image:image_py", @@ -63,6 +64,7 @@ py_library( "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/contrib/linear_optimizer:sdca_estimator_py", "//tensorflow/contrib/linear_optimizer:sdca_ops_py", + "//tensorflow/contrib/lite/python:lite", "//tensorflow/contrib/lookup:lookup_py", "//tensorflow/contrib/losses:losses_py", "//tensorflow/contrib/losses:metric_learning_py", @@ -107,7 +109,6 @@ py_library( "//tensorflow/contrib/tfprof", "//tensorflow/contrib/timeseries", "//tensorflow/contrib/tpu", - "//tensorflow/contrib/tpu:tpu_py", "//tensorflow/contrib/training:training_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:util", @@ -135,7 +136,6 @@ py_library( # This is an issue with the tensorrt static library and will be fixed by # the next tensorrt release, so fix the order here after that. "//tensorflow/contrib/tensorrt:init_py", # doesn't compile on windows - "//tensorflow/contrib/lite/python:lite", # unix dependency, need to fix code ]), ) @@ -147,6 +147,7 @@ cc_library( "//tensorflow/contrib/coder:all_kernels", "//tensorflow/contrib/data/kernels:dataset_kernels", "//tensorflow/contrib/factorization/kernels:all_kernels", + "//tensorflow/contrib/hadoop:dataset_kernels", "//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels", "//tensorflow/contrib/layers:sparse_feature_cross_op_kernel", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_ops_kernels", @@ -182,6 +183,7 @@ cc_library( "//tensorflow/contrib/data:dataset_ops_op_lib", "//tensorflow/contrib/factorization:all_ops", "//tensorflow/contrib/framework:all_ops", + "//tensorflow/contrib/hadoop:dataset_ops_op_lib", "//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib", "//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib", "//tensorflow/contrib/nccl:nccl_ops_op_lib", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index ded05da71877566781a5fb6d0c21e1c8d43de9ed..45a7680160251b37fbfb923eb23a5d68ccb2c5fb 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -22,6 +22,7 @@ from __future__ import print_function import os # Add projects here, they will show up under tf.contrib. +from tensorflow.contrib import autograph from tensorflow.contrib import batching from tensorflow.contrib import bayesflow from tensorflow.contrib import checkpoint @@ -93,8 +94,7 @@ from tensorflow.contrib import tpu from tensorflow.contrib import training from tensorflow.contrib import util from tensorflow.contrib.eager.python import tfe as eager -if os.name != "nt": - from tensorflow.contrib.lite.python import lite +from tensorflow.contrib.lite.python import lite from tensorflow.contrib.optimizer_v2 import optimizer_v2_symbols as optimizer_v2 from tensorflow.contrib.receptive_field import receptive_field_api as receptive_field from tensorflow.contrib.recurrent.python import recurrent_api as recurrent diff --git a/tensorflow/contrib/all_reduce/python/all_reduce.py b/tensorflow/contrib/all_reduce/python/all_reduce.py index 159d985db5c48f8fe1a26350255f8d8f68482473..3b539734a236804026826a8117d9c668c0dd089a 100644 --- a/tensorflow/contrib/all_reduce/python/all_reduce.py +++ b/tensorflow/contrib/all_reduce/python/all_reduce.py @@ -32,10 +32,10 @@ def _flatten_tensors(tensors): """Check tensors for isomorphism and flatten. Args: - tensors: list of T @{tf.Tensor} which must all have the same shape. + tensors: list of T `tf.Tensor` which must all have the same shape. Returns: - tensors: a list of T @{tf.Tensor} which are flattened (1D) views of tensors + tensors: a list of T `tf.Tensor` which are flattened (1D) views of tensors shape: the original shape of each element of input tensors Raises: @@ -61,12 +61,12 @@ def _reshape_tensors(tensors, shape): """Reshape tensors flattened by _flatten_tensors. Args: - tensors: list of T @{tf.Tensor} of identical length 1D tensors. + tensors: list of T `tf.Tensor` of identical length 1D tensors. shape: list of integers describing the desired shape. Product of the elements must equal the length of each tensor. Returns: - list of T @{tf.Tensor} which are the reshaped inputs. + list of T `tf.Tensor` which are the reshaped inputs. """ reshaped = [] for t in tensors: @@ -79,12 +79,12 @@ def _padded_split(tensor, pieces): """Like split for 1D tensors but pads-out case where len % pieces != 0. Args: - tensor: T @{tf.Tensor} that must be 1D. + tensor: T `tf.Tensor` that must be 1D. pieces: a positive integer specifying the number of pieces into which tensor should be split. Returns: - list of T @{tf.Tensor} of length pieces, which hold the values of + list of T `tf.Tensor` of length pieces, which hold the values of thin input tensor, in order. The final tensor may be zero-padded on the end to make its size equal to those of all of the other tensors. @@ -132,11 +132,11 @@ def _strip_padding(tensors, pad_len): """Strip the suffix padding added by _padded_split. Args: - tensors: list of T @{tf.Tensor} of identical length 1D tensors. + tensors: list of T `tf.Tensor` of identical length 1D tensors. pad_len: number of elements to be stripped from the end of each tensor. Returns: - list of T @{tf.Tensor} which are the stripped inputs. + list of T `tf.Tensor` which are the stripped inputs. Raises: ValueError: tensors must be a non-empty list of 1D tensors, and @@ -161,12 +161,12 @@ def _ragged_split(tensor, pieces): """Like split for 1D tensors but allows case where len % pieces != 0. Args: - tensor: T @{tf.Tensor} that must be 1D. + tensor: T `tf.Tensor` that must be 1D. pieces: a positive integer specifying the number of pieces into which tensor should be split. Returns: - list of T @{tf.Tensor} of length pieces, which hold the values of + list of T `tf.Tensor` of length pieces, which hold the values of the input tensor, in order. The final tensor may be shorter than the others, which will all be of equal length. @@ -256,7 +256,7 @@ def build_ring_all_reduce(input_tensors, num_workers, num_subchunks, """Construct a subgraph performing a ring-style all-reduce of input_tensors. Args: - input_tensors: a list of T @{tf.Tensor} objects, which must all + input_tensors: a list of T `tf.Tensor` objects, which must all have the same shape and type. num_workers: number of worker tasks spanned by input_tensors. num_subchunks: number of subchunks each device should process in one tick. @@ -272,7 +272,7 @@ def build_ring_all_reduce(input_tensors, num_workers, num_subchunks, size. Returns: - a list of T @{tf.Tensor} identical sum-reductions of input_tensors. + a list of T `tf.Tensor` identical sum-reductions of input_tensors. """ if len(input_tensors) < 2: raise ValueError("input_tensors must be length 2 or longer") @@ -299,7 +299,7 @@ def _build_ring_gather(input_tensors, devices, num_subchunks, """Construct a subgraph for the first (reduction) pass of ring all-reduce. Args: - input_tensors: a list of T @{tf.Tensor} 1D input tensors of same + input_tensors: a list of T `tf.Tensor` 1D input tensors of same shape and type. devices: array of device name strings num_subchunks: number of subchunks each device should process in one tick. @@ -311,7 +311,7 @@ def _build_ring_gather(input_tensors, devices, num_subchunks, ValueError: tensors must all be one dimensional. Returns: - list of list of T @{tf.Tensor} of (partially) reduced values where + list of list of T `tf.Tensor` of (partially) reduced values where exactly num_subchunks chunks at each device are fully reduced. """ num_devices = len(input_tensors) @@ -360,11 +360,11 @@ def _apply_unary_to_chunks(f, chunks_by_dev): """Apply a unary op to each tensor in chunks_by_dev, on same device. Args: - f: a unary function over T @{tf.Tensor}. - chunks_by_dev: list of lists of T @{tf.Tensor}. + f: a unary function over T `tf.Tensor`. + chunks_by_dev: list of lists of T `tf.Tensor`. Returns: - new list of lists of T @{tf.Tensor} with the same structure as + new list of lists of T `tf.Tensor` with the same structure as chunks_by_dev containing the derived tensors. """ output = [] @@ -381,14 +381,14 @@ def _build_ring_scatter(pred_by_s_d, rank_by_s_d, Args: pred_by_s_d: as produced by _ring_permutations rank_by_s_d: as produced by _ring_permutations - chunks_by_dev: list of list of T @{tf.Tensor} indexed by ints + chunks_by_dev: list of list of T `tf.Tensor` indexed by ints (device, chunk) Raises: ValueError: chunks_by_dev is not well-formed Returns: - list of T @{tf.Tensor} which are the fully reduced tensors, one + list of T `tf.Tensor` which are the fully reduced tensors, one at each device corresponding to the outer dimension of chunks_by_dev. """ num_devices = len(chunks_by_dev) @@ -448,12 +448,12 @@ def build_recursive_hd_all_reduce(input_tensors, red_op, un_op=None): the future with edge-case specific logic. Args: - input_tensors: list of T @{tf.Tensor} to be elementwise reduced. + input_tensors: list of T `tf.Tensor` to be elementwise reduced. red_op: a binary elementwise reduction Op. un_op: an optional unary elementwise Op to apply to reduced values. Returns: - list of T @{tf.Tensor} which are the fully reduced tensors, one + list of T `tf.Tensor` which are the fully reduced tensors, one at each device of input_tensors. Raises: @@ -475,13 +475,13 @@ def _build_recursive_hd_gather(input_tensors, devices, red_op): """Construct the gather phase of recursive halving-doubling all-reduce. Args: - input_tensors: list of T @{tf.Tensor} to be elementwise reduced. + input_tensors: list of T `tf.Tensor` to be elementwise reduced. devices: a list of strings naming the devices hosting input_tensors, which will also be used to host the (partial) reduction values. red_op: a binary elementwise reduction Op. Returns: - list of T @{tf.Tensor} which are the fully reduced tensor shards. + list of T `tf.Tensor` which are the fully reduced tensor shards. Raises: ValueError: num_devices not a power of 2, or tensor len not divisible @@ -516,12 +516,12 @@ def _build_recursive_hd_scatter(input_tensors, devices): """Construct the scatter phase of recursive halving-doublng all-reduce. Args: - input_tensors: list of T @{tf.Tensor} that are fully-reduced shards. + input_tensors: list of T `tf.Tensor` that are fully-reduced shards. devices: a list of strings naming the devices on which the reconstituted full tensors should be placed. Returns: - list of T @{tf.Tensor} which are the fully reduced tensors. + list of T `tf.Tensor` which are the fully reduced tensors. """ num_devices = len(devices) num_hops = int(math.log(num_devices, 2)) @@ -571,7 +571,7 @@ def build_shuffle_all_reduce(input_tensors, gather_devices, red_op, un_op=None): un_op: optional elementwise unary Op to be applied to fully-reduced values. Returns: - list of T @{tf.Tensor} which are the fully reduced tensors. + list of T `tf.Tensor` which are the fully reduced tensors. """ input_tensors, shape = _flatten_tensors(input_tensors) dst_devices = [t.device for t in input_tensors] @@ -594,7 +594,7 @@ def _build_shuffle_gather(input_tensors, gather_devices, red_op, un_op=None): un_op: optional elementwise unary Op to be applied to fully-reduced values. Returns: - list of T @{tf.Tensor} which are the fully reduced shards. + list of T `tf.Tensor` which are the fully reduced shards. Raises: ValueError: inputs not well-formed. @@ -629,7 +629,7 @@ def _build_shuffle_scatter(reduced_shards, dst_devices): should be reconstituted. Returns: - list of T @{tf.Tensor} scattered tensors. + list of T `tf.Tensor` scattered tensors. """ num_devices = len(dst_devices) out_tensors = [] @@ -644,7 +644,7 @@ def _split_by_task(devices, values): Args: devices: list of device name strings - values: list of T @{tf.tensor} of same length as devices. + values: list of T `tf.tensor` of same length as devices. Returns: (per_task_devices, per_task_values) where both values are @@ -680,14 +680,14 @@ def build_nccl_all_reduce(input_tensors, red_op, un_op=None): """Build a subgraph that does one full all-reduce, using NCCL. Args: - input_tensors: list of T @{tf.Tensor} of same-shape and type values to + input_tensors: list of T `tf.Tensor` of same-shape and type values to be reduced. red_op: binary elementwise reduction operator. Must be one of {tf.add} un_op: optional unary elementwise Op to apply to fully-reduce values. Returns: - list of T @{tf.Tensor} of reduced values. + list of T `tf.Tensor` of reduced values. Raises: ValueError: red_op not supported. @@ -709,14 +709,14 @@ def _build_nccl_hybrid(input_tensors, red_op, upper_level_f): """Construct a subgraph for NCCL hybrid all-reduce. Args: - input_tensors: list of T @{tf.Tensor} of same-shape and type values to + input_tensors: list of T `tf.Tensor` of same-shape and type values to be reduced. red_op: binary elementwise reduction operator. upper_level_f: function for reducing one value per worker, across workers. Returns: - list of T @{tf.Tensor} of reduced values. + list of T `tf.Tensor` of reduced values. Raises: ValueError: inputs not well-formed. @@ -797,7 +797,7 @@ def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f): """Construct a subgraph for Shuffle hybrid all-reduce. Args: - input_tensors: list of T @{tf.Tensor} of same-shape and type values to + input_tensors: list of T `tf.Tensor` of same-shape and type values to be reduced. gather_devices: list of device names on which to host gather shards. red_op: binary elementwise reduction operator. @@ -805,7 +805,7 @@ def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f): workers. Returns: - list of T @{tf.Tensor} of reduced values. + list of T `tf.Tensor` of reduced values. Raises: ValueError: inputs not well-formed. diff --git a/tensorflow/contrib/autograph/converters/BUILD b/tensorflow/contrib/autograph/converters/BUILD index 7cbba7168383f3d0cdc80fda9908cb7d70836bb4..2d2ab7040a8bb76f9538f201f75a2e4dcba0f511 100644 --- a/tensorflow/contrib/autograph/converters/BUILD +++ b/tensorflow/contrib/autograph/converters/BUILD @@ -204,6 +204,7 @@ py_test( name = "side_effect_guards_test", srcs = ["side_effect_guards_test.py"], srcs_version = "PY2AND3", + tags = ["notsan"], deps = [ ":converters", "//tensorflow/contrib/autograph/core:test_lib", diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py index 2a60750bdae273ca349c305b033313fa61f41872..180779670d91abd7d395bda0b63f592967c5015b 100644 --- a/tensorflow/contrib/autograph/converters/break_statements.py +++ b/tensorflow/contrib/autograph/converters/break_statements.py @@ -42,7 +42,7 @@ class BreakTransformer(converter.Base): var_name = self.state[_Break].control_var_name # TODO(mdan): This will fail when expanded inside a top-level else block. template = """ - var_name = True + var_name = tf.constant(True) continue """ return templates.replace(template, var_name=var_name) @@ -85,7 +85,7 @@ class BreakTransformer(converter.Base): guarded_orelse = self._guard_if_present(node.orelse, break_var) template = """ - var_name = False + var_name = tf.constant(False) while test and not var_name: body else: @@ -122,7 +122,7 @@ class BreakTransformer(converter.Base): # the control variable is marked as used. # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name) template = """ - var_name = False + var_name = tf.constant(False) for target in iter_: (var_name,) body diff --git a/tensorflow/contrib/autograph/converters/break_statements_test.py b/tensorflow/contrib/autograph/converters/break_statements_test.py index c26ca2946ce40e30248d1d835bbe6517911540c0..fcae7d68c0f90817e001b45fa86ca6be08456027 100644 --- a/tensorflow/contrib/autograph/converters/break_statements_test.py +++ b/tensorflow/contrib/autograph/converters/break_statements_test.py @@ -20,13 +20,16 @@ from __future__ import print_function from tensorflow.contrib.autograph.converters import break_statements from tensorflow.contrib.autograph.core import converter_testing +from tensorflow.python.eager import context as tfe_ctx +from tensorflow.python.framework import constant_op from tensorflow.python.platform import test class BreakCanonicalizationTest(converter_testing.TestCase): def assertTransformedEquivalent(self, test_fn, *inputs): - with self.converted(test_fn, break_statements, {}) as result: + with self.converted(test_fn, break_statements, {}, + constant_op.constant) as result: self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) def test_while_loop(self): @@ -40,9 +43,10 @@ class BreakCanonicalizationTest(converter_testing.TestCase): v.append(x) return v - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 1) - self.assertTransformedEquivalent(test_fn, 4) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 1) + self.assertTransformedEquivalent(test_fn, 4) def test_for_loop(self): @@ -55,7 +59,8 @@ class BreakCanonicalizationTest(converter_testing.TestCase): v.append(x) return v - with self.converted(test_fn, break_statements, {}) as result: + with self.converted(test_fn, break_statements, {}, + constant_op.constant) as result: # The break is incompletely canonicalized. The loop will not interrupt, # but the section following the break will be skipped. self.assertEqual([3], result.test_fn([5, 4])) @@ -77,9 +82,10 @@ class BreakCanonicalizationTest(converter_testing.TestCase): v.append(x) return v, u, w - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 3) - self.assertTransformedEquivalent(test_fn, 11) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 3) + self.assertTransformedEquivalent(test_fn, 11) def test_nested_loops(self): @@ -99,10 +105,11 @@ class BreakCanonicalizationTest(converter_testing.TestCase): v.append(x) return v, u - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 2) - self.assertTransformedEquivalent(test_fn, 3) - self.assertTransformedEquivalent(test_fn, 5) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(test_fn, 3) + self.assertTransformedEquivalent(test_fn, 5) def test_loop_orelse(self): @@ -120,9 +127,10 @@ class BreakCanonicalizationTest(converter_testing.TestCase): v.append(x) return v, u - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 2) - self.assertTransformedEquivalent(test_fn, 3) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(test_fn, 3) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/call_trees.py b/tensorflow/contrib/autograph/converters/call_trees.py index a36b3d77a9233daed864c616306b2ad27f582a38..2d1bed3367fa0b283200b775c5953da80c855367 100644 --- a/tensorflow/contrib/autograph/converters/call_trees.py +++ b/tensorflow/contrib/autograph/converters/call_trees.py @@ -238,7 +238,7 @@ class CallTreeTransformer(converter.Base): # Before we could convert all the time though, we'd need a reasonable # caching mechanism. template = """ - ag__.converted_call(func, True, False, {}, args) + ag__.converted_call(func, True, False, False, {}, args) """ call_expr = templates.replace(template, func=node.func, args=node.args) new_call = call_expr[0].value diff --git a/tensorflow/contrib/autograph/converters/continue_statements.py b/tensorflow/contrib/autograph/converters/continue_statements.py index 958bde0a58764e705c35ab73ce879b2c11ce7cdc..0476e97c15e33dcfc09b3555cf8dc7ff3fd7ce19 100644 --- a/tensorflow/contrib/autograph/converters/continue_statements.py +++ b/tensorflow/contrib/autograph/converters/continue_statements.py @@ -37,7 +37,7 @@ class ContinueCanonicalizationTransformer(converter.Base): def visit_Continue(self, node): self.set_local(CONTINUE_USED, True) template = """ - var_name = True + var_name = tf.constant(True) """ return templates.replace( template, var_name=self.get_local(CONTROL_VAR_NAME)) @@ -92,7 +92,7 @@ class ContinueCanonicalizationTransformer(converter.Base): if self.get_local(CONTINUE_USED, False): template = """ - var_name = False + var_name = tf.constant(False) """ control_var_init = templates.replace(template, var_name=continue_var) nodes = control_var_init + nodes diff --git a/tensorflow/contrib/autograph/converters/continue_statements_test.py b/tensorflow/contrib/autograph/converters/continue_statements_test.py index 3a7c7d1486de81482a191a321547ec1e67bf8618..37c15211b4fe266e57879249fe7e060ded44dc1f 100644 --- a/tensorflow/contrib/autograph/converters/continue_statements_test.py +++ b/tensorflow/contrib/autograph/converters/continue_statements_test.py @@ -20,13 +20,16 @@ from __future__ import print_function from tensorflow.contrib.autograph.converters import continue_statements from tensorflow.contrib.autograph.core import converter_testing +from tensorflow.python.eager import context as tfe_ctx +from tensorflow.python.framework import constant_op from tensorflow.python.platform import test class ContinueCanonicalizationTest(converter_testing.TestCase): def assertTransformedEquivalent(self, test_fn, *inputs): - with self.converted(test_fn, continue_statements, {}) as result: + with self.converted(test_fn, continue_statements, {}, + constant_op.constant) as result: self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) def test_basic(self): @@ -40,10 +43,11 @@ class ContinueCanonicalizationTest(converter_testing.TestCase): v.append(x) return v - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 1) - self.assertTransformedEquivalent(test_fn, 3) - self.assertTransformedEquivalent(test_fn, 4) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 1) + self.assertTransformedEquivalent(test_fn, 3) + self.assertTransformedEquivalent(test_fn, 4) def test_for_loop(self): @@ -56,10 +60,11 @@ class ContinueCanonicalizationTest(converter_testing.TestCase): v.append(x) return v - self.assertTransformedEquivalent(test_fn, []) - self.assertTransformedEquivalent(test_fn, [1]) - self.assertTransformedEquivalent(test_fn, [2]) - self.assertTransformedEquivalent(test_fn, [1, 2, 3]) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, []) + self.assertTransformedEquivalent(test_fn, [1]) + self.assertTransformedEquivalent(test_fn, [2]) + self.assertTransformedEquivalent(test_fn, [1, 2, 3]) def test_nested(self): @@ -78,10 +83,11 @@ class ContinueCanonicalizationTest(converter_testing.TestCase): v.append(x) return v, u, w - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 1) - self.assertTransformedEquivalent(test_fn, 3) - self.assertTransformedEquivalent(test_fn, 4) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 1) + self.assertTransformedEquivalent(test_fn, 3) + self.assertTransformedEquivalent(test_fn, 4) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/directives.py b/tensorflow/contrib/autograph/converters/directives.py index ccdf79d47be65dd777a7ae3a226246a62e274430..77f625bac792621c45799d1a220f99eb4b99f7af 100644 --- a/tensorflow/contrib/autograph/converters/directives.py +++ b/tensorflow/contrib/autograph/converters/directives.py @@ -42,10 +42,30 @@ def _map_args(call_node, function): Returns: Dict[Text, ast.AST], mapping each of the function's argument names to the respective AST node. + Raises: + ValueError: if the default arguments are not correctly set """ args = call_node.args kwds = {kwd.arg: kwd.value for kwd in call_node.keywords} - return tf_inspect.getcallargs(function, *args, **kwds) + call_args = tf_inspect.getcallargs(function, *args, **kwds) + + # Keyword arguments not specified in kwds will be mapped to their defaults, + # which are Python values. Since we don't currently have a way to transform + # those into AST references, we simply remove them. By convention, directives + # use UNSPECIFIED as default value for for optional arguments. No other + # defaults should be present. + unexpected_defaults = [] + for k in call_args: + if (k not in kwds + and call_args[k] not in args + and call_args[k] is not directives.UNSPECIFIED): + unexpected_defaults.append(k) + if unexpected_defaults: + raise ValueError('Unexpected keyword argument values, %s, for function %s' + % (zip(unexpected_defaults, + [call_args[k] for k in unexpected_defaults]), + function)) + return {k: v for k, v in call_args.items() if v is not directives.UNSPECIFIED} class DirectivesTransformer(converter.Base): diff --git a/tensorflow/contrib/autograph/converters/directives_test.py b/tensorflow/contrib/autograph/converters/directives_test.py index a573ba5850609f65ea60432470485c523cd3da3b..a2d083b891314d2f8f3fa61b46edc347ca8e24eb 100644 --- a/tensorflow/contrib/autograph/converters/directives_test.py +++ b/tensorflow/contrib/autograph/converters/directives_test.py @@ -23,6 +23,7 @@ from tensorflow.contrib.autograph.core import converter_testing from tensorflow.contrib.autograph.core.converter import AgAnno from tensorflow.contrib.autograph.lang import directives from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser from tensorflow.python.platform import test @@ -71,7 +72,23 @@ class DirectivesTest(converter_testing.TestCase): d = d[directives.set_loop_options] self.assertEqual(d['parallel_iterations'].n, 10) self.assertEqual(d['back_prop'].id, 'a') - self.assertEqual(d['swap_memory'], directives.UNSPECIFIED) + self.assertNotIn('swap_memory', d) + + def test_invalid_default(self): + + def invalid_directive(valid_arg, invalid_default=object()): + del valid_arg + del invalid_default + return + + def call_invalid_directive(): + invalid_directive(1) + + node, _ = parser.parse_entity(call_invalid_directive) + # Find the call to the invalid directive + node = node.body[0].body[0].value + with self.assertRaisesRegexp(ValueError, 'Unexpected keyword.*'): + directives_converter._map_args(node, invalid_directive) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/error_handlers.py b/tensorflow/contrib/autograph/converters/error_handlers.py index 3f2366215268cffe1aa2c55a174dbdba6127d701..193682139438c1d0133b17165d7f7fb84e2eaaac 100644 --- a/tensorflow/contrib/autograph/converters/error_handlers.py +++ b/tensorflow/contrib/autograph/converters/error_handlers.py @@ -37,7 +37,8 @@ class ErrorRewritingTransformer(converter.Base): def visit_FunctionDef(self, node): node = self.generic_visit(node) - if anno.hasanno(node, anno.Basic.ORIGIN): + if (anno.hasanno(node, anno.Basic.ORIGIN) and + len(self.enclosing_entities) <= 1): template = """ try: body diff --git a/tensorflow/contrib/autograph/converters/error_handlers_test.py b/tensorflow/contrib/autograph/converters/error_handlers_test.py index cd74e5f18f76d0c531f487bc0c736b421c9c3fb4..5d61b220afa0fcf9a9e619bbd78f83a5076c473a 100644 --- a/tensorflow/contrib/autograph/converters/error_handlers_test.py +++ b/tensorflow/contrib/autograph/converters/error_handlers_test.py @@ -34,8 +34,10 @@ class ErrorHandlersTest(converter_testing.TestCase): raise ValueError() node, ctx = self.prepare(test_fn, {}) - anno.setanno(node, anno.Basic.ORIGIN, - origin_info.OriginInfo(None, None, None)) + anno.setanno( + node, anno.Basic.ORIGIN, + origin_info.OriginInfo(None, 'test_function_name', 'test_code', + 'test_comment')) node = error_handlers.transform(node, ctx) with self.compiled(node, {}) as result: with self.assertRaises(errors.GraphConstructionError): diff --git a/tensorflow/contrib/autograph/core/converter.py b/tensorflow/contrib/autograph/core/converter.py index a93e4a806469db63e7d767563e64dadfe71f50ee..83a80c1f52123c325782a67c651e892163af83b3 100644 --- a/tensorflow/contrib/autograph/core/converter.py +++ b/tensorflow/contrib/autograph/core/converter.py @@ -233,7 +233,7 @@ class Base(transformer.Base): arg_values = [] for def_ in defs: if (directive not in def_.directives or - arg not in arg not in def_.directives[directive]): + arg not in def_.directives[directive]): continue arg_value = def_.directives[directive][arg] for prev_value in arg_values: diff --git a/tensorflow/contrib/autograph/core/errors.py b/tensorflow/contrib/autograph/core/errors.py index c219b372c13f2870ebde2d35c50dcc1fb270490c..5a57d57e7d4c6461f05030b72cc9bfe1b33210db 100644 --- a/tensorflow/contrib/autograph/core/errors.py +++ b/tensorflow/contrib/autograph/core/errors.py @@ -33,8 +33,6 @@ import traceback from tensorflow.contrib.autograph.pyct import origin_info from tensorflow.python.framework import errors_impl -from tensorflow.python.util import tf_inspect - # TODO(mdan): Add a superclass common to all errors. @@ -68,47 +66,29 @@ class TfRuntimeError(Exception): return message + ''.join(traceback.format_list(self.custom_traceback)) -def _rewrite_tb(source_map, tb, filter_function_name=None): +def _rewrite_tb(source_map, tb): """Rewrites code references in a traceback. Args: source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo], mapping locations to their origin tb: List[Tuple[Text, Text, Text, Text]], consistent with - traceback.extract_tb - filter_function_name: Optional[Text], allows restricting restricts the - frames to rewrite to a particular function name + traceback.extract_tb. Returns: List[Tuple[Text, Text, Text, Text]], the rewritten traceback """ new_tb = [] for frame in tb: - filename, lineno, function_name, _ = frame + filename, lineno, _, _ = frame loc = origin_info.LineLocation(filename, lineno) origin = source_map.get(loc) - # TODO(mdan): We shouldn't need the function name at all. - # filename + lineno should be sufficient, even if there are multiple source - # maps. if origin is not None: - if filter_function_name == function_name or filter_function_name is None: - new_tb.append(origin.as_frame()) - else: - new_tb.append(frame) + new_tb.append(origin.as_frame()) else: new_tb.append(frame) return new_tb -# TODO(znado): Make more robust to name changes in the rewriting logic. -def _remove_rewrite_frames(tb): - """Remove stack frames containing the error rewriting logic.""" - cleaned_tb = [] - for f in tb: - if 'ag__.rewrite_graph_construction_error' not in f[3]: - cleaned_tb.append(f) - return cleaned_tb - - # TODO(mdan): rename to raise_* def rewrite_graph_construction_error(source_map): """Rewrites errors raised by non-AG APIs inside AG generated code. @@ -132,20 +112,17 @@ def rewrite_graph_construction_error(source_map): _, original_error, e_traceback = error_info assert original_error is not None try: - _, _, _, func_name, _, _ = tf_inspect.stack()[1] + current_traceback = _cut_traceback_loops(source_map, + traceback.extract_tb(e_traceback)) if isinstance(original_error, GraphConstructionError): # TODO(mdan): This is incomplete. # The error might have bubbled through a non-converted function. - cleaned_traceback = traceback.extract_tb(e_traceback) previous_traceback = original_error.custom_traceback - cleaned_traceback = [cleaned_traceback[0]] + previous_traceback + cleaned_traceback = [current_traceback[0]] + previous_traceback else: - cleaned_traceback = traceback.extract_tb(e_traceback) + cleaned_traceback = current_traceback - # Remove the frame corresponding to this function call. - cleaned_traceback = cleaned_traceback[1:] - - cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback, func_name) + cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback) if isinstance(original_error, GraphConstructionError): original_error.custom_traceback = cleaned_traceback @@ -163,6 +140,60 @@ def rewrite_graph_construction_error(source_map): del e_traceback +def _cut_traceback_loops(source_map, original_traceback): + """Check for cases where we leave a user method and re-enter it. + + This is done by looking at the function names when the filenames are from any + files the user code is in. If we find a case where we return to a user method + after leaving it then we cut out the frames in between because we assume this + means these in between frames are from internal AutoGraph code that shouldn't + be included. + + An example of this is: + + File "file1.py", line 57, in my_func + ... + File "control_flow_ops.py", line 231, in cond + ... + File "control_flow_ops.py", line 1039, in inner_cond + ... + File "file1.py", line 68, in my_func + ... + + Where we would remove the control_flow_ops.py frames because we re-enter + my_func in file1.py. + + The source map keys are (file_path, line_number) so get the set of all user + file_paths. + + Args: + source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo], mapping + locations to their origin + original_traceback: List[Tuple[Text, Text, Text, Text]], consistent with + traceback.extract_tb. + + Returns: + List[Tuple[Text, Text, Text, Text]], the traceback with any loops removed. + """ + all_user_files = set(loc.filename for loc in source_map) + cleaned_traceback = [] + last_user_frame_index = None + last_user_user_file_path = None + # TODO(mdan): Simplify this logic. + for fi, frame in enumerate(original_traceback): + frame_file_path, lineno, _, _ = frame + src_map_key = origin_info.LineLocation(frame_file_path, lineno) + if frame_file_path in all_user_files: + if src_map_key in source_map: + if (last_user_frame_index is not None and + last_user_user_file_path == frame_file_path): + cleaned_traceback = cleaned_traceback[:last_user_frame_index] + last_user_frame_index = fi + last_user_user_file_path = frame_file_path + cleaned_traceback.append(frame) + return cleaned_traceback + + # TODO(mdan): This should be consistent with rewrite_graph_construction_error # Both should either raise or return. def rewrite_tf_runtime_error(error, source_map): @@ -175,56 +206,9 @@ def rewrite_tf_runtime_error(error, source_map): Returns: TfRuntimeError, the rewritten underlying error. """ - # Check for cases where we leave a user method and re-enter it in the - # traceback. This is done by looking at the function names when the - # filenames are from any files the user code is in. If we find a case where - # we return to a user method after leaving it then we cut out the frames in - # between because we assume this means these in between frames are from - # internal AutoGraph code that shouldn't be included. - # - # An example of this is: - # - # File "file1.py", line 57, in my_func - # ... - # File "control_flow_ops.py", line 231, in cond - # ... - # File "control_flow_ops.py", line 1039, in inner_cond - # ... - # File "file1.py", line 68, in my_func - # ... - # - # Where we would remove the control_flow_ops.py frames because we re-enter - # my_func in file1.py. - # - # The source map keys are (file_path, line_number) so get the set of all user - # file_paths. try: - all_user_files = set(loc.filename for loc in source_map) - cleaned_traceback = [] - last_user_frame_index = None - last_user_user_file_path = None - last_user_user_fn_name = None - # TODO(mdan): Simplify this logic. - for fi, frame in enumerate(error.op.traceback): - frame_file_path, lineno, _, _ = frame - lineno -= 1 # Frame line numbers are 1-based. - src_map_key = origin_info.LineLocation(frame_file_path, lineno) - if frame_file_path in all_user_files: - if src_map_key in source_map: - original_fn_name = source_map[src_map_key].function_name - if (last_user_frame_index is not None and - last_user_user_file_path == frame_file_path): - if last_user_user_fn_name == original_fn_name: - cleaned_traceback = cleaned_traceback[:last_user_frame_index] - else: - cleaned_traceback = cleaned_traceback[:last_user_frame_index + 1] - last_user_user_fn_name = original_fn_name - else: - last_user_user_fn_name = None - last_user_frame_index = fi - last_user_user_file_path = frame_file_path - cleaned_traceback.append(frame) - + cleaned_traceback = _cut_traceback_loops(source_map, error.op.traceback) + # cleaned_traceback = error.op.traceback cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback) op_name = error.op.name diff --git a/tensorflow/contrib/autograph/core/errors_test.py b/tensorflow/contrib/autograph/core/errors_test.py index c0e2c74e47ddfb8ee812d6d839b06784e7a01dba..404c1f5456f9654724d068e3007fe9ced15cbf07 100644 --- a/tensorflow/contrib/autograph/core/errors_test.py +++ b/tensorflow/contrib/autograph/core/errors_test.py @@ -43,7 +43,8 @@ class RuntimeErrorsTest(test.TestCase): filename = tf_inspect.getsourcefile(function) lineno += line_offset loc = origin_info.LineLocation(filename, lineno) - origin = origin_info.OriginInfo(loc, 'test_function_name', 'test_code') + origin = origin_info.OriginInfo(loc, 'test_function_name', 'test_code', + 'test_comment') return loc, origin def test_improved_errors_basic(self): diff --git a/tensorflow/contrib/autograph/docs/pyfunc_dtypes.md b/tensorflow/contrib/autograph/docs/pyfunc_dtypes.md new file mode 100644 index 0000000000000000000000000000000000000000..bcbb920cc53de4b89dc67128c9c2c2312f030f0a --- /dev/null +++ b/tensorflow/contrib/autograph/docs/pyfunc_dtypes.md @@ -0,0 +1,33 @@ +# Specifying return data type for `py_func` calls + +The `py_func` op requires specifying a +[data type](https://www.tensorflow.org/guide/tensors#data_types). + +When wrapping a function with `py_func`, for instance using +`@autograph.do_not_convert(run_mode=autograph.RunMode.PY_FUNC)`, you have two +options to specify the returned data type: + + * explicitly, with a specified `tf.DType` value + * by matching the data type of an input argument, which is then assumed to be + a `Tensor` + +Examples: + +Specify an explicit data type: + +``` + def foo(a): + return a + 1 + + autograph.util.wrap_py_func(f, return_dtypes=[tf.float32]) +``` + +Match the data type of the first argument: + +``` + def foo(a): + return a + 1 + + autograph.util.wrap_py_func( + f, return_dtypes=[autograph.utils.py_func.MatchDType(0)]) +``` diff --git a/tensorflow/contrib/autograph/examples/integration_tests/BUILD b/tensorflow/contrib/autograph/examples/integration_tests/BUILD index d20c17b63b923458952dbfdb1e07e808cf6a36ff..6c281485b4a3c4d09292a4d7af16330cdc44edd4 100644 --- a/tensorflow/contrib/autograph/examples/integration_tests/BUILD +++ b/tensorflow/contrib/autograph/examples/integration_tests/BUILD @@ -16,6 +16,19 @@ filegroup( visibility = ["//tensorflow:__subpackages__"], ) +py_test( + name = "errors_test", + srcs = [ + "errors_test.py", + ], + srcs_version = "PY2AND3", + tags = ["no_windows"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow:tensorflow_py", + ], +) + py_test( name = "keras_test", srcs = [ diff --git a/tensorflow/contrib/autograph/examples/integration_tests/errors_test.py b/tensorflow/contrib/autograph/examples/integration_tests/errors_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b9159942bcf8837b97dfac000d8fb34d15a314 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/integration_tests/errors_test.py @@ -0,0 +1,162 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Error traceback rewriting integration tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.contrib import autograph as ag +from tensorflow.python.util import tf_inspect + + +class ErrorsTest(tf.test.TestCase): + + def test_graph_construction_error_rewriting_call_tree(self): + + def innermost(x): + if x > 0: + return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32) + return tf.zeros((2, 3)) + + def inner_caller(): + return innermost(1.0) + + def caller(): + return inner_caller() + + with self.assertRaises(ag.GraphConstructionError) as error: + graph = ag.to_graph(caller) + graph() + expected = error.exception + custom_traceback = expected.custom_traceback + found_correct_filename = False + num_innermost_names = 0 + num_inner_caller_names = 0 + num_caller_names = 0 + ag_output_filename = tf_inspect.getsourcefile(graph) + for frame in custom_traceback: + filename, _, fn_name, _ = frame + self.assertFalse('control_flow_ops.py' in filename) + self.assertFalse(ag_output_filename in filename) + found_correct_filename |= __file__ in filename + self.assertNotEqual('tf__test_fn', fn_name) + num_innermost_names += int('innermost' == fn_name) + self.assertNotEqual('tf__inner_caller', fn_name) + num_inner_caller_names += int('inner_caller' == fn_name) + self.assertNotEqual('tf__caller', fn_name) + num_caller_names += int('caller' == fn_name) + self.assertTrue(found_correct_filename) + self.assertEqual(num_innermost_names, 1) + self.assertEqual(num_inner_caller_names, 1) + self.assertEqual(num_caller_names, 1) + + def test_graph_construction_error_rewriting_class(self): + + class TestClass(object): + + def test_fn(self): + return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32) + + def inner_caller(self): + return self.test_fn() + + def caller(self): + return self.inner_caller() + + # Note we expect a TypeError here because the traceback will not be + # rewritten for classes. + with self.assertRaises(TypeError): + graph = ag.to_graph(TestClass) + graph().caller() + + def test_runtime_error_rewriting(self): + + def g(x, s): + while tf.reduce_sum(x) > s: + x //= 0 + return x + + def test_fn(x): + return g(x, 10) + + compiled_fn = ag.to_graph(test_fn) + + with self.assertRaises(ag.TfRuntimeError) as error: + with self.test_session() as sess: + x = compiled_fn(tf.constant([4, 8])) + with ag.improved_errors(compiled_fn): + sess.run(x) + expected = error.exception + custom_traceback = expected.custom_traceback + found_correct_filename = False + num_test_fn_frames = 0 + num_g_frames = 0 + ag_output_filename = tf_inspect.getsourcefile(compiled_fn) + for frame in custom_traceback: + filename, _, fn_name, source_code = frame + self.assertFalse(ag_output_filename in filename) + self.assertFalse('control_flow_ops.py' in filename) + self.assertFalse('ag__.' in fn_name) + self.assertFalse('tf__g' in fn_name) + self.assertFalse('tf__test_fn' in fn_name) + found_correct_filename |= __file__ in filename + num_test_fn_frames += int('test_fn' == fn_name and + 'return g(x, 10)' in source_code) + # This makes sure that the code is correctly rewritten from "x_1 //= 0" to + # "x //= 0". + num_g_frames += int('g' == fn_name and 'x //= 0' in source_code) + self.assertTrue(found_correct_filename) + self.assertEqual(num_test_fn_frames, 1) + self.assertEqual(num_g_frames, 1) + + def test_runtime_error_rewriting_nested(self): + + def test_fn(x): + + def g(y): + return y**2 // 0 + + s = 0 + for xi in x: + s += g(xi) + return s + + compiled_fn = ag.to_graph(test_fn) + + # TODO(b/111408261): Nested functions currently do not rewrite correctly, + # when they do we should change this test to check for the same traceback + # properties as the other tests. This should throw a runtime error with a + # frame with "g" as the function name but because we don't yet add + # try/except blocks to inner functions the name is "tf__g". + with self.assertRaises(ag.TfRuntimeError) as error: + with self.test_session() as sess: + x = compiled_fn(tf.constant([4, 8])) + with ag.improved_errors(compiled_fn): + sess.run(x) + expected = error.exception + custom_traceback = expected.custom_traceback + num_tf_g_frames = 0 + for frame in custom_traceback: + _, _, fn_name, _ = frame + self.assertNotEqual('g', fn_name) + num_tf_g_frames += int('tf__g' == fn_name) + self.assertEqual(num_tf_g_frames, 1) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py b/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py index 73125eb452fc3f3f94a8323d677341345931c4ea..7e7ef5a3e2bbf6a15936eb181c9c4112f8b820e6 100644 --- a/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py +++ b/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py @@ -44,6 +44,33 @@ class ModelWithStaticConditional(object): return x +class BasicBlock(tf.keras.Model): + + def __init__(self): + super(BasicBlock, self).__init__() + self.conv1 = tf.keras.layers.Conv2D(8, 3) + self.pool = tf.keras.layers.GlobalAveragePooling2D() + self.dense = tf.keras.layers.Dense(3) + + def call(self, x): + x = self.conv1(x) + x = self.pool(x) + x = self.dense(x) + return x + + +class CompoundModel(tf.keras.Model): + + def __init__(self): + super(CompoundModel, self).__init__() + self.block = BasicBlock() + + @autograph.convert(recursive=True) + def call(self, x): + x = self.block(x) # pylint: disable=not-callable + return x + + class KerasTest(tf.test.TestCase): def test_basic(self): @@ -57,6 +84,20 @@ class KerasTest(tf.test.TestCase): model = ModelWithStaticConditional(True) self.assertEqual(model.call(), 25) + def test_recursive_true(self): + with self.assertRaisesRegexp(NotImplementedError, + 'Object conversion is not yet supported.'): + with tf.Graph().as_default(): + model = CompoundModel() + model.build(tf.TensorShape((None, 10, 10, 1))) + init = tf.global_variables_initializer() + + with tf.Session() as sess: + sess.run(init) + sample_input = tf.random_uniform((1, 10, 10, 1)) + output = model(sample_input) # pylint: disable=not-callable + self.assertEqual(sess.run(output).shape, (1, 3)) + if __name__ == '__main__': tf.test.main() diff --git a/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb b/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb index a3109fa5db2b895817545f5ff611c6979375f85b..7e9cc54d4cafa64e4cd3b48f9376b1b2b4d3575e 100644 --- a/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb +++ b/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb @@ -392,7 +392,7 @@ "output_type": "stream", "text": [ "Got error message: assertion failed: [Do not pass zero!]\n", - "\t [[Node: f/Assert/Assert = Assert[T=[DT_STRING], summarize=3, _device=\"/job:localhost/replica:0/task:0/device:CPU:0\"](f/NotEqual, f/Assert/Assert/data_0)]]\n" + "\t [[{{node f/Assert/Assert}} = Assert[T=[DT_STRING], summarize=3, _device=\"/job:localhost/replica:0/task:0/device:CPU:0\"](f/NotEqual, f/Assert/Assert/data_0)]]\n" ] } ], diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py index ee71f4f9acf8b7f00f849dcdbcdc020fed04c278..276a3871801da2c66fbfffc38ac1ea39704b5de1 100644 --- a/tensorflow/contrib/autograph/impl/api.py +++ b/tensorflow/contrib/autograph/impl/api.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Public API.""" +"""This module contains the user-facing API for AutoGraph.""" from __future__ import absolute_import from __future__ import division @@ -42,33 +42,30 @@ from tensorflow.python.util import tf_inspect # (currently we require (module + class name, type)) -def convert(recursive=False, verbose=False, arg_types=None): - """Decorator that compiles a function to graph mode. +# TODO(mdan): This should behave like to_graph (e.g. convert statically). +def convert(recursive=False, verbose=False): + """Decorator that compiles a function to use TensorFlow ops. - The decorator is dynamic - invoking compilation whenever the decorated - function is called. This means the parameter values are known at compilation. + The decorator is dynamic - it recompiles the target whenever the decorated + function is called. This means the parameter values are known at conversion. + It also means that repeated calls with different types of parameters will be + correctly processed. Args: - recursive: Whether to recursively convert any functions that the decorator - function may call. - verbose: Whether to output the compiled code in the logs. - arg_types: See to_graph. + recursive: bool, whether to recursively convert any functions or classes + that the converted function may use. + verbose: bool, whether to output the compiled code in the logs. Returns: - A decorator that compiles the given function to graph mode. - - Raises: - ValueError: If any of the arguments are illegal. + Callable, a decorator that converts the given function into an equivalent + function that uses TensorFlow ops. """ - if arg_types is None: - arg_types = {} - def decorator(f): """Decorator implementation.""" @wraps(f) def wrapper(*args, **kwargs): - return converted_call(f, recursive, verbose, arg_types, *args, **kwargs) + return converted_call(f, recursive, verbose, True, {}, *args, **kwargs) wrapper = tf_decorator.make_decorator(f, wrapper) @@ -81,22 +78,34 @@ def convert(recursive=False, verbose=False, arg_types=None): class RunMode(Enum): + """Specifies the way a converted function or method should be executed in TF. + + The enum values have the following semantics: + + * GRAPH: Call this function directly, as-is. This is suitable for functions + that were already designed for TF graphs and contain ops. + * PY_FUNC: Wrap this function into a py_func op. This is suitable for code + that will only run correctly in Python, for example code that renders + to the display, reads keyboard input, etc. + """ GRAPH = 1 PY_FUNC = 2 def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None): - """Decorator that suppresses compilation of a function. + """Decorator that suppresses the conversion of a function. + + See also: docs/pyfunc_dtypes.md Args: - run_as: RunMode value. Whether to run the function as-is, or wrap it into - a py_func. - return_dtypes: See autograph.utils.py_func.wrap_py_func. Setting to None or - empty list or tuple will create a dummy return value that can be used - to set control dependencies. + run_as: RunMode, specifies how to use the function in TensorFlow. + return_dtypes: Optional[Iterable[ + Union[tf.DType, utils.py_func.MatchDType]]], the return data types of + the converted function, if run_as is RunMode.PY_FUNC. Ignored otherwise. + May be set to None if the function has no return values. Returns: - A decorator that wraps the original function. + Callable, a decorator that wraps the original function. """ def decorator(f): @@ -129,12 +138,13 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None): return decorator -def converted_call(f, recursive, verbose, arg_types, *args, **kwargs): - """Compiles a function call inline.""" +# TODO(mdan): Move to a private, undocumented module. +def converted_call(f, recursive, verbose, force_conversion, arg_types, *args, + **kwargs): + """Compiles a function call inline. For internal use only.""" # TODO(mdan): This needs cleanup. # In particular, we may want to avoid renaming functions altogether. - - if conversion.is_whitelisted_for_graph(f): + if not force_conversion and conversion.is_whitelisted_for_graph(f): return f(*args, **kwargs) unknown_arg_value = object() # Sentinel for arguments of unknown value @@ -201,39 +211,41 @@ def converted_call(f, recursive, verbose, arg_types, *args, **kwargs): return converted_f(*effective_args, **kwargs) +# TODO(mdan): Rename: to_ops? +# TODO(mdan): Looki into overloading as function and decorator, like tfe.defun. +# TODO(mdan): Remove partial_types. def to_graph(e, recursive=True, verbose=False, arg_values=None, arg_types=None, partial_types=None): - """Compile a Python entity into equivalent TensorFlow code. + """Converts a Python entity into equivalent code that uses TensorFlow ops. - Currently supported entities: + Supported Python entities include: * functions * classes - Classes are handled by converting all their methods into a new class. + Classes are converted by converting all their methods into a new class. Args: - e: A Python entity. - recursive: Whether to recursively convert any functions that the decorator - function may call. - verbose: Whether to output the compiled code in the logs. - arg_values: A dict containing value hints for symbols like function - parameters. - arg_types: A dict containing type hints for symbols like function - parameters. - partial_types: A set of types (e.g. classes) that will not be converted - entirely. Calls to member functions for these types will be renamed - independently. + e: Union[Callable, Type], the Python entity to convert. + recursive: bool, whether to recursively convert any functions that the + converted function may call. + verbose: bool, whether to output the compiled code in the logs. + arg_values: Optional[Dict[Text, Any]], value hints for symbols including + function arguments. + arg_types: Optional[Dict[Text, Type]], type hints for symbols including + function arguments. + partial_types: Set[Type], reserved for internal use. Returns: - A function with a signature identical to `o`, but which when executed it - creates TF a graph that has the same functionality as the original entity. + Union[Callable, Type], the converted entity, which is the same kind as e + (that is, a function is e is a function, a class if e is a class, etc.) but + its code has been converted to use TF ops. + Raises: - ValueError: If the converted function defines or refers to symbol names that - are reserved for AutoGraph. + ValueError: If the entity could not be converted. """ program_ctx = converter.ProgramContext( recursive=recursive, @@ -258,25 +270,27 @@ def to_graph(e, # Avoid overwriting entities that have been transformed. if key not in compiled_module.__dict__: compiled_module.__dict__[key] = val - compiled_fn = getattr(compiled_module, name) + compiled = getattr(compiled_module, name) # Need this so the source_mapping attribute is available for the context # manager to access for runtime errors. # # Note that compiler.ast_to_object attaches the source map 'ag_source_map__' # symbol to the compiled module. + # TODO(mdan): Record this statically in the generated code. + # TODO(mdan): Rename this attribute to 'autograph_info__' source_map_attribute_name = 'ag_source_map' - if getattr(compiled_fn, source_map_attribute_name, None) is not None: + if getattr(compiled, source_map_attribute_name, None) is not None: raise ValueError('cannot convert %s because is has an attribute ' '"%s", which is reserved for AutoGraph.' % - (compiled_fn, source_map_attribute_name)) - setattr(compiled_fn, source_map_attribute_name, + (compiled, source_map_attribute_name)) + setattr(compiled, source_map_attribute_name, compiled_module.__dict__['ag_source_map__']) if verbose: logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src) - return compiled_fn + return compiled def to_code(e, @@ -285,20 +299,23 @@ def to_code(e, arg_types=None, partial_types=None, indentation=' '): - """Return the equivalent of an entity in TensorFlow code. + """Returns the equivalent code that uses TensorFlow ops. - See `to_graph` for more details. + Also see: `to_graph`, `convert` Args: - e: A Python entity. - recursive: See to_graph. - arg_values: See to_graph. - arg_types: See to_graph. - partial_types: See to_graph. - indentation: String, when to use for each level of indentation. + e: Union[Callable, Type], the Python entity to convert. + recursive: bool, whether to recursively convert any functions that the + converted function may call. + arg_values: Optional[Dict[Text, Any]], value hints for symbols including + function arguments. + arg_types: Optional[Dict[Text, Type]], type hints for symbols including + function arguments. + partial_types: Set[Type], reserved for internal use. + indentation: Text, when to use for each level of indentation. Returns: - String. + Text, the converted code. """ program_ctx = converter.ProgramContext( recursive=recursive, diff --git a/tensorflow/contrib/autograph/impl/api_test.py b/tensorflow/contrib/autograph/impl/api_test.py index 4de7df657204db2f625098d15e475f942eb352b8..803fde9089b1c004d9bfc0dfefd3d6b422752f0a 100644 --- a/tensorflow/contrib/autograph/impl/api_test.py +++ b/tensorflow/contrib/autograph/impl/api_test.py @@ -183,8 +183,8 @@ class ApiTest(test.TestCase): @api.convert(recursive=True) def test_method(self, x, s, a): while tf.reduce_sum(x) > s: - x //= api.converted_call(self.called_member, False, False, {}, self, - a) + x //= api.converted_call(self.called_member, False, False, False, {}, + self, a) return x tc = TestClass() @@ -195,7 +195,7 @@ class ApiTest(test.TestCase): self.assertListEqual([0, 1], sess.run(x).tolist()) def test_converted_call_builtin(self): - x = api.converted_call(range, False, False, {}, 3) + x = api.converted_call(range, False, False, False, {}, 3) self.assertEqual((0, 1, 2), tuple(x)) def test_converted_call_function(self): @@ -206,7 +206,7 @@ class ApiTest(test.TestCase): return x with self.test_session() as sess: - x = api.converted_call(test_fn, False, False, {}, + x = api.converted_call(test_fn, False, False, False, {}, constant_op.constant(-1)) self.assertEqual(1, sess.run(x)) @@ -224,7 +224,7 @@ class ApiTest(test.TestCase): with self.test_session() as sess: tc = TestClass(constant_op.constant(-1)) - x = api.converted_call(tc.test_method, False, False, {}, tc) + x = api.converted_call(tc.test_method, False, False, False, {}, tc) self.assertEqual(1, sess.run(x)) def test_converted_call_method_by_class(self): @@ -241,7 +241,7 @@ class ApiTest(test.TestCase): with self.test_session() as sess: tc = TestClass(constant_op.constant(-1)) - x = api.converted_call(TestClass.test_method, False, False, {}, tc) + x = api.converted_call(TestClass.test_method, False, False, False, {}, tc) self.assertEqual(1, sess.run(x)) def test_converted_call_callable_object(self): @@ -258,7 +258,7 @@ class ApiTest(test.TestCase): with self.test_session() as sess: tc = TestClass(constant_op.constant(-1)) - x = api.converted_call(tc, False, False, {}) + x = api.converted_call(tc, False, False, False, {}) self.assertEqual(1, sess.run(x)) def test_converted_call_constructor(self): @@ -274,12 +274,27 @@ class ApiTest(test.TestCase): return self.x with self.test_session() as sess: - tc = api.converted_call(TestClass, False, False, {}, + tc = api.converted_call(TestClass, False, False, False, {}, constant_op.constant(-1)) # tc is now a converted object. x = tc.test_method() self.assertEqual(1, sess.run(x)) + def test_converted_call_already_converted(self): + + def f(x): + return x == 0 + + with self.test_session() as sess: + x = api.converted_call(f, False, False, False, {}, + constant_op.constant(0)) + self.assertTrue(sess.run(x)) + + converted_f = api.to_graph(f) + x = api.converted_call(converted_f, False, False, False, {}, + constant_op.constant(0)) + self.assertTrue(sess.run(x)) + def test_to_graph_basic(self): def test_fn(x, s): diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py index 57ec739a8006414377de451723538d27e286b2bd..fc8a976d3f3ecdc9c6339995dd0dfc776824b90d 100644 --- a/tensorflow/contrib/autograph/impl/conversion.py +++ b/tensorflow/contrib/autograph/impl/conversion.py @@ -48,6 +48,7 @@ from tensorflow.contrib.autograph.pyct import inspect_utils from tensorflow.contrib.autograph.pyct import origin_info from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import templates from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.util import tf_inspect @@ -70,6 +71,8 @@ def is_whitelisted_for_graph(o): for prefix, in config.DEFAULT_UNCOMPILED_MODULES: if m.__name__.startswith(prefix): return True + if hasattr(o, 'autograph_info__'): + return True return False @@ -115,12 +118,32 @@ def entity_to_graph(o, program_ctx, arg_values, arg_types): node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types) elif tf_inspect.ismethod(o): node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types) + # TODO(mdan,yashkatariya): Remove when object conversion is implemented. + elif hasattr(o, '__class__'): + raise NotImplementedError( + 'Object conversion is not yet supported. If you are ' + 'trying to convert code that uses an existing object, ' + 'try including the creation of that object in the ' + 'conversion. For example, instead of converting the method ' + 'of a class, try converting the entire class instead. ' + 'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/' + 'contrib/autograph/README.md#using-the-functional-api ' + 'for more information.') else: raise ValueError( 'Entity "%s" has unsupported type "%s". Only functions and classes are ' 'supported for now.' % (o, type(o))) + # TODO(mdan): This is temporary. it should be created using a converter. + # TODO(mdan): The attribute should be added with a helper, not directly. + # The helper can ensure there are no collisions. + template = ''' + entity.autograph_info__ = {} + ''' + node.extend(templates.replace(template, entity=name)) + program_ctx.add_to_cache(o, node) + if program_ctx.recursive: while True: candidate = None @@ -169,7 +192,7 @@ def class_to_graph(c, program_ctx): class_name = namer.compiled_class_name(c.__name__, c) # TODO(mdan): This needs to be explained more thoroughly. - # Process any base classes: if the sueprclass if of a whitelisted type, an + # Process any base classes: if the superclass if of a whitelisted type, an # absolute import line is generated. Otherwise, it is marked for conversion # (as a side effect of the call to namer.compiled_class_name() followed by # program_ctx.update_name_map(namer)). @@ -268,18 +291,18 @@ def function_to_graph(f, context = converter.EntityContext(namer, entity_info, program_ctx) node = node_to_graph(node, context, rewrite_errors=rewrite_errors) - # TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py + # TODO(mdan): This somewhat duplicates the call rename logic in call_trees.py new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type) if not did_rename: new_name = f.__name__ if node.name != f.__name__: raise NotImplementedError('Strange corner case. Send us offending code!') - node.name = new_name + program_ctx.update_name_map(namer) # TODO(mdan): Use this at compilation. - return (node,), new_name, namespace + return [node], new_name, namespace def node_to_graph(node, context, rewrite_errors=True): diff --git a/tensorflow/contrib/autograph/impl/conversion_test.py b/tensorflow/contrib/autograph/impl/conversion_test.py index bfc51365a3031176ed5151e5478b368a9d26aac6..86432573a719ea3f2b163746996dbf3301785a91 100644 --- a/tensorflow/contrib/autograph/impl/conversion_test.py +++ b/tensorflow/contrib/autograph/impl/conversion_test.py @@ -50,7 +50,7 @@ class ConversionTest(test.TestCase): self.assertTrue(conversion.is_whitelisted_for_graph(constant_op.constant)) def test_entity_to_graph_unsupported_types(self): - with self.assertRaises(ValueError): + with self.assertRaises(NotImplementedError): program_ctx = self._simple_program_ctx() conversion.entity_to_graph('dummy', program_ctx, None, None) @@ -61,7 +61,7 @@ class ConversionTest(test.TestCase): program_ctx = self._simple_program_ctx() nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None) - fn_node, = nodes + fn_node, _ = nodes self.assertIsInstance(fn_node, gast.FunctionDef) self.assertEqual('tf__f', name) self.assertIs(ns['b'], b) @@ -115,10 +115,12 @@ class ConversionTest(test.TestCase): self.assertTrue(TestBase in program_ctx.dependency_cache) self.assertTrue(TestSubclass in program_ctx.dependency_cache) + # The returned nodes will include: + # , , self.assertEqual('TfTestBase', - program_ctx.dependency_cache[TestBase][-1].name) + program_ctx.dependency_cache[TestBase][-2].name) self.assertEqual('TfTestSubclass', - program_ctx.dependency_cache[TestSubclass][-1].name) + program_ctx.dependency_cache[TestSubclass][-2].name) def test_entity_to_graph_class_hierarchy_whitelisted(self): @@ -138,8 +140,10 @@ class ConversionTest(test.TestCase): self.assertFalse(training.Model in program_ctx.dependency_cache) self.assertEqual( 'Model', program_ctx.dependency_cache[TestSubclass][0].names[0].name) + # The returned nodes will include: + # , , self.assertEqual('TfTestSubclass', - program_ctx.dependency_cache[TestSubclass][-1].name) + program_ctx.dependency_cache[TestSubclass][-2].name) def test_entity_to_graph_lambda(self): f = lambda a: a diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py index 988df70157170ed0a9ece33976e871e6f7693bbc..9909e521644a7a901653dc09853222167828c75c 100644 --- a/tensorflow/contrib/autograph/operators/control_flow.py +++ b/tensorflow/contrib/autograph/operators/control_flow.py @@ -141,7 +141,7 @@ def _dataset_for_stmt(ds, extra_test, body, init_state): while_body, init_state=(epoch_number, iterate) + init_state, extra_deps=()) - # Dropping the epoch number and iterate because they are not not syntactically + # Dropping the epoch number and iterate because they are not syntactically # visible. results = results[2:] @@ -212,12 +212,12 @@ def if_stmt(cond, body, orelse): Tuple containing the statement outputs. """ if tensor_util.is_tensor(cond): - return _tf_if_stmt(cond, body, orelse) + return tf_if_stmt(cond, body, orelse) else: return _py_if_stmt(cond, body, orelse) -def _tf_if_stmt(cond, body, orelse): +def tf_if_stmt(cond, body, orelse): """Overload of if_stmt that stages a TF cond.""" return control_flow_ops.cond(cond, body, orelse) diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/BUILD b/tensorflow/contrib/autograph/pyct/common_transformers/BUILD index ca1441cf6f8bb034c95b37fcdd9e8158d1db2e39..a0938b3e5f0e52532f63fea6fb4c3e478fc51d93 100644 --- a/tensorflow/contrib/autograph/pyct/common_transformers/BUILD +++ b/tensorflow/contrib/autograph/pyct/common_transformers/BUILD @@ -24,6 +24,7 @@ py_library( deps = [ "//tensorflow/contrib/autograph/pyct", "@gast_archive//:gast", + "@six_archive//:six", ], ) diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py index cc039986c219db1febfe610a5078e26eeb2d5a83..e42f679cfe31f919e10f7baf409247014b3cf386 100644 --- a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py +++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py @@ -12,12 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Conversion to A-normal form.""" +"""Conversion to A-normal form. + +The general idea of A-normal form is that every intermediate value is +explicitly named with a variable. For more, see +https://en.wikipedia.org/wiki/A-normal_form. + +The specific converters used here are based on Python AST semantics as +documented at https://greentreesnakes.readthedocs.io/en/latest/. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function +import gast +import six + +from tensorflow.contrib.autograph.pyct import templates from tensorflow.contrib.autograph.pyct import transformer @@ -32,26 +44,375 @@ class DummyGensym(object): # * the symbols generated so far self._idx = 0 - def new_name(self, stem): + def new_name(self, stem='tmp'): self._idx += 1 return stem + '_' + str(1000 + self._idx) class AnfTransformer(transformer.Base): - """Performs the actual conversion.""" + """Performs the conversion to A-normal form (ANF).""" - # TODO(mdan): Link to a reference. - # TODO(mdan): Implement. + # The algorithm is a postorder recursive tree walk. Any given node A may, in + # general, require creation of a series B of Assign statements, which compute + # and explicitly name the intermediate values needed to compute the value of + # A. If A was already a statement, it can be replaced with the sequence B + + # [A]. If A was an expression, B needs to be propagated up the tree until a + # statement is encountered. Since the `ast.NodeTransformer` framework makes + # no provision for subtraversals returning side information, this class + # accumulates the sequence B in an instance variable. - def __init__(self, entity_info): - """Creates a transformer. + # The only other subtlety is that some Python statements (like `if`) have both + # expression fields (`test`) and statement list fields (`body` and `orelse`). + # Any additional assignments needed to name all the intermediate values in the + # `test` can be prepended to the `if` node, but assignments produced by + # processing the `body` and the `orelse` need to be kept together with them, + # and not accidentally lifted out of the `if`. + + def __init__(self, entity_info, gensym_source=None): + """Creates an ANF transformer. Args: entity_info: transformer.EntityInfo + gensym_source: An optional object with the same interface as `DummyGensym` + for generating unique names """ super(AnfTransformer, self).__init__(entity_info) - self._gensym = DummyGensym(entity_info) + if gensym_source is None: + self._gensym = DummyGensym(entity_info) + else: + self._gensym = gensym_source(entity_info) + self._pending_statements = [] + + def _consume_pending_statements(self): + ans = self._pending_statements + self._pending_statements = [] + return ans + + def _add_pending_statement(self, stmt): + self._pending_statements.append(stmt) + + _trivial_nodes = ( + # Non-nodes that show up as AST fields + bool, six.string_types, + # Leaf nodes that are already in A-normal form + gast.expr_context, gast.Name, gast.Num, gast.Str, gast.Bytes, + gast.NameConstant, gast.Ellipsis, + # Binary operators + gast.Add, gast.Sub, gast.Mult, gast.Div, gast.Mod, gast.Pow, gast.LShift, + gast.RShift, gast.BitOr, gast.BitXor, gast.BitAnd, gast.FloorDiv, + # Unary operators + gast.Invert, gast.Not, gast.UAdd, gast.USub, + # Comparison operators + gast.Eq, gast.NotEq, gast.Lt, gast.LtE, gast.Gt, gast.GtE, + gast.Is, gast.IsNot, gast.In, gast.NotIn, + ) + + def _is_node_trivial(self, node): + if node is None: + return True + elif isinstance(node, self._trivial_nodes): + return True + elif isinstance(node, gast.keyword): + return self._is_node_trivial(node.value) + elif isinstance(node, (gast.Starred, gast.withitem, gast.slice)): + return self._are_children_trivial(node) + return False + + def _are_children_trivial(self, node): + for field in node._fields: + if not field.startswith('__'): + if not self._is_node_trivial(getattr(node, field)): + return False + return True + + def _ensure_node_is_trivial(self, node): + if node is None: + return node + elif isinstance(node, self._trivial_nodes): + return node + elif isinstance(node, list): + # If something's field was actually a list, e.g., variadic arguments. + return [self._ensure_node_is_trivial(n) for n in node] + elif isinstance(node, gast.keyword): + node.value = self._ensure_node_is_trivial(node.value) + return node + elif isinstance(node, (gast.Starred, gast.withitem, gast.slice)): + return self._ensure_fields_trivial(node) + elif isinstance(node, gast.expr): + temp_name = self._gensym.new_name() + temp_assign = templates.replace( + 'temp_name = expr', temp_name=temp_name, expr=node)[0] + self._add_pending_statement(temp_assign) + answer = templates.replace('temp_name', temp_name=temp_name)[0] + return answer + else: + raise ValueError('Do not know how to treat {}'.format(node)) + + def _ensure_fields_trivial(self, node): + for field in node._fields: + if field.startswith('__'): + continue + setattr(node, field, self._ensure_node_is_trivial(getattr(node, field))) + return node + + def _visit_strict_statement(self, node, trivialize_children=True): + assert not self._pending_statements + node = self.generic_visit(node) + if trivialize_children: + self._ensure_fields_trivial(node) + results = self._consume_pending_statements() + results.append(node) + return results + + def _visit_strict_expression(self, node): + node = self.generic_visit(node) + self._ensure_fields_trivial(node) + return node + + # Note on code order: These are listed in the same order as the grammar + # elements on https://github.com/serge-sans-paille/gast + + # FunctionDef, AsyncFunctionDef, and ClassDef should be correct by default. + + def visit_Return(self, node): + return self._visit_strict_statement(node) + + def visit_Delete(self, node): + return self._visit_strict_statement(node, trivialize_children=False) + + def visit_Assign(self, node): + return self._visit_strict_statement(node, trivialize_children=False) + + def visit_AugAssign(self, node): + return self._visit_strict_statement(node, trivialize_children=False) + + def visit_Print(self, node): + return self._visit_strict_statement(node) + + def visit_For(self, node): + assert not self._pending_statements + # It's important to visit node.iter first, because any statements created + # thereby need to live outside the body. + self.visit(node.iter) + node.iter = self._ensure_node_is_trivial(node.iter) + iter_stmts = self._consume_pending_statements() + # This generic_visit will revisit node.iter, but that is both correct and + # cheap because by this point node.iter is trivial. + node = self.generic_visit(node) + assert not self._pending_statements + iter_stmts.append(node) + return iter_stmts + + def visit_AsyncFor(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial AsyncFor nodes not supported yet ' + '(need to think through the semantics).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_While(self, node): + if not self._is_node_trivial(node.test): + msg = ('While with nontrivial test not supported yet ' + '(need to avoid precomputing the test).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_If(self, node): + assert not self._pending_statements + # It's important to visit node.test first, because any statements created + # thereby need to live outside the body. + self.visit(node.test) + node.test = self._ensure_node_is_trivial(node.test) + condition_stmts = self._consume_pending_statements() + # This generic_visit will revisit node.test, but that is both correct and + # cheap because by this point node.test is trivial. + node = self.generic_visit(node) + assert not self._pending_statements + condition_stmts.append(node) + return condition_stmts + + def visit_With(self, node): + assert not self._pending_statements + # It's important to visit node.items first, because any statements created + # thereby need to live outside the body. + for item in node.items: + self.visit(item) + node.items = [self._ensure_node_is_trivial(n) for n in node.items] + contexts_stmts = self._consume_pending_statements() + # This generic_visit will revisit node.items, but that is both correct and + # cheap because by this point node.items is trivial. + node = self.generic_visit(node) + assert not self._pending_statements + contexts_stmts.append(node) + return contexts_stmts + + def visit_AsyncWith(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial AsyncWith nodes not supported yet ' + '(need to think through the semantics).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_Raise(self, node): + return self._visit_strict_statement(node) + + # Try should be correct by default. + + def visit_Assert(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial Assert nodes not supported yet ' + '(need to avoid computing the test when assertions are off, and ' + 'avoid computing the irritant when the assertion does not fire).') + raise ValueError(msg) + return self.generic_visit(node) + + # Import and ImportFrom should be correct by default. + + def visit_Exec(self, node): + return self._visit_strict_statement(node) + + # Global and Nonlocal should be correct by default. + + def visit_Expr(self, node): + return self._visit_strict_statement(node, trivialize_children=False) + + # Pass, Break, and Continue should be correct by default. + + def visit_BoolOp(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial BoolOp nodes not supported yet ' + '(need to preserve short-circuiting semantics).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_BinOp(self, node): + return self._visit_strict_expression(node) + + def visit_UnaryOp(self, node): + return self._visit_strict_expression(node) + + def visit_Lambda(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial Lambda nodes not supported ' + '(cannot insert statements into lambda bodies).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_IfExp(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial IfExp nodes not supported yet ' + '(need to convert to If statement, to evaluate branches lazily ' + 'and insert statements into them).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_Dict(self, node): + return self._visit_strict_expression(node) + + def visit_Set(self, node): + return self._visit_strict_expression(node) + + def visit_ListComp(self, node): + msg = ('ListComp nodes not supported ' + '(need to convert to a form that tolerates ' + 'assignment statements in clause bodies).') + raise ValueError(msg) + + def visit_SetComp(self, node): + msg = ('SetComp nodes not supported ' + '(need to convert to a form that tolerates ' + 'assignment statements in clause bodies).') + raise ValueError(msg) + + def visit_DictComp(self, node): + msg = ('DictComp nodes not supported ' + '(need to convert to a form that tolerates ' + 'assignment statements in clause bodies).') + raise ValueError(msg) + + def visit_GeneratorExp(self, node): + msg = ('GeneratorExp nodes not supported ' + '(need to convert to a form that tolerates ' + 'assignment statements in clause bodies).') + raise ValueError(msg) + + def visit_Await(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial Await nodes not supported yet ' + '(need to think through the semantics).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_Yield(self, node): + return self._visit_strict_expression(node) + + def visit_YieldFrom(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial YieldFrom nodes not supported yet ' + '(need to unit-test them in Python 2).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_Compare(self, node): + if len(node.ops) > 1: + msg = ('Multi-ary compare nodes not supported yet ' + '(need to preserve short-circuiting semantics).') + raise ValueError(msg) + return self._visit_strict_expression(node) + + def visit_Call(self, node): + return self._visit_strict_expression(node) + + def visit_Repr(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial Repr nodes not supported yet ' + '(need to research their syntax and semantics).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_FormattedValue(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial FormattedValue nodes not supported yet ' + '(need to unit-test them in Python 2).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_JoinedStr(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial JoinedStr nodes not supported yet ' + '(need to unit-test them in Python 2).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_Attribute(self, node): + return self._visit_strict_expression(node) + + def visit_Subscript(self, node): + return self._visit_strict_expression(node) + + # Starred and Name are correct by default, because the right thing to do is to + # just recur. + + def visit_List(self, node): + return self._visit_strict_expression(node) + + def visit_Tuple(self, node): + return self._visit_strict_expression(node) + + +def transform(node, entity_info, gensym_source=None): + """Converts the given node to A-normal form (ANF). + + The general idea of A-normal form: https://en.wikipedia.org/wiki/A-normal_form + The specific converters used here are based on Python AST semantics as + documented at https://greentreesnakes.readthedocs.io/en/latest/. -def transform(node, entity_info): - return AnfTransformer(entity_info).visit(node) + Args: + node: The node to transform. + entity_info: transformer.EntityInfo. TODO(mdan): What information does this + argument provide? + gensym_source: An optional object with the same interface as `DummyGensym` + for generating unique names. + """ + return AnfTransformer(entity_info, gensym_source=gensym_source).visit(node) diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py index aefbc69d8cfb0d9b75e3da10fddee71b2c5e309a..951974820c784974cb5bb2320adbb2b07f9332df 100644 --- a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py +++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import textwrap + from tensorflow.contrib.autograph.pyct import compiler from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import transformer @@ -25,6 +27,22 @@ from tensorflow.contrib.autograph.pyct.common_transformers import anf from tensorflow.python.platform import test +class DummyGensym(object): + """A dumb gensym that suffixes a stem by sequential numbers from 1000.""" + + def __init__(self, entity_info): + del entity_info + # A proper implementation needs to account for: + # * entity_info.namespace + # * all the symbols defined in the AST + # * the symbols generated so far + self._idx = 0 + + def new_name(self, stem='tmp'): + self._idx += 1 + return stem + '_' + str(1000 + self._idx) + + class AnfTransformerTest(test.TestCase): def _simple_source_info(self): @@ -37,17 +55,349 @@ class AnfTransformerTest(test.TestCase): owner_type=None) def test_basic(self): - def test_function(): a = 0 return a - node, _ = parser.parse_entity(test_function) node = anf.transform(node.body[0], self._simple_source_info()) result, _ = compiler.ast_to_object(node) - self.assertEqual(test_function(), result.test_function()) + def assert_same_ast(self, expected_node, node, msg=None): + expected_source = compiler.ast_to_source(expected_node, indentation=' ') + expected_str = textwrap.dedent(expected_source).strip() + got_source = compiler.ast_to_source(node, indentation=' ') + got_str = textwrap.dedent(got_source).strip() + self.assertEqual(expected_str, got_str, msg=msg) + + def assert_body_anfs_as_expected(self, expected_fn, test_fn): + # Testing the code bodies only. Wrapping them in functions so the + # syntax highlights nicely, but Python doesn't try to execute the + # statements. + exp_node, _ = parser.parse_entity(expected_fn) + node, _ = parser.parse_entity(test_fn) + node = anf.transform( + node, self._simple_source_info(), gensym_source=DummyGensym) + exp_name = exp_node.body[0].name + # Ignoring the function names in the result because they can't be + # the same (because both functions have to exist in the same scope + # at the same time). + node.body[0].name = exp_name + self.assert_same_ast(exp_node, node) + # Check that ANF is idempotent + node_repeated = anf.transform( + node, self._simple_source_info(), gensym_source=DummyGensym) + self.assert_same_ast(node_repeated, node) + + def test_binop_basic(self): + + def test_function(x, y, z): + a = x + y + z + return a + + def expected_result(x, y, z): + tmp_1001 = x + y + a = tmp_1001 + z + return a + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_if_basic(self): + + def test_function(a, b, c, e, f, g): + if a + b + c: + d = e + f + g + return d + + def expected_result(a, b, c, e, f, g): + tmp_1001 = a + b + tmp_1002 = tmp_1001 + c + if tmp_1002: + tmp_1003 = e + f + d = tmp_1003 + g + return d + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_nested_binop_and_return(self): + + def test_function(b, c, d, e): + return (2 * b + c) + (d + e) + + def expected_result(b, c, d, e): + tmp_1001 = 2 * b + tmp_1002 = tmp_1001 + c + tmp_1003 = d + e + tmp_1004 = tmp_1002 + tmp_1003 + return tmp_1004 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_function_call_and_expr(self): + + def test_function(call_something, a, b, y, z, c, d, e, f, g, h, i): + call_something(a + b, y * z, kwarg=c + d, *(e + f), **(g + h + i)) + + def expected_result(call_something, a, b, y, z, c, d, e, f, g, h, i): + tmp_1001 = g + h + tmp_1002 = a + b + tmp_1003 = y * z + tmp_1004 = e + f + tmp_1005 = c + d + tmp_1006 = tmp_1001 + i + call_something(tmp_1002, tmp_1003, kwarg=tmp_1005, *tmp_1004, **tmp_1006) + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_with_and_print(self): + + def test_function(a, b, c): + with a + b + c as d: + print(2 * d + 1) + + def expected_result(a, b, c): + tmp_1001 = a + b + tmp_1002 = tmp_1001 + c + with tmp_1002 as d: + tmp_1003 = 2 * d + tmp_1004 = tmp_1003 + 1 + print(tmp_1004) + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_local_definition_and_binary_compare(self): + + def test_function(): + def foo(a, b): + return 2 * a < b + return foo + + def expected_result(): + def foo(a, b): + tmp_1001 = 2 * a + tmp_1002 = tmp_1001 < b + return tmp_1002 + return foo + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_list_literal(self): + + def test_function(a, b, c, d, e, f): + return [a + b, c + d, e + f] + + def expected_result(a, b, c, d, e, f): + tmp_1001 = a + b + tmp_1002 = c + d + tmp_1003 = e + f + tmp_1004 = [tmp_1001, tmp_1002, tmp_1003] + return tmp_1004 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_tuple_literal_and_unary(self): + + def test_function(a, b, c, d, e, f): + return (a + b, -(c + d), e + f) + + def expected_result(a, b, c, d, e, f): + tmp_1001 = c + d + tmp_1002 = a + b + tmp_1003 = -tmp_1001 + tmp_1004 = e + f + tmp_1005 = (tmp_1002, tmp_1003, tmp_1004) + return tmp_1005 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_set_literal(self): + + def test_function(a, b, c, d, e, f): + return set(a + b, c + d, e + f) + + def expected_result(a, b, c, d, e, f): + tmp_1001 = a + b + tmp_1002 = c + d + tmp_1003 = e + f + tmp_1004 = set(tmp_1001, tmp_1002, tmp_1003) + return tmp_1004 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_dict_literal_and_repr(self): + + def test_function(foo, bar, baz): + return repr({foo + bar + baz: 7 | 8}) + + def expected_result(foo, bar, baz): + tmp_1001 = foo + bar + tmp_1002 = tmp_1001 + baz + tmp_1003 = 7 | 8 + tmp_1004 = {tmp_1002: tmp_1003} + tmp_1005 = repr(tmp_1004) + return tmp_1005 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_field_read_and_write(self): + + def test_function(a, d): + a.b.c = d.e.f + 3 + + def expected_result(a, d): + tmp_1001 = a.b + tmp_1002 = d.e + tmp_1003 = tmp_1002.f + tmp_1001.c = tmp_1003 + 3 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_subscript_read_and_write(self): + + def test_function(a, b, c, d, e, f): + a[b][c] = d[e][f] + 3 + + def expected_result(a, b, c, d, e, f): + tmp_1001 = a[b] + tmp_1002 = d[e] + tmp_1003 = tmp_1002[f] + tmp_1001[c] = tmp_1003 + 3 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_augassign_and_delete(self): + + def test_function(a, x, y, z): + a += x + y + z + del a + del z[y][x] + + def expected_result(a, x, y, z): + tmp_1001 = x + y + a += tmp_1001 + z + del a + tmp_1002 = z[y] + del tmp_1002[x] + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_raise_yield_and_raise(self): + + def test_function(a, c, some_computed, exception): + yield a ** c + raise some_computed('complicated' + exception) + + def expected_result(a, c, some_computed, exception): + tmp_1001 = a ** c + yield tmp_1001 + tmp_1002 = 'complicated' + exception + tmp_1003 = some_computed(tmp_1002) + raise tmp_1003 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_with_and_if_with_expressions(self): + + def test_function(foo, bar, function, quux, quozzle, w, x, y, z): + with foo + bar: + function(x + y) + if quux + quozzle: + function(z / w) + + def expected_result(foo, bar, function, quux, quozzle, w, x, y, z): + tmp_1001 = foo + bar + with tmp_1001: + tmp_1002 = x + y + function(tmp_1002) + tmp_1003 = quux + quozzle + if tmp_1003: + tmp_1004 = z / w + function(tmp_1004) + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_exec(self): + + def test_function(): + # The point is to test A-normal form conversion of exec + # pylint: disable=exec-used + exec('computed' + 5 + 'stuff', globals(), locals()) + + def expected_result(): + # pylint: disable=exec-used + tmp_1001 = 'computed' + 5 + tmp_1002 = tmp_1001 + 'stuff' + tmp_1003 = globals() + tmp_1004 = locals() + exec(tmp_1002, tmp_1003, tmp_1004) + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_simple_while_and_assert(self): + + def test_function(foo, quux): + while foo: + assert quux + foo = foo + 1 * 3 + + def expected_result(foo, quux): + while foo: + assert quux + tmp_1001 = 1 * 3 + foo = foo + tmp_1001 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_for(self): + + def test_function(compute, something, complicated, foo): + for foo in compute(something + complicated): + bar = foo + 1 * 3 + return bar + + def expected_result(compute, something, complicated, foo): + tmp_1001 = something + complicated + tmp_1002 = compute(tmp_1001) + for foo in tmp_1002: + tmp_1003 = 1 * 3 + bar = foo + tmp_1003 + return bar + + self.assert_body_anfs_as_expected(expected_result, test_function) + + # This test collects several examples where the definition of A-normal form + # implemented by this transformer is questionable. Mostly it's here to spell + # out what the definition is in these cases. + def test_controversial(self): + + def test_function(b, c, d, f): + a = c + d + a.b = c + d + a[b] = c + d + a += c + d + a, b = c + a, b = c, d + a = f(c) + a = f(c + d) + a[b + d] = f.e(c + d) + + def expected_result(b, c, d, f): + a = c + d + a.b = c + d # Should be a.b = tmp? (Definitely not tmp = c + d) + a[b] = c + d # Should be a[b] = tmp? (Definitely not tmp = c + d) + a += c + d # Should be a += tmp? (Definitely not tmp = c + d) + a, b = c # Should be a = c[0], b = c[1]? Or not? + a, b = c, d # Should be a = c, b = d? Or not? + a = f(c) + tmp_1001 = c + d + a = f(tmp_1001) + tmp_1002 = b + d + tmp_1003 = f.e + tmp_1004 = c + d + a[tmp_1002] = tmp_1003(tmp_1004) # Or should be a[tmp1] = tmp2? + + self.assert_body_anfs_as_expected(expected_result, test_function) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/pyct/origin_info.py b/tensorflow/contrib/autograph/pyct/origin_info.py index 1aad2f47dfa66cc1bfd3a7a66d31b03e2aa0d09e..b60651a30e342dabe40cbcef1486826e16c2e2c7 100644 --- a/tensorflow/contrib/autograph/pyct/origin_info.py +++ b/tensorflow/contrib/autograph/pyct/origin_info.py @@ -18,8 +18,10 @@ from __future__ import division from __future__ import print_function import collections +import tokenize import gast +import six from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import ast_util @@ -56,13 +58,14 @@ class Location( class OriginInfo( collections.namedtuple( 'OriginInfo', - ('loc', 'function_name', 'source_code_line'))): + ('loc', 'function_name', 'source_code_line', 'comment'))): """Container for information about the source code before conversion. Attributes: loc: Location function_name: Optional[Text] source_code_line: Text + comment: Optional[Text] """ def as_frame(self): @@ -152,6 +155,15 @@ def resolve(nodes, source, function=None): function_lineno = None function_filepath = None + # TODO(mdan): Pull this to a separate utility. + code_reader = six.StringIO(source) + comment_map = {} + for token in tokenize.generate_tokens(code_reader.readline): + tok_type, tok_string, loc, _, _ = token + srow, _ = loc + if tok_type == tokenize.COMMENT: + comment_map[srow] = tok_string.strip()[1:].strip() + source_lines = source.split('\n') for node in nodes: for n in gast.walk(node): @@ -169,5 +181,6 @@ def resolve(nodes, source, function=None): function_name = None location = Location(function_filepath, source_lineno, n.col_offset) - origin = OriginInfo(location, function_name, source_code_line) + origin = OriginInfo(location, function_name, + source_code_line, comment_map.get(source_lineno)) anno.setanno(n, anno.Basic.ORIGIN, origin) diff --git a/tensorflow/contrib/autograph/pyct/origin_info_test.py b/tensorflow/contrib/autograph/pyct/origin_info_test.py index 6d7d8b1622a2ddb1a1d0eaeec50bdfaf38f05182..eeaa13007ea0ae331293c216a76352956c0ee9ec 100644 --- a/tensorflow/contrib/autograph/pyct/origin_info_test.py +++ b/tensorflow/contrib/autograph/pyct/origin_info_test.py @@ -85,16 +85,19 @@ class OriginInfoTest(test.TestCase): self.assertEqual(origin.loc.lineno, 1) self.assertEqual(origin.loc.col_offset, 0) self.assertEqual(origin.source_code_line, 'def test_fn(x):') + self.assertIsNone(origin.comment) origin = anno.getanno(fn_node.body[0], anno.Basic.ORIGIN) self.assertEqual(origin.loc.lineno, 2) self.assertEqual(origin.loc.col_offset, 2) self.assertEqual(origin.source_code_line, ' """Docstring."""') + self.assertIsNone(origin.comment) origin = anno.getanno(fn_node.body[1], anno.Basic.ORIGIN) self.assertEqual(origin.loc.lineno, 3) self.assertEqual(origin.loc.col_offset, 2) self.assertEqual(origin.source_code_line, ' return x # comment') + self.assertEqual(origin.comment, 'comment') if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py b/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py index 9a84f1231cb71745f778285f30ada151a7c1accd..7f2b379d3de236020f1ec2b8a4972cc67b10b060 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py @@ -39,7 +39,7 @@ from tensorflow.contrib.autograph.pyct.static_analysis import annos class Definition(object): """Definition objects describe a unique definition of a variable. - Subclasses of this may be used by passing an appropriate factory fuction to + Subclasses of this may be used by passing an appropriate factory function to resolve. Attributes: diff --git a/tensorflow/contrib/autograph/pyct/testing/BUILD b/tensorflow/contrib/autograph/pyct/testing/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..9ef1ac9663eac8febffd697d7164425716b65d9d --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/testing/BUILD @@ -0,0 +1,46 @@ +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "testing", + srcs = [ + "codegen.py", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/utils", + "@gast_archive//:gast", + ], +) + +py_test( + name = "codegen_test", + size = "large", + srcs = ["codegen_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_windows", + "nomsan", + ], + deps = [ + ":testing", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/python:client_testlib", + "@gast_archive//:gast", + ], +) diff --git a/tensorflow/contrib/autograph/pyct/testing/codegen.py b/tensorflow/contrib/autograph/pyct/testing/codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..279e7c09dc6449184e2029ad65fc3f71d94db8b4 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/testing/codegen.py @@ -0,0 +1,234 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Random code generation for testing/fuzzing.""" +# pylint: disable=invalid-name +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random +import string + +import gast +import numpy as np + +from tensorflow.contrib.autograph.pyct import templates + + +class NodeSampler(object): + sample_map = None + + def sample(self): + nodes, magnitudes = zip(*self.sample_map.items()) + return np.random.choice( + nodes, p=np.array(magnitudes, dtype='float32') / np.sum(magnitudes)) + + +class StatementSampler(NodeSampler): + sample_map = dict(( + (gast.Assign, 10), + (gast.Print, 1), + (gast.If, 2), + (gast.While, 2), + (gast.For, 0), + )) + + +class ExpressionSampler(NodeSampler): + sample_map = dict(( + (gast.UnaryOp, 1), + (gast.BinOp, 8), + (gast.Name, 1), + (gast.Call, 0), + )) + + +class CompareSampler(NodeSampler): + sample_map = dict(( + (gast.Eq, 1), + (gast.NotEq, 1), + (gast.Lt, 1), + (gast.LtE, 1), + (gast.Gt, 1), + (gast.GtE, 1), + (gast.Is, 1), + (gast.IsNot, 1), + )) + + +class BinaryOpSampler(NodeSampler): + sample_map = dict(( + (gast.Add, 1), + (gast.Sub, 1), + (gast.Mult, 1), + (gast.Div, 1), + (gast.FloorDiv, 1), + (gast.Mod, 1), + (gast.Pow, 1), + )) + + +class UnaryOpSampler(NodeSampler): + sample_map = dict(((gast.USub, 1), (gast.UAdd, 0))) + + +class NameSampler(NodeSampler): + sample_map = dict(( + ('new', 1), + ('existing', 1), + )) + + +N_CONTROLFLOW_STATEMENTS = 10 +N_FUNCTIONDEF_STATEMENTS = 10 + + +class CodeGenerator(object): + """Generate random syntactically-valid Python ASTs.""" + + def __init__(self, max_depth=3, depth=0): + self.max_depth = max_depth + self.depth = depth + + def generate_statement(self): + """Generate a statement node, dispatching to the correct class method.""" + desired_node = StatementSampler().sample() + self.depth += 1 + + # Enforce some constraints on generating statements. + # E.g., if statements need at least 3 readable variables. + # If we fail to satisfy our constraints, draw another sample. + if desired_node in (gast.While, gast.For, gast.If): + if self.depth > self.max_depth: + return self.generate_statement() + + # Go get the generator method and run it + method = 'generate_' + desired_node.__name__ + visitor = getattr(self, method) + node = visitor() + self.depth -= 1 + return node + + def sample_node_list(self, low, high, generator): + """Generate a list of statements of random length. + + Args: + low: Fewest number of statements to generate. + high: Highest number of statements to generate. + generator: Function to call to generate nodes. + + Returns: + A list of statements. + """ + statements = [] + for _ in range(np.random.randint(low, high)): + statements.append(generator()) + return statements + + def generate_Name(self, ctx=gast.Load()): + variable_name = '_' + ''.join( + random.choice(string.ascii_lowercase) for _ in range(4)) + return gast.Name(variable_name, ctx=ctx, annotation=None) + + def generate_BinOp(self): + # TODO(alexbw): convert to generate_expression when we get to limit + # expression depth. + op = BinaryOpSampler().sample()() + return gast.BinOp(self.generate_Name(), op, self.generate_Name()) + + def generate_Compare(self): + op = CompareSampler().sample()() + return gast.Compare(self.generate_Name(), [op], [self.generate_Name()]) + + def generate_UnaryOp(self): + operand = self.generate_Name() + op = UnaryOpSampler().sample()() + return gast.UnaryOp(op, operand) + + def generate_expression(self): + desired_node = ExpressionSampler().sample() + # Go get the generator method and run it + method = 'generate_' + desired_node.__name__ + generator = getattr(self, method) + return generator() + + def generate_Assign(self): + """Generate an Assign node.""" + # Generate left-hand side + target_node = self.generate_Name(gast.Store()) + # Generate right-hand side + value_node = self.generate_expression() + # Put it all together + node = gast.Assign(targets=[target_node], value=value_node) + return node + + def generate_If(self): + """Generate an If node.""" + test = self.generate_Compare() + + # Generate true branch statements + body = self.sample_node_list( + low=1, + high=N_CONTROLFLOW_STATEMENTS // 2, + generator=self.generate_statement) + + # Generate false branch statements + orelse = self.sample_node_list( + low=1, + high=N_CONTROLFLOW_STATEMENTS // 2, + generator=self.generate_statement) + + node = gast.If(test, body, orelse) + return node + + def generate_While(self): + """Generate a While node.""" + + test = self.generate_Compare() + body = self.sample_node_list( + low=1, high=N_CONTROLFLOW_STATEMENTS, generator=self.generate_statement) + orelse = [] # not generating else statements + + node = gast.While(test, body, orelse) + return node + + def generate_Call(self): + raise NotImplementedError + + def generate_Return(self): + return gast.Return(self.generate_expression()) + + def generate_Print(self): + return templates.replace('print(x)', x=self.generate_expression())[0] + + def generate_FunctionDef(self): + """Generate a FunctionDef node.""" + + # Generate the arguments, register them as available + arg_vars = self.sample_node_list( + low=2, high=10, generator=lambda: self.generate_Name(gast.Param())) + args = gast.arguments(arg_vars, None, [], [], None, []) + + # Generate the function body + body = self.sample_node_list( + low=1, high=N_FUNCTIONDEF_STATEMENTS, generator=self.generate_statement) + body.append(self.generate_Return()) + fn_name = self.generate_Name().id + node = gast.FunctionDef(fn_name, args, body, (), None) + return node + + +def generate_random_functiondef(): + return CodeGenerator().generate_FunctionDef() diff --git a/tensorflow/contrib/autograph/pyct/testing/codegen_test.py b/tensorflow/contrib/autograph/pyct/testing/codegen_test.py new file mode 100644 index 0000000000000000000000000000000000000000..255c3b2a2edc65ab978d8c32682fafd8ce00f5ac --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/testing/codegen_test.py @@ -0,0 +1,40 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for type_info module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct.testing import codegen +from tensorflow.python.platform import test + + +class CodeGenTest(test.TestCase): + + def test_codegen_gens(self): + np.random.seed(0) + for _ in range(1000): + node = codegen.generate_random_functiondef() + fn = compiler.ast_to_object(node) + self.assertIsNotNone( + fn, 'Generated invalid AST that could not convert to source.') + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py index 71079cfdc04feaf26ab07b7dba193f745555433f..4dd440ef197b7e24b901bc9e30794b0182378a32 100644 --- a/tensorflow/contrib/autograph/utils/builtins.py +++ b/tensorflow/contrib/autograph/utils/builtins.py @@ -27,6 +27,7 @@ from tensorflow.contrib.autograph.utils import type_check from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import list_ops from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops @@ -43,6 +44,8 @@ def dynamic_builtin(f, *args, **kwargs): return dynamic_int(*args, **kwargs) if f is float: return dynamic_float(*args, **kwargs) + if f is abs: + return dynamic_abs(*args, **kwargs) raise NotImplementedError( 'The "%s" builtin is not yet supported.' % f.__name__) @@ -50,7 +53,9 @@ def dynamic_builtin(f, *args, **kwargs): def dynamic_len(list_or_tensor): """Implementation of len using dynamic dispatch.""" - if tensor_util.is_tensor(list_or_tensor): + if _is_tensor_list(list_or_tensor): + return list_ops.tensor_list_length(list_or_tensor) + elif tensor_util.is_tensor(list_or_tensor): shape = list_or_tensor.shape if not shape.ndims: raise ValueError( @@ -59,6 +64,11 @@ def dynamic_len(list_or_tensor): return len(list_or_tensor) +def _is_tensor_list(list_or_tensor): + return (tensor_util.is_tensor(list_or_tensor) + and list_or_tensor.dtype == dtypes.variant) + + def dynamic_int(num_or_tensor, **kwargs): """Implementation of int() using dynamic dispatch.""" if tensor_util.is_tensor(num_or_tensor): @@ -73,6 +83,13 @@ def dynamic_float(num_or_tensor, **kwargs): return float(num_or_tensor) +def dynamic_abs(num_or_tensor, **kwargs): + if tensor_util.is_tensor(num_or_tensor): + return math_ops.abs(num_or_tensor, **kwargs) + else: + return abs(num_or_tensor, **kwargs) + + def dynamic_range(start_or_stop, stop=None, step=None): """Implementation of range using dynamic dispatch.""" if type_check.is_tensor(start_or_stop, stop, step): diff --git a/tensorflow/contrib/autograph/utils/builtins_test.py b/tensorflow/contrib/autograph/utils/builtins_test.py index b4821f36fcab8c201956e366d394bababb9f02b6..b1cd5253bc3ffb1e67d89ef79cf56eaeb65fae07 100644 --- a/tensorflow/contrib/autograph/utils/builtins_test.py +++ b/tensorflow/contrib/autograph/utils/builtins_test.py @@ -44,6 +44,23 @@ class BuiltinsTest(test.TestCase): with self.test_session() as sess: self.assertEqual(3, sess.run(builtins.dynamic_builtin(len, a))) + def test_dynamic_abs_tf_scalar(self): + a = constant_op.constant(-1) + + with self.test_session() as sess: + self.assertEqual(1, sess.run(builtins.dynamic_builtin(abs, a))) + + def test_dynamic_abs_tf_array(self): + a = constant_op.constant([-1, 2, -3]) + + with self.test_session() as sess: + self.assertListEqual([1, 2, 3], + list(sess.run(builtins.dynamic_builtin(abs, a)))) + + def test_dynamic_abs_py_scalar(self): + a = -1 + self.assertEqual(1, builtins.dynamic_builtin(abs, a)) + def test_dynamic_len_tf_matrix(self): a = constant_op.constant([[1, 2], [3, 4]]) diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py index 68ead2f7609ca987180fe8973cf902f1e56b8388..9afe3df585fed6dc7feed1c364a4dac72041257d 100644 --- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py @@ -14,8 +14,6 @@ # ============================================================================== """Monte Carlo integration and helpers. -See the @{$python/contrib.bayesflow.monte_carlo} guide. - @@expectation @@expectation_importance_sampler @@expectation_importance_sampler_logspace diff --git a/tensorflow/contrib/bigtable/README.md b/tensorflow/contrib/bigtable/README.md index d7c71a20ed4ba6a55dc0356ab5a3d096ed042e59..b9abfa8295f9013cd8e92f87466a73952ccceb10 100644 --- a/tensorflow/contrib/bigtable/README.md +++ b/tensorflow/contrib/bigtable/README.md @@ -1,4 +1,4 @@ -# Bigtable # +# Google Cloud Bigtable [Cloud Bigtable](https://cloud.google.com/bigtable/) is a high performance storage system that can store and serve training data. This contrib @@ -13,7 +13,7 @@ Bigtable at high speed, in particular to feed modern accelerators. For general-purpose Cloud Bigtable APIs, see the [official Cloud Bigtable client library documentation][clientdoc]. -[clientdoc]: https://cloud.google.com/bigtable/docs/reference/libraries +[clientdoc]: https://cloud.google.com/bigtable/docs/reference/libraries ## Sample Use @@ -324,7 +324,7 @@ If you encounter a log line that includes the following: "filename":"/usr/share/grpc/roots.pem" ``` -you likely need to copy the [gRPC roots.pem file][grpcPem] to +you likely need to copy the [gRPC `roots.pem` file][grpcPem] to `/usr/share/grpc/roots.pem` on your local machine. [grpcPem]: https://github.com/grpc/grpc/blob/master/etc/roots.pem @@ -338,7 +338,10 @@ are available. - **Compute Engine**: When running on Compute Engine, the client will often use the service account from the virtual machine's metadata service. Be sure to authorize your Compute Engine VM to have access to the Cloud Bigtable service - when creating your VM. + when creating your VM, or [update the VM's scopes][update-vm-scopes] on a + running VM if you run into this issue. - **Cloud TPU**: Your Cloud TPUs run with the designated Cloud TPU service account dedicated to your GCP project. Ensure the service account has been authorized via the Cloud Console to access your Cloud Bigtable instances. + +[update-vm-scopes]: https://cloud.google.com/compute/docs/access/create-enable-service-accounts-for-instances#changeserviceaccountandscopes diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc index a6755a3496f3e1720f1c8c67f75521f2380a9845..a25a641cdb4608dee6d6c1bd18697860cc1f5613 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc @@ -84,6 +84,8 @@ class BigtableClientOp : public OpKernel { channel_args.SetMaxReceiveMessageSize( max_receive_message_size_); channel_args.SetUserAgentPrefix("tensorflow"); + channel_args.SetInt(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 0); + channel_args.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, 60 * 1000); client_options.set_channel_arguments(channel_args); std::shared_ptr client = google::cloud::bigtable::CreateDefaultDataClient( @@ -216,11 +218,11 @@ class ToBigtableOp : public AsyncOpKernel { OP_REQUIRES_OK_ASYNC( ctx, GetDatasetFromVariantTensor(ctx->input(1), &dataset), done); - IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx); std::unique_ptr iterator; OP_REQUIRES_OK_ASYNC( ctx, - dataset->MakeIterator(&iter_ctx, "ToBigtableOpIterator", &iterator), + dataset->MakeIterator(IteratorContext(ctx), "ToBigtableOpIterator", + &iterator), done); int64 timestamp_int; @@ -243,9 +245,10 @@ class ToBigtableOp : public AsyncOpKernel { ::google::cloud::bigtable::BulkMutation mutation; // TODO(saeta): Make # of mutations configurable. for (uint64 i = 0; i < 100 && !end_of_sequence; ++i) { - OP_REQUIRES_OK_ASYNC( - ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence), - done); + OP_REQUIRES_OK_ASYNC(ctx, + iterator->GetNext(IteratorContext(ctx), + &components, &end_of_sequence), + done); if (!end_of_sequence) { OP_REQUIRES_OK_ASYNC( ctx, diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc index 9e49fa35db4b2cd2c8991100a28a5b9c55f01ffe..bd32672aa99d7bf70c44a264f488482c4f213a0b 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc @@ -53,7 +53,7 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, BigtableTableResource* table, @@ -61,7 +61,7 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { std::vector columns, const DataTypeVector& output_types, std::vector output_shapes) - : GraphDatasetBase(ctx), + : DatasetBase(DatasetContext(ctx)), input_(input), table_(table), column_families_(std::move(column_families)), @@ -80,8 +80,8 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr MakeIteratorInternal( const string& prefix) const override { - return std::unique_ptr(new Iterator( - {this, strings::StrCat(prefix, "::BigtableLookupDataset")})); + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::BigtableLookup")})); } const DataTypeVector& output_dtypes() const override { @@ -96,6 +96,14 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { return "BigtableLookupDatasetOp::Dataset"; } + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + return errors::Unimplemented("%s does not support serialization", + DebugString()); + } + private: static ::google::cloud::bigtable::Filter MakeFilter( const std::vector& column_families, diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc index e960719614a1c7c6c4af53ea924aef214a09b24d..a803fdcb49604ef4e596b64d62c7278c69764c15 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc @@ -35,11 +35,13 @@ class BigtablePrefixKeyDatasetOp : public DatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table, string prefix) - : GraphDatasetBase(ctx), table_(table), prefix_(std::move(prefix)) { + : DatasetBase(DatasetContext(ctx)), + table_(table), + prefix_(std::move(prefix)) { table_->Ref(); } @@ -47,8 +49,8 @@ class BigtablePrefixKeyDatasetOp : public DatasetOpKernel { std::unique_ptr MakeIteratorInternal( const string& prefix) const override { - return std::unique_ptr(new Iterator( - {this, strings::StrCat(prefix, "::BigtablePrefixKeyDataset")})); + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::BigtablePrefixKey")})); } const DataTypeVector& output_dtypes() const override { @@ -68,6 +70,14 @@ class BigtablePrefixKeyDatasetOp : public DatasetOpKernel { BigtableTableResource* table() const { return table_; } + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + return errors::Unimplemented("%s does not support serialization", + DebugString()); + } + private: class Iterator : public BigtableReaderDatasetIterator { public: diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc index 96d3565d9b90e72f9e25e69e91f1931c982714cd..5cd0371c79f7eded9303b81dd388df8d306dff80 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc @@ -39,11 +39,11 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table, string start_key, string end_key) - : GraphDatasetBase(ctx), + : DatasetBase(DatasetContext(ctx)), table_(table), start_key_(std::move(start_key)), end_key_(std::move(end_key)) { @@ -54,8 +54,8 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel { std::unique_ptr MakeIteratorInternal( const string& prefix) const override { - return std::unique_ptr(new Iterator( - {this, strings::StrCat(prefix, "::BigtableRangeKeyDataset")})); + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::BigtableRangeKey")})); } const DataTypeVector& output_dtypes() const override { @@ -75,6 +75,14 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel { BigtableTableResource* table() const { return table_; } + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + return errors::Unimplemented("%s does not support serialization", + DebugString()); + } + private: class Iterator : public BigtableReaderDatasetIterator { public: diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc index a1a63a975afd62325e01586542006058fa2c83bc..6928d9423c84f7504fea3ac1abd929357da034a5 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc @@ -52,11 +52,11 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table, string prefix, string start_key, string end_key) - : GraphDatasetBase(ctx), + : DatasetBase(DatasetContext(ctx)), table_(table), key_range_(MakeMultiModeKeyRange( std::move(prefix), std::move(start_key), std::move(end_key))) { @@ -68,7 +68,7 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel { std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr(new Iterator( - {this, strings::StrCat(prefix, "::BigtableSampleKeyPairsDataset")})); + {this, strings::StrCat(prefix, "::BigtableSampleKeyPairs")})); } const DataTypeVector& output_dtypes() const override { @@ -87,6 +87,14 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel { return "BigtableSampleKeyPairsDatasetOp::Dataset"; } + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + return errors::Unimplemented("%s does not support serialization", + DebugString()); + } + private: static MultiModeKeyRange MakeMultiModeKeyRange(string prefix, string start_key, diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc index a5a47cfe2dcf7c4034e0d5bc7d9a73ef9c1dc94e..a759fb5063900199325304ccf83c52f3bdd7d702 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc @@ -31,10 +31,10 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table) - : GraphDatasetBase(ctx), table_(table) { + : DatasetBase(DatasetContext(ctx)), table_(table) { table_->Ref(); } @@ -43,7 +43,7 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel { std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr(new Iterator( - {this, strings::StrCat(prefix, "::BigtableSampleKeysDataset")})); + {this, strings::StrCat(prefix, "::BigtableSampleKeys")})); } const DataTypeVector& output_dtypes() const override { @@ -63,6 +63,14 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel { BigtableTableResource* table() const { return table_; } + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + return errors::Unimplemented("%s does not support serialization", + DebugString()); + } + private: class Iterator : public DatasetIterator { public: diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc index 13cb8681679ec1541b74a20474665f770790201f..78a920b077680980a209ad8c30c09409a6f4ebf5 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc @@ -84,7 +84,7 @@ class BigtableScanDatasetOp : public DatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table, string prefix, string start_key, string end_key, @@ -92,7 +92,7 @@ class BigtableScanDatasetOp : public DatasetOpKernel { std::vector columns, float probability, const DataTypeVector& output_types, std::vector output_shapes) - : GraphDatasetBase(ctx), + : DatasetBase(DatasetContext(ctx)), table_(table), prefix_(std::move(prefix)), start_key_(std::move(start_key)), @@ -111,8 +111,8 @@ class BigtableScanDatasetOp : public DatasetOpKernel { std::unique_ptr MakeIteratorInternal( const string& prefix) const override { - return std::unique_ptr(new Iterator( - {this, strings::StrCat(prefix, "::BigtableScanDataset")})); + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::BigtableScan")})); } const DataTypeVector& output_dtypes() const override { @@ -129,6 +129,14 @@ class BigtableScanDatasetOp : public DatasetOpKernel { BigtableTableResource* table() const { return table_; } + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + return errors::Unimplemented("%s does not support serialization", + DebugString()); + } + private: class Iterator : public BigtableReaderDatasetIterator { public: diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py index fd30aa8bbb962257c1ef5ac07e047fffca88c4bc..3e1b6228673fbdcb5a228a11532d29e6b2c817dc 100644 --- a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py +++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""The Python API for TensorFlow's Bigtable integration. +"""The Python API for TensorFlow's Cloud Bigtable integration. TensorFlow has support for reading from and writing to Cloud Bigtable. To use -the Bigtable TensorFlow integration, first create a BigtableClient (which -configures your connection to Cloud Bigtable), and then open a Table. The Table -object then allows you to create numerous @{tf.data.Dataset}s to read data, or -write a @{tf.data.Dataset} object to the underlying Bigtable Table. +TensorFlow + Cloud Bigtable integration, first create a BigtableClient to +configure your connection to Cloud Bigtable, and then create a BigtableTable +object to allow you to create numerous `tf.data.Dataset`s to read data, or +write a `tf.data.Dataset` object to the underlying Cloud Bigtable table. -For background on Google Cloud Bigtable, see: https://cloud.google.com/bigtable. +For background on Cloud Bigtable, see: https://cloud.google.com/bigtable . """ from __future__ import absolute_import @@ -48,7 +48,7 @@ class BigtableClient(object): """BigtableClient is the entrypoint for interacting with Cloud Bigtable in TF. BigtableClient encapsulates a connection to Cloud Bigtable, and exposes the - `table` method to open a Bigtable Table. + `table` method to open a Bigtable table. """ def __init__(self, @@ -94,7 +94,7 @@ class BigtableClient(object): project_id, instance_id, connection_pool_size, max_receive_message_size) def table(self, name, snapshot=None): - """Opens a table and returns a `BigtableTable` object. + """Opens a table and returns a `tf.contrib.bigtable.BigtableTable` object. Args: name: A `tf.string` `tf.Tensor` name of the table to open. @@ -102,8 +102,8 @@ class BigtableClient(object): request the creation of a snapshot. (Note: currently unimplemented.) Returns: - A `BigtableTable` python object representing the operations available on - the table. + A `tf.contrib.bigtable.BigtableTable` Python object representing the + operations available on the table. """ # TODO(saeta): Implement snapshot functionality. table = gen_bigtable_ops.bigtable_table(self._resource, name) @@ -133,7 +133,8 @@ class BigtableTable(object): """Retrieves the values of columns for a dataset of keys. Example usage: - ``` + + ```python table = bigtable_client.table("my_table") key_dataset = table.get_keys_prefix("imagenet") images = key_dataset.apply(table.lookup_columns(("cf1", "image"), @@ -144,7 +145,8 @@ class BigtableTable(object): Alternatively, you can use keyword arguments to specify the columns to capture. Example (same as above, rewritten): - ``` + + ```python table = bigtable_client.table("my_table") key_dataset = table.get_keys_prefix("imagenet") images = key_dataset.apply(table.lookup_columns( @@ -152,15 +154,17 @@ class BigtableTable(object): training_data = images.map(parse_and_crop, num_parallel_calls=64).batch(128) ``` - Note: certain kwargs keys are reserved, and thus some column families cannot - be identified using the kwargs syntax. Instead, please use the args syntax. - This list includes: + Note: certain `kwargs` keys are reserved, and thus, some column families + cannot be identified using the `kwargs` syntax. Instead, please use the + `args` syntax. This list includes: + - 'name' - This list can change at any time. + + Note: this list can change at any time. Args: *args: A list of tuples containing (column family, column name) pairs. - **kwargs: Column families and + **kwargs: Column families (keys) and column qualifiers (values). Returns: A function that can be passed to `tf.data.Dataset.apply` to retrieve the @@ -199,7 +203,7 @@ class BigtableTable(object): be retrieved. If end is None, all subsequent row keys will be retrieved. Returns: - A @{tf.data.Dataset} containing `tf.string` Tensors corresponding to all + A `tf.data.Dataset` containing `tf.string` Tensors corresponding to all of the row keys between `start` and `end`. """ # TODO(saeta): Make inclusive / exclusive configurable? @@ -215,7 +219,7 @@ class BigtableTable(object): retrieved. Returns: - A @{tf.data.Dataset}. containing `tf.string` Tensors corresponding to all + A `tf.data.Dataset`. containing `tf.string` Tensors corresponding to all of the row keys matching that prefix. """ return _BigtablePrefixKeyDataset(self, prefix) @@ -224,11 +228,11 @@ class BigtableTable(object): """Retrieves a sampling of row keys from the Bigtable table. This dataset is most often used in conjunction with - @{tf.contrib.data.parallel_interleave} to construct a set of ranges for + `tf.contrib.data.parallel_interleave` to construct a set of ranges for scanning in parallel. Returns: - A @{tf.data.Dataset} returning string row keys. + A `tf.data.Dataset` returning string row keys. """ return _BigtableSampleKeysDataset(self) @@ -268,7 +272,7 @@ class BigtableTable(object): that are treated as the column qualifier (column name). Returns: - A @{tf.data.Dataset} returning the row keys and the cell contents. + A `tf.data.Dataset` returning the row keys and the cell contents. Raises: ValueError: If the configured probability is unexpected. @@ -313,7 +317,7 @@ class BigtableTable(object): that are treated as the column qualifier (column name). Returns: - A @{tf.data.Dataset} returning the row keys and the cell contents. + A `tf.data.Dataset` returning the row keys and the cell contents. Raises: ValueError: If the configured probability is unexpected. @@ -331,7 +335,7 @@ class BigtableTable(object): """Retrieves row (including values) from the Bigtable service at high speed. Rows with row-key prefixed by `prefix` will be retrieved. This method is - similar to `scan_prefix`, but by constrast performs multiple sub-scans in + similar to `scan_prefix`, but by contrast performs multiple sub-scans in parallel in order to achieve higher performance. Note: The dataset produced by this method is not deterministic! @@ -369,7 +373,7 @@ class BigtableTable(object): that are treated as the column qualifier (column name). Returns: - A @{tf.data.Dataset} returning the row keys and the cell contents. + A `tf.data.Dataset` returning the row keys and the cell contents. Raises: ValueError: If the configured probability is unexpected. @@ -390,7 +394,7 @@ class BigtableTable(object): """Retrieves rows (including values) from the Bigtable service. Rows with row-keys between `start` and `end` will be retrieved. This method - is similar to `scan_range`, but by constrast performs multiple sub-scans in + is similar to `scan_range`, but by contrast performs multiple sub-scans in parallel in order to achieve higher performance. Note: The dataset produced by this method is not deterministic! @@ -431,7 +435,7 @@ class BigtableTable(object): that are treated as the column qualifier (column name). Returns: - A @{tf.data.Dataset} returning the row keys and the cell contents. + A `tf.data.Dataset` returning the row keys and the cell contents. Raises: ValueError: If the configured probability is unexpected. @@ -446,12 +450,12 @@ class BigtableTable(object): """Writes a dataset to the table. Args: - dataset: A @{tf.data.Dataset} to be written to this table. It must produce + dataset: A `tf.data.Dataset` to be written to this table. It must produce a list of number-of-columns+1 elements, all of which must be strings. The first value will be used as the row key, and subsequent values will be used as cell values for the corresponding columns from the corresponding column_families and columns entries. - column_families: A @{tf.Tensor} of `tf.string`s corresponding to the + column_families: A `tf.Tensor` of `tf.string`s corresponding to the column names to store the dataset's elements into. columns: A `tf.Tensor` of `tf.string`s corresponding to the column names to store the dataset's elements into. @@ -459,7 +463,7 @@ class BigtableTable(object): Leave as None to use server-provided timestamps. Returns: - A @{tf.Operation} that can be run to perform the write. + A `tf.Operation` that can be run to perform the write. Raises: ValueError: If there are unexpected or incompatible types, or if the @@ -498,7 +502,7 @@ class BigtableTable(object): normalized_columns: The column families and column qualifiers to retrieve. Returns: - A @{tf.data.Dataset} representing the result of the parallel scan. + A `tf.data.Dataset` representing the result of the parallel scan. """ if num_parallel_scans is None: num_parallel_scans = 50 @@ -712,7 +716,7 @@ class _BigtableScanDataset(dataset_ops.Dataset): class _BigtableSampleKeyPairsDataset(dataset_ops.Dataset): - """_BigtableKeyRangeDataset returns key pairs from the Bigtable. + """_BigtableSampleKeyPairsDataset returns key pairs from a Bigtable table. """ def __init__(self, table, prefix, start, end): diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD index ef0e80cd0997bc0e95cd0d150e87db144a2dde44..5fcb19a47aac492d49b0d8e99af5699bae2ad9f0 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD +++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD @@ -147,6 +147,7 @@ py_library( deps = [ ":distillation_loss", ":estimator_utils", + ":model", ":trainer_hooks", "//tensorflow/contrib/boosted_trees:gbdt_batch", "//tensorflow/contrib/boosted_trees:model_ops_py", @@ -190,7 +191,7 @@ py_test( py_test( name = "estimator_test", - size = "medium", + size = "large", srcs = ["estimator_test.py"], srcs_version = "PY2AND3", tags = [ diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py index 7eb429b636a5193a124dd9b0c020dae6cac910cb..194a5c8754cb0ab2db299e3fb5c998c0f27f8435 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py @@ -26,6 +26,7 @@ from __future__ import print_function import six from tensorflow.contrib import layers +from tensorflow.contrib.boosted_trees.estimator_batch import model from tensorflow.contrib.boosted_trees.estimator_batch import distillation_loss from tensorflow.contrib.boosted_trees.estimator_batch import estimator_utils from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks @@ -34,6 +35,7 @@ from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batc from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib +from tensorflow.python.estimator import estimator as core_estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.python.feature_column import feature_column as feature_column_lib from tensorflow.python.framework import ops @@ -62,27 +64,30 @@ def _add_hidden_layer_summary(value, tag): summary.histogram("%s_activation" % tag, value) -def _dnn_tree_combined_model_fn(features, - labels, - mode, - head, - dnn_hidden_units, - dnn_feature_columns, - tree_learner_config, - num_trees, - tree_examples_per_layer, - config=None, - dnn_optimizer="Adagrad", - dnn_activation_fn=nn.relu, - dnn_dropout=None, - dnn_input_layer_partitioner=None, - dnn_input_layer_to_tree=True, - dnn_steps_to_train=10000, - predict_with_tree_only=False, - tree_feature_columns=None, - tree_center_bias=False, - dnn_to_tree_distillation_param=None, - use_core_versions=False): +def _dnn_tree_combined_model_fn( + features, + labels, + mode, + head, + dnn_hidden_units, + dnn_feature_columns, + tree_learner_config, + num_trees, + tree_examples_per_layer, + config=None, + dnn_optimizer="Adagrad", + dnn_activation_fn=nn.relu, + dnn_dropout=None, + dnn_input_layer_partitioner=None, + dnn_input_layer_to_tree=True, + dnn_steps_to_train=10000, + predict_with_tree_only=False, + tree_feature_columns=None, + tree_center_bias=False, + dnn_to_tree_distillation_param=None, + use_core_versions=False, + output_type=model.ModelBuilderOutputType.MODEL_FN_OPS, + override_global_step_value=None): """DNN and GBDT combined model_fn. Args: @@ -131,6 +136,12 @@ def _dnn_tree_combined_model_fn(features, will be set to True. use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec + (new interface). + override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. Returns: A `ModelFnOps` object. @@ -156,6 +167,10 @@ def _dnn_tree_combined_model_fn(features, partitioned_variables.min_max_variable_partitioner( max_partitions=config.num_ps_replicas, min_slice_size=64 << 20)) + if (output_type == model.ModelBuilderOutputType.ESTIMATOR_SPEC and + not use_core_versions): + raise ValueError("You must use core versions with Estimator Spec") + with variable_scope.variable_scope( dnn_parent_scope, values=tuple(six.itervalues(features)), @@ -235,7 +250,8 @@ def _dnn_tree_combined_model_fn(features, learner_config=tree_learner_config, feature_columns=tree_feature_columns, logits_dimension=head.logits_dimension, - features=tree_features) + features=tree_features, + use_core_columns=use_core_versions) with ops.name_scope("gbdt"): predictions_dict = gbdt_model.predict(mode) @@ -284,63 +300,98 @@ def _dnn_tree_combined_model_fn(features, del loss return control_flow_ops.no_op() - if use_core_versions: - model_fn_ops = head.create_estimator_spec( - features=features, - mode=mode, - labels=labels, - train_op_fn=_no_train_op_fn, - logits=tree_train_logits) - dnn_train_op = head.create_estimator_spec( - features=features, - mode=mode, - labels=labels, - train_op_fn=_dnn_train_op_fn, - logits=dnn_logits) - dnn_train_op = estimator_utils.estimator_spec_to_model_fn_ops( - dnn_train_op).train_op + if tree_center_bias: + num_trees += 1 + finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor() - tree_train_op = head.create_estimator_spec( - features=tree_features, - mode=mode, - labels=labels, - train_op_fn=_tree_train_op_fn, - logits=tree_train_logits) - tree_train_op = estimator_utils.estimator_spec_to_model_fn_ops( - tree_train_op).train_op + if output_type == model.ModelBuilderOutputType.MODEL_FN_OPS: + if use_core_versions: + model_fn_ops = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_no_train_op_fn, + logits=tree_train_logits) + dnn_train_op = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_dnn_train_op_fn, + logits=dnn_logits) + dnn_train_op = estimator_utils.estimator_spec_to_model_fn_ops( + dnn_train_op).train_op - model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(model_fn_ops) - else: - model_fn_ops = head.create_model_fn_ops( + tree_train_op = head.create_estimator_spec( + features=tree_features, + mode=mode, + labels=labels, + train_op_fn=_tree_train_op_fn, + logits=tree_train_logits) + tree_train_op = estimator_utils.estimator_spec_to_model_fn_ops( + tree_train_op).train_op + + model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops( + model_fn_ops) + else: + model_fn_ops = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_no_train_op_fn, + logits=tree_train_logits) + dnn_train_op = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_dnn_train_op_fn, + logits=dnn_logits).train_op + tree_train_op = head.create_model_fn_ops( + features=tree_features, + mode=mode, + labels=labels, + train_op_fn=_tree_train_op_fn, + logits=tree_train_logits).train_op + + # Add the hooks + model_fn_ops.training_hooks.extend([ + trainer_hooks.SwitchTrainOp(dnn_train_op, dnn_steps_to_train, + tree_train_op), + trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, + finalized_trees, + override_global_step_value) + ]) + return model_fn_ops + + elif output_type == model.ModelBuilderOutputType.ESTIMATOR_SPEC: + fusion_spec = head.create_estimator_spec( features=features, mode=mode, labels=labels, train_op_fn=_no_train_op_fn, logits=tree_train_logits) - dnn_train_op = head.create_model_fn_ops( + dnn_spec = head.create_estimator_spec( features=features, mode=mode, labels=labels, train_op_fn=_dnn_train_op_fn, - logits=dnn_logits).train_op - tree_train_op = head.create_model_fn_ops( + logits=dnn_logits) + tree_spec = head.create_estimator_spec( features=tree_features, mode=mode, labels=labels, train_op_fn=_tree_train_op_fn, - logits=tree_train_logits).train_op - - if tree_center_bias: - num_trees += 1 - finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor() - - model_fn_ops.training_hooks.extend([ - trainer_hooks.SwitchTrainOp(dnn_train_op, dnn_steps_to_train, - tree_train_op), - trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, finalized_trees) - ]) + logits=tree_train_logits) - return model_fn_ops + training_hooks = [ + trainer_hooks.SwitchTrainOp(dnn_spec.train_op, dnn_steps_to_train, + tree_spec.train_op), + trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, + finalized_trees, + override_global_step_value) + ] + fusion_spec = fusion_spec._replace(training_hooks=training_hooks + + list(fusion_spec.training_hooks)) + return fusion_spec class DNNBoostedTreeCombinedClassifier(estimator.Estimator): @@ -369,7 +420,8 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): tree_feature_columns=None, tree_center_bias=False, dnn_to_tree_distillation_param=None, - use_core_versions=False): + use_core_versions=False, + override_global_step_value=None): """Initializes a DNNBoostedTreeCombinedClassifier instance. Args: @@ -425,6 +477,10 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): will be set to True. use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. """ head = head_lib.multi_class_head( n_classes=n_classes, @@ -455,7 +511,8 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): tree_feature_columns=tree_feature_columns, tree_center_bias=tree_center_bias, dnn_to_tree_distillation_param=dnn_to_tree_distillation_param, - use_core_versions=use_core_versions) + use_core_versions=use_core_versions, + override_global_step_value=override_global_step_value) super(DNNBoostedTreeCombinedClassifier, self).__init__( model_fn=_model_fn, @@ -489,7 +546,8 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): tree_feature_columns=None, tree_center_bias=False, dnn_to_tree_distillation_param=None, - use_core_versions=False): + use_core_versions=False, + override_global_step_value=None): """Initializes a DNNBoostedTreeCombinedRegressor instance. Args: @@ -545,6 +603,10 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): will be set to True. use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. """ head = head_lib.regression_head( label_name=label_name, @@ -580,7 +642,8 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): tree_feature_columns=tree_feature_columns, tree_center_bias=tree_center_bias, dnn_to_tree_distillation_param=dnn_to_tree_distillation_param, - use_core_versions=use_core_versions) + use_core_versions=use_core_versions, + override_global_step_value=override_global_step_value) super(DNNBoostedTreeCombinedRegressor, self).__init__( model_fn=_model_fn, @@ -615,7 +678,8 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): tree_feature_columns=None, tree_center_bias=False, dnn_to_tree_distillation_param=None, - use_core_versions=False): + use_core_versions=False, + override_global_step_value=None): """Initializes a DNNBoostedTreeCombinedEstimator instance. Args: @@ -666,6 +730,10 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): will be set to True. use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. """ def _model_fn(features, labels, mode, config): @@ -690,10 +758,109 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): tree_feature_columns=tree_feature_columns, tree_center_bias=tree_center_bias, dnn_to_tree_distillation_param=dnn_to_tree_distillation_param, - use_core_versions=use_core_versions) + use_core_versions=use_core_versions, + override_global_step_value=override_global_step_value) super(DNNBoostedTreeCombinedEstimator, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config, feature_engineering_fn=feature_engineering_fn) + + +class CoreDNNBoostedTreeCombinedEstimator(core_estimator.Estimator): + """Initializes a core version of DNNBoostedTreeCombinedEstimator. + + Args: + dnn_hidden_units: List of hidden units per layer for DNN. + dnn_feature_columns: An iterable containing all the feature columns + used by the model's DNN. + tree_learner_config: A config for the tree learner. + num_trees: Number of trees to grow model to after training DNN. + tree_examples_per_layer: Number of examples to accumulate before + growing the tree a layer. This value has a big impact on model + quality and should be set equal to the number of examples in + training dataset if possible. It can also be a function that computes + the number of examples based on the depth of the layer that's + being built. + head: `Head` instance. + model_dir: Directory for model exports. + config: `RunConfig` of the estimator. + dnn_optimizer: string, `Optimizer` object, or callable that defines the + optimizer to use for training the DNN. If `None`, will use the Adagrad + optimizer with default learning rate. + dnn_activation_fn: Activation function applied to each layer of the DNN. + If `None`, will use `tf.nn.relu`. + dnn_dropout: When not `None`, the probability to drop out a given + unit in the DNN. + dnn_input_layer_partitioner: Partitioner for input layer of the DNN. + Defaults to `min_max_variable_partitioner` with `min_slice_size` + 64 << 20. + dnn_input_layer_to_tree: Whether to provide the DNN's input layer + as a feature to the tree. + dnn_steps_to_train: Number of steps to train dnn for before switching + to gbdt. + predict_with_tree_only: Whether to use only the tree model output as the + final prediction. + tree_feature_columns: An iterable containing all the feature columns + used by the model's boosted trees. If dnn_input_layer_to_tree is + set to True, these features are in addition to dnn_feature_columns. + tree_center_bias: Whether a separate tree should be created for + first fitting the bias. + dnn_to_tree_distillation_param: A Tuple of (float, loss_fn), where the + float defines the weight of the distillation loss, and the loss_fn, for + computing distillation loss, takes dnn_logits, tree_logits and weight + tensor. If the entire tuple is None, no distillation will be applied. If + only the loss_fn is None, we will take the sigmoid/softmax cross entropy + loss be default. When distillation is applied, `predict_with_tree_only` + will be set to True. + """ + + def __init__(self, + dnn_hidden_units, + dnn_feature_columns, + tree_learner_config, + num_trees, + tree_examples_per_layer, + head, + model_dir=None, + config=None, + dnn_optimizer="Adagrad", + dnn_activation_fn=nn.relu, + dnn_dropout=None, + dnn_input_layer_partitioner=None, + dnn_input_layer_to_tree=True, + dnn_steps_to_train=10000, + predict_with_tree_only=False, + tree_feature_columns=None, + tree_center_bias=False, + dnn_to_tree_distillation_param=None): + + def _model_fn(features, labels, mode, config): + return _dnn_tree_combined_model_fn( + features=features, + labels=labels, + mode=mode, + head=head, + dnn_hidden_units=dnn_hidden_units, + dnn_feature_columns=dnn_feature_columns, + tree_learner_config=tree_learner_config, + num_trees=num_trees, + tree_examples_per_layer=tree_examples_per_layer, + config=config, + dnn_optimizer=dnn_optimizer, + dnn_activation_fn=dnn_activation_fn, + dnn_dropout=dnn_dropout, + dnn_input_layer_partitioner=dnn_input_layer_partitioner, + dnn_input_layer_to_tree=dnn_input_layer_to_tree, + dnn_steps_to_train=dnn_steps_to_train, + predict_with_tree_only=predict_with_tree_only, + tree_feature_columns=tree_feature_columns, + tree_center_bias=tree_center_bias, + dnn_to_tree_distillation_param=dnn_to_tree_distillation_param, + output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC, + use_core_versions=True, + override_global_step_value=None) + + super(CoreDNNBoostedTreeCombinedEstimator, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py index 9b7acfa664b0398216b5a7fb904960d8363929d6..839eedd3a87ccaa1faecd1966fe5907d682cac02 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py @@ -28,10 +28,11 @@ from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops.losses import losses from tensorflow.python.platform import googletest - +from tensorflow.python.training import checkpoint_utils def _train_input_fn(): features = { @@ -156,5 +157,72 @@ class DNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase): classifier.evaluate(input_fn=_eval_input_fn, steps=1) +class CoreDNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase): + + def _assert_checkpoint(self, model_dir, global_step): + reader = checkpoint_utils.load_checkpoint(model_dir) + self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)) + + def testTrainEvaluateInferDoesNotThrowErrorWithNoDnnInput(self): + head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 3 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + est = estimator.CoreDNNBoostedTreeCombinedEstimator( + head=head_fn, + dnn_hidden_units=[1], + dnn_feature_columns=[core_feature_column.numeric_column("x")], + tree_learner_config=learner_config, + num_trees=1, + tree_examples_per_layer=3, + model_dir=model_dir, + config=config, + dnn_steps_to_train=10, + dnn_input_layer_to_tree=False, + tree_feature_columns=[core_feature_column.numeric_column("x")]) + + # Train for a few steps. + est.train(input_fn=_train_input_fn, steps=1000) + # 10 steps for dnn, 3 for 1 tree of depth 3 + 1 after the tree finished + self._assert_checkpoint(est.model_dir, global_step=14) + res = est.evaluate(input_fn=_eval_input_fn, steps=1) + self.assertLess(0.5, res["auc"]) + est.predict(input_fn=_eval_input_fn) + + def testTrainEvaluateInferDoesNotThrowErrorWithDnnInput(self): + head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 3 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + est = estimator.CoreDNNBoostedTreeCombinedEstimator( + head=head_fn, + dnn_hidden_units=[1], + dnn_feature_columns=[core_feature_column.numeric_column("x")], + tree_learner_config=learner_config, + num_trees=1, + tree_examples_per_layer=3, + model_dir=model_dir, + config=config, + dnn_steps_to_train=10, + dnn_input_layer_to_tree=True, + tree_feature_columns=[]) + + # Train for a few steps. + est.train(input_fn=_train_input_fn, steps=1000) + res = est.evaluate(input_fn=_eval_input_fn, steps=1) + self.assertLess(0.5, res["auc"]) + est.predict(input_fn=_eval_input_fn) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py index 38fa8c38345f5006628b3b944d0c89d2df54f998..870ce2442bb5e98db7615c43054c9c827b8c88f0 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py @@ -22,8 +22,16 @@ from tensorflow.contrib.boosted_trees.estimator_batch import model from tensorflow.contrib.boosted_trees.python.utils import losses from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib +from tensorflow.python.estimator.canned import head as core_head_lib from tensorflow.python.estimator import estimator as core_estimator from tensorflow.python.ops import math_ops +from tensorflow.python.ops.losses import losses as core_losses + + +# ================== Old estimator interface=================================== +# The estimators below were designed for old feature columns and old estimator +# interface. They can be used with new feature columns and losses by setting +# use_core_libs = True. class GradientBoostedDecisionTreeClassifier(estimator.Estimator): @@ -43,7 +51,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): logits_modifier_function=None, center_bias=True, use_core_libs=False, - output_leaf_index=False): + output_leaf_index=False, + override_global_step_value=None): """Initializes a GradientBoostedDecisionTreeClassifier estimator instance. Args: @@ -77,6 +86,14 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): for result_dict in result_iter: # access leaf index list by result_dict["leaf_index"] # which contains one leaf index per tree + override_global_step_value: If after the training is done, global step + value must be reset to this value. This should be used to reset global + step to a number > number of steps used to train the current ensemble. + For example, the usual way is to train a number of trees and set a very + large number of training steps. When the training is done (number of + trees were trained), this parameter can be used to set the global step + to a large value, making it look like that number of training steps ran. + If None, no override of global step will happen. Raises: ValueError: If learner_config is not valid. @@ -117,6 +134,7 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): 'logits_modifier_function': logits_modifier_function, 'use_core_libs': use_core_libs, 'output_leaf_index': output_leaf_index, + 'override_global_step_value': override_global_step_value }, model_dir=model_dir, config=config, @@ -140,7 +158,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): logits_modifier_function=None, center_bias=True, use_core_libs=False, - output_leaf_index=False): + output_leaf_index=False, + override_global_step_value=None): """Initializes a GradientBoostedDecisionTreeRegressor estimator instance. Args: @@ -174,6 +193,14 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): for example_prediction_result in result_dict: # access leaf index list by example_prediction_result["leaf_index"] # which contains one leaf index per tree + override_global_step_value: If after the training is done, global step + value must be reset to this value. This should be used to reset global + step to a number > number of steps used to train the current ensemble. + For example, the usual way is to train a number of trees and set a very + large number of training steps. When the training is done (number of + trees were trained), this parameter can be used to set the global step + to a large value, making it look like that number of training steps ran. + If None, no override of global step will happen. """ head = head_lib.regression_head( label_name=label_name, @@ -197,6 +224,7 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): 'center_bias': center_bias, 'use_core_libs': use_core_libs, 'output_leaf_index': False, + 'override_global_step_value': override_global_step_value }, model_dir=model_dir, config=config, @@ -222,7 +250,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): logits_modifier_function=None, center_bias=True, use_core_libs=False, - output_leaf_index=False): + output_leaf_index=False, + override_global_step_value=None): """Initializes a GradientBoostedDecisionTreeEstimator estimator instance. Args: @@ -252,6 +281,14 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): for example_prediction_result in result_dict: # access leaf index list by example_prediction_result["leaf_index"] # which contains one leaf index per tree + override_global_step_value: If after the training is done, global step + value must be reset to this value. This should be used to reset global + step to a number > number of steps used to train the current ensemble. + For example, the usual way is to train a number of trees and set a very + large number of training steps. When the training is done (number of + trees were trained), this parameter can be used to set the global step + to a large value, making it look like that number of training steps ran. + If None, no override of global step will happen. """ super(GradientBoostedDecisionTreeEstimator, self).__init__( model_fn=model.model_builder, @@ -266,6 +303,7 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): 'center_bias': center_bias, 'use_core_libs': use_core_libs, 'output_leaf_index': False, + 'override_global_step_value': override_global_step_value }, model_dir=model_dir, config=config, @@ -275,24 +313,23 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): class GradientBoostedDecisionTreeRanker(estimator.Estimator): """A ranking estimator using gradient boosted decision trees.""" - def __init__( - self, - learner_config, - examples_per_layer, - head, - ranking_model_pair_keys, - num_trees=None, - feature_columns=None, - weight_column_name=None, - model_dir=None, - config=None, - label_keys=None, - feature_engineering_fn=None, - logits_modifier_function=None, - center_bias=False, - use_core_libs=False, - output_leaf_index=False, - ): + def __init__(self, + learner_config, + examples_per_layer, + head, + ranking_model_pair_keys, + num_trees=None, + feature_columns=None, + weight_column_name=None, + model_dir=None, + config=None, + label_keys=None, + feature_engineering_fn=None, + logits_modifier_function=None, + center_bias=False, + use_core_libs=False, + output_leaf_index=False, + override_global_step_value=None): """Initializes a GradientBoostedDecisionTreeRanker instance. This is an estimator that can be trained off the pairwise data and can be @@ -332,7 +369,14 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator): for result_dict in result_iter: # access leaf index list by result_dict["leaf_index"] # which contains one leaf index per tree - + override_global_step_value: If after the training is done, global step + value must be reset to this value. This should be used to reset global + step to a number > number of steps used to train the current ensemble. + For example, the usual way is to train a number of trees and set a very + large number of training steps. When the training is done (number of + trees were trained), this parameter can be used to set the global step + to a large value, making it look like that number of training steps ran. + If None, no override of global step will happen. Raises: ValueError: If learner_config is not valid. """ @@ -351,14 +395,41 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator): 'use_core_libs': use_core_libs, 'output_leaf_index': output_leaf_index, 'ranking_model_pair_keys': ranking_model_pair_keys, + 'override_global_step_value': override_global_step_value }, model_dir=model_dir, config=config, feature_engineering_fn=feature_engineering_fn) +# ================== New Estimator interface=================================== +# The estimators below use new core Estimator interface and must be used with +# new feature columns and heads. + +# For multiclass classification, use the following head since it uses loss +# that is twice differentiable. +def core_multiclass_head(n_classes): + """Core head for multiclass problems.""" + + def loss_fn(labels, logits): + result = losses.per_example_maxent_loss( + labels=labels, logits=logits, weights=None, num_classes=n_classes) + return result[0] + + # pylint:disable=protected-access + head_fn = core_head_lib._multi_class_head_with_softmax_cross_entropy_loss( + n_classes=n_classes, + loss_fn=loss_fn, + loss_reduction=core_losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + # pylint:enable=protected-access + + return head_fn + class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator): - """An estimator using gradient boosted decision trees.""" + """An estimator using gradient boosted decision trees. + + Useful for training with user specified `Head`. + """ def __init__(self, learner_config, @@ -374,6 +445,36 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator): logits_modifier_function=None, center_bias=True, output_leaf_index=False): + """Initializes a core version of GradientBoostedDecisionTreeEstimator. + + Args: + learner_config: A config for the learner. + examples_per_layer: Number of examples to accumulate before growing a + layer. It can also be a function that computes the number of examples + based on the depth of the layer that's being built. + head: `Head` instance. + num_trees: An int, number of trees to build. + feature_columns: A list of feature columns. + weight_column_name: Name of the column for weights, or None if not + weighted. + model_dir: Directory for model exports, etc. + config: `RunConfig` object to configure the runtime settings. + label_keys: Optional list of strings with size `[n_classes]` defining the + label vocabulary. Only supported for `n_classes` > 2. + feature_engineering_fn: Feature engineering function. Takes features and + labels which are the output of `input_fn` and returns features and + labels which will be fed into the model. + logits_modifier_function: A modifier function for the logits. + center_bias: Whether a separate tree should be created for first fitting + the bias. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. For example, + result_dict = classifier.predict(...) + for example_prediction_result in result_dict: + # access leaf index list by example_prediction_result["leaf_index"] + # which contains one leaf index per tree + """ def _model_fn(features, labels, mode, config): return model.model_builder( @@ -392,8 +493,92 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator): 'logits_modifier_function': logits_modifier_function, 'use_core_libs': True, 'output_leaf_index': output_leaf_index, + 'override_global_step_value': None }, output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC) super(CoreGradientBoostedDecisionTreeEstimator, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config) + + +class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator): + """A ranking estimator using gradient boosted decision trees.""" + + def __init__(self, + learner_config, + examples_per_layer, + head, + ranking_model_pair_keys, + num_trees=None, + feature_columns=None, + weight_column_name=None, + model_dir=None, + config=None, + label_keys=None, + logits_modifier_function=None, + center_bias=False, + output_leaf_index=False): + """Initializes a GradientBoostedDecisionTreeRanker instance. + + This is an estimator that can be trained off the pairwise data and can be + used for inference on non-paired data. This is essentially LambdaMart. + Args: + learner_config: A config for the learner. + examples_per_layer: Number of examples to accumulate before growing a + layer. It can also be a function that computes the number of examples + based on the depth of the layer that's being built. + head: `Head` instance. + ranking_model_pair_keys: Keys to distinguish between features + for left and right part of the training pairs for ranking. For example, + for an Example with features "a.f1" and "b.f1", the keys would be + ("a", "b"). + num_trees: An int, number of trees to build. + feature_columns: A list of feature columns. + weight_column_name: Name of the column for weights, or None if not + weighted. + model_dir: Directory for model exports, etc. + config: `RunConfig` object to configure the runtime settings. + label_keys: Optional list of strings with size `[n_classes]` defining the + label vocabulary. Only supported for `n_classes` > 2. + logits_modifier_function: A modifier function for the logits. + center_bias: Whether a separate tree should be created for first fitting + the bias. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is + [batch_size, num_trees]. + For example, + result_iter = classifier.predict(...) + for result_dict in result_iter: + # access leaf index list by result_dict["leaf_index"] + # which contains one leaf index per tree + + Raises: + ValueError: If learner_config is not valid. + """ + + def _model_fn(features, labels, mode, config): + return model.ranking_model_builder( + features=features, + labels=labels, + mode=mode, + config=config, + params={ + 'head': head, + 'n_classes': 2, + 'feature_columns': feature_columns, + 'learner_config': learner_config, + 'num_trees': num_trees, + 'weight_column_name': weight_column_name, + 'examples_per_layer': examples_per_layer, + 'center_bias': center_bias, + 'logits_modifier_function': logits_modifier_function, + 'use_core_libs': True, + 'output_leaf_index': output_leaf_index, + 'ranking_model_pair_keys': ranking_model_pair_keys, + 'override_global_step_value': None + }, + output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC) + + super(CoreGradientBoostedDecisionTreeRanker, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py index f787d3cdb81febded62e472549ec98250a0393ff..c155128c0e4ccf928349ee6453baff4384222096 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py @@ -16,7 +16,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function + import tempfile +import numpy as np + from tensorflow.contrib.boosted_trees.estimator_batch import estimator from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.layers.python.layers import feature_column as contrib_feature_column @@ -25,10 +28,13 @@ from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops.losses import losses from tensorflow.python.platform import gfile from tensorflow.python.platform import googletest +from tensorflow.python.training import checkpoint_utils def _train_input_fn(): @@ -37,6 +43,15 @@ def _train_input_fn(): return features, label +def _multiclass_train_input_fn(): + features = { + "x": constant_op.constant([[2.], [1.], [1.], [5.], [3.5], [4.6], [3.5]]) + } + label = constant_op.constant( + [[1], [0], [0], [2], [2], [0], [1]], dtype=dtypes.int32) + return features, label + + def _ranking_train_input_fn(): features = { "a.f1": constant_op.constant([[3.], [0.3], [1.]]), @@ -68,6 +83,10 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): self._export_dir_base = tempfile.mkdtemp() + "export/" gfile.MkDir(self._export_dir_base) + def _assert_checkpoint(self, model_dir, global_step): + reader = checkpoint_utils.load_checkpoint(model_dir) + self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)) + def testFitAndEvaluateDontThrowException(self): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 @@ -202,8 +221,128 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): model.evaluate(input_fn=_ranking_train_input_fn, steps=1) model.predict(input_fn=_infer_ranking_train_input_fn) + def testDoesNotOverrideGlobalSteps(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 2 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=[contrib_feature_column.real_valued_column("x")], + output_leaf_index=False) + + classifier.fit(input_fn=_train_input_fn, steps=15) + # When no override of global steps, 5 steps were used. + self._assert_checkpoint(classifier.model_dir, global_step=5) + + def testOverridesGlobalSteps(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 2 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=[contrib_feature_column.real_valued_column("x")], + output_leaf_index=False, + override_global_step_value=10000000) + + classifier.fit(input_fn=_train_input_fn, steps=15) + self._assert_checkpoint(classifier.model_dir, global_step=10000000) + + def testFitAndEvaluateMultiClassTreePerClassDontThrowException(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 3 + learner_config.constraints.max_tree_depth = 1 + learner_config.multi_class_strategy = ( + learner_pb2.LearnerConfig.TREE_PER_CLASS) + + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + n_classes=learner_config.num_classes, + num_trees=1, + examples_per_layer=7, + model_dir=model_dir, + config=config, + feature_columns=[contrib_feature_column.real_valued_column("x")]) + + classifier.fit(input_fn=_multiclass_train_input_fn, steps=100) + classifier.evaluate(input_fn=_eval_input_fn, steps=1) + classifier.export(self._export_dir_base) + result_iter = classifier.predict(input_fn=_eval_input_fn) + for prediction_dict in result_iter: + self.assertTrue("classes" in prediction_dict) + + def testFitAndEvaluateMultiClassDiagonalDontThrowException(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 3 + learner_config.constraints.max_tree_depth = 1 + learner_config.multi_class_strategy = ( + learner_pb2.LearnerConfig.DIAGONAL_HESSIAN) + + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + n_classes=learner_config.num_classes, + num_trees=1, + examples_per_layer=7, + model_dir=model_dir, + config=config, + center_bias=False, + feature_columns=[contrib_feature_column.real_valued_column("x")]) + + classifier.fit(input_fn=_multiclass_train_input_fn, steps=100) + classifier.evaluate(input_fn=_eval_input_fn, steps=1) + classifier.export(self._export_dir_base) + result_iter = classifier.predict(input_fn=_eval_input_fn) + for prediction_dict in result_iter: + self.assertTrue("classes" in prediction_dict) + + def testFitAndEvaluateMultiClassFullDontThrowException(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 3 + learner_config.constraints.max_tree_depth = 1 + learner_config.multi_class_strategy = ( + learner_pb2.LearnerConfig.FULL_HESSIAN) + + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + n_classes=learner_config.num_classes, + num_trees=1, + examples_per_layer=7, + model_dir=model_dir, + config=config, + center_bias=False, + feature_columns=[contrib_feature_column.real_valued_column("x")]) + + classifier.fit(input_fn=_multiclass_train_input_fn, steps=100) + classifier.evaluate(input_fn=_eval_input_fn, steps=1) + classifier.export(self._export_dir_base) + result_iter = classifier.predict(input_fn=_eval_input_fn) + for prediction_dict in result_iter: + self.assertTrue("classes" in prediction_dict) + -class CoreGradientBoostedDecisionTreeEstimator(test_util.TensorFlowTestCase): +class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): def testTrainEvaluateInferDoesNotThrowError(self): head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( @@ -229,6 +368,172 @@ class CoreGradientBoostedDecisionTreeEstimator(test_util.TensorFlowTestCase): est.evaluate(input_fn=_eval_input_fn, steps=1) est.predict(input_fn=_eval_input_fn) + def testRankingDontThrowExceptionForForEstimator(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + est = estimator.CoreGradientBoostedDecisionTreeRanker( + head=head_fn, + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=[ + core_feature_column.numeric_column("f1"), + core_feature_column.numeric_column("f2") + ], + ranking_model_pair_keys=("a", "b")) + + # Train for a few steps. + est.train(input_fn=_ranking_train_input_fn, steps=1000) + est.evaluate(input_fn=_ranking_train_input_fn, steps=1) + est.predict(input_fn=_infer_ranking_train_input_fn) + + def testFitAndEvaluateMultiClassTreePerClasssDontThrowException(self): + n_classes = 3 + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = n_classes + learner_config.constraints.max_tree_depth = 1 + learner_config.multi_class_strategy = ( + learner_pb2.LearnerConfig.TREE_PER_CLASS) + + head_fn = estimator.core_multiclass_head(n_classes=n_classes) + + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.CoreGradientBoostedDecisionTreeEstimator( + learner_config=learner_config, + head=head_fn, + num_trees=1, + center_bias=False, + examples_per_layer=7, + model_dir=model_dir, + config=config, + feature_columns=[core_feature_column.numeric_column("x")]) + + classifier.train(input_fn=_multiclass_train_input_fn, steps=100) + classifier.evaluate(input_fn=_multiclass_train_input_fn, steps=1) + classifier.predict(input_fn=_eval_input_fn) + + def testFitAndEvaluateMultiClassDiagonalDontThrowException(self): + n_classes = 3 + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = n_classes + learner_config.constraints.max_tree_depth = 1 + learner_config.multi_class_strategy = ( + learner_pb2.LearnerConfig.DIAGONAL_HESSIAN) + + head_fn = estimator.core_multiclass_head(n_classes=n_classes) + + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.CoreGradientBoostedDecisionTreeEstimator( + learner_config=learner_config, + head=head_fn, + num_trees=1, + center_bias=False, + examples_per_layer=7, + model_dir=model_dir, + config=config, + feature_columns=[core_feature_column.numeric_column("x")]) + + classifier.train(input_fn=_multiclass_train_input_fn, steps=100) + classifier.evaluate(input_fn=_multiclass_train_input_fn, steps=1) + classifier.predict(input_fn=_eval_input_fn) + + def testFitAndEvaluateMultiClassFullDontThrowException(self): + n_classes = 3 + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = n_classes + learner_config.constraints.max_tree_depth = 1 + learner_config.multi_class_strategy = ( + learner_pb2.LearnerConfig.FULL_HESSIAN) + + head_fn = estimator.core_multiclass_head(n_classes=n_classes) + + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.CoreGradientBoostedDecisionTreeEstimator( + learner_config=learner_config, + head=head_fn, + num_trees=1, + center_bias=False, + examples_per_layer=7, + model_dir=model_dir, + config=config, + feature_columns=[core_feature_column.numeric_column("x")]) + + classifier.train(input_fn=_multiclass_train_input_fn, steps=100) + classifier.evaluate(input_fn=_multiclass_train_input_fn, steps=1) + classifier.predict(input_fn=_eval_input_fn) + + def testWeightedCategoricalColumn(self): + head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + feature_columns = [ + core_feature_column.weighted_categorical_column( + categorical_column=core_feature_column. + categorical_column_with_vocabulary_list( + key="word", vocabulary_list=["the", "cat", "dog"]), + weight_feature_key="weight") + ] + + labels = np.array([[1], [1], [0], [0.]], dtype=np.float32) + + def _make_input_fn(): + + def _input_fn(): + features_dict = {} + # Sparse tensor representing + # example 0: "cat","the" + # examaple 1: "dog" + # example 2: - + # example 3: "the" + # Weights for the words are 5 - cat, 6- dog and 1 -the. + features_dict["word"] = sparse_tensor.SparseTensor( + indices=[[0, 0], [0, 1], [1, 0], [3, 0]], + values=constant_op.constant( + ["the", "cat", "dog", "the"], dtype=dtypes.string), + dense_shape=[4, 3]) + features_dict["weight"] = sparse_tensor.SparseTensor( + indices=[[0, 0], [0, 1], [1, 0], [3, 0]], + values=[1., 5., 6., 1.], + dense_shape=[4, 3]) + return features_dict, labels + + return _input_fn + + est = estimator.CoreGradientBoostedDecisionTreeEstimator( + head=head_fn, + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=feature_columns) + + input_fn = _make_input_fn() + est.train(input_fn=input_fn, steps=100) + est.evaluate(input_fn=input_fn, steps=1) + est.predict(input_fn=input_fn) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py index 2fbe72951a559808ee2ee12a2efa07b1d857883a..04b46c3483fa25286078b88c2776b76e4f3c0bcf 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py @@ -58,7 +58,13 @@ def model_builder(features, * weight_column_name: The name of weight column. * center_bias: Whether a separate tree should be created for first fitting the bias. + * override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. config: `RunConfig` of the estimator. + output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec + (new interface). Returns: A `ModelFnOps` object. @@ -74,6 +80,7 @@ def model_builder(features, use_core_libs = params["use_core_libs"] logits_modifier_function = params["logits_modifier_function"] output_leaf_index = params["output_leaf_index"] + override_global_step_value = params.get("override_global_step_value", None) if features is None: raise ValueError("At least one feature must be specified.") @@ -126,14 +133,16 @@ def model_builder(features, create_estimator_spec_op = getattr(head, "create_estimator_spec", None) + training_hooks = [] if num_trees: if center_bias: num_trees += 1 + finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor() - training_hooks = [ + training_hooks.append( trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, - finalized_trees) - ] + finalized_trees, + override_global_step_value)) if output_type == ModelBuilderOutputType.MODEL_FN_OPS: if use_core_libs and callable(create_estimator_spec_op): @@ -175,7 +184,12 @@ def model_builder(features, return model_fn_ops -def ranking_model_builder(features, labels, mode, params, config): +def ranking_model_builder(features, + labels, + mode, + params, + config, + output_type=ModelBuilderOutputType.MODEL_FN_OPS): """Multi-machine batch gradient descent tree model for ranking. Args: @@ -198,7 +212,14 @@ def ranking_model_builder(features, labels, mode, params, config): for left and right part of the training pairs for ranking. For example, for an Example with features "a.f1" and "b.f1", the keys would be ("a", "b"). + * override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. config: `RunConfig` of the estimator. + output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec + (new interface). + Returns: A `ModelFnOps` object. @@ -215,6 +236,7 @@ def ranking_model_builder(features, labels, mode, params, config): logits_modifier_function = params["logits_modifier_function"] output_leaf_index = params["output_leaf_index"] ranking_model_pair_keys = params["ranking_model_pair_keys"] + override_global_step_value = params.get("override_global_step_value", None) if features is None: raise ValueError("At least one feature must be specified.") @@ -326,31 +348,55 @@ def ranking_model_builder(features, labels, mode, params, config): return update_op create_estimator_spec_op = getattr(head, "create_estimator_spec", None) - if use_core_libs and callable(create_estimator_spec_op): - model_fn_ops = head.create_estimator_spec( - features=features, - mode=mode, - labels=labels, - train_op_fn=_train_op_fn, - logits=logits) - model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(model_fn_ops) - else: - model_fn_ops = head.create_model_fn_ops( - features=features, - mode=mode, - labels=labels, - train_op_fn=_train_op_fn, - logits=logits) - if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict: - model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[ - gbdt_batch.LEAF_INDEX] + training_hooks = [] if num_trees: if center_bias: num_trees += 1 + finalized_trees, attempted_trees = ( gbdt_model_main.get_number_of_trees_tensor()) - model_fn_ops.training_hooks.append( + training_hooks.append( trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, - finalized_trees)) + finalized_trees, + override_global_step_value)) + + if output_type == ModelBuilderOutputType.MODEL_FN_OPS: + if use_core_libs and callable(create_estimator_spec_op): + model_fn_ops = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_op_fn, + logits=logits) + model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops( + model_fn_ops) + else: + model_fn_ops = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_op_fn, + logits=logits) + + if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict: + model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[ + gbdt_batch.LEAF_INDEX] + + model_fn_ops.training_hooks.extend(training_hooks) + return model_fn_ops + + elif output_type == ModelBuilderOutputType.ESTIMATOR_SPEC: + assert callable(create_estimator_spec_op) + estimator_spec = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_op_fn, + logits=logits) + + estimator_spec = estimator_spec._replace( + training_hooks=training_hooks + list(estimator_spec.training_hooks)) + return estimator_spec + return model_fn_ops diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py b/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py index 2e4151cac40f770e2bece70d752122eb7f34dd40..f137ada35524bf2467314f4a284ea35a82f06825 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py @@ -25,6 +25,7 @@ from tensorflow.contrib.learn.python.learn.session_run_hook import SessionRunArg from tensorflow.core.framework.summary_pb2 import Summary from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import state_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import training_util from tensorflow.python.training.summary_io import SummaryWriterCache @@ -150,12 +151,23 @@ class FeedFnHook(session_run_hook.SessionRunHook): class StopAfterNTrees(session_run_hook.SessionRunHook): """Stop training after building N full trees.""" - def __init__(self, n, num_attempted_trees_tensor, num_finalized_trees_tensor): + def __init__(self, n, num_attempted_trees_tensor, num_finalized_trees_tensor, + override_global_step_value=None): self._num_trees = n # num_attempted_trees_tensor and num_finalized_trees_tensor are both # tensors. self._num_attempted_trees_tensor = num_attempted_trees_tensor self._num_finalized_trees_tensor = num_finalized_trees_tensor + self._override_global_step_value = override_global_step_value + + def begin(self): + self._global_step_tensor = training_util.get_global_step() + if self._global_step_tensor is None: + raise RuntimeError("Global step should be created.") + + if self._override_global_step_value is not None: + self._override_global_step_op = state_ops.assign( + self._global_step_tensor, self._override_global_step_value) def before_run(self, run_context): del run_context # unused by StopTrainingAfterNTrees. @@ -175,6 +187,9 @@ class StopAfterNTrees(session_run_hook.SessionRunHook): num_attempted_trees > 2 * self._num_trees): logging.info("Requesting stop since we have reached %d trees.", num_finalized_trees) + if self._override_global_step_value is not None: + logging.info("Overriding global steps value.") + run_context.session.run(self._override_global_step_op) run_context.request_stop() diff --git a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc index 5b4be2f25838d5405a8148ea20cb0f759cd3a8fb..1375fddf2bea1a8f856c35d756c38a8beb14a53f 100644 --- a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc @@ -125,6 +125,8 @@ void QuantizeFeatures( auto flat_values = values_tensor.flat(); for (int64 instance = 0; instance < num_values; ++instance) { const float value = flat_values(instance); + CHECK(!buckets_vector.empty()) + << "Got empty buckets for feature " << feature_index; auto bucket_iter = std::lower_bound(buckets_vector.begin(), buckets_vector.end(), value); if (bucket_iter == buckets_vector.end()) { diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 401bec84a20a0fefcddbfa1039a117e65f853633..d9e7a0f4660470a0c79ad7a832db233481161770 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -34,7 +34,9 @@ namespace tensorflow { +using boosted_trees::learner::LearnerConfig; using boosted_trees::learner::LearnerConfig_MultiClassStrategy; +using boosted_trees::learner::ObliviousSplitInfo; using boosted_trees::learner::SplitInfo; using boosted_trees::learner::stochastic::GradientStats; using boosted_trees::learner::stochastic::NodeStats; @@ -158,6 +160,11 @@ class BuildDenseInequalitySplitsOp : public OpKernel { const Tensor* hessians_t; OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); + const Tensor* weak_learner_type_t; + OP_REQUIRES_OK(context, + context->input("weak_learner_type", &weak_learner_type_t)); + const int32 weak_learner_type = weak_learner_type_t->scalar()(); + // Find the number of unique partitions before we allocate the output. std::vector partition_boundaries; partition_boundaries.push_back(0); @@ -188,20 +195,59 @@ class BuildDenseInequalitySplitsOp : public OpKernel { tensorflow::TTypes::Vec output_partition_ids = output_partition_ids_t->vec(); - Tensor* gains_t = nullptr; - OP_REQUIRES_OK( - context, context->allocate_output("gains", TensorShape({num_elements}), - &gains_t)); + // For a normal tree, we output a split per partition. For an oblivious + // tree, we output one split for all partitions of the layer + int32 size_output = num_elements; + if (weak_learner_type == LearnerConfig::OBLIVIOUS_DECISION_TREE && + num_elements > 0) { + size_output = 1; + } + Tensor* gains_t = nullptr; + OP_REQUIRES_OK(context, context->allocate_output( + "gains", TensorShape({size_output}), &gains_t)); tensorflow::TTypes::Vec gains = gains_t->vec(); Tensor* output_splits_t = nullptr; - OP_REQUIRES_OK(context, context->allocate_output( - "split_infos", TensorShape({num_elements}), - &output_splits_t)); + OP_REQUIRES_OK(context, context->allocate_output("split_infos", + TensorShape({size_output}), + &output_splits_t)); tensorflow::TTypes::Vec output_splits = output_splits_t->vec(); + + if (num_elements == 0) { + return; + } SplitBuilderState state(context); + switch (weak_learner_type) { + case LearnerConfig::NORMAL_DECISION_TREE: { + ComputeNormalDecisionTree( + &state, normalizer_ratio, num_elements, partition_boundaries, + bucket_boundaries, partition_ids, bucket_ids, gradients_t, + hessians_t, &output_partition_ids, &gains, &output_splits); + break; + } + case LearnerConfig::OBLIVIOUS_DECISION_TREE: { + ComputeObliviousDecisionTree( + &state, normalizer_ratio, num_elements, partition_boundaries, + bucket_boundaries, partition_ids, bucket_ids, gradients_t, + hessians_t, &output_partition_ids, &gains, &output_splits); + break; + } + } + } + + private: + void ComputeNormalDecisionTree( + SplitBuilderState* state, const float normalizer_ratio, + const int num_elements, const std::vector& partition_boundaries, + const tensorflow::TTypes::ConstVec& bucket_boundaries, + const tensorflow::TTypes::ConstVec& partition_ids, + const tensorflow::TTypes::ConstMatrix& bucket_ids, + const Tensor* gradients_t, const Tensor* hessians_t, + tensorflow::TTypes::Vec* output_partition_ids, + tensorflow::TTypes::Vec* gains, + tensorflow::TTypes::Vec* output_splits) { for (int root_idx = 0; root_idx < num_elements; ++root_idx) { float best_gain = std::numeric_limits::lowest(); int start_index = partition_boundaries[root_idx]; @@ -213,7 +259,7 @@ class BuildDenseInequalitySplitsOp : public OpKernel { GradientStats(*gradients_t, *hessians_t, bucket_idx); } root_gradient_stats *= normalizer_ratio; - NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats); + NodeStats root_stats = state->ComputeNodeStats(root_gradient_stats); int32 best_bucket_idx = 0; NodeStats best_right_node_stats(0); NodeStats best_left_node_stats(0); @@ -223,10 +269,10 @@ class BuildDenseInequalitySplitsOp : public OpKernel { GradientStats g(*gradients_t, *hessians_t, bucket_idx); g *= normalizer_ratio; left_gradient_stats += g; - NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats); + NodeStats left_stats = state->ComputeNodeStats(left_gradient_stats); GradientStats right_gradient_stats = root_gradient_stats - left_gradient_stats; - NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats); + NodeStats right_stats = state->ComputeNodeStats(right_gradient_stats); if (left_stats.gain + right_stats.gain > best_gain) { best_gain = left_stats.gain + right_stats.gain; best_left_node_stats = left_stats; @@ -237,20 +283,124 @@ class BuildDenseInequalitySplitsOp : public OpKernel { SplitInfo split_info; auto* dense_split = split_info.mutable_split_node()->mutable_dense_float_binary_split(); - dense_split->set_feature_column(state.feature_column_group_id()); + dense_split->set_feature_column(state->feature_column_group_id()); dense_split->set_threshold( bucket_boundaries(bucket_ids(best_bucket_idx, 0))); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); - state.FillLeaf(best_left_node_stats, left_child); - state.FillLeaf(best_right_node_stats, right_child); - split_info.SerializeToString(&output_splits(root_idx)); - gains(root_idx) = - best_gain - root_stats.gain - state.tree_complexity_regularization(); - output_partition_ids(root_idx) = partition_ids(start_index); + state->FillLeaf(best_left_node_stats, left_child); + state->FillLeaf(best_right_node_stats, right_child); + split_info.SerializeToString(&(*output_splits)(root_idx)); + (*gains)(root_idx) = + best_gain - root_stats.gain - state->tree_complexity_regularization(); + (*output_partition_ids)(root_idx) = partition_ids(start_index); + } + } + void ComputeObliviousDecisionTree( + SplitBuilderState* state, const float normalizer_ratio, + const int num_elements, const std::vector& partition_boundaries, + const tensorflow::TTypes::ConstVec& bucket_boundaries, + const tensorflow::TTypes::ConstVec& partition_ids, + const tensorflow::TTypes::ConstMatrix& bucket_ids, + const Tensor* gradients_t, const Tensor* hessians_t, + tensorflow::TTypes::Vec* output_partition_ids, + tensorflow::TTypes::Vec* gains, + tensorflow::TTypes::Vec* output_splits) { + // Holds the root stats per each node to be split. + std::vector current_layer_stats; + current_layer_stats.reserve(num_elements); + for (int root_idx = 0; root_idx < num_elements; root_idx++) { + const int start_index = partition_boundaries[root_idx]; + const int end_index = partition_boundaries[root_idx + 1]; + GradientStats root_gradient_stats; + for (int64 bucket_idx = start_index; bucket_idx < end_index; + ++bucket_idx) { + root_gradient_stats += + GradientStats(*gradients_t, *hessians_t, bucket_idx); + } + root_gradient_stats *= normalizer_ratio; + current_layer_stats.push_back(root_gradient_stats); + } + + float best_gain = std::numeric_limits::lowest(); + int64 best_bucket_idx = 0; + std::vector best_right_node_stats(num_elements, NodeStats(0)); + std::vector best_left_node_stats(num_elements, NodeStats(0)); + std::vector current_left_node_stats(num_elements, NodeStats(0)); + std::vector current_right_node_stats(num_elements, NodeStats(0)); + int64 current_bucket_id = 0; + int64 last_bucket_id = -1; + // Indexes offsets for each of the partitions that can be used to access + // gradients of a partition for a current bucket we consider. + std::vector current_layer_offsets(num_elements, 0); + std::vector left_gradient_stats(num_elements); + // The idea is to try every bucket id in increasing order. In each iteration + // we calculate the gain of the layer using the current bucket id as split + // value, and we also obtain the following bucket id to try. + while (current_bucket_id > last_bucket_id) { + last_bucket_id = current_bucket_id; + int64 next_bucket_id = -1; + for (int root_idx = 0; root_idx < num_elements; root_idx++) { + int idx = + current_layer_offsets[root_idx] + partition_boundaries[root_idx]; + const int end_index = partition_boundaries[root_idx + 1]; + if (idx < end_index && bucket_ids(idx, 0) == current_bucket_id) { + GradientStats g(*gradients_t, *hessians_t, idx); + g *= normalizer_ratio; + left_gradient_stats[root_idx] += g; + current_layer_offsets[root_idx]++; + idx++; + } + if (idx < end_index && + (bucket_ids(idx, 0) < next_bucket_id || next_bucket_id == -1)) { + next_bucket_id = bucket_ids(idx, 0); + } + } + float gain_of_split = 0.0; + for (int root_idx = 0; root_idx < num_elements; root_idx++) { + GradientStats right_gradient_stats = + current_layer_stats[root_idx] - left_gradient_stats[root_idx]; + NodeStats left_stat = + state->ComputeNodeStats(left_gradient_stats[root_idx]); + NodeStats right_stat = state->ComputeNodeStats(right_gradient_stats); + gain_of_split += left_stat.gain + right_stat.gain; + current_left_node_stats[root_idx] = left_stat; + current_right_node_stats[root_idx] = right_stat; + } + if (gain_of_split > best_gain) { + best_gain = gain_of_split; + best_left_node_stats = current_left_node_stats; + best_right_node_stats = current_right_node_stats; + } + current_bucket_id = next_bucket_id; + } + + for (int root_idx = 0; root_idx < num_elements; root_idx++) { + best_gain -= state->ComputeNodeStats(current_layer_stats[root_idx]).gain; + } + best_gain -= num_elements * state->tree_complexity_regularization(); + + ObliviousSplitInfo oblivious_split_info; + auto* oblivious_dense_split = oblivious_split_info.mutable_split_node() + ->mutable_dense_float_binary_split(); + oblivious_dense_split->set_feature_column(state->feature_column_group_id()); + oblivious_dense_split->set_threshold( + bucket_boundaries(bucket_ids(best_bucket_idx, 0))); + (*gains)(0) = best_gain; + + for (int root_idx = 0; root_idx < num_elements; root_idx++) { + auto* left_children = oblivious_split_info.add_children_leaves(); + auto* right_children = oblivious_split_info.add_children_leaves(); + + state->FillLeaf(best_left_node_stats[root_idx], left_children); + state->FillLeaf(best_right_node_stats[root_idx], right_children); + + const int start_index = partition_boundaries[root_idx]; + (*output_partition_ids)(root_idx) = partition_ids(start_index); } + oblivious_split_info.SerializeToString(&(*output_splits)(0)); } }; REGISTER_KERNEL_BUILDER(Name("BuildDenseInequalitySplits").Device(DEVICE_CPU), diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc index 1bfeed306641111718984b2097512e5ec3fa8630..6d9a6ee5a0d05465459393c4339558f1ca38d417 100644 --- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc @@ -372,12 +372,18 @@ class GrowTreeEnsembleOp : public OpKernel { return; } + // Get the max tree depth. + const Tensor* max_tree_depth_t; + OP_REQUIRES_OK(context, + context->input("max_tree_depth", &max_tree_depth_t)); + const int32 max_tree_depth = max_tree_depth_t->scalar()(); + // Update and retrieve the growable tree. // If the tree is fully built and dropout was applied, it also adjusts the // weights of dropped and the last tree. boosted_trees::trees::DecisionTreeConfig* const tree_config = UpdateAndRetrieveGrowableTree(ensemble_resource, learning_rate, - dropout_seed); + dropout_seed, max_tree_depth); // Split tree nodes. for (auto& split_entry : best_splits) { @@ -494,7 +500,8 @@ class GrowTreeEnsembleOp : public OpKernel { boosted_trees::trees::DecisionTreeConfig* UpdateAndRetrieveGrowableTree( boosted_trees::models::DecisionTreeEnsembleResource* const ensemble_resource, - const float learning_rate, const uint64 dropout_seed) { + const float learning_rate, const uint64 dropout_seed, + const int32 max_tree_depth) { const auto num_trees = ensemble_resource->num_trees(); if (num_trees <= 0 || ensemble_resource->LastTreeMetadata()->is_finalized()) { @@ -506,8 +513,7 @@ class GrowTreeEnsembleOp : public OpKernel { tree_config->add_nodes()->mutable_leaf(); boosted_trees::trees::DecisionTreeMetadata* const tree_metadata = ensemble_resource->LastTreeMetadata(); - tree_metadata->set_is_finalized( - learner_config_.constraints().max_tree_depth() <= 1); + tree_metadata->set_is_finalized(max_tree_depth <= 1); tree_metadata->set_num_tree_weight_updates(1); } else { // The growable tree is by definition the last tree in the ensemble. @@ -518,8 +524,7 @@ class GrowTreeEnsembleOp : public OpKernel { << num_trees - 1 << " of ensemble of " << num_trees << " trees."; // Update growable tree metadata. tree_metadata->set_num_layers_grown(new_num_layers); - tree_metadata->set_is_finalized( - new_num_layers >= learner_config_.constraints().max_tree_depth()); + tree_metadata->set_is_finalized(new_num_layers >= max_tree_depth); } UpdateTreeWeightsIfDropout(ensemble_resource, dropout_seed); return ensemble_resource->LastTree(); diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py index 1b7f59ea4218355a13f1df7264352bd68503bd19..5d4819b0f1cb598cfbe146f569aecd7883186339 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py @@ -131,6 +131,10 @@ class BaseSplitHandler(object): }, stamp_token, None) return control_flow_ops.group(update_1, *update_2[self]) + @abc.abstractmethod + def reset(self, stamp_token, next_stamp_token): + """Resets the state maintained by the handler.""" + @abc.abstractmethod def make_splits(self, stamp_token, next_stamp_token, class_id): """Create the best split using the accumulated stats and flush the state. diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py index bf686237ff696dadad9713d26bf784d7442b80d0..efe29216c2a7d8aa985da54cdbb839b9e6f69078 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py @@ -202,3 +202,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): # always return ready. are_splits_ready = constant_op.constant(True) return (are_splits_ready, partition_ids, gains, split_infos) + + def reset(self, stamp_token, next_stamp_token): + reset = self._stats_accumulator.flush(stamp_token, next_stamp_token) + return reset diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index df0bec1fe363e07bbff6b059e86076239bd605e9..f45010ec26ed25127ca78b97f4d6fd7ebd6467ae 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -64,6 +64,7 @@ from __future__ import print_function import re from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler +from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.boosted_trees.python.ops import gen_quantile_ops from tensorflow.contrib.boosted_trees.python.ops import gen_stats_accumulator_ops from tensorflow.contrib.boosted_trees.python.ops import quantile_ops @@ -79,6 +80,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops + _BIAS_FEATURE_ID = -1 # Pattern to remove all non alpha numeric from a string. _PATTERN = re.compile(r"[\W_]+") @@ -147,6 +149,11 @@ class InequalitySplitHandler(base_split_handler.BaseSplitHandler): num_quantiles=num_quantiles, name="QuantileAccumulator/{}".format(self._name)) + def reset(self, stamp_token, next_stamp_token): + reset_1 = self._stats_accumulator.flush(stamp_token, next_stamp_token) + reset_2 = self._quantile_accumulator.flush(stamp_token, next_stamp_token) + return control_flow_ops.group([reset_1, reset_2]) + class DenseSplitHandler(InequalitySplitHandler): """Computes stats and finds the best inequality splits on dense columns.""" @@ -165,6 +172,7 @@ class DenseSplitHandler(InequalitySplitHandler): multiclass_strategy, init_stamp_token=0, loss_uses_sum_reduction=False, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE, name=None): """Initialize the internal state for this split handler. @@ -186,6 +194,7 @@ class DenseSplitHandler(InequalitySplitHandler): stamped objects. loss_uses_sum_reduction: A scalar boolean tensor that specifies whether SUM or MEAN reduction was used for the loss. + weak_learner_type: Specifies the type of weak learner to use. name: An optional handler name. """ super(DenseSplitHandler, self).__init__( @@ -203,6 +212,7 @@ class DenseSplitHandler(InequalitySplitHandler): multiclass_strategy=multiclass_strategy, loss_uses_sum_reduction=loss_uses_sum_reduction) self._dense_float_column = dense_float_column + self._weak_learner_type = weak_learner_type # Register dense_make_stats_update function as an Op to the graph. g = ops.get_default_graph() dense_make_stats_update.add_to_graph(g) @@ -263,15 +273,17 @@ class DenseSplitHandler(InequalitySplitHandler): next_stamp_token, self._multiclass_strategy, class_id, self._feature_column_group_id, self._l1_regularization, self._l2_regularization, self._tree_complexity_regularization, - self._min_node_weight, self._loss_uses_sum_reduction)) + self._min_node_weight, self._loss_uses_sum_reduction, + self._weak_learner_type)) return are_splits_ready, partition_ids, gains, split_infos -def _make_dense_split( - quantile_accumulator_handle, stats_accumulator_handle, stamp_token, - next_stamp_token, multiclass_strategy, class_id, feature_column_id, - l1_regularization, l2_regularization, tree_complexity_regularization, - min_node_weight, is_multi_dimentional, loss_uses_sum_reduction): +def _make_dense_split(quantile_accumulator_handle, stats_accumulator_handle, + stamp_token, next_stamp_token, multiclass_strategy, + class_id, feature_column_id, l1_regularization, + l2_regularization, tree_complexity_regularization, + min_node_weight, is_multi_dimentional, + loss_uses_sum_reduction, weak_learner_type): """Function that builds splits for a dense feature column.""" # Get the bucket boundaries are_splits_ready, buckets = ( @@ -320,7 +332,8 @@ def _make_dense_split( l2_regularization=l2_regularization, tree_complexity_regularization=tree_complexity_regularization, min_node_weight=min_node_weight, - multiclass_strategy=multiclass_strategy)) + multiclass_strategy=multiclass_strategy, + weak_learner_type=weak_learner_type)) return are_splits_ready, partition_ids, gains, split_infos @@ -500,7 +513,40 @@ def _make_sparse_split( return are_splits_ready, partition_ids, gains, split_infos -def _specialize_make_split(func, is_multi_dimentional): +def _specialize_make_split_dense(func, is_multi_dimentional): + """Builds a specialized version of the function.""" + + @function.Defun( + dtypes.resource, + dtypes.resource, + dtypes.int64, + dtypes.int64, + dtypes.int32, + dtypes.int32, + dtypes.int32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + dtypes.bool, + dtypes.int32, + noinline=True) + def f(quantile_accumulator_handle, stats_accumulator_handle, stamp_token, + next_stamp_token, multiclass_strategy, class_id, feature_column_id, + l1_regularization, l2_regularization, tree_complexity_regularization, + min_node_weight, loss_uses_sum_reduction, weak_learner_type): + """Function that builds splits for a sparse feature column.""" + return func(quantile_accumulator_handle, stats_accumulator_handle, + stamp_token, next_stamp_token, multiclass_strategy, class_id, + feature_column_id, l1_regularization, l2_regularization, + tree_complexity_regularization, min_node_weight, + is_multi_dimentional, loss_uses_sum_reduction, + weak_learner_type) + + return f + + +def _specialize_make_split_sparse(func, is_multi_dimentional): """Builds a specialized version of the function.""" @function.Defun( @@ -530,15 +576,17 @@ def _specialize_make_split(func, is_multi_dimentional): return f -make_dense_split_scalar = _specialize_make_split(_make_dense_split, - is_multi_dimentional=False) -make_dense_split_tensor = _specialize_make_split(_make_dense_split, - is_multi_dimentional=True) -make_sparse_split_scalar = _specialize_make_split(_make_sparse_split, - is_multi_dimentional=False) -make_sparse_split_tensor = _specialize_make_split(_make_sparse_split, - is_multi_dimentional=True) +make_dense_split_scalar = _specialize_make_split_dense( + _make_dense_split, is_multi_dimentional=False) + +make_dense_split_tensor = _specialize_make_split_dense( + _make_dense_split, is_multi_dimentional=True) + +make_sparse_split_scalar = _specialize_make_split_sparse( + _make_sparse_split, is_multi_dimentional=False) +make_sparse_split_tensor = _specialize_make_split_sparse( + _make_sparse_split, is_multi_dimentional=True) @function.Defun( @@ -579,8 +627,10 @@ def dense_make_stats_update(is_active, are_buckets_ready, float_column, example_partition_ids, feature_ids, gradients, hessians = ( control_flow_ops.cond( - math_ops.logical_and(are_buckets_ready, is_active[0]), - ready_inputs_fn, not_ready_inputs_fn)) + math_ops.logical_and( + math_ops.logical_and(are_buckets_ready, + array_ops.size(quantile_buckets) > 0), + is_active[0]), ready_inputs_fn, not_ready_inputs_fn)) return (quantile_values, quantile_weights, example_partition_ids, feature_ids, gradients, hessians) @@ -674,8 +724,10 @@ def sparse_make_stats_update( lambda: handler_not_active)) example_partition_ids, feature_ids, gradients, hessians = ( - control_flow_ops.cond(are_buckets_ready, quantiles_ready, - quantiles_not_ready)) + control_flow_ops.cond( + math_ops.logical_and(are_buckets_ready, + array_ops.size(quantile_buckets) > 0), + quantiles_ready, quantiles_not_ready)) return (quantile_indices, quantile_values, quantile_shape, quantile_weights, example_partition_ids, feature_ids, gradients, hessians) diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py index d59732cf92eb85e88732ac5a17dccf475ae5342f..6572f2f414b5d6741f43ec9f79ac7f6ab0f22deb 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py @@ -182,6 +182,133 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.threshold, 0.00001) + def testObliviousFeatureSplitGeneration(self): + with self.test_session() as sess: + # The data looks like the following: + # Example | Gradients | Partition | Dense Quantile | + # i0 | (0.2, 0.12) | 0 | 2 | + # i1 | (-0.5, 0.07) | 0 | 2 | + # i2 | (1.2, 0.2) | 0 | 0 | + # i3 | (4.0, 0.13) | 1 | 1 | + dense_column = array_ops.constant([0.62, 0.62, 0.3, 0.52]) + gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) + hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) + partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32) + class_id = -1 + + gradient_shape = tensor_shape.scalar() + hessian_shape = tensor_shape.scalar() + split_handler = ordinal_split_handler.DenseSplitHandler( + l1_regularization=0.1, + l2_regularization=1., + tree_complexity_regularization=0., + min_node_weight=0., + epsilon=0.001, + num_quantiles=10, + feature_column_group_id=0, + dense_float_column=dense_column, + init_stamp_token=0, + gradient_shape=gradient_shape, + hessian_shape=hessian_shape, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS, + weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE) + resources.initialize_resources(resources.shared_resources()).run() + + empty_gradients, empty_hessians = get_empty_tensors( + gradient_shape, hessian_shape) + example_weights = array_ops.ones([4, 1], dtypes.float32) + + update_1 = split_handler.update_stats_sync( + 0, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_1]): + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] + + with ops.control_dependencies([are_splits_ready]): + update_2 = split_handler.update_stats_sync( + 1, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_2]): + are_splits_ready2, partitions, gains, splits = ( + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) + are_splits_ready, are_splits_ready2, partitions, gains, splits = ( + sess.run([ + are_splits_ready, are_splits_ready2, partitions, gains, splits + ])) + + # During the first iteration, inequality split handlers are not going to + # have any splits. Make sure that we return not_ready in that case. + self.assertFalse(are_splits_ready) + self.assertTrue(are_splits_ready2) + + self.assertAllEqual([0, 1], partitions) + + oblivious_split_info = split_info_pb2.ObliviousSplitInfo() + oblivious_split_info.ParseFromString(splits[0]) + split_node = oblivious_split_info.split_node.dense_float_binary_split + + self.assertAllClose(0.3, split_node.threshold, 0.00001) + self.assertEqual(0, split_node.feature_column) + + # Check the split on partition 0. + # -(1.2 - 0.1) / (0.2 + 1) + expected_left_weight_0 = -0.9166666666666666 + + # expected_left_weight_0 * -(1.2 - 0.1) + expected_left_gain_0 = 1.008333333333333 + + # (-0.5 + 0.2 + 0.1) / (0.19 + 1) + expected_right_weight_0 = 0.1680672 + + # expected_right_weight_0 * -(-0.5 + 0.2 + 0.1)) + expected_right_gain_0 = 0.033613445378151252 + + # (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1) + expected_bias_gain_0 = 0.46043165467625896 + + left_child = oblivious_split_info.children_leaves[0].vector + right_child = oblivious_split_info.children_leaves[1].vector + + self.assertAllClose([expected_left_weight_0], left_child.value, 0.00001) + + self.assertAllClose([expected_right_weight_0], right_child.value, 0.00001) + + # Check the split on partition 1. + expected_left_weight_1 = 0 + expected_left_gain_1 = 0 + # -(4 - 0.1) / (0.13 + 1) + expected_right_weight_1 = -3.4513274336283186 + # expected_right_weight_1 * -(4 - 0.1) + expected_right_gain_1 = 13.460176991150442 + # (-4 + 0.1) ** 2 / (0.13 + 1) + expected_bias_gain_1 = 13.460176991150442 + + left_child = oblivious_split_info.children_leaves[2].vector + right_child = oblivious_split_info.children_leaves[3].vector + + self.assertAllClose([expected_left_weight_1], left_child.value, 0.00001) + + self.assertAllClose([expected_right_weight_1], right_child.value, 0.00001) + + # The layer gain is the sum of the gains of each partition + layer_gain = ( + expected_left_gain_0 + expected_right_gain_0 - expected_bias_gain_0) + ( + expected_left_gain_1 + expected_right_gain_1 - expected_bias_gain_1) + self.assertAllClose(layer_gain, gains[0], 0.00001) + def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self): with self.test_session() as sess: # The data looks like the following: @@ -1072,8 +1199,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self): with self.test_session() as sess: # Batch is 4, 2 classes - gradients = array_ops.constant( - [[0.2, 1.4], [-0.5, 0.1], [1.2, 3], [4.0, -3]]) + gradients = array_ops.constant([[0.2, 1.4], [-0.5, 0.1], [1.2, 3], + [4.0, -3]]) # 2x2 matrix for each instance hessian_0 = [[0.12, 0.02], [0.3, 0.11]] hessian_1 = [[0.07, -0.2], [-0.5, 0.2]] @@ -1167,8 +1294,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): def testGenerateFeatureSplitCandidatesMulticlassDiagonalHessian(self): with self.test_session() as sess: # Batch is 4, 2 classes - gradients = array_ops.constant( - [[0.2, 1.4], [-0.5, 0.1], [1.2, 3], [4.0, -3]]) + gradients = array_ops.constant([[0.2, 1.4], [-0.5, 0.1], [1.2, 3], + [4.0, -3]]) # Each hessian is a diagonal from a full hessian matrix. hessian_0 = [0.12, 0.11] hessian_1 = [0.07, 0.2] @@ -1406,6 +1533,100 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(len(gains), 0) self.assertEqual(len(splits), 0) + def testEmptyBuckets(self): + """Test that reproduces the case when quantile buckets were empty.""" + with self.test_session() as sess: + sparse_column = array_ops.sparse_placeholder(dtypes.float32) + + # We have two batches - at first, a sparse feature is empty. + empty_indices = array_ops.constant([], dtype=dtypes.int64, shape=[0, 2]) + empty_values = array_ops.constant([], dtype=dtypes.float32) + empty_sparse_column = sparse_tensor.SparseTensor(empty_indices, + empty_values, [4, 2]) + empty_sparse_column = empty_sparse_column.eval(session=sess) + + # For the second batch, the sparse feature is not empty. + non_empty_indices = array_ops.constant( + [[0, 0], [2, 1], [3, 2]], dtype=dtypes.int64, shape=[3, 2]) + non_empty_values = array_ops.constant( + [0.52, 0.3, 0.52], dtype=dtypes.float32) + non_empty_sparse_column = sparse_tensor.SparseTensor( + non_empty_indices, non_empty_values, [4, 2]) + non_empty_sparse_column = non_empty_sparse_column.eval(session=sess) + + gradient_shape = tensor_shape.scalar() + hessian_shape = tensor_shape.scalar() + class_id = -1 + + split_handler = ordinal_split_handler.SparseSplitHandler( + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, + epsilon=0.01, + num_quantiles=2, + feature_column_group_id=0, + sparse_float_column=sparse_column, + init_stamp_token=0, + gradient_shape=gradient_shape, + hessian_shape=hessian_shape, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS) + resources.initialize_resources(resources.shared_resources()).run() + gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) + hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) + partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32) + + empty_gradients, empty_hessians = get_empty_tensors( + gradient_shape, hessian_shape) + example_weights = array_ops.ones([4, 1], dtypes.float32) + + update_1 = split_handler.update_stats_sync( + 0, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_1]): + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] + + # First, calculate quantiles and try to update on an empty data for a + # feature. + are_splits_ready = ( + sess.run( + are_splits_ready, + feed_dict={sparse_column: empty_sparse_column})) + self.assertFalse(are_splits_ready) + + update_2 = split_handler.update_stats_sync( + 1, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_2]): + are_splits_ready2, partitions, gains, splits = ( + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) + + # Now the feature in the second batch is not empty, but buckets + # calculated on the first batch are empty. + are_splits_ready2, partitions, gains, splits = ( + sess.run( + [are_splits_ready2, partitions, gains, splits], + feed_dict={sparse_column: non_empty_sparse_column})) + self.assertFalse(are_splits_ready) + self.assertTrue(are_splits_ready2) + # Since the buckets were empty, we can't calculate the splits. + self.assertEqual(len(partitions), 0) + self.assertEqual(len(gains), 0) + self.assertEqual(len(splits), 0) + def testDegenerativeCase(self): with self.test_session() as sess: # One data example only, one leaf and thus one quantile bucket.The same diff --git a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc index ca5c7f3d8c78a543c63fbfa9f7eb7c3d348f11b8..9b68a9de96ec8f6c7679410ca8a468978f2149e6 100644 --- a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc @@ -36,6 +36,7 @@ REGISTER_OP("BuildDenseInequalitySplits") .Input("tree_complexity_regularization: float") .Input("min_node_weight: float") .Input("multiclass_strategy: int32") + .Input("weak_learner_type: int32") .Output("output_partition_ids: int32") .Output("gains: float32") .Output("split_infos: string") @@ -84,6 +85,8 @@ min_node_weight: A scalar, minimum sum of example hessian needed in a child. be considered. multiclass_strategy: A scalar, specifying the multiclass handling strategy. See LearnerConfig.MultiClassStrategy for valid values. +weak_learner_type: A scalar, specifying the weak learner type to use. + See LearnerConfig.WeakLearnerType for valid values. output_partition_ids: A rank 1 tensor, the partition IDs that we created splits for. gains: A rank 1 tensor, for the computed gain for the created splits. diff --git a/tensorflow/contrib/boosted_trees/ops/training_ops.cc b/tensorflow/contrib/boosted_trees/ops/training_ops.cc index f63c199ad6146c23c22437ffe2287a77ee91ca44..22ac9edb72ea91ecef6fd1dff9f399b3c9020083 100644 --- a/tensorflow/contrib/boosted_trees/ops/training_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/training_ops.cc @@ -56,6 +56,7 @@ REGISTER_OP("GrowTreeEnsemble") .Input("next_stamp_token: int64") .Input("learning_rate: float") .Input("dropout_seed: int64") + .Input("max_tree_depth: int32") .Input("partition_ids: num_handlers * int32") .Input("gains: num_handlers * float") .Input("splits: num_handlers * string") @@ -67,6 +68,8 @@ REGISTER_OP("GrowTreeEnsemble") TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_input)); // Dropout seed. TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_input)); + // Maximum tree depth. + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_input)); return Status::OK(); }) .Doc(R"doc( diff --git a/tensorflow/contrib/boosted_trees/proto/learner.proto b/tensorflow/contrib/boosted_trees/proto/learner.proto index d84ba7438e7f03685d5bafca52ff8283f0fce898..c49cb48cdea6d8c85588f4c3c2bda6faf7e125db 100644 --- a/tensorflow/contrib/boosted_trees/proto/learner.proto +++ b/tensorflow/contrib/boosted_trees/proto/learner.proto @@ -108,6 +108,11 @@ message LearnerConfig { DIAGONAL_HESSIAN = 3; } + enum WeakLearnerType { + NORMAL_DECISION_TREE = 0; + OBLIVIOUS_DECISION_TREE = 1; + } + // Number of classes. uint32 num_classes = 1; @@ -141,4 +146,7 @@ message LearnerConfig { // If you want to average the ensembles (for regularization), provide the // config below. AveragingConfig averaging_config = 11; + + // By default we use NORMAL_DECISION_TREE as weak learner. + WeakLearnerType weak_learner_type = 12; } diff --git a/tensorflow/contrib/boosted_trees/proto/split_info.proto b/tensorflow/contrib/boosted_trees/proto/split_info.proto index a300c24c8ec507dea0af662b2361d408a2085237..850340f5c2096ca674616254de45d96b84200a64 100644 --- a/tensorflow/contrib/boosted_trees/proto/split_info.proto +++ b/tensorflow/contrib/boosted_trees/proto/split_info.proto @@ -17,3 +17,10 @@ message SplitInfo { // Right Leaf node. tensorflow.boosted_trees.trees.Leaf right_child = 3; } + +message ObliviousSplitInfo { + // The split node with the feature_column and threshold defined. + tensorflow.boosted_trees.trees.TreeNode split_node = 1; + // The new leaves of the tree. + repeated tensorflow.boosted_trees.trees.Leaf children_leaves = 2; +} diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py index 5cd37ec67ec3bdefb6ea19049a7a12249162d45a..2589504762787deaf598777650b8372320824c22 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py @@ -59,7 +59,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): min_node_weight=0, class_id=-1, feature_column_group_id=0, - multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS)) + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)) partitions, gains, splits = sess.run([partitions, gains, splits]) self.assertAllEqual([0, 1], partitions) @@ -132,7 +133,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): min_node_weight=0, class_id=-1, feature_column_group_id=0, - multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN)) + multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)) partitions, gains, splits = sess.run([partitions, gains, splits]) self.assertAllEqual([0, 1], partitions) @@ -171,7 +173,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): min_node_weight=0, class_id=-1, feature_column_group_id=0, - multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS)) + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)) partitions, gains, splits = sess.run([partitions, gains, splits]) # .assertEmpty doesn't exist on ubuntu-contrib self.assertEqual(0, len(partitions)) diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py index 3e524efbeac74ff754d63cae92b3e194411cb2de..e39e1de8d1954c7f4dcab87d7727a64affa13c8c 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py @@ -296,7 +296,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE, # Dropout does not change anything here, tree is not finalized. - dropout_probability=0.5).SerializeToString() + dropout_probability=0.5) # Prepare handler inputs. # Note that handlers 1 & 3 have the same gain but different splits. @@ -321,9 +321,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): ], gains=[handler1_gains, handler2_gains, handler3_gains], splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the simpler split from handler 1 to be chosen. @@ -443,7 +444,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE, # Dropout does not change anything here - tree is not finalized. - dropout_probability=0.5).SerializeToString() + dropout_probability=0.5) # Prepare handler inputs. # Handler 1 only has a candidate for partition 1, handler 2 has candidates @@ -472,9 +473,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): ], gains=[handler1_gains, handler2_gains, handler3_gains], splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the split for partition 1 to be chosen from handler 1 and @@ -632,8 +634,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=1, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( - ) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) # Prepare handler inputs. handler1_partitions = np.array([0], dtype=np.int32) @@ -657,9 +658,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): ], gains=[handler1_gains, handler2_gains, handler3_gains], splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect a new tree to be added with the split from handler 1. @@ -773,8 +775,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=1, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( - ) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) # Prepare handler inputs. # All handlers have negative gain. @@ -794,9 +795,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): partition_ids=[handler1_partitions, handler2_partitions], gains=[handler1_gains, handler2_gains], splits=[handler1_split, handler2_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the ensemble to be empty. @@ -839,8 +841,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=1, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.POST_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( - ) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) # Prepare handler inputs. # Note that handlers 1 & 3 have the same gain but different splits. @@ -865,9 +866,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): ], gains=[handler1_gains, handler2_gains, handler3_gains], splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the simpler split from handler 1 to be chosen. @@ -946,8 +948,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=2, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.POST_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( - ) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) # Prepare handler inputs. # All handlers have negative gain. @@ -967,9 +968,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): partition_ids=[handler1_partitions, handler2_partitions], gains=[handler1_gains, handler2_gains], splits=[handler1_split, handler2_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the split from handler 2 to be chosen despite the negative gain. @@ -1048,9 +1050,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): partition_ids=[handler1_partitions], gains=[handler1_gains], splits=[handler1_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the ensemble to be empty as post-pruning will prune @@ -1094,8 +1097,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=2, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.POST_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( - ) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) # Prepare handler inputs. # Second handler has positive gain. @@ -1115,9 +1117,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): partition_ids=[handler1_partitions, handler2_partitions], gains=[handler1_gains, handler2_gains], splits=[handler1_split, handler2_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the split from handler 2 to be chosen despite the negative gain. @@ -1194,9 +1197,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): partition_ids=[handler1_partitions], gains=[handler1_gains], splits=[handler1_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the negative gain split of partition 1 to be pruned and the @@ -1335,7 +1339,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, growing_mode=learner_pb2.LearnerConfig.LAYER_BY_LAYER, # Dropout will have no effect, since the tree will not be fully grown. - dropout_probability=1.0).SerializeToString() + dropout_probability=1.0) # Prepare handler inputs. # Handler 1 only has a candidate for partition 1, handler 2 has candidates @@ -1364,9 +1368,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): ], gains=[handler1_gains, handler2_gains, handler3_gains], splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the split for partition 1 to be chosen from handler 1 and @@ -1543,7 +1548,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE, - dropout_probability=1.0).SerializeToString() + dropout_probability=1.0) # Prepare handler inputs. handler1_partitions = np.array([0], dtype=np.int32) @@ -1567,9 +1572,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): ], gains=[handler1_gains, handler2_gains, handler3_gains], splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect a new tree to be added with the split from handler 1. @@ -1669,7 +1675,6 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) learner_config.constraints.max_number_of_unique_feature_columns = 3 - learner_config = learner_config.SerializeToString() # Prepare handler inputs. handler1_partitions = np.array([0], dtype=np.int32) handler1_gains = np.array([7.62], dtype=np.float32) @@ -1692,9 +1697,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): ], gains=[handler1_gains, handler2_gains, handler3_gains], splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) _, serialized = session.run( diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index e08b230f468ea2179aa6a8abf5225e9549ea2ee4..2f75d8aa99c54ce1127b3c907702a7220be16155 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -51,6 +51,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary from tensorflow.python.training import device_setter + # Key names for prediction dict. ENSEMBLE_STAMP = "ensemble_stamp" PREDICTIONS = "predictions" @@ -217,6 +218,21 @@ def extract_features(features, feature_columns, use_core_columns): sparse_int_shapes = [] for key in sorted(features.keys()): tensor = features[key] + # TODO(nponomareva): consider iterating over feature columns instead. + if isinstance(tensor, tuple): + # Weighted categorical feature. + categorical_tensor = tensor[0] + weight_tensor = tensor[1] + + shape = categorical_tensor.dense_shape + indices = array_ops.concat([ + array_ops.slice(categorical_tensor.indices, [0, 0], [-1, 1]), + array_ops.expand_dims( + math_ops.to_int64(categorical_tensor.values), -1) + ], 1) + tensor = sparse_tensor.SparseTensor( + indices=indices, values=weight_tensor.values, dense_shape=shape) + if isinstance(tensor, sparse_tensor.SparseTensor): if tensor.values.dtype == dtypes.float32: sparse_float_names.append(key) @@ -353,6 +369,9 @@ class GradientBoostedDecisionTreeModel(object): self._gradient_shape = tensor_shape.scalar() self._hessian_shape = tensor_shape.scalar() else: + if center_bias: + raise ValueError("Center bias should be False for multiclass.") + self._gradient_shape = tensor_shape.TensorShape([logits_dimension]) if (learner_config.multi_class_strategy == learner_pb2.LearnerConfig.FULL_HESSIAN): @@ -380,6 +399,8 @@ class GradientBoostedDecisionTreeModel(object): self._learner_config = learner_config self._feature_columns = feature_columns self._learner_config_serialized = learner_config.SerializeToString() + self._max_tree_depth = variables.Variable( + initial_value=self._learner_config.constraints.max_tree_depth) self._attempted_trees = variables.Variable( initial_value=array_ops.zeros([], dtypes.int64), trainable=False, @@ -666,6 +687,8 @@ class GradientBoostedDecisionTreeModel(object): self._learner_config.constraints.min_node_weight, dtypes.float32) loss_uses_sum_reduction = self._loss_reduction == losses.Reduction.SUM loss_uses_sum_reduction = constant_op.constant(loss_uses_sum_reduction) + weak_learner_type = constant_op.constant( + self._learner_config.weak_learner_type) epsilon = 0.01 num_quantiles = 100 strategy_tensor = constant_op.constant(strategy) @@ -690,6 +713,7 @@ class GradientBoostedDecisionTreeModel(object): multiclass_strategy=strategy_tensor, init_stamp_token=init_stamp_token, loss_uses_sum_reduction=loss_uses_sum_reduction, + weak_learner_type=weak_learner_type, )) fc_name_idx += 1 @@ -893,7 +917,7 @@ class GradientBoostedDecisionTreeModel(object): reset_ops = [] for handler in handlers: - reset_ops.append(handler.make_splits(stamp_token, next_stamp_token, 0)) + reset_ops.append(handler.reset(stamp_token, next_stamp_token)) if self._center_bias: reset_ops.append( bias_stats_accumulator.flush(stamp_token, next_stamp_token)) @@ -1051,7 +1075,8 @@ class GradientBoostedDecisionTreeModel(object): splits=split_info_list, learner_config=self._learner_config_serialized, dropout_seed=dropout_seed, - center_bias=self._center_bias) + center_bias=self._center_bias, + max_tree_depth=self._max_tree_depth) def _grow_ensemble_not_ready_fn(): # Don't grow the ensemble, just update the stamp. @@ -1065,7 +1090,8 @@ class GradientBoostedDecisionTreeModel(object): splits=[], learner_config=self._learner_config_serialized, dropout_seed=dropout_seed, - center_bias=self._center_bias) + center_bias=self._center_bias, + max_tree_depth=self._max_tree_depth) def _grow_ensemble_fn(): # Conditionally grow an ensemble depending on whether the splits @@ -1105,6 +1131,9 @@ class GradientBoostedDecisionTreeModel(object): def get_number_of_trees_tensor(self): return self._finalized_trees, self._attempted_trees + def get_max_tree_depth(self): + return self._max_tree_depth + def train(self, loss, predictions_dict, labels): """Updates the accumalator stats and grows the ensemble. diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index 2fbaa31d5e19b58c335cd0a894e1db9af2c34d08..150d734db6cdd8023ab6d91a49872f657bcdbdea 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -31,6 +31,12 @@ Checkpointable data structures: @@List @@Mapping @@UniqueNameTracker + +Checkpoint management: +@@CheckpointManager + +Saving and restoring Python state: +@@NumpyState """ from __future__ import absolute_import @@ -38,9 +44,11 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker +from tensorflow.contrib.checkpoint.python.python_state import NumpyState from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph +from tensorflow.python.training.checkpoint_management import CheckpointManager from tensorflow.python.training.checkpointable.base import CheckpointableBase from tensorflow.python.training.checkpointable.data_structures import List from tensorflow.python.training.checkpointable.data_structures import Mapping diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD index 7b200a29bf60087d6da1010b0be05c04faec80cd..ada41687261ab63286933d01da4e286173042e0c 100644 --- a/tensorflow/contrib/checkpoint/python/BUILD +++ b/tensorflow/contrib/checkpoint/python/BUILD @@ -9,6 +9,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":containers", + ":python_state", ":split_dependency", ":visualize", "//tensorflow/python/training/checkpointable:data_structures", @@ -40,6 +41,33 @@ py_test( ], ) +py_library( + name = "python_state", + srcs = ["python_state.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/python/training/checkpointable:base", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + +py_test( + name = "python_state_test", + srcs = ["python_state_test.py"], + deps = [ + ":python_state", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:session", + "//tensorflow/python:variables", + "//tensorflow/python/eager:test", + "//tensorflow/python/training/checkpointable:util", + "//third_party/py/numpy", + ], +) + py_library( name = "split_dependency", srcs = ["split_dependency.py"], diff --git a/tensorflow/contrib/checkpoint/python/python_state.py b/tensorflow/contrib/checkpoint/python/python_state.py new file mode 100644 index 0000000000000000000000000000000000000000..9b11035b6d277851ea0a0071062bf5cf6b6b2185 --- /dev/null +++ b/tensorflow/contrib/checkpoint/python/python_state.py @@ -0,0 +1,166 @@ +"""Utilities for including Python state in TensorFlow checkpoints.""" +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +import numpy + +from tensorflow.python.training.checkpointable import base + +# pylint: disable=g-import-not-at-top +try: + # In Python 2.x, use the faster string buffering option. + from cStringIO import StringIO as BytesIO +except ImportError: + from io import BytesIO +# pylint: enable=g-import-not-at-top + + +class NumpyState(base.CheckpointableBase): + """A checkpointable object whose NumPy array attributes are saved/restored. + + Example usage: + + ```python + arrays = tf.contrib.checkpoint.NumpyState() + checkpoint = tf.train.Checkpoint(numpy_arrays=arrays) + arrays.x = numpy.zeros([3, 4]) + save_path = checkpoint.save("/tmp/ckpt") + arrays.x[1, 1] = 4. + checkpoint.restore(save_path) + assert (arrays.x == numpy.zeros([3, 4])).all() + + second_checkpoint = tf.train.Checkpoint( + numpy_arrays=tf.contrib.checkpoint.NumpyState()) + # Attributes of NumpyState objects are created automatically by restore() + second_checkpoint.restore(save_path) + assert (second_checkpoint.numpy_arrays.x == numpy.zeros([3, 4])).all() + ``` + + Note that `NumpyState` objects re-create the attributes of the previously + saved object on `restore()`. This is in contrast to TensorFlow variables, for + which a `Variable` object must be created and assigned to an attribute. + + This snippet works both when graph building and when executing eagerly. On + save, the NumPy array(s) are fed as strings to be saved in the checkpoint (via + a placeholder when graph building, or as a string constant when executing + eagerly). When restoring they skip the TensorFlow graph entirely, and so no + restore ops need be run. This means that restoration always happens eagerly, + rather than waiting for `checkpoint.restore(...).run_restore_ops()` like + TensorFlow variables when graph building. + """ + + def _lookup_dependency(self, name): + """Create placeholder NumPy arrays for to-be-restored attributes. + + Typically `_lookup_dependency` is used to check by name whether a dependency + exists. We cheat slightly by creating a checkpointable object for `name` if + we don't already have one, giving us attribute re-creation behavior when + loading a checkpoint. + + Args: + name: The name of the dependency being checked. + Returns: + An existing dependency if one exists, or a new `_NumpyWrapper` placeholder + dependency (which will generally be restored immediately). + """ + value = super(NumpyState, self)._lookup_dependency(name) + if value is None: + value = _NumpyWrapper(numpy.array([])) + new_reference = base.CheckpointableReference(name=name, ref=value) + self._unconditional_checkpoint_dependencies.append(new_reference) + self._unconditional_dependency_names[name] = value + super(NumpyState, self).__setattr__(name, value) + return value + + def __getattribute__(self, name): + """Un-wrap `_NumpyWrapper` objects when accessing attributes.""" + value = super(NumpyState, self).__getattribute__(name) + if isinstance(value, _NumpyWrapper): + return value.array + return value + + def __setattr__(self, name, value): + """Automatically wrap NumPy arrays assigned to attributes.""" + # TODO(allenl): Consider supporting lists/tuples, either ad-hoc or by making + # ndarrays checkpointable natively and using standard checkpointable list + # tracking. + if isinstance(value, numpy.ndarray): + try: + existing = super(NumpyState, self).__getattribute__(name) + existing.array = value + return + except AttributeError: + value = _NumpyWrapper(value) + self._track_checkpointable(value, name=name, overwrite=True) + elif (name not in ("_setattr_tracking", "_update_uid") + and getattr(self, "_setattr_tracking", True)): + # Mixing restore()-created attributes with user-added checkpointable + # objects is tricky, since we can't use the `_lookup_dependency` trick to + # re-create attributes (we might accidentally steal the restoration for + # another checkpointable object). For now `NumpyState` objects must be + # leaf nodes. Theoretically we could add some extra arguments to + # `_lookup_dependency` to figure out whether we should create a NumPy + # array for the attribute or not. + raise NotImplementedError( + ("Assigned %s to the %s property of %s, which is not a NumPy array. " + "Currently mixing NumPy arrays and other checkpointable objects is " + "not supported. File a feature request if this limitation bothers " + "you.") + % (value, name, self)) + super(NumpyState, self).__setattr__(name, value) + + +class _NumpyWrapper(base.CheckpointableBase): + """Wraps a NumPy array for storage in an object-based checkpoint.""" + + def __init__(self, array): + """Specify a NumPy array to wrap. + + Args: + array: The NumPy array to save and restore (may be overwritten). + """ + self.array = array + + def _serialize(self): + """Callback for `PythonStringStateSaveable` to serialize the array.""" + string_file = BytesIO() + try: + numpy.save(string_file, self.array, allow_pickle=False) + serialized = string_file.getvalue() + finally: + string_file.close() + return serialized + + def _deserialize(self, string_value): + """Callback for `PythonStringStateSaveable` to deserialize the array.""" + string_file = BytesIO(string_value) + try: + self.array = numpy.load(string_file, allow_pickle=False) + finally: + string_file.close() + + def _gather_saveables_for_checkpoint(self): + """Specify callbacks for saving and restoring `array`.""" + return { + "array": functools.partial( + base.PythonStringStateSaveable, + state_callback=self._serialize, + restore_callback=self._deserialize) + } diff --git a/tensorflow/contrib/checkpoint/python/python_state_test.py b/tensorflow/contrib/checkpoint/python/python_state_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0439a4755e36fc3be6e065d18d3e835feda8aab3 --- /dev/null +++ b/tensorflow/contrib/checkpoint/python/python_state_test.py @@ -0,0 +1,101 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import numpy + +from tensorflow.contrib.checkpoint.python import python_state +from tensorflow.python.client import session +from tensorflow.python.eager import test +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import variables +from tensorflow.python.training.checkpointable import util + + +class NumpyStateTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes + def testSaveRestoreNumpyState(self): + directory = self.get_temp_dir() + prefix = os.path.join(directory, "ckpt") + save_state = python_state.NumpyState() + saver = util.Checkpoint(numpy=save_state) + save_state.a = numpy.ones([2, 2]) + save_state.b = numpy.ones([2, 2]) + save_state.b = numpy.zeros([2, 2]) + self.assertAllEqual(numpy.ones([2, 2]), save_state.a) + self.assertAllEqual(numpy.zeros([2, 2]), save_state.b) + first_save_path = saver.save(prefix) + save_state.a[1, 1] = 2. + second_save_path = saver.save(prefix) + + load_state = python_state.NumpyState() + loader = util.Checkpoint(numpy=load_state) + loader.restore(first_save_path).initialize_or_restore() + self.assertAllEqual(numpy.ones([2, 2]), load_state.a) + self.assertAllEqual(numpy.zeros([2, 2]), load_state.b) + load_state.a[0, 0] = 42. + self.assertAllEqual([[42., 1.], [1., 1.]], load_state.a) + loader.restore(first_save_path).run_restore_ops() + self.assertAllEqual(numpy.ones([2, 2]), load_state.a) + loader.restore(second_save_path).run_restore_ops() + self.assertAllEqual([[1., 1.], [1., 2.]], load_state.a) + self.assertAllEqual(numpy.zeros([2, 2]), load_state.b) + + def testNoGraphPollution(self): + graph = ops.Graph() + with graph.as_default(), session.Session(): + directory = self.get_temp_dir() + prefix = os.path.join(directory, "ckpt") + save_state = python_state.NumpyState() + saver = util.Checkpoint(numpy=save_state) + save_state.a = numpy.ones([2, 2]) + save_path = saver.save(prefix) + saver.restore(save_path) + graph.finalize() + saver.save(prefix) + save_state.a = numpy.zeros([2, 2]) + saver.save(prefix) + saver.restore(save_path) + + @test_util.run_in_graph_and_eager_modes + def testNoMixedNumpyStateTF(self): + save_state = python_state.NumpyState() + save_state.a = numpy.ones([2, 2]) + with self.assertRaises(NotImplementedError): + save_state.v = variables.Variable(1.) + + @test_util.run_in_graph_and_eager_modes + def testDocstringExample(self): + arrays = python_state.NumpyState() + checkpoint = util.Checkpoint(numpy_arrays=arrays) + arrays.x = numpy.zeros([3, 4]) + save_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt")) + arrays.x[1, 1] = 4. + checkpoint.restore(save_path) + self.assertAllEqual(numpy.zeros([3, 4]), arrays.x) + + second_checkpoint = util.Checkpoint(numpy_arrays=python_state.NumpyState()) + second_checkpoint.restore(save_path) + self.assertAllEqual(numpy.zeros([3, 4]), second_checkpoint.numpy_arrays.x) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc index 1bfd27305d569668a0bd67d876e59eec082296b3..58fadffce32f9a8fec047d1e99f9f4eb5a710d91 100644 --- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc @@ -85,7 +85,7 @@ Status BigQueryTableAccessor::New( int64 timestamp_millis, int64 row_buffer_size, const string& end_point, const std::vector& columns, const BigQueryTablePartition& partition, std::unique_ptr auth_provider, - std::unique_ptr http_request_factory, + std::shared_ptr http_request_factory, std::unique_ptr* accessor) { if (timestamp_millis <= 0) { return errors::InvalidArgument( @@ -94,29 +94,19 @@ Status BigQueryTableAccessor::New( const string& big_query_end_point = end_point.empty() ? kBigQueryEndPoint : end_point; if (auth_provider == nullptr && http_request_factory == nullptr) { - accessor->reset(new BigQueryTableAccessor( - project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, - big_query_end_point, columns, partition)); - } else { - accessor->reset(new BigQueryTableAccessor( - project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, - big_query_end_point, columns, partition, std::move(auth_provider), - std::move(http_request_factory))); + http_request_factory = std::make_shared(); + auto compute_engine_metadata_client = + std::make_shared(http_request_factory); + auth_provider = std::unique_ptr( + new GoogleAuthProvider(compute_engine_metadata_client)); } - return (*accessor)->ReadSchema(); -} -BigQueryTableAccessor::BigQueryTableAccessor( - const string& project_id, const string& dataset_id, const string& table_id, - int64 timestamp_millis, int64 row_buffer_size, const string& end_point, - const std::vector& columns, const BigQueryTablePartition& partition) - : BigQueryTableAccessor( - project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, - end_point, columns, partition, - std::unique_ptr(new GoogleAuthProvider()), - std::unique_ptr( - new CurlHttpRequest::Factory())) { - row_buffer_.resize(row_buffer_size); + accessor->reset(new BigQueryTableAccessor( + project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, + big_query_end_point, columns, partition, std::move(auth_provider), + std::move(http_request_factory))); + + return (*accessor)->ReadSchema(); } BigQueryTableAccessor::BigQueryTableAccessor( @@ -124,7 +114,7 @@ BigQueryTableAccessor::BigQueryTableAccessor( int64 timestamp_millis, int64 row_buffer_size, const string& end_point, const std::vector& columns, const BigQueryTablePartition& partition, std::unique_ptr auth_provider, - std::unique_ptr http_request_factory) + std::shared_ptr http_request_factory) : project_id_(project_id), dataset_id_(dataset_id), table_id_(table_id), diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h index b349063715c903c982cfe2fb116b6525e35ff63b..1af43a3e1070d466bb50019f12b22a060c1e6ab1 100644 --- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h @@ -109,24 +109,17 @@ class BigQueryTableAccessor { const std::vector& columns, const BigQueryTablePartition& partition, std::unique_ptr auth_provider, - std::unique_ptr http_request_factory, + std::shared_ptr http_request_factory, std::unique_ptr* accessor); /// \brief Constructs an object for a given table and partition. - BigQueryTableAccessor(const string& project_id, const string& dataset_id, - const string& table_id, int64 timestamp_millis, - int64 row_buffer_size, const string& end_point, - const std::vector& columns, - const BigQueryTablePartition& partition); - - /// Used for unit testing. BigQueryTableAccessor( const string& project_id, const string& dataset_id, const string& table_id, int64 timestamp_millis, int64 row_buffer_size, const string& end_point, const std::vector& columns, const BigQueryTablePartition& partition, std::unique_ptr auth_provider, - std::unique_ptr http_request_factory); + std::shared_ptr http_request_factory); /// \brief Parses column values for a given row. Status ParseColumnValues(const Json::Value& value, @@ -199,7 +192,7 @@ class BigQueryTableAccessor { SchemaNode schema_root_; std::unique_ptr auth_provider_; - std::unique_ptr http_request_factory_; + std::shared_ptr http_request_factory_; TF_DISALLOW_COPY_AND_ASSIGN(BigQueryTableAccessor); }; diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py index 95e7e744d34391a511cdba7702aad369b8d9d9c0..cb45e42734256d140276fafdb39c0a44199a4e9d 100644 --- a/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py +++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import json +import os from tensorflow.contrib.cloud.python.ops import gen_gcs_config_ops from tensorflow.python.framework import dtypes @@ -188,6 +189,8 @@ def configure_colab_session(session): session: A `tf.Session` session. """ # Read from the application default credentials (adc). - with open('/content/datalab/adc.json') as f: + adc_filename = os.environ.get( + 'GOOGLE_APPLICATION_CREDENTIALS', '/content/adc.json') + with open(adc_filename) as f: data = json.load(f) configure_gcs(session, credentials=data) diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index 8f521ffee4d31e090c13bac98290656d6e1d330e..1ab150d74ac00c5f9acf3c9399880708b2f62b1e 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -148,6 +148,9 @@ class TPUClusterResolver(ClusterResolver): else: tpu = self._envVarFallback() + if tpu is None: + raise ValueError('Please provide a TPU Name to connect to.') + self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes self._job_name = job_name self._credentials = credentials @@ -259,11 +262,11 @@ class TPUClusterResolver(ClusterResolver): if 'state' in response and response['state'] != 'READY': raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' % - (self._tpu, response['state'])) + (compat.as_text(self._tpu), response['state'])) if 'health' in response and response['health'] != 'HEALTHY': - raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu, - response['health'])) + raise RuntimeError('TPU "%s" is unhealthy: "%s"' % + (compat.as_text(self._tpu), response['health'])) if 'networkEndpoints' in response: worker_list = [ diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index 6c93487e0d6d5caf576f1af80c36e4df895e6afa..f6c928e2be62e7292c6feaa3bb26fd463320158b 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -471,7 +471,6 @@ if (tensorflow_ENABLE_GPU) ${CUDA_TOOLKIT_TARGET_DIR}/include/cuComplex.h ${CUDA_TOOLKIT_TARGET_DIR}/include/cublas_v2.h ${CUDA_TOOLKIT_TARGET_DIR}/include/cusolverDn.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda_fp16.h ${CUDA_TOOLKIT_TARGET_DIR}/include/device_functions.h ${CUDA_TOOLKIT_TARGET_DIR}/include/cufft.h ${CUDA_TOOLKIT_TARGET_DIR}/include/curand.h diff --git a/tensorflow/contrib/cmake/external/eigen.cmake b/tensorflow/contrib/cmake/external/eigen.cmake index 45a0096085cc2a6332c82e1ea284812acdd45152..33bb31148d2e5b7ca177d7c30b7781e8f620c3cb 100644 --- a/tensorflow/contrib/cmake/external/eigen.cmake +++ b/tensorflow/contrib/cmake/external/eigen.cmake @@ -19,6 +19,12 @@ # build_file = "eigen.BUILD", #) +option(eigen_PATCH_FILE "Patch file to apply to eigen" OFF) +set(eigen_PATCH_COMMAND "") +if(eigen_PATCH_FILE) + set(eigen_PATCH_COMMAND PATCH_COMMAND patch -p0 -i "${eigen_PATCH_FILE}") +endif(eigen_PATCH_FILE) + include (ExternalProject) # We parse the current Eigen version and archive hash from the bazel configuration @@ -45,6 +51,7 @@ ExternalProject_Add(eigen URL ${eigen_URL} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" INSTALL_DIR "${eigen_INSTALL}" + ${eigen_PATCH_COMMAND} CMAKE_CACHE_ARGS -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF diff --git a/tensorflow/contrib/cmake/external/highwayhash.cmake b/tensorflow/contrib/cmake/external/highwayhash.cmake index a6e8a38d8c2ee3deb5453c264e0c5eb23248301f..7d260b85f21e7e56e153daf550c81155e4b68777 100644 --- a/tensorflow/contrib/cmake/external/highwayhash.cmake +++ b/tensorflow/contrib/cmake/external/highwayhash.cmake @@ -20,14 +20,6 @@ set(highwayhash_TAG be5edafc2e1a455768e260ccd68ae7317b6690ee) set(highwayhash_BUILD ${CMAKE_CURRENT_BINARY_DIR}/highwayhash/src/highwayhash) set(highwayhash_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/highwayhash/install) -# put highwayhash includes in the directory where they are expected -add_custom_target(highwayhash_create_destination_dir - COMMAND ${CMAKE_COMMAND} -E make_directory ${highwayhash_INCLUDE_DIR}/highwayhash - DEPENDS highwayhash) - -add_custom_target(highwayhash_copy_headers_to_destination - DEPENDS highwayhash_create_destination_dir) - if(WIN32) set(highwayhash_HEADERS "${highwayhash_BUILD}/highwayhash/*.h") set(highwayhash_STATIC_LIBRARIES ${highwayhash_INSTALL}/lib/highwayhash.lib) @@ -36,6 +28,20 @@ else() set(highwayhash_STATIC_LIBRARIES ${highwayhash_INSTALL}/lib/libhighwayhash.a) endif() +set(highwayhash_HEADERS + "${highwayhash_INSTALL}/include/code_annotation.h" + "${highwayhash_INSTALL}/include/highway_tree_hash.h" + "${highwayhash_INSTALL}/include/scalar_highway_tree_hash.h" + "${highwayhash_INSTALL}/include/scalar_sip_tree_hash.h" + "${highwayhash_INSTALL}/include/sip_hash.h" + "${highwayhash_INSTALL}/include/sip_tree_hash.h" + "${highwayhash_INSTALL}/include/sse41_highway_tree_hash.h" + "${highwayhash_INSTALL}/include/state_helpers.h" + "${highwayhash_INSTALL}/include/types.h" + "${highwayhash_INSTALL}/include/vec.h" + "${highwayhash_INSTALL}/include/vec2.h" +) + ExternalProject_Add(highwayhash PREFIX highwayhash GIT_REPOSITORY ${highwayhash_URL} @@ -50,5 +56,15 @@ ExternalProject_Add(highwayhash -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_INSTALL_PREFIX:STRING=${highwayhash_INSTALL}) -add_custom_command(TARGET highwayhash_copy_headers_to_destination PRE_BUILD - COMMAND ${CMAKE_COMMAND} -E copy_directory ${highwayhash_INSTALL}/include/ ${highwayhash_INCLUDE_DIR}/highwayhash) +# put highwayhash includes in the directory where they are expected +add_custom_target(highwayhash_create_destination_dir + COMMAND ${CMAKE_COMMAND} -E make_directory ${highwayhash_INCLUDE_DIR}/highwayhash + DEPENDS highwayhash) + +add_custom_target(highwayhash_copy_headers_to_destination + DEPENDS highwayhash_create_destination_dir) + +foreach(header_file ${highwayhash_HEADERS}) + add_custom_command(TARGET highwayhash_copy_headers_to_destination PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${highwayhash_INCLUDE_DIR}/highwayhash/) +endforeach() diff --git a/tensorflow/contrib/cmake/external/nsync.cmake b/tensorflow/contrib/cmake/external/nsync.cmake index eba3bcfc79efe87d0a45c979c5accfa1b6511ed0..479609458c64f7c7bd7b3ce6b23aceaa3db17f21 100644 --- a/tensorflow/contrib/cmake/external/nsync.cmake +++ b/tensorflow/contrib/cmake/external/nsync.cmake @@ -16,24 +16,16 @@ include (ExternalProject) set(nsync_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/public) set(nsync_URL https://github.com/google/nsync) -set(nsync_TAG 1.20.0) +set(nsync_TAG 1.20.1) set(nsync_BUILD ${CMAKE_CURRENT_BINARY_DIR}/nsync/src/nsync) set(nsync_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/nsync/install) -# put nsync includes in the directory where they are expected -add_custom_target(nsync_create_destination_dir - COMMAND ${CMAKE_COMMAND} -E make_directory ${nsync_INCLUDE_DIR} - DEPENDS nsync) - -add_custom_target(nsync_copy_headers_to_destination - DEPENDS nsync_create_destination_dir) - if(WIN32) set(nsync_HEADERS "${nsync_BUILD}/public/*.h") - set(nsync_STATIC_LIBRARIES ${nsync_INSTALL}/lib/nsync.lib) + set(nsync_STATIC_LIBRARIES ${nsync_INSTALL}/lib/nsync_cpp.lib) else() set(nsync_HEADERS "${nsync_BUILD}/public/*.h") - set(nsync_STATIC_LIBRARIES ${nsync_INSTALL}/lib/libnsync.a) + set(nsync_STATIC_LIBRARIES ${nsync_INSTALL}/lib/libnsync_cpp.a) endif() ExternalProject_Add(nsync @@ -43,13 +35,41 @@ ExternalProject_Add(nsync DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 BUILD_BYPRODUCTS ${nsync_STATIC_LIBRARIES} - PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/nsync/CMakeLists.txt ${nsync_BUILD} INSTALL_DIR ${nsync_INSTALL} CMAKE_CACHE_ARGS -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_INSTALL_PREFIX:STRING=${nsync_INSTALL} - -DNSYNC_LANGUAGE:STRING=c++11) + -DCMAKE_INSTALL_LIBDIR:STRING=lib + -DNSYNC_LANGUAGE:STRING=c++11) + +set(nsync_HEADERS + "${nsync_INSTALL}/include/nsync.h" + "${nsync_INSTALL}/include/nsync_atomic.h" + "${nsync_INSTALL}/include/nsync_counter.h" + "${nsync_INSTALL}/include/nsync_cpp.h" + "${nsync_INSTALL}/include/nsync_cv.h" + "${nsync_INSTALL}/include/nsync_debug.h" + "${nsync_INSTALL}/include/nsync_mu.h" + "${nsync_INSTALL}/include/nsync_mu_wait.h" + "${nsync_INSTALL}/include/nsync_note.h" + "${nsync_INSTALL}/include/nsync_once.h" + "${nsync_INSTALL}/include/nsync_time.h" + "${nsync_INSTALL}/include/nsync_time_internal.h" + "${nsync_INSTALL}/include/nsync_waiter.h" +) + +# put nsync includes in the directory where they are expected +add_custom_target(nsync_create_destination_dir + COMMAND ${CMAKE_COMMAND} -E make_directory ${nsync_INCLUDE_DIR} + DEPENDS nsync) + +add_custom_target(nsync_copy_headers_to_destination + DEPENDS nsync_create_destination_dir) + +foreach(header_file ${nsync_HEADERS}) + add_custom_command(TARGET nsync_copy_headers_to_destination PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${nsync_INCLUDE_DIR}/) +endforeach() + -add_custom_command(TARGET nsync_copy_headers_to_destination PRE_BUILD - COMMAND ${CMAKE_COMMAND} -E copy_directory ${nsync_INSTALL}/include/ ${nsync_INCLUDE_DIR}/) diff --git a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt b/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt deleted file mode 100644 index 6f059c7225dd0938b758e8f9c28ec36fcff6db4c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt +++ /dev/null @@ -1,325 +0,0 @@ -cmake_minimum_required (VERSION 2.8.12) - -# nsync provides portable synchronization primitives, such as mutexes and -# condition variables. -project (nsync) - -# Set variable NSYNC_LANGUAGE to "c++11" to build with C++11 -# rather than C. - -# Some builds need position-independent code. -set (CMAKE_POSITION_INDEPENDENT_CODE ON) - -# ----------------------------------------------------------------- -# Platform dependencies - -# Many platforms use these posix related sources; even Win32. -set (NSYNC_POSIX_SRC - "platform/posix/src/nsync_panic.c" - "platform/posix/src/per_thread_waiter.c" - "platform/posix/src/time_rep.c" - "platform/posix/src/yield.c" -) - -if (WIN32) - # Suppress warnings to reduce build log size. - add_definitions(/wd4267 /wd4244 /wd4800 /wd4503 /wd4554 /wd4996 /wd4348 /wd4018) - add_definitions(/wd4099 /wd4146 /wd4267 /wd4305 /wd4307) - add_definitions(/wd4715 /wd4722 /wd4723 /wd4838 /wd4309 /wd4334) - add_definitions(/wd4003 /wd4244 /wd4267 /wd4503 /wd4506 /wd4800 /wd4996) - add_definitions(/wd8029) -endif() - -# Many of the string matches below use a literal "X" suffix on both sides. -# This is because some versions of cmake treat (for example) "MSVC" (in quotes) -# as a reference to the variable MSVC, thus the expression -# "${CMAKE_C_COMPILER_ID}" STREQUAL "MSVC" -# is false when ${CMAKE_C_COMPILER_ID} has the value "MSVC"! See -# https://cmake.org/cmake/help/v3.1/policy/CMP0054.html - -# Pick the include directory for the operating system. -if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") - include_directories ("${PROJECT_SOURCE_DIR}/platform/c++11") - add_definitions ("-DNSYNC_USE_CPP11_TIMEPOINT -DNSYNC_ATOMIC_CPP11") - set (NSYNC_OS_CPP_SRC - "platform/c++11/src/per_thread_waiter.cc" - "platform/c++11/src/yield.cc" - "platform/c++11/src/time_rep_timespec.cc" - "platform/c++11/src/nsync_panic.cc" - ) - if ("${CMAKE_SYSTEM_NAME}X" STREQUAL "WindowsX") - include_directories ("${PROJECT_SOURCE_DIR}/platform/win32") - add_compile_options ("/TP") - set (NSYNC_OS_SRC - "platform/c++11/src/nsync_semaphore_mutex.cc" - "platform/win32/src/clock_gettime.c" - "platform/win32/src/pthread_key_win32.cc" - ${NSYNC_OS_CPP_SRC} - ) - set (NSYNC_TEST_OS_SRC - "platform/win32/src/start_thread.c" - ) - elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "DarwinX") - include_directories ("${PROJECT_SOURCE_DIR}/platform/macos") - include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") - # Some versions of MacOS, such as Sierra, require _DARWIN_C_SOURCE - # when including certin C++ standard header files, such as . - add_definitions ("-D_DARWIN_C_SOURCE") - add_compile_options ("-std=c++11") - set (NSYNC_OS_SRC - ${NSYNC_OS_CPP_SRC} - "platform/c++11/src/nsync_semaphore_mutex.cc" - "platform/posix/src/clock_gettime.c" - "platform/posix/src/nsync_semaphore_mutex.c" - ) - set (NSYNC_TEST_OS_SRC - "platform/posix/src/start_thread.c" - ) - elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "LinuxX") - include_directories (BEFORE "${PROJECT_SOURCE_DIR}/platform/c++11.futex") - include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") - add_compile_options ("-std=c++11") - set (NSYNC_OS_SRC - "platform/linux/src/nsync_semaphore_futex.c" - ${NSYNC_OS_CPP_SRC} - ) - set (NSYNC_TEST_OS_SRC - "platform/posix/src/start_thread.c" - ) - elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "NetBSDX") - include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") - add_compile_options ("-std=c++11") - set (NSYNC_OS_SRC - "platform/c++11/src/nsync_semaphore_mutex.cc" - ${NSYNC_OS_CPP_SRC} - ) - set (NSYNC_TEST_OS_SRC - "platform/posix/src/start_thread.c" - ) - elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "FreeBSDX") - include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") - add_compile_options ("-std=c++11") - set (NSYNC_OS_SRC - "platform/c++11/src/nsync_semaphore_mutex.cc" - ${NSYNC_OS_CPP_SRC} - ) - set (NSYNC_TEST_OS_SRC - "platform/posix/src/start_thread.c" - ) - elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "OpenBSDX") - include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") - add_compile_options ("-std=c++11") - set (NSYNC_OS_SRC - "platform/c++11/src/nsync_semaphore_mutex.cc" - ${NSYNC_OS_CPP_SRC} - ) - set (NSYNC_TEST_OS_SRC - "platform/posix/src/start_thread.c" - ) - endif () -endif () - -# Pick the include directory for the compiler. -if ("${CMAKE_C_COMPILER_ID}X" STREQUAL "GNUX") - include_directories ("${PROJECT_SOURCE_DIR}/platform/gcc") - set (THREADS_HAVE_PTHREAD_ARG ON) -elseif ("${CMAKE_C_COMPILER_ID}X" STREQUAL "ClangX") - include_directories ("${PROJECT_SOURCE_DIR}/platform/clang") - set (THREADS_HAVE_PTHREAD_ARG ON) -elseif ("${CMAKE_C_COMPILER_ID}X" STREQUAL "MSVCX") - include_directories ("${PROJECT_SOURCE_DIR}/platform/msvc") -else () - message (WARNING "CMAKE_C_COMPILER_ID (${CMAKE_C_COMPILER_ID}) matched NOTHING") -endif () - -if (NOT "${NSYNC_LANGUAGE}X" STREQUAL "c++11X") - if ("${CMAKE_SYSTEM_NAME}X" STREQUAL "WindowsX") - include_directories ("${PROJECT_SOURCE_DIR}/platform/win32") - set (NSYNC_OS_SRC - ${NSYNC_POSIX_SRC} - "platform/win32/src/clock_gettime.c" - "platform/win32/src/init_callback_win32.c" - "platform/win32/src/nanosleep.c" - "platform/win32/src/nsync_semaphore_win32.c" - "platform/win32/src/pthread_cond_timedwait_win32.c" - "platform/win32/src/pthread_key_win32.cc" - ) - set (NSYNC_TEST_OS_SRC - "platform/win32/src/start_thread.c" - ) - elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "DarwinX") - include_directories ("${PROJECT_SOURCE_DIR}/platform/macos") - set (NSYNC_POSIX ON) - set (NSYNC_OS_EXTRA_SRC - "platform/posix/src/clock_gettime.c" - "platform/posix/src/nsync_semaphore_mutex.c" - ) - include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") - elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "LinuxX") - include_directories ("${PROJECT_SOURCE_DIR}/platform/linux") - set (NSYNC_POSIX ON) - set (NSYNC_OS_EXTRA_SRC - "platform/linux/src/nsync_semaphore_futex.c" - ) - elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "NetBSDX") - include_directories ("${PROJECT_SOURCE_DIR}/platform/netbsd") - set (NSYNC_POSIX ON) - set (NSYNC_OS_EXTRA_SRC - "platform/posix/src/nsync_semaphore_mutex.c" - ) - elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "FreeBSDX") - include_directories ("${PROJECT_SOURCE_DIR}/platform/freebsd") - set (NSYNC_POSIX ON) - set (NSYNC_OS_EXTRA_SRC - "platform/posix/src/nsync_semaphore_mutex.c" - ) - elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "OpenBSDX") - include_directories ("${PROJECT_SOURCE_DIR}/platform/openbsd") - set (NSYNC_POSIX ON) - set (NSYNC_OS_EXTRA_SRC - "platform/posix/src/nsync_semaphore_mutex.c" - ) - endif () -endif () - -if (NSYNC_POSIX) - include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") - set (NSYNC_OS_SRC - ${NSYNC_POSIX_SRC} - ${NSYNC_OS_EXTRA_SRC} - ) - set (NSYNC_TEST_OS_SRC - "platform/posix/src/start_thread.c" - ) -endif () - -# Pick the include directory for the architecture. -if (("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "x86_64X") OR - ("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "amd64X") OR - ("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "AMD64X")) - include_directories ("${PROJECT_SOURCE_DIR}/platform/x86_64") -elseif (("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "x86_32X") OR - ("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "i386X") OR - ("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "i686X")) - include_directories ("${PROJECT_SOURCE_DIR}/platform/x86_32") -elseif (("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "armv6lX") OR - ("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "armv7lX") OR - ("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "armX")) - include_directories ("${PROJECT_SOURCE_DIR}/platform/arm") -elseif (("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "aarch64X") OR - ("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "arm64X")) - include_directories ("${PROJECT_SOURCE_DIR}/platform/aarch64") -elseif (("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "ppcX") OR - ("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "ppc32X")) - include_directories ("${PROJECT_SOURCE_DIR}/platform/ppc32") -elseif (("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "ppc64X")) - include_directories ("${PROJECT_SOURCE_DIR}/platform/ppc64") -endif () - -# Windows uses some include files from the posix directory also. -if ("${CMAKE_SYSTEM_NAME}X" STREQUAL "WindowsX") - include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") -endif () - -# ----------------------------------------------------------------- - -include_directories ("${PROJECT_SOURCE_DIR}/public") -include_directories ("${PROJECT_SOURCE_DIR}/internal") - -set (NSYNC_SRC - "internal/common.c" - "internal/counter.c" - "internal/cv.c" - "internal/debug.c" - "internal/dll.c" - "internal/mu.c" - "internal/mu_wait.c" - "internal/note.c" - "internal/once.c" - "internal/sem_wait.c" - "internal/time_internal.c" - "internal/wait.c" - ${NSYNC_OS_SRC} -) -add_library (nsync ${NSYNC_SRC}) - -set (NSYNC_TEST_SRC - "testing/array.c" - "testing/atm_log.c" - "testing/closure.c" - "testing/smprintf.c" - "testing/testing.c" - "testing/time_extra.c" - ${NSYNC_TEST_OS_SRC} -) -add_library (nsync_test ${NSYNC_TEST_SRC}) - -set (NSYNC_TESTS - "counter_test" - "cv_mu_timeout_stress_test" - "cv_test" - "cv_wait_example_test" - "dll_test" - "mu_starvation_test" - "mu_test" - "mu_wait_example_test" - "mu_wait_test" - "note_test" - "once_test" - "pingpong_test" - "wait_test" -) - -if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") - foreach (s IN ITEMS ${NSYNC_SRC} ${NSYNC_TEST_SRC}) - SET_SOURCE_FILES_PROPERTIES ("${s}" PROPERTIES LANGUAGE CXX) - endforeach (s) - foreach (t IN ITEMS ${NSYNC_TESTS}) - SET_SOURCE_FILES_PROPERTIES ("testing/${t}.c" PROPERTIES LANGUAGE CXX) - endforeach (t) -endif () - -enable_testing () -foreach (t IN ITEMS ${NSYNC_TESTS}) - add_executable (${t} "testing/${t}.c") -endforeach (t) - -find_package (Threads REQUIRED) -set (THREADS_PREFER_PTHREAD_FLAG ON) -foreach (t IN ITEMS "nsync" "nsync_test" ${NSYNC_TESTS}) - if (THREADS_HAVE_PTHREAD_ARG) - target_compile_options (${t} PUBLIC "-pthread") - endif () - if (CMAKE_THREAD_LIBS_INIT) - target_link_libraries (${t} "${CMAKE_THREAD_LIBS_INIT}") - endif () -endforeach (t) - -foreach (t IN ITEMS ${NSYNC_TESTS}) - target_link_libraries (${t} nsync_test nsync) - add_test (NAME ${t} COMMAND ${t}) -endforeach (t) - -install (TARGETS nsync - LIBRARY DESTINATION lib COMPONENT RuntimeLibraries - ARCHIVE DESTINATION lib COMPONENT Development) - -set (NSYNC_INCLUDES - "public/nsync.h" - "public/nsync_atomic.h" - "public/nsync_counter.h" - "public/nsync_cpp.h" - "public/nsync_cv.h" - "public/nsync_debug.h" - "public/nsync_mu.h" - "public/nsync_mu_wait.h" - "public/nsync_note.h" - "public/nsync_once.h" - "public/nsync_time.h" - "public/nsync_time_internal.h" - "public/nsync_waiter.h" -) - -foreach (NSYNC_INCLUDE ${NSYNC_INCLUDES}) - install (FILES ${NSYNC_INCLUDE} DESTINATION include COMPONENT Development) -endforeach () diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index 75e00f32675df1b7e523bc7e8bb44fa584b79347..07934ef3247c6e05323de8ccea70e0264561441f 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -4,6 +4,8 @@ tensorflow tensorflow/core tensorflow/core/example tensorflow/core/framework +tensorflow/core/kernels +tensorflow/core/kernels/boosted_trees tensorflow/core/lib tensorflow/core/lib/core tensorflow/core/profiler @@ -115,7 +117,6 @@ tensorflow/contrib/coder tensorflow/contrib/coder/kernels tensorflow/contrib/coder/ops tensorflow/contrib/coder/python -tensorflow/contrib/coder/python/layers tensorflow/contrib/coder/python/ops tensorflow/contrib/compiler tensorflow/contrib/constrained_optimization @@ -187,6 +188,8 @@ tensorflow/contrib/graph_editor/examples tensorflow/contrib/grid_rnn tensorflow/contrib/grid_rnn/python tensorflow/contrib/grid_rnn/python/ops +tensorflow/contrib/hadoop/python +tensorflow/contrib/hadoop/python/ops tensorflow/contrib/hooks tensorflow/contrib/hooks/python tensorflow/contrib/image diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 32b185f07b6ba836ffb47e85beff6fb2481fdc3e..6d86daf5f174a3238ab92e5bba6085c904766766 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -198,7 +198,7 @@ function(add_python_module MODULE_NAME) # so we currently add explicit commands to include those files # later on in this script. if (NOT "${script}" MATCHES "_test\.py$") - add_custom_command(TARGET tf_python_copy_scripts_to_destination PRE_BUILD + add_custom_command(TARGET tf_python_copy_scripts_to_destination PRE_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${tensorflow_source_dir}/${script} ${CMAKE_CURRENT_BINARY_DIR}/tf_python/${script}) endif() endforeach() @@ -297,7 +297,7 @@ function(GENERATE_PYTHON_OP_LIB tf_python_op_lib_name) ) target_link_libraries(${tf_python_op_lib_name}_gen_python PRIVATE tf_protos_cc - tf_python_protos_cc + tf_python_protos_cc ${tensorflow_EXTERNAL_LIBRARIES} ) @@ -549,15 +549,15 @@ if(WIN32) ${NUMPY_INCLUDE_DIR} ) #target_link_libraries(pywrap_tensorflow_internal_static - # tf_protos_cc - # tf_python_protos_cc + # tf_protos_cc + # tf_python_protos_cc #) add_dependencies(pywrap_tensorflow_internal_static tf_protos_cc tf_python_protos_cc) set(pywrap_tensorflow_internal_static_dependencies $ $ $ - ${nsync_STATIC_LIBRARIES} + ${nsync_STATIC_LIBRARIES} ) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") @@ -737,7 +737,7 @@ endif() ######################################################## # Parse tensorflow/python/tools/api/generator/BUILD to get list of generated files. -FILE(READ ${tensorflow_source_dir}/tensorflow/python/tools/api/generator/api_gen.bzl api_generator_BUILD_text) +FILE(READ ${tensorflow_source_dir}/tensorflow/python/tools/api/generator/api_init_files.bzl api_generator_BUILD_text) STRING(REGEX MATCH "# BEGIN GENERATED FILES.*# END GENERATED FILES" api_init_files_text ${api_generator_BUILD_text}) string(REPLACE "# BEGIN GENERATED FILES" "" api_init_files_text ${api_init_files_text}) string(REPLACE "# END GENERATED FILES" "" api_init_files_text ${api_init_files_text}) @@ -763,57 +763,40 @@ file(WRITE "${api_init_list_file}" "${api_init_files}") # recongnize paths. As CUDA isn't built with MKL, the MKL built directory is the only path to this command to work around that issue. # To not override the CUDA and system path in other circumstances, `if-else` branch used here to handle this problem, # and should be removed if the path issue can be resolved. +# UPDATE: Below block appears to handle multiple items in PATH correctly, but risks command line limits if PATH is large. +# If you have issues, try `set(PY_RUNTIME_ENV "PATH=${mkl_BIN_DIRS}")` instead. ### -if (tensorflow_ENABLE_MKL_SUPPORT) +set(PY_RUNTIME_ENV "") +if(tensorflow_ENABLE_MKL_SUPPORT) # add mkl dist dlls to system path for python - # TODO: In current cmake version, PY_RUNTIME_ENV behaves strange with multiple paths, - # so we have to specify only one path in it to work around the issue. We need this if/else - # to protect overwriting CUDA environments - set(PY_RUNTIME_ENV ${mkl_BIN_DIRS}) - add_custom_command( - OUTPUT ${api_init_files} - DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops - - # tensorflow/__init__.py depends on files generated in this step. So, remove it while - # this step is running since the files aren't there yet. - COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py - - # Run create_python_api.py to generate API init files. - COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python PATH=${PY_RUNTIME_ENV} ${PYTHON_EXECUTABLE} - "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/tools/api/generator/create_python_api.py" - "--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py" - "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow" - "--package=tensorflow.python" - "--apiname=tensorflow" - "${api_init_list_file}" - - COMMENT "Generating __init__.py files for Python API." - WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python" - VERBATIM - ) -else (tensorflow_ENABLE_MKL_SUPPORT) - add_custom_command( - OUTPUT ${api_init_files} - DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops - - # tensorflow/__init__.py depends on files generated in this step. So, remove it while - # this step is running since the files aren't there yet. - COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py - - # Run create_python_api.py to generate API init files. - COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE} - "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/tools/api/generator/create_python_api.py" - "--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py" - "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow" - "--package=tensorflow.python" - "--apiname=tensorflow" - "${api_init_list_file}" - - COMMENT "Generating __init__.py files for Python API." - WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python" - ) -endif (tensorflow_ENABLE_MKL_SUPPORT) + file(TO_CMAKE_PATH "$ENV{PATH}" PY_RUNTIME_ENV) + set(PY_RUNTIME_ENV ${mkl_BIN_DIRS} ${PY_RUNTIME_ENV}) + file(TO_NATIVE_PATH "${PY_RUNTIME_ENV}" PY_RUNTIME_ENV) + set(PY_RUNTIME_ENV "PATH=${PY_RUNTIME_ENV}") +endif(tensorflow_ENABLE_MKL_SUPPORT) + +add_custom_command( + OUTPUT ${api_init_files} + DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops + + # tensorflow/__init__.py depends on files generated in this step. So, remove it while + # this step is running since the files aren't there yet. + COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py + + # Run create_python_api.py to generate API init files. + COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python "${PY_RUNTIME_ENV}" ${PYTHON_EXECUTABLE} + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/tools/api/generator/create_python_api.py" + "--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py" + "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow" + "--package=tensorflow.python" + "--apiname=tensorflow" + "${api_init_list_file}" + + COMMENT "Generating __init__.py files for Python API." + WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python" + VERBATIM +) add_custom_target(tf_python_api SOURCES ${api_init_files}) add_dependencies(tf_python_api tf_python_ops) @@ -848,12 +831,12 @@ add_custom_command( DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops # Run create_python_api.py to generate API init files. - COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE} + COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python "${PY_RUNTIME_ENV}" ${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/tools/api/generator/create_python_api.py" "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/estimator/api" "--package=tensorflow.python.estimator" "--apiname=estimator" - "--output_package=tensorflow.python.estimator.api" + "--output_package=tensorflow.python.estimator.api" "${estimator_api_init_list_file}" COMMENT "Generating __init__.py files for Python API." diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index b2330c4e340d531f70234de812ab6f6b2e5c1160..2c878c17167c662d10a8c7dabf41687efdbf65d8 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -122,6 +122,17 @@ function(AddPythonTests) endforeach() endfunction(AddPythonTests) +# +# ensure that every element is an existing file +# +function(CheckExists TYPE SOURCES) + foreach(source ${SOURCES}) + if(NOT EXISTS ${source}) + message(SEND_ERROR "${TYPE} not found: ${source}") + endif() + endforeach(source) +endfunction(CheckExists) + if (tensorflow_BUILD_PYTHON_TESTS) # # python tests. This assumes that the tensorflow wheel is @@ -145,7 +156,6 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/debug/wrappers/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/estimator/python/estimator/*_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/*.py" - "${tensorflow_source_dir}/tensorflow/python/meta_graph_transform/*_test.py" "${tensorflow_source_dir}/tensorflow/python/ops/quantized_conv_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/ops/quantized_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/platform/build_info_test.py" @@ -198,7 +208,6 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/saved_model/saved_model_test.py" "${tensorflow_source_dir}/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py" # requires scipy - "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/preprocessing/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler_test.py" "${tensorflow_source_dir}/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py" # Takes very long to run without sharding (defined in bazel build file). @@ -256,10 +265,9 @@ if (tensorflow_BUILD_PYTHON_TESTS) # Flaky because of local cluster creation. "${tensorflow_source_dir}/tensorflow/python/training/sync_replicas_optimizer_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_grpc_test.py" - "${tensorflow_source_dir}tensorflow/python/training/localhost_cluster_performance_test.py" + "${tensorflow_source_dir}/tensorflow/python/training/localhost_cluster_performance_test.py" "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/functional_ops_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py" # Type error in testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU. "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py" @@ -329,6 +337,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/utils/io_utils_test.py" # b/72894325 ) endif() + CheckExists(${tf_test_src_py_exclude}) list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude}) AddPythonTests( @@ -480,6 +489,7 @@ if (tensorflow_BUILD_CC_TESTS) "${tensorflow_source_dir}/tensorflow/cc/saved_model/*_test.cc" ) + CheckExists(${tf_test_src_simple_exclude}) list(REMOVE_ITEM tf_test_src_simple ${tf_test_src_simple_exclude} ${tf_cc_saved_model_test_srcs} @@ -494,6 +504,7 @@ if (tensorflow_BUILD_CC_TESTS) ${tf_core_profiler_test_srcs} ) + CheckExists(${tf_src_testlib}) set(tf_test_lib tf_test_lib) add_library(${tf_test_lib} STATIC ${tf_src_testlib}) diff --git a/tensorflow/contrib/coder/BUILD b/tensorflow/contrib/coder/BUILD index a2c6e413039ee3b5af3cb53d1af3325037536d36..855c824ead2f7de4c37db2d2a3648a9ee00fb9e9 100644 --- a/tensorflow/contrib/coder/BUILD +++ b/tensorflow/contrib/coder/BUILD @@ -1,5 +1,5 @@ # Description: -# Contains tools related to data compression. +# Contains ops related to data compression. package(default_visibility = [ "//learning/brain:__subpackages__", @@ -168,7 +168,6 @@ py_library( srcs_version = "PY2AND3", deps = [ ":coder_ops_py", - ":entropybottleneck_py", ], ) @@ -205,44 +204,3 @@ tf_py_test( ], main = "python/ops/coder_ops_test.py", ) - -py_library( - name = "entropybottleneck_py", - srcs = [ - "python/layers/entropybottleneck.py", - ], - srcs_version = "PY2AND3", - deps = [ - ":coder_ops_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:functional_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn", - "//tensorflow/python:ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary_ops", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:variable_scope", - "//tensorflow/python/eager:context", - "//tensorflow/python/keras:engine", - "//third_party/py/numpy", - ], -) - -tf_py_test( - name = "entropybottleneck_py_test", - srcs = [ - "python/layers/entropybottleneck_test.py", - ], - additional_deps = [ - ":entropybottleneck_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:variables", - "//tensorflow/python:training", - ], - main = "python/layers/entropybottleneck_test.py", -) diff --git a/tensorflow/contrib/coder/README.md b/tensorflow/contrib/coder/README.md deleted file mode 100644 index c6c379c458893551b765327c0c1cbfff7f24f9c3..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/coder/README.md +++ /dev/null @@ -1,73 +0,0 @@ -# Entropy coder - -This module contains range encoder and range decoder which can encode integer -data into string with cumulative distribution functions (CDF). - -## Data and CDF values - -The data to be encoded should be non-negative integers in half-open interval -`[0, m)`. Then a CDF is represented as an integral vector of length `m + 1` -where `CDF(i) = f(Pr(X < i) * 2^precision)` for i = 0,1,...,m, and `precision` -is an attribute in range `0 < precision <= 16`. The function `f` maps real -values into integers, e.g., round or floor. It is important that to encode a -number `i`, `CDF(i + 1) - CDF(i)` cannot be zero. - -Note that we used `Pr(X < i)` not `Pr(X <= i)`, and therefore CDF(0) = 0 always. - -## RangeEncode: data shapes and CDF shapes - -For each data element, its CDF has to be provided. Therefore if the shape of CDF -should be `data.shape + (m + 1,)` in NumPy-like notation. For example, if `data` -is a 2-D tensor of shape (10, 10) and its elements are in `[0, 64)`, then the -CDF tensor should have shape (10, 10, 65). - -This may make CDF tensor too large, and in many applications all data elements -may have the same probability distribution. To handle this, `RangeEncode` -supports limited broadcasting CDF into data. Broadcasting is limited in the -following sense: - -- All CDF axes but the last one is broadcasted into data but not the other way - around, -- The number of CDF axes does not extend, i.e., `CDF.ndim == data.ndim + 1`. - -In the previous example where data has shape (10, 10), the following are -acceptable CDF shapes: - -- (10, 10, 65) -- (1, 10, 65) -- (10, 1, 65) -- (1, 1, 65) - -## RangeDecode - -`RangeEncode` encodes neither data shape nor termination character. Therefore -the decoder should know how many characters are encoded into the string, and -`RangeDecode` takes the encoded data shape as the second argument. The same -shape restrictions as `RangeEncode` inputs apply here. - -## Example - -```python -data = tf.random_uniform((128, 128), 0, 10, dtype=tf.int32) - -histogram = tf.bincount(data, minlength=10, maxlength=10) -cdf = tf.cumsum(histogram, exclusive=False) -# CDF should have length m + 1. -cdf = tf.pad(cdf, [[1, 0]]) -# CDF axis count must be one more than data. -cdf = tf.reshape(cdf, [1, 1, -1]) - -# Note that data has 2^14 elements, and therefore the sum of CDF is 2^14. -data = tf.cast(data, tf.int16) -encoded = coder.range_encode(data, cdf, precision=14) -decoded = coder.range_decode(encoded, tf.shape(data), cdf, precision=14) - -# data and decoded should be the same. -sess = tf.Session() -x, y = sess.run((data, decoded)) -assert np.all(x == y) -``` - -## Authors -Sung Jin Hwang (github: [ssjhv](https://github.com/ssjhv)) and Nick Johnston -(github: [nmjohn](https://github.com/nmjohn)) diff --git a/tensorflow/contrib/coder/__init__.py b/tensorflow/contrib/coder/__init__.py index 99b8ac7595ec632b2918e6b7ca22c06dd7f0a8b3..8897312046c63c42d85e7fba5b62d2ed908dd6e9 100644 --- a/tensorflow/contrib/coder/__init__.py +++ b/tensorflow/contrib/coder/__init__.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Data compression tools.""" +"""Data compression ops.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import -from tensorflow.contrib.coder.python.layers.entropybottleneck import * from tensorflow.contrib.coder.python.ops.coder_ops import * # pylint: enable=wildcard-import diff --git a/tensorflow/contrib/coder/python/layers/entropybottleneck.py b/tensorflow/contrib/coder/python/layers/entropybottleneck.py deleted file mode 100644 index 0c997bd4fdfa4233117c9fec2c4397301b1c8cb9..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/coder/python/layers/entropybottleneck.py +++ /dev/null @@ -1,697 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Entropy bottleneck layer.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.coder.python.ops import coder_ops - -from tensorflow.python.eager import context -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras.engine import base_layer -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import functional_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.summary import summary - - -class EntropyBottleneck(base_layer.Layer): - """Entropy bottleneck layer. - - This layer can be used to model the entropy (the amount of information - conveyed) of the tensor passing through it. During training, this can be used - to impose a (soft) entropy constraint on its activations, limiting the amount - of information flowing through the layer. Note that this is distinct from - other types of bottlenecks, which reduce the dimensionality of the space, for - example. Dimensionality reduction does not limit the amount of information, - and does not enable efficient data compression per se. - - After training, this layer can be used to compress any input tensor to a - string, which may be written to a file, and to decompress a file which it - previously generated back to a reconstructed tensor (possibly on a different - machine having access to the same model checkpoint). The entropies estimated - during training or evaluation are approximately equal to the average length of - the strings in bits. - - The layer implements a flexible probability density model to estimate entropy, - which is described in the appendix of the paper (please cite the paper if you - use this code for scientific work): - - "Variational image compression with a scale hyperprior" - - Johannes Ballé, David Minnen, Saurabh Singh, Sung Jin Hwang, Nick Johnston - - https://arxiv.org/abs/1802.01436 - - The layer assumes that the input tensor is at least 2D, with a batch dimension - at the beginning and a channel dimension as specified by `data_format`. The - layer trains an independent probability density model for each channel, but - assumes that across all other dimensions, the inputs are i.i.d. (independent - and identically distributed). Because the entropy (and hence, average - codelength) is a function of the densities, this assumption may have a direct - effect on the compression performance. - - Because data compression always involves discretization, the outputs of the - layer are generally only approximations of its inputs. During training, - discretization is modeled using additive uniform noise to ensure - differentiability. The entropies computed during training are differential - entropies. During evaluation, the data is actually quantized, and the - entropies are discrete (Shannon entropies). To make sure the approximated - tensor values are good enough for practical purposes, the training phase must - be used to balance the quality of the approximation with the entropy, by - adding an entropy term to the training loss, as in the following example. - - Here, we use the entropy bottleneck to compress the latent representation of - an autoencoder. The data vectors `x` in this case are 4D tensors in - `'channels_last'` format (for example, 16x16 pixel grayscale images). - - The layer always produces exactly one auxiliary loss and one update op which - are only significant for compression and decompression. To use the compression - feature, the auxiliary loss must be minimized during or after training. After - that, the update op must be executed at least once. Here, we simply attach - them to the main training step. - - Training: - ``` - # Build autoencoder. - x = tf.placeholder(tf.float32, shape=[None, 16, 16, 1]) - y = forward_transform(x) - entropy_bottleneck = EntropyBottleneck() - y_, likelihoods = entropy_bottleneck(y, training=True) - x_ = backward_transform(y_) - - # Information content (= predicted codelength) in bits of each batch element - # (note that taking the natural logarithm and dividing by `log(2)` is - # equivalent to taking base-2 logarithms): - bits = tf.reduce_sum(tf.log(likelihoods), axis=(1, 2, 3)) / -np.log(2) - - # Squared difference of each batch element: - squared_error = tf.reduce_sum(tf.squared_difference(x, x_), axis=(1, 2, 3)) - - # The loss is a weighted sum of mean squared error and entropy (average - # information content), where the weight controls the trade-off between - # approximation error and entropy. - main_loss = 0.5 * tf.reduce_mean(squared_error) + tf.reduce_mean(bits) - - # Minimize loss and auxiliary loss, and execute update op. - main_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4) - main_step = optimizer.minimize(main_loss) - # 1e-2 is a good starting point for the learning rate of the auxiliary loss, - # assuming Adam is used. - aux_optimizer = tf.train.AdamOptimizer(learning_rate=1e-2) - aux_step = optimizer.minimize(entropy_bottleneck.losses[0]) - step = tf.group(main_step, aux_step, entropy_bottleneck.updates[0]) - ``` - - Evaluation: - ``` - # Build autoencoder. - x = tf.placeholder(tf.float32, shape=[None, 16, 16, 1]) - y = forward_transform(x) - y_, likelihoods = EntropyBottleneck()(y, training=False) - x_ = backward_transform(y_) - - # Information content (= predicted codelength) in bits of each batch element: - bits = tf.reduce_sum(tf.log(likelihoods), axis=(1, 2, 3)) / -np.log(2) - - # Squared difference of each batch element: - squared_error = tf.reduce_sum(tf.squared_difference(x, x_), axis=(1, 2, 3)) - - # The loss is a weighted sum of mean squared error and entropy (average - # information content), where the weight controls the trade-off between - # approximation error and entropy. - loss = 0.5 * tf.reduce_mean(squared_error) + tf.reduce_mean(bits) - ``` - - To be able to compress the bottleneck tensor and decompress it in a different - session, or on a different machine, you need three items: - - The compressed representations stored as strings. - - The shape of the bottleneck for these string representations as a `Tensor`, - as well as the number of channels of the bottleneck at graph construction - time. - - The checkpoint of the trained model that was used for compression. Note: - It is crucial that the auxiliary loss produced by this layer is minimized - during or after training, and that the update op is run after training and - minimization of the auxiliary loss, but *before* the checkpoint is saved. - - Compression: - ``` - x = tf.placeholder(tf.float32, shape=[None, 16, 16, 1]) - y = forward_transform(x) - strings = EntropyBottleneck().compress(y) - shape = tf.shape(y)[1:] - ``` - - Decompression: - ``` - strings = tf.placeholder(tf.string, shape=[None]) - shape = tf.placeholder(tf.int32, shape=[3]) - entropy_bottleneck = EntropyBottleneck(dtype=tf.float32) - y_ = entropy_bottleneck.decompress(strings, shape, channels=5) - x_ = backward_transform(y_) - ``` - Here, we assumed that the tensor produced by the forward transform has 5 - channels. - - The above four use cases can also be implemented within the same session (i.e. - on the same `EntropyBottleneck` instance), for testing purposes, etc., by - calling the object more than once. - - Arguments: - init_scale: Float. A scaling factor determining the initial width of the - probability densities. This should be chosen big enough so that the - range of values of the layer inputs roughly falls within the interval - [`-init_scale`, `init_scale`] at the beginning of training. - filters: An iterable of ints, giving the number of filters at each layer of - the density model. Generally, the more filters and layers, the more - expressive is the density model in terms of modeling more complicated - distributions of the layer inputs. For details, refer to the paper - referenced above. The default is `[3, 3, 3]`, which should be sufficient - for most practical purposes. - tail_mass: Float, between 0 and 1. The bottleneck layer automatically - determines the range of input values that should be represented based on - their frequency of occurrence. Values occurring in the tails of the - distributions will be clipped to that range during compression. - `tail_mass` determines the amount of probability mass in the tails which - is cut off in the worst case. For example, the default value of `1e-9` - means that at most 1 in a billion input samples will be clipped to the - range. - optimize_integer_offset: Boolean. Typically, the input values of this layer - are floats, which means that quantization during evaluation can be - performed with an arbitrary offset. By default, the layer determines that - offset automatically. In special situations, such as when it is known that - the layer will receive only full integer values during evaluation, it can - be desirable to set this argument to `False` instead, in order to always - quantize to full integer values. - likelihood_bound: Float. If positive, the returned likelihood values are - ensured to be greater than or equal to this value. This prevents very - large gradients with a typical entropy loss (defaults to 1e-9). - range_coder_precision: Integer, between 1 and 16. The precision of the range - coder used for compression and decompression. This trades off computation - speed with compression efficiency, where 16 is the slowest but most - efficient setting. Choosing lower values may increase the average - codelength slightly compared to the estimated entropies. - data_format: Either `'channels_first'` or `'channels_last'` (default). - trainable: Boolean. Whether the layer should be trained. - name: String. The name of the layer. - dtype: Default dtype of the layer's parameters (default of `None` means use - the type of the first input). - - Read-only properties: - init_scale: See above. - filters: See above. - tail_mass: See above. - optimize_integer_offset: See above. - likelihood_bound: See above. - range_coder_precision: See above. - data_format: See above. - name: String. See above. - dtype: See above. - trainable_variables: List of trainable variables. - non_trainable_variables: List of non-trainable variables. - variables: List of all variables of this layer, trainable and non-trainable. - updates: List of update ops of this layer. Always contains exactly one - update op, which must be run once after the last training step, before - `compress` or `decompress` is used. - losses: List of losses added by this layer. Always contains exactly one - auxiliary loss, which must be added to the training loss. - - Mutable properties: - trainable: Boolean. Whether the layer should be trained. - input_spec: Optional `InputSpec` object specifying the constraints on inputs - that can be accepted by the layer. - """ - - def __init__(self, init_scale=10, filters=(3, 3, 3), tail_mass=1e-9, - optimize_integer_offset=True, likelihood_bound=1e-9, - range_coder_precision=16, data_format="channels_last", **kwargs): - super(EntropyBottleneck, self).__init__(**kwargs) - self._init_scale = float(init_scale) - self._filters = tuple(int(f) for f in filters) - self._tail_mass = float(tail_mass) - if not 0 < self.tail_mass < 1: - raise ValueError( - "`tail_mass` must be between 0 and 1, got {}.".format(self.tail_mass)) - self._optimize_integer_offset = bool(optimize_integer_offset) - self._likelihood_bound = float(likelihood_bound) - self._range_coder_precision = int(range_coder_precision) - self._data_format = data_format - self._channel_axis(2) # trigger ValueError early - self.input_spec = base_layer.InputSpec(min_ndim=2) - - @property - def init_scale(self): - return self._init_scale - - @property - def filters(self): - return self._filters - - @property - def tail_mass(self): - return self._tail_mass - - @property - def optimize_integer_offset(self): - return self._optimize_integer_offset - - @property - def likelihood_bound(self): - return self._likelihood_bound - - @property - def range_coder_precision(self): - return self._range_coder_precision - - @property - def data_format(self): - return self._data_format - - def _channel_axis(self, ndim): - try: - return {"channels_first": 1, "channels_last": ndim - 1}[self.data_format] - except KeyError: - raise ValueError("Unsupported `data_format` for {} layer: {}.".format( - self.__class__.__name__, self.data_format)) - - def _logits_cumulative(self, inputs, stop_gradient): - """Evaluate logits of the cumulative densities. - - Args: - inputs: The values at which to evaluate the cumulative densities, expected - to be a `Tensor` of shape `(channels, 1, batch)`. - stop_gradient: Boolean. Whether to add `array_ops.stop_gradient` calls so - that the gradient of the output with respect to the density model - parameters is disconnected (the gradient with respect to `inputs` is - left untouched). - - Returns: - A `Tensor` of the same shape as `inputs`, containing the logits of the - cumulative densities evaluated at the given inputs. - """ - logits = inputs - - for i in range(len(self.filters) + 1): - matrix = self._matrices[i] - if stop_gradient: - matrix = array_ops.stop_gradient(matrix) - logits = math_ops.matmul(matrix, logits) - - bias = self._biases[i] - if stop_gradient: - bias = array_ops.stop_gradient(bias) - logits += bias - - if i < len(self._factors): - factor = self._factors[i] - if stop_gradient: - factor = array_ops.stop_gradient(factor) - logits += factor * math_ops.tanh(logits) - - return logits - - def build(self, input_shape): - """Builds the layer. - - Creates the variables for the network modeling the densities, creates the - auxiliary loss estimating the median and tail quantiles of the densities, - and then uses that to create the probability mass functions and the update - op that produces the discrete cumulative density functions used by the range - coder. - - Args: - input_shape: Shape of the input tensor, used to get the number of - channels. - - Raises: - ValueError: if `input_shape` doesn't specify the length of the channel - dimension. - """ - input_shape = tensor_shape.TensorShape(input_shape) - channel_axis = self._channel_axis(input_shape.ndims) - channels = input_shape[channel_axis].value - if channels is None: - raise ValueError("The channel dimension of the inputs must be defined.") - self.input_spec = base_layer.InputSpec( - ndim=input_shape.ndims, axes={channel_axis: channels}) - filters = (1,) + self.filters + (1,) - scale = self.init_scale ** (1 / (len(self.filters) + 1)) - - # Create variables. - self._matrices = [] - self._biases = [] - self._factors = [] - for i in range(len(self.filters) + 1): - init = np.log(np.expm1(1 / scale / filters[i + 1])) - matrix = self.add_variable( - "matrix_{}".format(i), dtype=self.dtype, - shape=(channels, filters[i + 1], filters[i]), - initializer=init_ops.Constant(init)) - matrix = nn.softplus(matrix) - self._matrices.append(matrix) - - bias = self.add_variable( - "bias_{}".format(i), dtype=self.dtype, - shape=(channels, filters[i + 1], 1), - initializer=init_ops.RandomUniform(-.5, .5)) - self._biases.append(bias) - - if i < len(self.filters): - factor = self.add_variable( - "factor_{}".format(i), dtype=self.dtype, - shape=(channels, filters[i + 1], 1), - initializer=init_ops.Zeros()) - factor = math_ops.tanh(factor) - self._factors.append(factor) - - # To figure out what range of the densities to sample, we need to compute - # the quantiles given by `tail_mass / 2` and `1 - tail_mass / 2`. Since we - # can't take inverses of the cumulative directly, we make it an optimization - # problem: - # `quantiles = argmin(|logit(cumulative) - target|)` - # where `target` is `logit(tail_mass / 2)` or `logit(1 - tail_mass / 2)`. - # Taking the logit (inverse of sigmoid) of the cumulative makes the - # representation of the right target more numerically stable. - - # Numerically stable way of computing logits of `tail_mass / 2` - # and `1 - tail_mass / 2`. - target = np.log(2 / self.tail_mass - 1) - # Compute lower and upper tail quantile as well as median. - target = constant_op.constant([-target, 0, target], dtype=self.dtype) - - def quantiles_initializer(shape, dtype=None, partition_info=None): - del partition_info # unused - assert tuple(shape[1:]) == (1, 3) - init = constant_op.constant( - [[[-self.init_scale, 0, self.init_scale]]], dtype=dtype) - return array_ops.tile(init, (shape[0], 1, 1)) - - quantiles = self.add_variable( - "quantiles", shape=(channels, 1, 3), dtype=self.dtype, - initializer=quantiles_initializer) - logits = self._logits_cumulative(quantiles, stop_gradient=True) - loss = math_ops.reduce_sum(abs(logits - target)) - self.add_loss(loss, inputs=None) - - # Save medians for `call`, `compress`, and `decompress`. - self._medians = quantiles[:, :, 1:2] - if not self.optimize_integer_offset: - self._medians = math_ops.round(self._medians) - - # Largest distance observed between lower tail quantile and median, - # or between median and upper tail quantile. - minima = math_ops.reduce_max(self._medians - quantiles[:, :, 0:1]) - maxima = math_ops.reduce_max(quantiles[:, :, 2:3] - self._medians) - minmax = math_ops.maximum(minima, maxima) - minmax = math_ops.ceil(minmax) - minmax = math_ops.maximum(minmax, 1) - - # Sample the density up to `minmax` around the median. - samples = math_ops.range(-minmax, minmax + 1, dtype=self.dtype) - samples += self._medians - - half = constant_op.constant(.5, dtype=self.dtype) - # We strip the sigmoid from the end here, so we can use the special rule - # below to only compute differences in the left tail of the sigmoid. - # This increases numerical stability (see explanation in `call`). - lower = self._logits_cumulative(samples - half, stop_gradient=True) - upper = self._logits_cumulative(samples + half, stop_gradient=True) - # Flip signs if we can move more towards the left tail of the sigmoid. - sign = -math_ops.sign(math_ops.add_n([lower, upper])) - pmf = abs(math_ops.sigmoid(sign * upper) - math_ops.sigmoid(sign * lower)) - # Add tail masses to first and last bin of pmf, as we clip values for - # compression, meaning that out-of-range values get mapped to these bins. - pmf = array_ops.concat([ - math_ops.add_n([pmf[:, 0, :1], math_ops.sigmoid(lower[:, 0, :1])]), - pmf[:, 0, 1:-1], - math_ops.add_n([pmf[:, 0, -1:], math_ops.sigmoid(-upper[:, 0, -1:])]), - ], axis=-1) - self._pmf = pmf - - cdf = coder_ops.pmf_to_quantized_cdf( - pmf, precision=self.range_coder_precision) - def cdf_getter(*args, **kwargs): - del args, kwargs # ignored - return variable_scope.get_variable( - "quantized_cdf", dtype=dtypes.int32, initializer=cdf, - trainable=False, validate_shape=False, collections=()) - # Need to provide a fake shape here since add_variable insists on it. - self._quantized_cdf = self.add_variable( - "quantized_cdf", shape=(channels, 1), dtype=dtypes.int32, - getter=cdf_getter, trainable=False) - - update_op = state_ops.assign( - self._quantized_cdf, cdf, validate_shape=False) - self.add_update(update_op, inputs=None) - - super(EntropyBottleneck, self).build(input_shape) - - def call(self, inputs, training): - """Pass a tensor through the bottleneck. - - Args: - inputs: The tensor to be passed through the bottleneck. - training: Boolean. If `True`, returns a differentiable approximation of - the inputs, and their likelihoods under the modeled probability - densities. If `False`, returns the quantized inputs and their - likelihoods under the corresponding probability mass function. These - quantities can't be used for training, as they are not differentiable, - but represent actual compression more closely. - - Returns: - values: `Tensor` with the same shape as `inputs` containing the perturbed - or quantized input values. - likelihood: `Tensor` with the same shape as `inputs` containing the - likelihood of `values` under the modeled probability distributions. - - Raises: - ValueError: if `inputs` has different `dtype` or number of channels than - a previous set of inputs the model was invoked with earlier. - """ - inputs = ops.convert_to_tensor(inputs) - ndim = self.input_spec.ndim - channel_axis = self._channel_axis(ndim) - half = constant_op.constant(.5, dtype=self.dtype) - - # Convert to (channels, 1, batch) format by commuting channels to front - # and then collapsing. - order = list(range(ndim)) - order.pop(channel_axis) - order.insert(0, channel_axis) - values = array_ops.transpose(inputs, order) - shape = array_ops.shape(values) - values = array_ops.reshape(values, (shape[0], 1, -1)) - - # Add noise or quantize. - if training: - noise = random_ops.random_uniform(array_ops.shape(values), -half, half) - values = math_ops.add_n([values, noise]) - elif self.optimize_integer_offset: - values = math_ops.round(values - self._medians) + self._medians - else: - values = math_ops.round(values) - - # Evaluate densities. - # We can use the special rule below to only compute differences in the left - # tail of the sigmoid. This increases numerical stability: sigmoid(x) is 1 - # for large x, 0 for small x. Subtracting two numbers close to 0 can be done - # with much higher precision than subtracting two numbers close to 1. - lower = self._logits_cumulative(values - half, stop_gradient=False) - upper = self._logits_cumulative(values + half, stop_gradient=False) - # Flip signs if we can move more towards the left tail of the sigmoid. - sign = -math_ops.sign(math_ops.add_n([lower, upper])) - sign = array_ops.stop_gradient(sign) - likelihood = abs( - math_ops.sigmoid(sign * upper) - math_ops.sigmoid(sign * lower)) - if self.likelihood_bound > 0: - likelihood_bound = constant_op.constant( - self.likelihood_bound, dtype=self.dtype) - # TODO(jballe): Override gradients. - likelihood = math_ops.maximum(likelihood, likelihood_bound) - - # Convert back to input tensor shape. - order = list(range(1, ndim)) - order.insert(channel_axis, 0) - values = array_ops.reshape(values, shape) - values = array_ops.transpose(values, order) - likelihood = array_ops.reshape(likelihood, shape) - likelihood = array_ops.transpose(likelihood, order) - - if not context.executing_eagerly(): - values_shape, likelihood_shape = self.compute_output_shape(inputs.shape) - values.set_shape(values_shape) - likelihood.set_shape(likelihood_shape) - - return values, likelihood - - def compress(self, inputs): - """Compress inputs and store their binary representations into strings. - - Args: - inputs: `Tensor` with values to be compressed. - - Returns: - String `Tensor` vector containing the compressed representation of each - batch element of `inputs`. - """ - with ops.name_scope(self._name_scope()): - inputs = ops.convert_to_tensor(inputs) - if not self.built: - # Check input assumptions set before layer building, e.g. input rank. - self._assert_input_compatibility(inputs) - if self.dtype is None: - self._dtype = inputs.dtype.base_dtype.name - self.build(inputs.shape) - - # Check input assumptions set after layer building, e.g. input shape. - if not context.executing_eagerly(): - self._assert_input_compatibility(inputs) - - ndim = self.input_spec.ndim - channel_axis = self._channel_axis(ndim) - # Tuple of slices for expanding dimensions of tensors below. - slices = ndim * [None] + [slice(None)] - slices[channel_axis] = slice(None) - slices = tuple(slices) - - # Expand dimensions of CDF to input dimensions, keeping the channels along - # the right dimension. - cdf = self._quantized_cdf[slices[1:]] - num_levels = array_ops.shape(cdf)[-1] - 1 - - # Bring inputs to the right range by centering the range on the medians. - half = constant_op.constant(.5, dtype=self.dtype) - medians = array_ops.squeeze(self._medians, [1, 2]) - offsets = (math_ops.cast(num_levels // 2, self.dtype) + half) - medians - # Expand offsets to input dimensions and add to inputs. - values = inputs + offsets[slices[:-1]] - - # Clip to range and cast to integers. Because we have added .5 above, and - # all values are positive, the cast effectively implements rounding. - values = math_ops.maximum(values, half) - values = math_ops.minimum( - values, math_ops.cast(num_levels, self.dtype) - half) - values = math_ops.cast(values, dtypes.int16) - - def loop_body(tensor): - return coder_ops.range_encode( - tensor, cdf, precision=self.range_coder_precision) - strings = functional_ops.map_fn( - loop_body, values, dtype=dtypes.string, back_prop=False) - - if not context.executing_eagerly(): - strings.set_shape(inputs.shape[:1]) - - return strings - - def decompress(self, strings, shape, channels=None): - """Decompress values from their compressed string representations. - - Args: - strings: A string `Tensor` vector containing the compressed data. - shape: A `Tensor` vector of int32 type. Contains the shape of the tensor - to be decompressed, excluding the batch dimension. - channels: Integer. Specifies the number of channels statically. Needs only - be set if the layer hasn't been built yet (i.e., this is the first input - it receives). - - Returns: - The decompressed `Tensor`. Its shape will be equal to `shape` prepended - with the batch dimension from `strings`. - - Raises: - ValueError: If the length of `shape` isn't available at graph construction - time. - """ - with ops.name_scope(self._name_scope()): - strings = ops.convert_to_tensor(strings) - shape = ops.convert_to_tensor(shape) - if self.built: - ndim = self.input_spec.ndim - channel_axis = self._channel_axis(ndim) - if channels is None: - channels = self.input_spec.axes[channel_axis] - else: - if not (shape.shape.is_fully_defined() and shape.shape.ndims == 1): - raise ValueError("`shape` must be a vector with known length.") - ndim = shape.shape[0].value + 1 - channel_axis = self._channel_axis(ndim) - input_shape = ndim * [None] - input_shape[channel_axis] = channels - self.build(input_shape) - - # Tuple of slices for expanding dimensions of tensors below. - slices = ndim * [None] + [slice(None)] - slices[channel_axis] = slice(None) - slices = tuple(slices) - - # Expand dimensions of CDF to input dimensions, keeping the channels along - # the right dimension. - cdf = self._quantized_cdf[slices[1:]] - num_levels = array_ops.shape(cdf)[-1] - 1 - - def loop_body(string): - return coder_ops.range_decode( - string, shape, cdf, precision=self.range_coder_precision) - outputs = functional_ops.map_fn( - loop_body, strings, dtype=dtypes.int16, back_prop=False) - outputs = math_ops.cast(outputs, self.dtype) - - medians = array_ops.squeeze(self._medians, [1, 2]) - offsets = math_ops.cast(num_levels // 2, self.dtype) - medians - outputs -= offsets[slices[:-1]] - - if not context.executing_eagerly(): - outputs_shape = ndim * [None] - outputs_shape[0] = strings.shape[0] - outputs_shape[channel_axis] = channels - outputs.set_shape(outputs_shape) - - return outputs - - def visualize(self): - """Multi-channel visualization of densities as images. - - Creates and returns an image summary visualizing the current probabilty - density estimates. The image contains one row for each channel. Within each - row, the pixel intensities are proportional to probability values, and each - row is centered on the median of the corresponding distribution. - - Returns: - The created image summary. - """ - with ops.name_scope(self._name_scope()): - image = self._pmf - image *= 255 / math_ops.reduce_max(image, axis=1, keepdims=True) - image = math_ops.cast(image + .5, dtypes.uint8) - image = image[None, :, :, None] - return summary.image("pmf", image, max_outputs=1) - - def compute_output_shape(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape) - return input_shape, input_shape diff --git a/tensorflow/contrib/coder/python/layers/entropybottleneck_test.py b/tensorflow/contrib/coder/python/layers/entropybottleneck_test.py deleted file mode 100644 index 798b0234ebcce7df108a0da65d1305502ce0253a..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/coder/python/layers/entropybottleneck_test.py +++ /dev/null @@ -1,315 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests of EntropyBottleneck class.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.coder.python.layers import entropybottleneck - -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variables -from tensorflow.python.platform import test -from tensorflow.python.training import gradient_descent - - -class EntropyBottleneckTest(test.TestCase): - - def test_noise(self): - # Tests that the noise added is uniform noise between -0.5 and 0.5. - inputs = array_ops.placeholder(dtypes.float32, (None, 1)) - layer = entropybottleneck.EntropyBottleneck() - noisy, _ = layer(inputs, training=True) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - values = np.linspace(-50, 50, 100)[:, None] - noisy, = sess.run([noisy], {inputs: values}) - self.assertFalse(np.allclose(values, noisy, rtol=0, atol=.49)) - self.assertAllClose(values, noisy, rtol=0, atol=.5) - - def test_quantization(self): - # Tests that inputs are quantized to full integer values, even after - # quantiles have been updated. - inputs = array_ops.placeholder(dtypes.float32, (None, 1)) - layer = entropybottleneck.EntropyBottleneck(optimize_integer_offset=False) - quantized, _ = layer(inputs, training=False) - opt = gradient_descent.GradientDescentOptimizer(learning_rate=1) - self.assertTrue(len(layer.losses) == 1) - step = opt.minimize(layer.losses[0]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - sess.run(step) - values = np.linspace(-50, 50, 100)[:, None] - quantized, = sess.run([quantized], {inputs: values}) - self.assertAllClose(np.around(values), quantized, rtol=0, atol=1e-6) - - def test_quantization_optimized_offset(self): - # Tests that inputs are not quantized to full integer values after quantiles - # have been updated. However, the difference between input and output should - # be between -0.5 and 0.5, and the offset must be consistent. - inputs = array_ops.placeholder(dtypes.float32, (None, 1)) - layer = entropybottleneck.EntropyBottleneck(optimize_integer_offset=True) - quantized, _ = layer(inputs, training=False) - opt = gradient_descent.GradientDescentOptimizer(learning_rate=1) - self.assertTrue(len(layer.losses) == 1) - step = opt.minimize(layer.losses[0]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - sess.run(step) - values = np.linspace(-50, 50, 100)[:, None] - quantized, = sess.run([quantized], {inputs: values}) - self.assertAllClose(values, quantized, rtol=0, atol=.5) - diff = np.ravel(np.around(values) - quantized) % 1 - self.assertAllClose(diff, np.full_like(diff, diff[0]), rtol=0, atol=5e-6) - self.assertNotEqual(diff[0], 0) - - def test_codec(self): - # Tests that inputs are compressed and decompressed correctly, and quantized - # to full integer values, even after quantiles have been updated. - inputs = array_ops.placeholder(dtypes.float32, (1, None, 1)) - layer = entropybottleneck.EntropyBottleneck( - data_format="channels_last", init_scale=60, - optimize_integer_offset=False) - bitstrings = layer.compress(inputs) - decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) - opt = gradient_descent.GradientDescentOptimizer(learning_rate=1) - self.assertTrue(len(layer.losses) == 1) - step = opt.minimize(layer.losses[0]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - sess.run(step) - self.assertTrue(len(layer.updates) == 1) - sess.run(layer.updates[0]) - values = np.linspace(-50, 50, 100)[None, :, None] - decoded, = sess.run([decoded], {inputs: values}) - self.assertAllClose(np.around(values), decoded, rtol=0, atol=1e-6) - - def test_codec_optimized_offset(self): - # Tests that inputs are compressed and decompressed correctly, and not - # quantized to full integer values after quantiles have been updated. - # However, the difference between input and output should be between -0.5 - # and 0.5, and the offset must be consistent. - inputs = array_ops.placeholder(dtypes.float32, (1, None, 1)) - layer = entropybottleneck.EntropyBottleneck( - data_format="channels_last", init_scale=60, - optimize_integer_offset=True) - bitstrings = layer.compress(inputs) - decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) - opt = gradient_descent.GradientDescentOptimizer(learning_rate=1) - self.assertTrue(len(layer.losses) == 1) - step = opt.minimize(layer.losses[0]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - sess.run(step) - self.assertTrue(len(layer.updates) == 1) - sess.run(layer.updates[0]) - values = np.linspace(-50, 50, 100)[None, :, None] - decoded, = sess.run([decoded], {inputs: values}) - self.assertAllClose(values, decoded, rtol=0, atol=.5) - diff = np.ravel(np.around(values) - decoded) % 1 - self.assertAllClose(diff, np.full_like(diff, diff[0]), rtol=0, atol=5e-6) - self.assertNotEqual(diff[0], 0) - - def test_codec_clipping(self): - # Tests that inputs are compressed and decompressed correctly, and clipped - # to the expected range. - inputs = array_ops.placeholder(dtypes.float32, (1, None, 1)) - layer = entropybottleneck.EntropyBottleneck( - data_format="channels_last", init_scale=40) - bitstrings = layer.compress(inputs) - decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertTrue(len(layer.updates) == 1) - sess.run(layer.updates[0]) - values = np.linspace(-50, 50, 100)[None, :, None] - decoded, = sess.run([decoded], {inputs: values}) - expected = np.clip(np.around(values), -40, 40) - self.assertAllClose(expected, decoded, rtol=0, atol=1e-6) - - def test_channels_last(self): - # Test the layer with more than one channel and multiple input dimensions, - # with the channels in the last dimension. - inputs = array_ops.placeholder(dtypes.float32, (None, None, None, 2)) - layer = entropybottleneck.EntropyBottleneck( - data_format="channels_last", init_scale=50) - noisy, _ = layer(inputs, training=True) - quantized, _ = layer(inputs, training=False) - bitstrings = layer.compress(inputs) - decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertTrue(len(layer.updates) == 1) - sess.run(layer.updates[0]) - values = 5 * np.random.normal(size=(7, 5, 3, 2)) - noisy, quantized, decoded = sess.run( - [noisy, quantized, decoded], {inputs: values}) - self.assertAllClose(values, noisy, rtol=0, atol=.5) - self.assertAllClose(values, quantized, rtol=0, atol=.5) - self.assertAllClose(values, decoded, rtol=0, atol=.5) - - def test_channels_first(self): - # Test the layer with more than one channel and multiple input dimensions, - # with the channel dimension right after the batch dimension. - inputs = array_ops.placeholder(dtypes.float32, (None, 3, None, None)) - layer = entropybottleneck.EntropyBottleneck( - data_format="channels_first", init_scale=50) - noisy, _ = layer(inputs, training=True) - quantized, _ = layer(inputs, training=False) - bitstrings = layer.compress(inputs) - decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertTrue(len(layer.updates) == 1) - sess.run(layer.updates[0]) - values = 5 * np.random.normal(size=(2, 3, 5, 7)) - noisy, quantized, decoded = sess.run( - [noisy, quantized, decoded], {inputs: values}) - self.assertAllClose(values, noisy, rtol=0, atol=.5) - self.assertAllClose(values, quantized, rtol=0, atol=.5) - self.assertAllClose(values, decoded, rtol=0, atol=.5) - - def test_compress(self): - # Test compression and decompression, and produce test data for - # `test_decompress`. If you set the constant at the end to `True`, this test - # will fail and the log will contain the new test data. - inputs = array_ops.placeholder(dtypes.float32, (2, 3, 10)) - layer = entropybottleneck.EntropyBottleneck( - data_format="channels_first", filters=(), init_scale=2) - bitstrings = layer.compress(inputs) - decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertTrue(len(layer.updates) == 1) - sess.run(layer.updates[0]) - values = 5 * np.random.uniform(size=(2, 3, 10)) - 2.5 - bitstrings, quantized_cdf, decoded = sess.run( - [bitstrings, layer._quantized_cdf, decoded], {inputs: values}) - self.assertAllClose(values, decoded, rtol=0, atol=.5) - # Set this constant to `True` to log new test data for `test_decompress`. - if False: # pylint:disable=using-constant-test - assert False, (bitstrings, quantized_cdf, decoded) - - # Data generated by `test_compress`. - # pylint:disable=g-inconsistent-quotes,bad-whitespace - bitstrings = np.array([ - b'\x1e\xbag}\xc2\xdaN\x8b\xbd.', - b'\x8dF\xf0%\x1cv\xccllW' - ], dtype=object) - - quantized_cdf = np.array([ - [ 0, 15636, 22324, 30145, 38278, 65536], - [ 0, 19482, 26927, 35052, 42904, 65535], - [ 0, 21093, 28769, 36919, 44578, 65536] - ], dtype=np.int32) - - expected = np.array([ - [[-2., 1., 0., -2., -1., -2., -2., -2., 2., -1.], - [ 1., 2., 1., 0., -2., -2., 1., 2., 0., 1.], - [ 2., 0., -2., 2., 0., -1., -2., 0., 2., 0.]], - [[ 1., 2., 0., -1., 1., 2., 1., 1., 2., -2.], - [ 2., -1., -1., 0., -1., 2., 0., 2., -2., 2.], - [ 2., -2., -2., -1., -2., 1., -2., 0., 0., 0.]] - ], dtype=np.float32) - # pylint:enable=g-inconsistent-quotes,bad-whitespace - - def test_decompress(self): - # Test that decompression of values compressed with a previous version - # works, i.e. that the file format doesn't change across revisions. - bitstrings = array_ops.placeholder(dtypes.string) - input_shape = array_ops.placeholder(dtypes.int32) - quantized_cdf = array_ops.placeholder(dtypes.int32) - layer = entropybottleneck.EntropyBottleneck( - data_format="channels_first", filters=(), dtype=dtypes.float32) - layer.build(self.expected.shape) - layer._quantized_cdf = quantized_cdf - decoded = layer.decompress(bitstrings, input_shape[1:]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - decoded, = sess.run([decoded], { - bitstrings: self.bitstrings, input_shape: self.expected.shape, - quantized_cdf: self.quantized_cdf}) - self.assertAllClose(self.expected, decoded, rtol=0, atol=1e-6) - - def test_build_decompress(self): - # Test that layer can be built when `decompress` is the first call to it. - bitstrings = array_ops.placeholder(dtypes.string) - input_shape = array_ops.placeholder(dtypes.int32, shape=[3]) - layer = entropybottleneck.EntropyBottleneck(dtype=dtypes.float32) - layer.decompress(bitstrings, input_shape[1:], channels=5) - self.assertTrue(layer.built) - - def test_pmf_normalization(self): - # Test that probability mass functions are normalized correctly. - layer = entropybottleneck.EntropyBottleneck(dtype=dtypes.float32) - layer.build((None, 10)) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - pmf, = sess.run([layer._pmf]) - self.assertAllClose(np.ones(10), np.sum(pmf, axis=-1), rtol=0, atol=1e-6) - - def test_visualize(self): - # Test that summary op can be constructed. - layer = entropybottleneck.EntropyBottleneck(dtype=dtypes.float32) - layer.build((None, 10)) - summary = layer.visualize() - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - sess.run([summary]) - - def test_normalization(self): - # Test that densities are normalized correctly. - inputs = array_ops.placeholder(dtypes.float32, (None, 1)) - layer = entropybottleneck.EntropyBottleneck(filters=(2,)) - _, likelihood = layer(inputs, training=True) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - x = np.repeat(np.arange(-200, 201), 1000)[:, None] - likelihood, = sess.run([likelihood], {inputs: x}) - self.assertEqual(x.shape, likelihood.shape) - integral = np.sum(likelihood) * .001 - self.assertAllClose(1, integral, rtol=0, atol=1e-4) - - def test_entropy_estimates(self): - # Test that entropy estimates match actual range coding. - inputs = array_ops.placeholder(dtypes.float32, (1, None, 1)) - layer = entropybottleneck.EntropyBottleneck( - filters=(2, 3), data_format="channels_last") - _, likelihood = layer(inputs, training=True) - diff_entropy = math_ops.reduce_sum(math_ops.log(likelihood)) / -np.log(2) - _, likelihood = layer(inputs, training=False) - disc_entropy = math_ops.reduce_sum(math_ops.log(likelihood)) / -np.log(2) - bitstrings = layer.compress(inputs) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertTrue(len(layer.updates) == 1) - sess.run(layer.updates[0]) - diff_entropy, disc_entropy, bitstrings = sess.run( - [diff_entropy, disc_entropy, bitstrings], - {inputs: np.random.normal(size=(1, 10000, 1))}) - codelength = 8 * sum(len(bitstring) for bitstring in bitstrings) - self.assertAllClose(diff_entropy, disc_entropy, rtol=5e-3, atol=0) - self.assertAllClose(disc_entropy, codelength, rtol=5e-3, atol=0) - self.assertGreater(codelength, disc_entropy) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/constrained_optimization/python/candidates.py b/tensorflow/contrib/constrained_optimization/python/candidates.py index ac86a6741be1f244476f917d0e151166db65524b..66d7ebed74d8d4b9493af3a0badafa8f9e95bd9f 100644 --- a/tensorflow/contrib/constrained_optimization/python/candidates.py +++ b/tensorflow/contrib/constrained_optimization/python/candidates.py @@ -204,7 +204,7 @@ def find_best_candidate_distribution(objective_vector, assert best_pp is not None # Throughout this loop, a maximum_violation of "lower" is not achievable, - # but a maximum_violation of "upper" is achiveable. + # but a maximum_violation of "upper" is achievable. while True: middle = 0.5 * (lower + upper) if (middle - lower <= epsilon) or (upper - middle <= epsilon): diff --git a/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py b/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py index 70813fb217956b167b80a7e1d555c8ba79088fdb..41258edd90866ae9f644a02c42dfe2dc589da998 100644 --- a/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py +++ b/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py @@ -72,7 +72,8 @@ class ConstrainedMinimizationProblem(object): else: proxy_constraints_shape = self.proxy_constraints.get_shape() - if (constraints_shape is None or proxy_constraints_shape is None or + if (constraints_shape.ndims is None or + proxy_constraints_shape.ndims is None or any([ii is None for ii in constraints_shape.as_list()]) or any([ii is None for ii in proxy_constraints_shape.as_list()])): raise ValueError( @@ -121,3 +122,19 @@ class ConstrainedMinimizationProblem(object): A tensor of proxy constraint functions. """ return None + + # This is a property, instead of an abstract property, since it doesn't need + # to be overridden: if pre_train_ops returns None, then there are no ops to + # run before train_op. + @property + def pre_train_ops(self): + """Returns a list of `Operation`s to run before the train_op. + + When a `ConstrainedOptimizer` creates a train_op (in `minimize` + `minimize_unconstrained`, or `minimize_constrained`), it will include these + ops before the main training step. + + Returns: + A list of `Operation`s. + """ + return None diff --git a/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py b/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py index 805554536610a5e2cc650ff0b47185f4fbd6fac5..0b79bdf7c05c5195b169797ca76b619032fc3a61 100644 --- a/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py +++ b/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py @@ -55,20 +55,21 @@ class ConstrainedOptimizer(object): """Returns the `tf.train.Optimizer` used for optimization.""" return self._optimizer - def minimize_unconstrained(self, - minimization_problem, - global_step=None, - var_list=None, - gate_gradients=train_optimizer.Optimizer.GATE_OP, - aggregation_method=None, - colocate_gradients_with_ops=False, - name=None, - grad_loss=None): - """Returns an `Op` for minimizing the unconstrained problem. + @abc.abstractmethod + def _minimize_constrained(self, + minimization_problem, + global_step=None, + var_list=None, + gate_gradients=train_optimizer.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + name=None, + grad_loss=None): + """Version of `minimize_constrained` to be overridden by subclasses. - Unlike `minimize_constrained`, this function ignores the `constraints` (and - `proxy_constraints`) portion of the minimization problem entirely, and only - minimizes `objective`. + Implementations of this method should ignore the `pre_train_ops` property of + the `minimization_problem`. The public `minimize_constrained` method will + take care of executing these before the returned train_op. Args: minimization_problem: ConstrainedMinimizationProblem, the problem to @@ -83,19 +84,10 @@ class ConstrainedOptimizer(object): grad_loss: as in `tf.train.Optimizer`'s `minimize` method. Returns: - TensorFlow Op. + `Operation`, the train_op. """ - return self.optimizer.minimize( - minimization_problem.objective, - global_step=global_step, - var_list=var_list, - gate_gradients=gate_gradients, - aggregation_method=aggregation_method, - colocate_gradients_with_ops=colocate_gradients_with_ops, - name=name, - grad_loss=grad_loss) + pass - @abc.abstractmethod def minimize_constrained(self, minimization_problem, global_step=None, @@ -105,7 +97,7 @@ class ConstrainedOptimizer(object): colocate_gradients_with_ops=False, name=None, grad_loss=None): - """Returns an `Op` for minimizing the constrained problem. + """Returns an `Operation` for minimizing the constrained problem. Unlike `minimize_unconstrained`, this function attempts to find a solution that minimizes the `objective` portion of the minimization problem while @@ -124,9 +116,83 @@ class ConstrainedOptimizer(object): grad_loss: as in `tf.train.Optimizer`'s `minimize` method. Returns: - TensorFlow Op. + `Operation`, the train_op. """ - pass + + def train_op_callback(): + return self._minimize_constrained( + minimization_problem, + global_step=global_step, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + name=name, + grad_loss=grad_loss) + + # If we have pre_train_ops, use tf.control_dependencies() to ensure that + # they execute before the train_op. + pre_train_ops = minimization_problem.pre_train_ops + if pre_train_ops: + with ops.control_dependencies(pre_train_ops): + train_op = train_op_callback() + else: + train_op = train_op_callback() + + return train_op + + def minimize_unconstrained(self, + minimization_problem, + global_step=None, + var_list=None, + gate_gradients=train_optimizer.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + name=None, + grad_loss=None): + """Returns an `Operation` for minimizing the unconstrained problem. + + Unlike `minimize_constrained`, this function ignores the `constraints` (and + `proxy_constraints`) portion of the minimization problem entirely, and only + minimizes `objective`. + + Args: + minimization_problem: ConstrainedMinimizationProblem, the problem to + optimize. + global_step: as in `tf.train.Optimizer`'s `minimize` method. + var_list: as in `tf.train.Optimizer`'s `minimize` method. + gate_gradients: as in `tf.train.Optimizer`'s `minimize` method. + aggregation_method: as in `tf.train.Optimizer`'s `minimize` method. + colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize` + method. + name: as in `tf.train.Optimizer`'s `minimize` method. + grad_loss: as in `tf.train.Optimizer`'s `minimize` method. + + Returns: + `Operation`, the train_op. + """ + + def train_op_callback(): + return self.optimizer.minimize( + minimization_problem.objective, + global_step=global_step, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + name=name, + grad_loss=grad_loss) + + # If we have pre_train_ops, use tf.control_dependencies() to ensure that + # they execute before the train_op. + pre_train_ops = minimization_problem.pre_train_ops + if pre_train_ops: + with ops.control_dependencies(pre_train_ops): + train_op = train_op_callback() + else: + train_op = train_op_callback() + + return train_op def minimize(self, minimization_problem, @@ -138,7 +204,7 @@ class ConstrainedOptimizer(object): colocate_gradients_with_ops=False, name=None, grad_loss=None): - """Returns an `Op` for minimizing the constrained problem. + """Returns an `Operation` for minimizing the constrained problem. This method combines the functionality of `minimize_unconstrained` and `minimize_constrained`. If global_step < unconstrained_steps, it will @@ -164,14 +230,14 @@ class ConstrainedOptimizer(object): grad_loss: as in `tf.train.Optimizer`'s `minimize` method. Returns: - TensorFlow Op. + `Operation`, the train_op. Raises: ValueError: If unconstrained_steps is provided, but global_step is not. """ def unconstrained_fn(): - """Returns an `Op` for minimizing the unconstrained problem.""" + """Returns an `Operation` for minimizing the unconstrained problem.""" return self.minimize_unconstrained( minimization_problem=minimization_problem, global_step=global_step, @@ -183,7 +249,7 @@ class ConstrainedOptimizer(object): grad_loss=grad_loss) def constrained_fn(): - """Returns an `Op` for minimizing the constrained problem.""" + """Returns an `Operation` for minimizing the constrained problem.""" return self.minimize_constrained( minimization_problem=minimization_problem, global_step=global_step, diff --git a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py index 01c6e4f08afb93e37aa124f31ca7faa10b07d4d6..d1af15f7e423c5135071ea73f6b7a0709d140600 100644 --- a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py +++ b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py @@ -70,11 +70,13 @@ def _project_multipliers_wrt_euclidean_norm(multipliers, radius): region w.r.t. the Euclidean norm. Raises: - ValueError: if the `multipliers` tensor does not have a fully-known shape, - or is not one-dimensional. + ValueError: if the `multipliers` tensor is not floating-point, does not have + a fully-known shape, or is not one-dimensional. """ + if not multipliers.dtype.is_floating: + raise ValueError("multipliers must have a floating-point dtype") multipliers_shape = multipliers.get_shape() - if multipliers_shape is None: + if multipliers_shape.ndims is None: raise ValueError("multipliers must have known shape") if multipliers_shape.ndims != 1: raise ValueError( @@ -101,12 +103,12 @@ def _project_multipliers_wrt_euclidean_norm(multipliers, radius): (radius - standard_ops.reduce_sum(multipliers)) / standard_ops.maximum( 1.0, standard_ops.reduce_sum(inactive))) multipliers += scale * inactive - new_inactive = standard_ops.to_float(multipliers > 0) + new_inactive = standard_ops.cast(multipliers > 0, multipliers.dtype) multipliers *= new_inactive return (iteration, multipliers, new_inactive, inactive) iteration = standard_ops.constant(0) - inactive = standard_ops.ones_like(multipliers) + inactive = standard_ops.ones_like(multipliers, dtype=multipliers.dtype) # We actually want a do-while loop, so we explicitly call while_loop_body() # once before tf.while_loop(). @@ -189,16 +191,16 @@ class _ExternalRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): def _projection_op(self, state, name=None): pass - def minimize_constrained(self, - minimization_problem, - global_step=None, - var_list=None, - gate_gradients=train_optimizer.Optimizer.GATE_OP, - aggregation_method=None, - colocate_gradients_with_ops=False, - name=None, - grad_loss=None): - """Returns an `Op` for minimizing the constrained problem. + def _minimize_constrained(self, + minimization_problem, + global_step=None, + var_list=None, + gate_gradients=train_optimizer.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + name=None, + grad_loss=None): + """Returns an `Operation` for minimizing the constrained problem. The `optimizer` constructor parameter will be used to update the model parameters, while the Lagrange multipliers will be updated using @@ -216,8 +218,11 @@ class _ExternalRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): name: as in `tf.train.Optimizer`'s `minimize` method. grad_loss: as in `tf.train.Optimizer`'s `minimize` method. + Raises: + ValueError: If the minimization_problem tensors have different dtypes. + Returns: - TensorFlow Op. + `Operation`, the train_op. """ objective = minimization_problem.objective @@ -225,6 +230,14 @@ class _ExternalRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): proxy_constraints = minimization_problem.proxy_constraints if proxy_constraints is None: proxy_constraints = constraints + + # Make sure that the objective, constraints and proxy constraints all have + # the same dtype. + if (objective.dtype.base_dtype != constraints.dtype.base_dtype or + objective.dtype.base_dtype != proxy_constraints.dtype.base_dtype): + raise ValueError("objective, constraints and proxy_constraints must " + "have the same dtype") + # Flatten both constraints tensors to 1d. num_constraints = minimization_problem.num_constraints constraints = standard_ops.reshape(constraints, shape=(num_constraints,)) @@ -241,8 +254,10 @@ class _ExternalRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): multipliers = self._lagrange_multipliers(state) loss = ( - objective + standard_ops.tensordot(multipliers, proxy_constraints, 1)) - multipliers_gradient = constraints + objective + standard_ops.tensordot( + standard_ops.cast(multipliers, proxy_constraints.dtype), + proxy_constraints, 1)) + multipliers_gradient = standard_ops.cast(constraints, multipliers.dtype) update_ops = [] if self.constraint_optimizer is None: @@ -356,6 +371,8 @@ class AdditiveExternalRegretOptimizer(_ExternalRegretOptimizer): # For an AdditiveExternalRegretOptimizer, the internal state is simply a # tensor of Lagrange multipliers with shape (m,), where m is the number of # constraints. + # + # FUTURE WORK: make the dtype a parameter. return standard_ops.zeros((num_constraints,), dtype=dtypes.float32) def _lagrange_multipliers(self, state): diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py index 3791dae8d7f6b03bc1115bca97811dfc4775c45b..2c673d9347141b3a12eb9ec76065d22f1769ac12 100644 --- a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py +++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py @@ -79,9 +79,11 @@ def _maximal_eigenvector_power_method(matrix, The maximal right-eigenvector of `matrix`. Raises: - ValueError: If the epsilon or maximum_iterations parameters violate their - bounds. + ValueError: If the `matrix` tensor is not floating-point, or if the + `epsilon` or `maximum_iterations` parameters violate their bounds. """ + if not matrix.dtype.is_floating: + raise ValueError("multipliers must have a floating-point dtype") if epsilon <= 0.0: raise ValueError("epsilon must be strictly positive") if maximum_iterations <= 0: @@ -139,18 +141,20 @@ def _project_stochastic_matrix_wrt_euclidean_norm(matrix): (i.e. the Frobenius norm). Raises: - ValueError: if the `matrix` tensor does not have a fully-known shape, or is - not two-dimensional and square. + ValueError: if the `matrix` tensor is not floating-point, does not have a + fully-known shape, or is not two-dimensional and square. """ + if not matrix.dtype.is_floating: + raise ValueError("multipliers must have a floating-point dtype") matrix_shape = matrix.get_shape() - if matrix_shape is None: + if matrix_shape.ndims is None: raise ValueError("matrix must have known shape") if matrix_shape.ndims != 2: raise ValueError( "matrix must be two dimensional (instead is %d-dimensional)" % matrix_shape.ndims) if matrix_shape[0] != matrix_shape[1]: - raise ValueError("matrix must be be square (instead has shape (%d,%d))" % + raise ValueError("matrix must be square (instead has shape (%d,%d))" % (matrix_shape[0], matrix_shape[1])) dimension = matrix_shape[0].value if dimension is None: @@ -172,12 +176,12 @@ def _project_stochastic_matrix_wrt_euclidean_norm(matrix): matrix, axis=0, keepdims=True)) / standard_ops.maximum( 1.0, standard_ops.reduce_sum(inactive, axis=0, keepdims=True)) matrix += scale * inactive - new_inactive = standard_ops.to_float(matrix > 0) + new_inactive = standard_ops.cast(matrix > 0, matrix.dtype) matrix *= new_inactive return (iteration, matrix, new_inactive, inactive) iteration = standard_ops.constant(0) - inactive = standard_ops.ones_like(matrix) + inactive = standard_ops.ones_like(matrix, dtype=matrix.dtype) # We actually want a do-while loop, so we explicitly call while_loop_body() # once before tf.while_loop(). @@ -218,7 +222,7 @@ class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): """Base class representing a `_SwapRegretOptimizer`. This class contains most of the logic for performing constrained optimization, - minimizing external regret for the constraints player. What it *doesn't* do is + minimizing swap regret for the constraints player. What it *doesn't* do is keep track of the internal state (the stochastic matrix). Instead, the state is accessed via the _initial_state(), _stochastic_matrix(), _constraint_grad_and_var() and _projection_op() methods. @@ -291,16 +295,16 @@ class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): def _projection_op(self, state, name=None): pass - def minimize_constrained(self, - minimization_problem, - global_step=None, - var_list=None, - gate_gradients=train_optimizer.Optimizer.GATE_OP, - aggregation_method=None, - colocate_gradients_with_ops=False, - name=None, - grad_loss=None): - """Returns an `Op` for minimizing the constrained problem. + def _minimize_constrained(self, + minimization_problem, + global_step=None, + var_list=None, + gate_gradients=train_optimizer.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + name=None, + grad_loss=None): + """Returns an `Operation` for minimizing the constrained problem. The `optimizer` constructor parameter will be used to update the model parameters, while the constraint/objective weight matrix (the analogue of @@ -320,8 +324,11 @@ class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): name: as in `tf.train.Optimizer`'s `minimize` method. grad_loss: as in `tf.train.Optimizer`'s `minimize` method. + Raises: + ValueError: If the minimization_problem tensors have different dtypes. + Returns: - TensorFlow Op. + `Operation`, the train_op. """ objective = minimization_problem.objective @@ -329,6 +336,14 @@ class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): proxy_constraints = minimization_problem.proxy_constraints if proxy_constraints is None: proxy_constraints = constraints + + # Make sure that the objective, constraints and proxy constraints all have + # the same dtype. + if (objective.dtype.base_dtype != constraints.dtype.base_dtype or + objective.dtype.base_dtype != proxy_constraints.dtype.base_dtype): + raise ValueError("objective, constraints and proxy_constraints must " + "have the same dtype") + # Flatten both constraints tensors to 1d. num_constraints = minimization_problem.num_constraints constraints = standard_ops.reshape(constraints, shape=(num_constraints,)) @@ -344,15 +359,18 @@ class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): name="swap_regret_optimizer_state") zero_and_constraints = standard_ops.concat( - (standard_ops.zeros((1,)), constraints), axis=0) + (standard_ops.zeros((1,), dtype=constraints.dtype), constraints), + axis=0) objective_and_proxy_constraints = standard_ops.concat( (standard_ops.expand_dims(objective, 0), proxy_constraints), axis=0) distribution = self._distribution(state) - loss = standard_ops.tensordot(distribution, objective_and_proxy_constraints, - 1) + loss = standard_ops.tensordot( + standard_ops.cast(distribution, objective_and_proxy_constraints.dtype), + objective_and_proxy_constraints, 1) matrix_gradient = standard_ops.matmul( - standard_ops.expand_dims(zero_and_constraints, 1), + standard_ops.expand_dims( + standard_ops.cast(zero_and_constraints, distribution.dtype), 1), standard_ops.expand_dims(distribution, 0)) update_ops = [] @@ -555,6 +573,7 @@ class MultiplicativeSwapRegretOptimizer(_SwapRegretOptimizer): log_initial_one = math.log(1.0 - (self._initial_multiplier_radius * (dimension - 1) / (dimension))) log_initial_zero = math.log(self._initial_multiplier_radius / dimension) + # FUTURE WORK: make the dtype a parameter. return standard_ops.concat( (standard_ops.constant( log_initial_one, dtype=dtypes.float32, shape=(1, dimension)), diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py index 5931c8a27996534cca80797e8b840559c124297c..6c9ab6aeb87fd39b22ab4f28d69b432b15899a13 100644 --- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py +++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py @@ -219,8 +219,10 @@ def copy_op_to_graph(org_instance, to_graph, variables, scope=''): op_def) #Use Graph's hidden methods to add the op to_graph._record_op_seen_by_control_dependencies(new_op) - for device_function in reversed(to_graph._device_function_stack): + # pylint: disable=protected-access + for device_function in to_graph._device_functions_outer_to_inner: new_op._set_device(device_function(new_op)) + # pylint: enable=protected-access return new_op diff --git a/tensorflow/contrib/crf/__init__.py b/tensorflow/contrib/crf/__init__.py index 615e62b16f1906dafa22a12cc7275a2335e8df88..fe5e34d258fbc1508a0a85655f29c2c9bc8fa8b1 100644 --- a/tensorflow/contrib/crf/__init__.py +++ b/tensorflow/contrib/crf/__init__.py @@ -14,7 +14,7 @@ # ============================================================================== """Linear-chain CRF layer. -See the @{$python/contrib.crf} guide. +See the [CRF](https://tensorflow.org/api_guides/python/contrib.crf) guide. @@crf_binary_score @@crf_decode diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py index f56a973f6f80b81697e9f58578e60a2efb90154e..8cfe14205927bf7763cf36fa31012ab10fce995c 100644 --- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py +++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py @@ -158,7 +158,7 @@ class CrfTest(test.TestCase): # Test both the length-1 and regular cases. sequence_lengths_list = [ np.array(3, dtype=np.int32), - np.array(1, dtype=np.int32) + np.array(1, dtype=np.int64) ] inputs_list = [ np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], @@ -291,7 +291,7 @@ class CrfTest(test.TestCase): # Test both the length-1 and regular cases. sequence_lengths_list = [ np.array(3, dtype=np.int32), - np.array(1, dtype=np.int32) + np.array(1, dtype=np.int64) ] inputs_list = [ np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index 8a7ff61bc8391efe453ee37019c23bd6ccbdf066..2a91dcb63a80016e62d10d1310ca57e3e54434c5 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -548,7 +548,9 @@ def crf_decode(potentials, transition_params, sequence_length): initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O] inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O] # Sequence length is not allowed to be less than zero. - sequence_length_less_one = math_ops.maximum(0, sequence_length - 1) + sequence_length_less_one = math_ops.maximum( + constant_op.constant(0, dtype=sequence_length.dtype), + sequence_length - 1) backpointers, last_score = rnn.dynamic_rnn( # [B, T - 1, O], [B, O] crf_fwd_cell, inputs=inputs, diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py index d58198faf353aab68430d2fa153a18de359112de..e26d56c8579e110d61c73c6154b82f47f0093687 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -56,7 +56,7 @@ class _CudnnRNN(base_layer.Layer): Cudnn RNNs have two major differences from other platform-independent RNNs tf provides: * Cudnn LSTM and GRU are mathematically different from their tf counterparts. - (e.g. @{tf.contrib.rnn.LSTMBlockCell} and @{tf.nn.rnn_cell.GRUCell}. + (e.g. `tf.contrib.rnn.LSTMBlockCell` and `tf.nn.rnn_cell.GRUCell`. * Cudnn-trained checkpoints are not directly compatible with tf RNNs: * They use a single opaque parameter buffer for the entire (possibly) multi-layer multi-directional RNN; Whereas tf RNN weights are per-cell and @@ -182,7 +182,7 @@ class _CudnnRNN(base_layer.Layer): dropout: dropout rate, a number between [0, 1]. Dropout is applied between each layer (no dropout is applied for a model with a single layer). When set to 0, dropout is disabled. - seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. dtype: tf.float16, tf.float32 or tf.float64 kernel_initializer: starting value to initialize the weight. diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 748d7cd011f32fdebd781176b560b9b7498f327e..2c92f31788378c2a9f01183bc04b035668b59b59 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -61,8 +61,8 @@ _WEIGHTS_VARIABLE_NAME = rnn_cell_impl._WEIGHTS_VARIABLE_NAME class CudnnCompatibleLSTMCell(lstm_ops.LSTMBlockCell): """Cudnn Compatible LSTMCell. - A simple wrapper around @{tf.contrib.rnn.LSTMBlockCell} to use along with - @{tf.contrib.cudnn_rnn.CudnnLSTM}. The latter's params can be used by + A simple wrapper around `tf.contrib.rnn.LSTMBlockCell` to use along with + `tf.contrib.cudnn_rnn.CudnnLSTM`. The latter's params can be used by this cell seamlessly. """ @@ -76,8 +76,8 @@ class CudnnCompatibleLSTMCell(lstm_ops.LSTMBlockCell): class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): """Cudnn Compatible GRUCell. - A GRU impl akin to @{tf.nn.rnn_cell.GRUCell} to use along with - @{tf.contrib.cudnn_rnn.CudnnGRU}. The latter's params can be used by + A GRU impl akin to `tf.nn.rnn_cell.GRUCell` to use along with + `tf.contrib.cudnn_rnn.CudnnGRU`. The latter's params can be used by it seamlessly. It differs from platform-independent GRUs in how the new memory gate is @@ -97,7 +97,7 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): $$h_t = (1 - u_t) .* h'_t + u_t .* h_t-1$$ ``` - Other GRU (see @{tf.nn.rnn_cell.GRUCell} and @{tf.contrib.rnn.GRUBlockCell}): + Other GRU (see `tf.nn.rnn_cell.GRUCell` and `tf.contrib.rnn.GRUBlockCell`): ```python # new memory gate \\(h'_t = tanh(x_t * W_h + (r_t .* h_t-1) * R_h + b_{Wh})\\) @@ -891,7 +891,7 @@ def _cudnn_rnn(inputs, direction: the direction model that the model operates. Could be either 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. name: name of the operation. Returns: @@ -957,7 +957,7 @@ def cudnn_lstm(inputs, direction: the direction model that the model operates. Could be either 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. name: name of the operation. Returns: @@ -998,7 +998,7 @@ def _cudnn_rnn_no_input_c(inputs, direction: the direction model that the model operates. Could be either 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. name: name of the operation. Returns: @@ -1040,7 +1040,7 @@ def cudnn_gru(inputs, direction: the direction model that the model operates. Could be either 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. name: name of the operation. Returns: @@ -1079,7 +1079,7 @@ def cudnn_rnn_relu(inputs, direction: the direction model that the model operates. Could be either 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. name: name of the operation. Returns: @@ -1119,7 +1119,7 @@ def cudnn_rnn_tanh(inputs, direction: the direction model that the model operates. Could be either 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. name: name of the operation. Returns: @@ -1161,7 +1161,7 @@ def cudnn_rnn_opaque_params_to_canonical(rnn_mode, direction: the direction model that the model operates. Could be either 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. name: name of the operation. Returns: @@ -1224,7 +1224,7 @@ def cudnn_rnn_canonical_to_opaque_params(rnn_mode, direction: the direction model that the model operates. Could be either 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. name: name of the operation. Returns: @@ -1282,7 +1282,7 @@ def cudnn_rnn_opaque_params_size(rnn_mode, 'unidirectional' or 'bidirectional' dtype: one of tf.float32 or tf.float64. dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. name: name of the operation. Returns: @@ -1349,7 +1349,7 @@ class _CudnnRNN(object): 'unidirectional' or 'bidirectional' dtype: dtype of params, tf.float32 or tf.float64. dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. Raises: ValueError: if direction is invalid. diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 7878e46e88b2ea8b0012768342c218baeda80eaa..5821d51bca491b1e5c5388c0c82088ca0eb8fed3 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -15,12 +15,12 @@ """Experimental API for building input pipelines. This module contains experimental `Dataset` sources and transformations that can -be used in conjunction with the @{tf.data.Dataset} API. Note that the +be used in conjunction with the `tf.data.Dataset` API. Note that the `tf.contrib.data` API is not subject to the same backwards compatibility guarantees as `tf.data`, but we will provide deprecation advice in advance of removing existing functionality. -See @{$guide/datasets$Importing Data} for an overview. +See [Importing Data](https://tensorflow.org/guide/datasets) for an overview. @@Counter @@CheckpointInputPipelineHook diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/contrib/data/kernels/BUILD index 566cbb246a104d1e6cfc284d220ca8386b8897e1..2e249f5c14ab111ae412ff3288acc25de8d7aa11 100644 --- a/tensorflow/contrib/data/kernels/BUILD +++ b/tensorflow/contrib/data/kernels/BUILD @@ -37,6 +37,7 @@ cc_library( "//third_party/eigen3", "@protobuf_archive//:protobuf_headers", ], + alwayslink = 1, ) cc_library( @@ -58,6 +59,7 @@ cc_library( "//third_party/eigen3", "@protobuf_archive//:protobuf_headers", ], + alwayslink = 1, ) cc_library( @@ -68,6 +70,7 @@ cc_library( "//third_party/eigen3", "@protobuf_archive//:protobuf_headers", ], + alwayslink = 1, ) cc_library( @@ -78,6 +81,7 @@ cc_library( "//third_party/eigen3", "@protobuf_archive//:protobuf_headers", ], + alwayslink = 1, ) cc_library( diff --git a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc b/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc index 95b8e1f7fd487119d77a5f708de42b014c55f79d..e36c9c0634235022362b59a6699b4d550d6d0eee 100644 --- a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc @@ -42,13 +42,13 @@ class AssertNextDatasetOp : public UnaryDatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, const DatasetBase* input, const std::vector& transformations, const DataTypeVector& output_types, const std::vector& output_shapes) - : GraphDatasetBase(ctx), + : DatasetBase(DatasetContext(ctx)), input_(input), transformations_(transformations), output_types_(output_types), @@ -76,10 +76,11 @@ class AssertNextDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph_node = nullptr; - TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* transformations_node = nullptr; TF_RETURN_IF_ERROR(b->AddVector(transformations_, &transformations_node)); TF_RETURN_IF_ERROR(b->AddDataset( @@ -121,13 +122,13 @@ class AssertNextDatasetOp : public UnaryDatasetOpKernel { protected: Status SaveInternal(IteratorStateWriter* writer) override { - TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); return Status::OK(); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { - TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); return Status::OK(); } diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc index f7e3ed886c6655cdc07e08bbe2fbe82e671a6802..d242cfdf4911ee43051b8aa2f7b960916b40374a 100644 --- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc @@ -131,7 +131,7 @@ class CSVDatasetOp : public DatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, std::vector filenames, bool header, string compression_type, io::ZlibCompressionOptions options, @@ -139,7 +139,7 @@ class CSVDatasetOp : public DatasetOpKernel { const std::vector& output_shapes, std::vector record_defaults, std::vector select_cols, bool use_quote_delim, char delim, string na_value) - : GraphDatasetBase(ctx), + : DatasetBase(DatasetContext(ctx)), filenames_(std::move(filenames)), header_(header), out_type_(output_types), @@ -168,7 +168,8 @@ class CSVDatasetOp : public DatasetOpKernel { string DebugString() const override { return "CSVDatasetOp::Dataset"; } protected: - Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* filenames = nullptr; Node* compression_type = nullptr; diff --git a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc index 6a12ca06f4d6cc2096aaf8191a01a899881b43db..ccf7ec1f842f5a1ad9b304c904f046ad49ed1757 100644 --- a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc @@ -63,11 +63,11 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, const DatasetBase* selector_input, std::vector data_inputs) - : GraphDatasetBase(ctx), + : DatasetBase(DatasetContext(ctx)), selector_input_(selector_input), data_inputs_(std::move(data_inputs)) { selector_input_->Ref(); @@ -110,15 +110,16 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* selector_input_node; TF_RETURN_IF_ERROR( - b->AddParentDataset(ctx, selector_input_, &selector_input_node)); + b->AddInputDataset(ctx, selector_input_, &selector_input_node)); std::vector data_input_nodes(data_inputs_.size()); for (size_t i = 0; i < data_inputs_.size(); ++i) { TF_RETURN_IF_ERROR( - b->AddParentDataset(ctx, data_inputs_[i], &data_input_nodes[i])); + b->AddInputDataset(ctx, data_inputs_[i], &data_input_nodes[i])); } TF_RETURN_IF_ERROR(b->AddDataset(this, {{0, selector_input_node}}, {{1, data_input_nodes}}, {}, output)); @@ -204,7 +205,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); if (selector_input_impl_) { - TF_RETURN_IF_ERROR(SaveParent(writer, selector_input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(writer, selector_input_impl_)); } else { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("selector_input_impl_empty"), "")); @@ -212,7 +213,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { for (size_t i = 0; i < data_input_impls_.size(); ++i) { const auto& data_input_impl = data_input_impls_[i]; if (data_input_impl) { - TF_RETURN_IF_ERROR(SaveParent(writer, data_input_impl)); + TF_RETURN_IF_ERROR(SaveInput(writer, data_input_impl)); } else { TF_RETURN_IF_ERROR(writer->WriteScalar( full_name(strings::StrCat("data_input_impl_empty[", i, "]")), @@ -226,15 +227,14 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { IteratorStateReader* reader) override { mutex_lock l(mu_); if (!reader->Contains(full_name("selector_input_impl_empty"))) { - TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, selector_input_impl_)); + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, selector_input_impl_)); } else { selector_input_impl_.reset(); } for (size_t i = 0; i < data_input_impls_.size(); ++i) { if (!reader->Contains(full_name( strings::StrCat("data_input_impl_empty[", i, "]")))) { - TF_RETURN_IF_ERROR( - RestoreParent(ctx, reader, data_input_impls_[i])); + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, data_input_impls_[i])); } else { data_input_impls_[i].reset(); } diff --git a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc index bbec50681c6f5decec5a3b5fbf09cc3011a21199..db24e608463224f05159b57eb721718afd7cbb20 100644 --- a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc @@ -35,10 +35,10 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: explicit Dataset(OpKernelContext* ctx, const DatasetBase* input) - : GraphDatasetBase(ctx), input_(input) { + : DatasetBase(DatasetContext(ctx)), input_(input) { input_->Ref(); } @@ -62,10 +62,11 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph_node = nullptr; - TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output)); return Status::OK(); } @@ -106,7 +107,7 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); if (input_impl_) - TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); else TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("input_impls_empty"), "")); @@ -119,7 +120,7 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { if (reader->Contains(full_name("input_impls_empty"))) input_impl_.reset(); else - TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); return Status::OK(); } diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc index 32f03ca68364e40c6fd6769f05d0566f50119240..74df1e42a8fbca9b6a65aa4800424d27aa90de24 100644 --- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc +++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc @@ -526,6 +526,15 @@ string SanitizeThreadSuffix(string suffix) { return clean; } +struct HostBufferElement { + Status status; + bool end_of_sequence; + std::vector value; +}; + +using MultiDeviceIteratorCallback = + std::function; + class MultiDeviceIterator : public ResourceBase { public: MultiDeviceIterator(const DataTypeVector& output_types, @@ -539,83 +548,45 @@ class MultiDeviceIterator : public ResourceBase { devices_(devices), flib_def_(std::move(flib_def)), pflr_(std::move(pflr)), - lib_(lib) { - buffer_.resize(devices_.size()); - } + lib_(lib) {} string DebugString() override { - return strings::StrCat("MultiDeviceIterator"); + return strings::StrCat("MultiDeviceIterator for ", devices_.size(), + " devices"); } - Status Init(std::unique_ptr iterator, int64* incarnation_id) { - mutex_lock l(mu_); + Status Init(std::unique_ptr iterator, int64 max_buffer_size, + int64* incarnation_id) { if (iterator) { TF_RETURN_IF_ERROR( VerifyTypesMatch(output_types_, iterator->output_dtypes())); TF_RETURN_IF_ERROR( VerifyShapesCompatible(output_shapes_, iterator->output_shapes())); } - host_iterator_.reset(iterator.release()); - incarnation_id_++; + + mutex_lock l(mu_); + if (multi_device_buffer_) { + multi_device_buffer_->Reset(); + } + + ++incarnation_id_; *incarnation_id = incarnation_id_; - max_buffer_size_ = 0; - num_elements_ = 0; - buffer_.clear(); - buffer_.resize(devices_.size()); + + multi_device_buffer_.reset( + new MultiDeviceBuffer(devices_.size(), max_buffer_size, incarnation_id_, + std::move(iterator))); return Status::OK(); } - Status GetNextFromShard(IteratorContext* ctx, int shard_num, - int64 incarnation_id, - std::vector* out_tensors, - bool* end_of_sequence) { - // TODO(rohanj): This might potentially strand elements in other shards. - // Opportunity to do smarter locking semantics. - mutex_lock l(mu_); - // Make sure we're in the right incarnation. - if (incarnation_id != incarnation_id_) { - return errors::InvalidArgument( - "Current incarnation: ", incarnation_id_, - "; Supplied incarnation: ", incarnation_id); - } - // Then look it up in the buffer. - if (!buffer_[shard_num].empty()) { - const HostBufferElement& elem = buffer_[shard_num].front(); - *out_tensors = elem.value; - *end_of_sequence = elem.end_of_sequence; - Status s = elem.status; - buffer_[shard_num].pop_front(); - return s; - } - std::shared_ptr captured_iterator(host_iterator_); - if (captured_iterator) { - if (lib_ != nullptr) { - ctx->set_lib(lib_); - } - while (true) { - HostBufferElement elem; - elem.status = - captured_iterator->GetNext(ctx, &elem.value, &elem.end_of_sequence); - int buffer_index = num_elements_ % devices_.size(); - num_elements_++; - if (buffer_index == shard_num) { - out_tensors->swap(elem.value); - *end_of_sequence = elem.end_of_sequence; - return elem.status; - } else { - buffer_[buffer_index].push_back(std::move(elem)); - // TODO(rohanj): Put an upper bound to buffer size. - if (buffer_[buffer_index].size() > max_buffer_size_) { - max_buffer_size_ = buffer_[buffer_index].size(); - VLOG(1) << "MultiDeviceIterator: Max buffer size increased to: " - << max_buffer_size_; - } - } - } - } else { - return errors::FailedPrecondition("Iterator not initialized"); + void GetNextFromShard(IteratorContext* ctx, int shard_num, + int64 incarnation_id, + MultiDeviceIteratorCallback callback) { + if (lib_ != nullptr) { + ctx->set_lib(lib_); } - return Status::OK(); + tf_shared_lock l(mu_); + multi_device_buffer_->GetNextFromShard(ctx, shard_num, incarnation_id, + std::move(callback)); } const DataTypeVector& output_types() const { return output_types_; } @@ -630,25 +601,218 @@ class MultiDeviceIterator : public ResourceBase { } private: - struct HostBufferElement { - Status status; - bool end_of_sequence; - std::vector value; + // A private class that uses a background thread to keep a per device buffer + // full. + class MultiDeviceBuffer { + public: + MultiDeviceBuffer(size_t size, int64 max_buffer_size, int64 incarnation_id, + std::unique_ptr host_iterator) + : buffer_(size), + size_(size), + max_buffer_size_(max_buffer_size), + incarnation_id_(incarnation_id), + host_iterator_(std::move(host_iterator)) {} + + ~MultiDeviceBuffer() { Reset(); } + + void Reset() LOCKS_EXCLUDED(mu_) { + { + mutex_lock l(mu_); + if (background_thread_finished_) { + return; + } + + cancelled_ = true; + // Wake up the background thread. + for (int i = 0; i < size_; ++i) { + buffer_[i].cond_var.notify_all(); + } + + // Make sure background thread has finished first. + while (!background_thread_finished_) { + shutdown_cond_var_.wait(l); + } + } + RunPendingCallbacks(); + } + + void GetNextFromShard(IteratorContext* ctx, int shard_num, + int64 incarnation_id, + MultiDeviceIteratorCallback callback) { + HostBufferElement elem; + if (incarnation_id_ != incarnation_id) { + elem.status = errors::InvalidArgument("Invalid incarnation id"); + callback(elem); + return; + } + + bool produced_output = false; + { + mutex_lock l(mu_); + if (cancelled_) { + elem.status = errors::Cancelled("Cancelled Multidevice iterator"); + callback(elem); + return; + } + + EnsureBackgroundThreadStarted(ctx); + + if (!buffer_[shard_num].data.empty()) { + produced_output = true; + std::swap(elem, buffer_[shard_num].data.front()); + buffer_[shard_num].data.pop_front(); + // Wake up background thread if it is blocked on this element. + if (buffer_[shard_num].data.size() == max_buffer_size_ - 1) { + buffer_[shard_num].cond_var.notify_all(); + } + } else { + if (background_thread_finished_) { + produced_output = true; + elem.end_of_sequence = true; + } else { + buffer_[shard_num].callbacks.push_back(std::move(callback)); + callback = nullptr; + } + } + } + + if (produced_output) { + callback(elem); + } + } + + private: + void EnsureBackgroundThreadStarted(IteratorContext* ctx) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!background_thread_) { + background_thread_.reset(ctx->env()->StartThread( + {}, "multi_device_iterator_background_thread", + std::bind(&MultiDeviceIterator::MultiDeviceBuffer::BackgroundThread, + this, new IteratorContext(*ctx)))); + } + } + + void RunPendingCallbacks() LOCKS_EXCLUDED(mu_) { + // Run all remaining callbacks. + std::vector cancellation_callbacks; + std::vector cancellation_elements; + { + mutex_lock l(mu_); + + for (int i = 0; i < size_; ++i) { + while (!buffer_[i].callbacks.empty()) { + if (buffer_[i].data.empty()) { + HostBufferElement elem; + elem.status = + errors::Cancelled("Cancelled and buffer not filled."); + cancellation_elements.push_back(std::move(elem)); + } else { + cancellation_elements.push_back( + std::move(buffer_[i].data.front())); + buffer_[i].data.pop_front(); + } + cancellation_callbacks.push_back( + std::move(buffer_[i].callbacks.front())); + buffer_[i].callbacks.pop_front(); + } + } + } + for (int i = 0; i < cancellation_callbacks.size(); ++i) { + cancellation_callbacks[i](cancellation_elements[i]); + } + } + + void BackgroundThread(IteratorContext* ctx) { + std::unique_ptr cleanup(ctx); + int shard_to_fetch = 0; + while (true) { + HostBufferElement elem; + MultiDeviceIteratorCallback callback = nullptr; + bool end_of_iterator = false; + + { + mutex_lock l(mu_); + while (!cancelled_ && + buffer_[shard_to_fetch].data.size() >= max_buffer_size_) { + buffer_[shard_to_fetch].cond_var.wait(l); + } + + if (cancelled_) { + background_thread_finished_ = true; + shutdown_cond_var_.notify_all(); + return; + } + } + + elem.status = + host_iterator_->GetNext(ctx, &elem.value, &elem.end_of_sequence); + + if (elem.status.ok() && elem.end_of_sequence) { + end_of_iterator = true; + } + + { + mutex_lock l(mu_); + // Try to find a callback, else just push stuff into buffer. + if (!buffer_[shard_to_fetch].callbacks.empty()) { + callback = buffer_[shard_to_fetch].callbacks.front(); + buffer_[shard_to_fetch].callbacks.pop_front(); + } else { + buffer_[shard_to_fetch].data.push_back(std::move(elem)); + elem = HostBufferElement(); + } + } + + if (callback) { + (*ctx->runner())(std::bind(std::move(callback), std::move(elem))); + } + + // Finish off the thread if we reach the end of the iterator. Runs + // pending callbacks. + if (end_of_iterator) { + { + mutex_lock l(mu_); + background_thread_finished_ = true; + shutdown_cond_var_.notify_all(); + } + RunPendingCallbacks(); + return; + } + shard_to_fetch = (shard_to_fetch + 1) % size_; + } + } + + struct HostBuffer { + condition_variable cond_var; + std::deque data; + std::deque callbacks; + }; + + mutex mu_; + std::unique_ptr background_thread_ GUARDED_BY(mu_); + bool background_thread_finished_ GUARDED_BY(mu_) = false; + bool cancelled_ GUARDED_BY(mu_) = false; + condition_variable shutdown_cond_var_ GUARDED_BY(mu_); + + std::vector buffer_; + + const size_t size_; + const int64 max_buffer_size_; + const int64 incarnation_id_; + const std::unique_ptr host_iterator_; }; mutex mu_; const DataTypeVector output_types_; const std::vector output_shapes_; const std::vector devices_; - int64 num_elements_ GUARDED_BY(mu_) = 0; - int64 max_buffer_size_ GUARDED_BY(mu_) = 0; - int64 incarnation_id_ GUARDED_BY(mu_) = 0; - std::vector> buffer_ GUARDED_BY(mu_); - std::unique_ptr flib_def_; - std::unique_ptr pflr_; - FunctionLibraryRuntime* lib_ = nullptr; // not owned. - std::shared_ptr host_iterator_; + const std::unique_ptr flib_def_; + const std::unique_ptr pflr_; + FunctionLibraryRuntime* const lib_ = nullptr; // not owned. std::shared_ptr lib_def_ GUARDED_BY(mu_); + + int64 incarnation_id_ GUARDED_BY(mu_) = 0; + std::unique_ptr multi_device_buffer_ GUARDED_BY(mu_); }; // Just creates a MultiDeviceIterator and returns it. @@ -754,6 +918,10 @@ class MultiDeviceIteratorInitOp : public OpKernel { : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { + const Tensor* tensor_max_buffer_size; + OP_REQUIRES_OK(ctx, ctx->input("max_buffer_size", &tensor_max_buffer_size)); + int64 max_buffer_size = tensor_max_buffer_size->scalar()(); + DatasetBase* dataset; OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset)); MultiDeviceIterator* resource; @@ -761,12 +929,12 @@ class MultiDeviceIteratorInitOp : public OpKernel { LookupResource(ctx, HandleFromInput(ctx, 1), &resource)); core::ScopedUnref unref(resource); - IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx); std::unique_ptr iterator; - OP_REQUIRES_OK(ctx, - dataset->MakeIterator(&iter_ctx, "Iterator", &iterator)); + OP_REQUIRES_OK(ctx, dataset->MakeIterator(IteratorContext(ctx), "Iterator", + &iterator)); int64 incarnation_id; - OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), &incarnation_id)); + OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), max_buffer_size, + &incarnation_id)); Tensor tensor_incarnation_id(DT_INT64, TensorShape({})); tensor_incarnation_id.scalar()() = incarnation_id; OP_REQUIRES_OK(ctx, @@ -804,9 +972,6 @@ class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel { ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done); thread_pool_->Schedule(std::bind( [ctx, iterator, shard_num, incarnation_id](DoneCallback done) { - std::vector components; - bool end_of_sequence = false; - IteratorContext::Params params; params.env = ctx->env(); params.runner = *(ctx->runner()); @@ -817,22 +982,26 @@ class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel { }; IteratorContext iter_ctx(std::move(params)); - Status s = - iterator->GetNextFromShard(&iter_ctx, shard_num, incarnation_id, - &components, &end_of_sequence); - iterator->Unref(); + MultiDeviceIteratorCallback callback = std::bind( + [ctx](const HostBufferElement& elem, DoneCallback done) { + // iterator->Unref(); + Status s = elem.status; + if (!s.ok()) { + ctx->SetStatus(s); + } else if (elem.end_of_sequence) { + ctx->SetStatus(errors::OutOfRange("End of sequence")); + } else { + for (int i = 0; i < elem.value.size(); ++i) { + ctx->set_output(i, elem.value[i]); + } + } + done(); + }, + std::placeholders::_1, std::move(done)); - if (!s.ok()) { - ctx->SetStatus(s); - } else if (end_of_sequence) { - ctx->SetStatus(errors::OutOfRange("End of sequence")); - } else { - for (int i = 0; i < components.size(); ++i) { - // TODO(mrry): Check that the shapes match the shape attrs. - ctx->set_output(i, components[i]); - } - } - done(); + iterator->GetNextFromShard(&iter_ctx, shard_num, incarnation_id, + callback); + iterator->Unref(); }, std::move(done))); } diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc index 141706f393b076d9f55898ca4bdbe7438f7c3625..ab584504a05369105d080df73750974af9fc70bb 100644 --- a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc @@ -130,11 +130,13 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, const DatasetBase* input, ThreadPoolResource* threadpool) - : GraphDatasetBase(ctx), input_(input), threadpool_(threadpool) { + : DatasetBase(DatasetContext(ctx)), + input_(input), + threadpool_(threadpool) { input_->Ref(); threadpool_->Ref(); } @@ -162,11 +164,11 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { - return errors::Unimplemented( - "Cannot currently serialize the thread pool for a " - "ThreadPoolDataset."); + return errors::Unimplemented("%s does not support serialization", + DebugString()); } private: diff --git a/tensorflow/contrib/data/kernels/unique_dataset_op.cc b/tensorflow/contrib/data/kernels/unique_dataset_op.cc index 67c237799c10a2724f18bb0df99e4bf8f5cd2b8a..6fbf5d2ebb598132a7e8433608e67436a172b615 100644 --- a/tensorflow/contrib/data/kernels/unique_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/unique_dataset_op.cc @@ -47,10 +47,10 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, const DatasetBase* input) - : GraphDatasetBase(ctx), input_(input) { + : DatasetBase(DatasetContext(ctx)), input_(input) { input_->Ref(); } @@ -75,10 +75,11 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph_node = nullptr; - TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output)); return Status::OK(); } @@ -116,7 +117,7 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel { Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); if (input_impl_) { - TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); } else { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("input_impl_empty"), "")); @@ -135,7 +136,7 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel { IteratorStateReader* reader) override { mutex_lock l(mu_); if (!reader->Contains(full_name("input_impl_empty"))) { - TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); } else { input_impl_.reset(); } diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc index 66a7c7fdcd5e0ab77596177c209470e17f63bc10..cc5e250ea15bf89be2db9aba14e3b29b72512a73 100644 --- a/tensorflow/contrib/data/ops/dataset_ops.cc +++ b/tensorflow/contrib/data/ops/dataset_ops.cc @@ -168,9 +168,11 @@ output_shapes: The list of shapes being produced. REGISTER_OP("MultiDeviceIteratorInit") .Input("dataset: variant") .Input("multi_device_iterator: resource") + .Input("max_buffer_size: int64") .Output("incarnation_id: int64") .Doc(R"doc( Initializes the multi device iterator with the given dataset. +max_buffer_size: The maximum size of the host side per device buffer to keep. incarnation_id: An int64 indicating which incarnation of the MultiDeviceIterator is running. dataset: Dataset to be iterated upon. diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 2de1a79d28c16706e3c237d62935212ce387c776..2b75aa2ca54509b42f431db2dd39261cf025588a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -175,7 +175,7 @@ py_test( "//tensorflow/python:variables", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/estimator", - "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:estimator_py", ], ) @@ -198,21 +198,46 @@ py_test( "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:io_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) +py_test( + name = "map_defun_op_test", + size = "small", + srcs = ["map_defun_op_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + "//tensorflow/contrib/data/python/ops:map_defun", + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:function", + "//tensorflow/python:math_ops", + ], +) + py_test( name = "optimize_dataset_op_test", size = "small", srcs = ["optimize_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ + ":stats_dataset_test_base", "//tensorflow/contrib/data/python/ops:optimization", + "//tensorflow/contrib/data/python/ops:stats_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:math_ops", "//tensorflow/python/data/ops:dataset_ops", "@absl_py//absl/testing:parameterized", ], @@ -239,7 +264,7 @@ cuda_py_test( tags = [ "manual", "no_oss", - "no_windows_gpu" + + "no_windows_gpu", "notap", ], ) @@ -431,8 +456,8 @@ py_test( tags = ["no_pip"], deps = [ ":reader_dataset_ops_test_base", + ":stats_dataset_test_base", "//tensorflow/contrib/data/python/ops:stats_ops", - "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", @@ -442,6 +467,16 @@ py_test( ], ) +py_library( + name = "stats_dataset_test_base", + srcs = ["stats_dataset_test_base.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + ], +) + py_test( name = "threadpool_dataset_ops_test", size = "small", diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index 44c3325a3db84bb844b7f860a7c925982f1e3d6a..7a3215f6ccfa807e8930ac8561587e474da61195 100644 --- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -777,6 +777,34 @@ class ParallelInterleaveDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(self.next_element) + def testShutdownRace(self): + dataset = dataset_ops.Dataset.range(20) + map_fn = lambda x: dataset_ops.Dataset.range(20 * x, 20 * (x + 1)) + dataset = dataset.apply( + interleave_ops.parallel_interleave( + map_fn, + cycle_length=3, + sloppy=False, + buffer_output_elements=1, + prefetch_input_elements=0)) + dataset = dataset.batch(32) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + results = [] + with self.test_session() as sess: + for _ in range(2): + elements = [] + sess.run(iterator.initializer) + try: + while True: + elements.extend(sess.run(next_element)) + except errors.OutOfRangeError: + pass + results.append(elements) + + self.assertAllEqual(results[0], results[1]) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py index 30a993b1f7056b9726f524b2279131339c80c5eb..77148aceec7fa90f927a9c009671c2939460877b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util @@ -55,7 +56,7 @@ class CheckpointInputPipelineHookTest(test.TestCase): def _read_vars(self, model_dir): """Returns (global_step, latest_feature).""" with ops.Graph().as_default() as g: - ckpt_path = saver_lib.latest_checkpoint(model_dir) + ckpt_path = checkpoint_management.latest_checkpoint(model_dir) meta_filename = ckpt_path + '.meta' saver_lib.import_meta_graph(meta_filename) saver = saver_lib.Saver() diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index 48adc98e9a4caee1651d5c7bca9dd813f11dfb01..009e21a34c8df86af6abbb7599dbcfa23ddf90a7 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -80,6 +80,7 @@ class MapDatasetTest(test.TestCase): sess.run(get_next) def testReadFileIgnoreError(self): + def write_string_to_file(value, filename): with open(filename, "w") as f: f.write(value) @@ -308,5 +309,50 @@ class MapDatasetBenchmark(test.Benchmark): opt_mark, chain_length)) +class MapAndFilterBenchmark(test.Benchmark): + + # This benchmark compares the performance of pipeline with multiple chained + # map + filter with and without map fusion. + def benchmarkMapAndFilter(self): + chain_lengths = [0, 1, 2, 5, 10, 20, 50] + for chain_length in chain_lengths: + self._benchmarkMapAndFilter(chain_length, False) + self._benchmarkMapAndFilter(chain_length, True) + + def _benchmarkMapAndFilter(self, chain_length, optimize_dataset): + with ops.Graph().as_default(): + dataset = dataset_ops.Dataset.from_tensors(0).repeat(None) + for _ in range(chain_length): + dataset = dataset.map(lambda x: x + 5).filter( + lambda x: math_ops.greater_equal(x - 5, 0)) + if optimize_dataset: + dataset = dataset.apply( + optimization.optimize(["map_and_filter_fusion"])) + + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with session.Session() as sess: + for _ in range(10): + sess.run(next_element.op) + deltas = [] + for _ in range(100): + start = time.time() + for _ in range(100): + sess.run(next_element.op) + end = time.time() + deltas.append(end - start) + + median_wall_time = np.median(deltas) / 100 + opt_mark = "opt" if optimize_dataset else "no-opt" + print("Map and filter dataset {} chain length: {} Median wall time: {}". + format(opt_mark, chain_length, median_wall_time)) + self.report_benchmark( + iters=1000, + wall_time=median_wall_time, + name="benchmark_map_and_filter_dataset_chain_latency_{}_{}".format( + opt_mark, chain_length)) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a711325daed12f45e4e533f18ee81adc7dec93be --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py @@ -0,0 +1,126 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for MapDefunOp.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import map_defun +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class MapDefunTest(test.TestCase): + + def testMapDefun_Simple(self): + + @function.Defun(dtypes.int32) + def simple_fn(x): + return x * 2 + 3 + + with self.test_session(): + nums = [[1, 2], [3, 4], [5, 6]] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(2,)])[0] + expected = elems * 2 + 3 + self.assertAllEqual(self.evaluate(r), self.evaluate(expected)) + + def testMapDefun_MismatchedTypes(self): + + @function.Defun(dtypes.int32) + def fn(x): + return math_ops.cast(x, dtypes.float64) + + with self.test_session(): + nums = [1, 2, 3, 4, 5, 6] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0] + with self.assertRaises(errors.InvalidArgumentError): + self.evaluate(r) + + def testMapDefun_MultipleOutputs(self): + + @function.Defun(dtypes.int32) + def fn(x): + return (x, math_ops.cast(x * 2 + 3, dtypes.float64)) + + with self.test_session(): + nums = [[1, 2], [3, 4], [5, 6]] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + r = map_defun.map_defun(fn, [elems], [dtypes.int32, dtypes.float64], + [(2,), (2,)]) + expected = [elems, elems * 2 + 3] + self.assertAllEqual(self.evaluate(r), self.evaluate(expected)) + + def testMapDefun_ShapeInference(self): + + @function.Defun(dtypes.int32) + def fn(x): + return x + + nums = [[1, 2], [3, 4], [5, 6]] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])[0] + self.assertEqual(result.get_shape(), (3, 2)) + + def testMapDefun_PartialShapeInference(self): + + @function.Defun(dtypes.int32) + def fn(x): + return x + + elems = array_ops.placeholder(dtypes.int64, (None, 2)) + result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)]) + self.assertEqual(result[0].get_shape().as_list(), [None, 2]) + + def testMapDefun_RaisesErrorOnRuntimeShapeMismatch(self): + + @function.Defun(dtypes.int32, dtypes.int32) + def fn(x, y): + return x, y + + elems1 = array_ops.placeholder(dtypes.int32) + elems2 = array_ops.placeholder(dtypes.int32) + result = map_defun.map_defun(fn, [elems1, elems2], + [dtypes.int32, dtypes.int32], [(), ()]) + with self.test_session() as sess: + with self.assertRaisesWithPredicateMatch( + errors.InvalidArgumentError, + "All inputs must have the same dimension 0."): + sess.run(result, feed_dict={elems1: [1, 2, 3, 4, 5], elems2: [1, 2, 3]}) + + def testMapDefun_RaisesDefunError(self): + + @function.Defun(dtypes.int32) + def fn(x): + with ops.control_dependencies([check_ops.assert_equal(x, 0)]): + return array_ops.identity(x) + + elems = constant_op.constant([0, 0, 0, 37, 0]) + result = map_defun.map_defun(fn, [elems], [dtypes.int32], [()]) + with self.test_session(): + with self.assertRaises(errors.InvalidArgumentError): + self.evaluate(result) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py index d8156dc9c7bf187d7399aede44c41c8c50670248..ae147b4fa79c5fc8e63e1860f45036709ecc9777 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py @@ -19,9 +19,14 @@ from __future__ import print_function from absl.testing import parameterized +from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base from tensorflow.contrib.data.python.ops import optimization +from tensorflow.contrib.data.python.ops import stats_ops from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -46,8 +51,7 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase): with self.assertRaisesRegexp( errors.InvalidArgumentError, "Asserted Whoops transformation at offset 0 but encountered " - "Map transformation instead." - ): + "Map transformation instead."): sess.run(get_next) def testAssertSuffixShort(self): @@ -123,19 +127,30 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase): functions = [identity, increment, increment_and_square] tests = [] - - for fun1 in functions: - for fun2 in functions: - tests.append(([fun1, fun2],)) - for fun3 in functions: - tests.append(([fun1, fun2, fun3],)) + for i, fun1 in enumerate(functions): + for j, fun2 in enumerate(functions): + tests.append(( + "test_{}_{}".format(i, j), + [fun1, fun2], + )) + for k, fun3 in enumerate(functions): + tests.append(( + "test_{}_{}_{}".format(i, j, k), + [fun1, fun2, fun3], + )) swap = lambda x, n: (n, x) - tests.append(([lambda x: (x, 42), swap],)) - tests.append(([lambda x: (x, 42), swap, swap],)) + tests.append(( + "swap1", + [lambda x: (x, 42), swap], + )) + tests.append(( + "swap2", + [lambda x: (x, 42), swap, swap], + )) return tuple(tests) - @parameterized.parameters(*map_functions.__func__()) + @parameterized.named_parameters(*map_functions.__func__()) def testMapFusion(self, functions): dataset = dataset_ops.Dataset.range(5).apply( optimization.assert_next(["Map", "Prefetch"])) @@ -159,6 +174,108 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + @staticmethod + def map_and_filter_functions(): + identity = lambda x: x + increment = lambda x: x + 1 + minus_five = lambda x: x - 5 + + def increment_and_square(x): + y = x + 1 + return y * y + + take_all = lambda x: constant_op.constant(True) + is_zero = lambda x: math_ops.equal(x, 0) + is_odd = lambda x: math_ops.equal(x % 2, 0) + greater = lambda x: math_ops.greater(x + 5, 0) + + functions = [identity, increment, minus_five, increment_and_square] + filters = [take_all, is_zero, is_odd, greater] + tests = [] + + for x, fun in enumerate(functions): + for y, predicate in enumerate(filters): + tests.append(("mixed_{}_{}".format(x, y), fun, predicate)) + + # Multi output + tests.append(("multiOne", lambda x: (x, x), + lambda x, y: constant_op.constant(True))) + tests.append( + ("multiTwo", lambda x: (x, 2), + lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0))) + return tuple(tests) + + @parameterized.named_parameters(*map_and_filter_functions.__func__()) + def testMapFilterFusion(self, function, predicate): + dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next( + ["Map", + "FilterByLastComponent"])).map(function).filter(predicate).apply( + optimization.optimize(["map_and_filter_fusion"])) + self._testMapAndFilter(dataset, function, predicate) + + def _testMapAndFilter(self, dataset, function, predicate): + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + with self.test_session() as sess: + for x in range(10): + r = function(x) + if isinstance(r, tuple): + b = predicate(*r) # Pass tuple as multiple arguments. + else: + b = predicate(r) + if sess.run(b): + result = sess.run(get_next) + self.assertAllEqual(r, result) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testAdditionalInputs(self): + a = constant_op.constant(3, dtype=dtypes.int64) + b = constant_op.constant(4, dtype=dtypes.int64) + some_tensor = math_ops.mul(a, b) + function = lambda x: x * x + + def predicate(y): + return math_ops.less(math_ops.cast(y, dtypes.int64), some_tensor) + + # We are currently not supporting functions with additional inputs. + dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next( + ["Map", "Filter"])).map(function).filter(predicate).apply( + optimization.optimize(["map_and_filter_fusion"])) + + self._testMapAndFilter(dataset, function, predicate) + + +class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): + + def testLatencyStatsOptimization(self): + + stats_aggregator = stats_ops.StatsAggregator() + dataset = dataset_ops.Dataset.from_tensors(1).apply( + optimization.assert_next( + ["LatencyStats", "Map", "LatencyStats", "Prefetch", + "LatencyStats"])).map(lambda x: x * x).prefetch(1).apply( + optimization.optimize(["latency_all_edges"])).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_initializable_iterator() + get_next = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run(iterator.initializer) + self.assertEqual(1 * 1, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + summary_str = sess.run(summary_t) + self._assertSummaryHasCount(summary_str, + "record_latency_TensorDataset/_1", 1) + self._assertSummaryHasCount(summary_str, "record_latency_MapDataset/_4", + 1) + self._assertSummaryHasCount(summary_str, + "record_latency_PrefetchDataset/_6", 1) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py index 2da6131e8e60ca53723da7f66a7ee52151640129..361fe0dd39bb3f855c3b0b11281a9909fd601232 100644 --- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py @@ -907,6 +907,42 @@ class CopyToDeviceTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) + def testIteratorGetNextAsOptionalOnGPU(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(3) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")) + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_elem = iterator_ops.get_next_as_optional(iterator) + elem_has_value_t = next_elem.has_value() + elem_value_t = next_elem.get_value() + + with self.test_session() as sess: + # Before initializing the iterator, evaluating the optional fails with + # a FailedPreconditionError. + with self.assertRaises(errors.FailedPreconditionError): + sess.run(elem_has_value_t) + with self.assertRaises(errors.FailedPreconditionError): + sess.run(elem_value_t) + + # For each element of the dataset, assert that the optional evaluates to + # the expected value. + sess.run(iterator.initializer) + for i in range(3): + elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t]) + self.assertTrue(elem_has_value) + self.assertEqual(i, elem_value) + + # After exhausting the iterator, `next_elem.has_value()` will evaluate to + # false, and attempting to get the value will fail. + for _ in range(2): + self.assertFalse(sess.run(elem_has_value_t)) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(elem_value_t) + class MultiDeviceIteratorTest(test.TestCase): @@ -985,7 +1021,7 @@ class MultiDeviceIteratorTest(test.TestCase): def testUneven(self): dataset = dataset_ops.Dataset.range(10) multi_device_iterator = prefetching_ops.MultiDeviceIterator( - dataset, ["/cpu:1", "/cpu:2"]) + dataset, ["/cpu:1", "/cpu:2"], max_buffer_size=4) elem_on_1, elem_on_2 = multi_device_iterator.get_next() config = config_pb2.ConfigProto(device_count={"CPU": 3}) @@ -1043,7 +1079,7 @@ class MultiDeviceIteratorTest(test.TestCase): with compat.forward_compatibility_horizon(2018, 8, 4): dataset = dataset_ops.Dataset.range(10) multi_device_iterator = prefetching_ops.MultiDeviceIterator( - dataset, ["/cpu:1", "/gpu:0"]) + dataset, ["/cpu:1", "/gpu:0"], max_buffer_size=4) elem_on_1, elem_on_2 = multi_device_iterator.get_next() config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1}) diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index 851a33dfc849a2d935887def44734aace5dcaf7f..15b342d30f85a05b3827998565ba5f84021ac885 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -173,15 +173,23 @@ class ReadBatchFeaturesTest( for num_epochs in [1, 10]: with ops.Graph().as_default(): # Basic test: read from file 0. - self.outputs = self.make_batch_feature( + outputs = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, batch_size=batch_size, drop_final_batch=True).make_one_shot_iterator().get_next() - for _, tensor in self.outputs.items(): + for _, tensor in outputs.items(): if isinstance(tensor, ops.Tensor): # Guard against SparseTensor. self.assertEqual(tensor.shape[0], batch_size) + def testIndefiniteRepeatShapeInference(self): + dataset = self.make_batch_feature( + filenames=self.test_filenames[0], num_epochs=None, batch_size=32) + for shape, clazz in zip(nest.flatten(dataset.output_shapes), + nest.flatten(dataset.output_classes)): + if issubclass(clazz, ops.Tensor): + self.assertEqual(32, shape[0]) + class MakeCsvDatasetTest(test.TestCase): @@ -795,6 +803,16 @@ class MakeCsvDatasetTest(test.TestCase): all_equal = all_equal and np.array_equal(batch1[i], batch2[i]) self.assertFalse(all_equal) + def testIndefiniteRepeatShapeInference(self): + column_names = ["col%d" % i for i in range(5)] + inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [ + ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19" + ]] + filenames = self._setup_files(inputs) + dataset = self._make_csv_dataset(filenames, batch_size=32, num_epochs=None) + for shape in nest.flatten(dataset.output_shapes): + self.assertEqual(32, shape[0]) + class MakeTFRecordDatasetTest( reader_dataset_ops_test_base.TFRecordDatasetTestBase): @@ -1002,5 +1020,12 @@ class MakeTFRecordDatasetTest( self._shuffle_test(batch_size, num_epochs, num_parallel_reads, seed=21345) + def testIndefiniteRepeatShapeInference(self): + dataset = readers.make_tf_record_dataset( + file_pattern=self.test_filenames, num_epochs=None, batch_size=32) + for shape in nest.flatten(dataset.output_shapes): + self.assertEqual(32, shape[0]) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD index 3c3f23f9a984c702abfdacf11bef0e5d4066782f..7b9ea191a4524891d1b589e1e228e29241fda7f8 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD @@ -56,6 +56,7 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python/data/ops:dataset_ops", + "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py index a0a1100893c7384b0e2bd9fcfdaa8d3698b95d28..1b6059ccbcc81937696e1b0ebb269f213adbb976 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py @@ -19,6 +19,8 @@ from __future__ import print_function import os +from absl.testing import parameterized + from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors @@ -26,7 +28,8 @@ from tensorflow.python.platform import test class CacheDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): + dataset_serialization_test_base.DatasetSerializationTestBase, + parameterized.TestCase): def setUp(self): self.range_size = 10 @@ -34,88 +37,123 @@ class CacheDatasetSerializationTest( self.num_outputs = self.range_size * self.num_repeats self.cache_file_prefix = 'test' - def ds_fn(self): - return dataset_ops.Dataset.range(self.range_size).cache( - os.path.join(self.get_temp_dir(), - self.cache_file_prefix)).repeat(self.num_repeats) + def make_dataset_fn(self, is_memory): + if is_memory: + filename = '' + else: + filename = os.path.join(self.get_temp_dir(), self.cache_file_prefix) + + def ds_fn(): + return dataset_ops.Dataset.range(self.range_size).cache(filename).repeat( + self.num_repeats) + + return ds_fn def expected_outputs(self): return list(range(self.range_size)) * self.num_repeats - def testCheckpointBeforeOneEpoch(self): + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointBeforeOneEpoch(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + # Generate 5 entries from iterator and save checkpoint. - outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False) self.assertSequenceEqual(outputs, range(5)) # Restore from checkpoint and produce the rest of the elements from the # iterator. outputs.extend( self.gen_outputs( - self.ds_fn, [], + ds_fn, [], self.num_outputs - 5, ckpt_saved=True, verify_exhausted=False)) self.assertSequenceEqual(outputs, self.expected_outputs()) - def testCheckpointBeforeOneEpochThenRunFewSteps(self): - # Generate 8 entries from iterator but save checkpoint after producing - # 5. + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointBeforeOneEpochThenRunFewSteps(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + + # Generate 8 entries from iterator but save checkpoint after producing 5. outputs = self.gen_outputs( - self.ds_fn, [5], - 8, - verify_exhausted=False, - save_checkpoint_at_end=False) + ds_fn, [5], 8, verify_exhausted=False, save_checkpoint_at_end=False) self.assertSequenceEqual(outputs, range(8)) - # Restoring from checkpoint and running GetNext should return a - # `AlreadExistsError` now because the lockfile already exists. - with self.assertRaises(errors.AlreadyExistsError): - self.gen_outputs( - self.ds_fn, [], - self.num_outputs - 5, - ckpt_saved=True, - verify_exhausted=False) + if is_memory: + outputs = outputs[:5] + outputs.extend( + self.gen_outputs( + ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False)) + self.assertSequenceEqual(outputs, self.expected_outputs()) + else: + # Restoring from checkpoint and running GetNext should return + # `AlreadExistsError` now because the lockfile already exists. + with self.assertRaises(errors.AlreadyExistsError): + self.gen_outputs( + ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False) + + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointAfterOneEpoch(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) - def testCheckpointAfterOneEpoch(self): # Generate 15 entries from iterator and save checkpoint. - outputs = self.gen_outputs(self.ds_fn, [], 15, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) + list(range(5))) # Restore from checkpoint and produce the rest of the elements from the # iterator. outputs.extend( self.gen_outputs( - self.ds_fn, [], + ds_fn, [], self.num_outputs - 15, ckpt_saved=True, verify_exhausted=False)) self.assertSequenceEqual(outputs, self.expected_outputs()) - def testCheckpointAfterOneEpochThenRunFewSteps(self): - # Generate 18 entries from iterator but save checkpoint after producing - # 15. + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointAfterOneEpochThenRunFewSteps(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + + # Generate 18 entries from iterator but save checkpoint after producing 15. outputs = self.gen_outputs( - self.ds_fn, [15], - 18, - verify_exhausted=False, - save_checkpoint_at_end=False) + ds_fn, [15], 18, verify_exhausted=False, save_checkpoint_at_end=False) self.assertSequenceEqual(outputs, list(range(10)) + list(range(8))) outputs = list(range(10)) + list(range(5)) + self.gen_outputs( - self.ds_fn, [], + ds_fn, [], self.num_outputs - 15, ckpt_saved=True, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) * 3) - def testCheckpointBeforeOneEpochButRunCompleteEpoch(self): - # Generate 13 entries from iterator but save checkpoint after producing - # 5. + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointBeforeOneEpochButRunCompleteEpoch(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + + # Generate 13 entries from iterator but save checkpoint after producing 5. outputs = self.gen_outputs( - self.ds_fn, [5], - 13, - verify_exhausted=False, - save_checkpoint_at_end=False) + ds_fn, [5], 13, verify_exhausted=False, save_checkpoint_at_end=False) self.assertSequenceEqual(outputs, list(range(10)) + list(range(3))) # Since we ran for more than one epoch, the cache was completely written. @@ -124,65 +162,90 @@ class CacheDatasetSerializationTest( # been completely written. outputs = list(range(5)) + self.gen_outputs( - self.ds_fn, [], + ds_fn, [], self.num_outputs - 5, ckpt_saved=True, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) * 3) - def testCheckpointUnusedWriterIterator(self): + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointUnusedWriterIterator(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + # Checkpoint before get_next is called even once. - outputs = self.gen_outputs(self.ds_fn, [], 0, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 0, verify_exhausted=False) self.assertSequenceEqual(outputs, []) outputs = self.gen_outputs( - self.ds_fn, [], - self.num_outputs, - ckpt_saved=True, - verify_exhausted=False) + ds_fn, [], self.num_outputs, ckpt_saved=True, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) * 3) - def testCheckpointUnusedMidwayWriterIterator(self): + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointUnusedMidwayWriterIterator(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + # Produce 5 elements and checkpoint. - outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False) self.assertSequenceEqual(outputs, range(5)) # Restore from checkpoint, then produce no elements and checkpoint. outputs.extend( - self.gen_outputs( - self.ds_fn, [], 0, ckpt_saved=True, verify_exhausted=False)) + self.gen_outputs(ds_fn, [], 0, ckpt_saved=True, verify_exhausted=False)) self.assertSequenceEqual(outputs, range(5)) # Restore from checkpoint and produce rest of the elements. outputs.extend( self.gen_outputs( - self.ds_fn, [], + ds_fn, [], self.num_outputs - 5, ckpt_saved=True, verify_exhausted=False)) self.assertSequenceEqual(outputs, list(range(10)) * 3) - def testUnusedCheckpointError(self): + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testUnusedCheckpointError(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + # Produce 5 elements and save ckpt. - outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False) self.assertSequenceEqual(outputs, range(5)) - # Since the complete cache has not been written, a new iterator which does - # not restore the checkpoint will throw an error since there is a partial - # cache shard. - with self.assertRaises(errors.AlreadyExistsError): + if is_memory: outputs = self.gen_outputs( - self.ds_fn, [], self.num_outputs, verify_exhausted=False) + ds_fn, [], self.num_outputs, verify_exhausted=False) + self.assertSequenceEqual(outputs, self.expected_outputs()) + else: + # Since the complete cache has not been written, a new iterator which does + # not restore the checkpoint will throw an error since there is a partial + # cache shard. + with self.assertRaises(errors.AlreadyExistsError): + outputs = self.gen_outputs( + ds_fn, [], self.num_outputs, verify_exhausted=False) + + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testIgnoreCheckpointIfCacheWritten(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) - def testIgnoreCheckpointIfCacheWritten(self): # Produce 15 elements and save ckpt. This will write the complete cache. - outputs = self.gen_outputs(self.ds_fn, [], 15, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) + list(range(5))) # Build the iterator again but do not restore from ckpt. Since the cache # has already been written we should be able to use it. outputs = self.gen_outputs( - self.ds_fn, [], self.num_outputs, verify_exhausted=False) + ds_fn, [], self.num_outputs, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) * 3) diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py index 393f08850b1865180a8b94e9209b2445b54c8b69..3ed4dfb7295ca77c78ce5318bf31e16a354e16a8 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import test +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as saver_lib from tensorflow.python.util import nest @@ -655,7 +656,7 @@ class DatasetSerializationTestBase(test.TestCase): return os.path.join(self.get_temp_dir(), "iterator") def _latest_ckpt(self): - return saver_lib.latest_checkpoint(self.get_temp_dir()) + return checkpoint_management.latest_checkpoint(self.get_temp_dir()) def _save(self, sess, saver): saver.save(sess, self._ckpt_path()) diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py index b4945685c1d1062bf416b73f1541f351adf45604..a41d21f8c14ed6bec7626599a5aa7f365765ce8b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -20,8 +20,8 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base +from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base from tensorflow.contrib.data.python.ops import stats_ops -from tensorflow.core.framework import summary_pb2 from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -29,28 +29,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class StatsDatasetTestBase(test.TestCase): - - def _assertSummaryHasCount(self, summary_str, tag, expected_value): - summary_proto = summary_pb2.Summary() - summary_proto.ParseFromString(summary_str) - for value in summary_proto.value: - if tag == value.tag: - self.assertEqual(expected_value, value.histo.num) - return - self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) - - def _assertSummaryHasSum(self, summary_str, tag, expected_value): - summary_proto = summary_pb2.Summary() - summary_proto.ParseFromString(summary_str) - for value in summary_proto.value: - if tag == value.tag: - self.assertEqual(expected_value, value.histo.sum) - return - self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) - - -class StatsDatasetTest(StatsDatasetTestBase): +class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): def testBytesProduced(self): stats_aggregator = stats_ops.StatsAggregator() @@ -197,7 +176,7 @@ class StatsDatasetTest(StatsDatasetTestBase): class FeatureStatsDatasetTest( - StatsDatasetTestBase, + stats_dataset_test_base.StatsDatasetTestBase, reader_dataset_ops_test_base.ReadBatchFeaturesTestBase): def testFeaturesStats(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..9a13acf8f0ac6690cad8847873768562da795496 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py @@ -0,0 +1,44 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base class for testing the input pipeline statistics gathering ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from tensorflow.core.framework import summary_pb2 +from tensorflow.python.platform import test + + +class StatsDatasetTestBase(test.TestCase): + """Base class for testing statistics gathered in `StatsAggregator`.""" + + def _assertSummaryHasCount(self, summary_str, tag, expected_value): + summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString(summary_str) + for value in summary_proto.value: + if tag == value.tag: + self.assertEqual(expected_value, value.histo.num) + return + self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + + def _assertSummaryHasSum(self, summary_str, tag, expected_value): + summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString(summary_str) + for value in summary_proto.value: + if tag == value.tag: + self.assertEqual(expected_value, value.histo.sum) + return + self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 1ad021ea037add48afee5bdfda9eea18485eca5d..ad9378dfb9d938c826f994da9bbb89101cfbd872 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -210,6 +210,17 @@ py_library( ], ) +py_library( + name = "map_defun", + srcs = ["map_defun.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:framework_ops", + "//tensorflow/python:tensor_shape", + ], +) + py_library( name = "resampling", srcs = ["resampling.py"], @@ -370,6 +381,7 @@ py_library( ":get_single_element", ":grouping", ":interleave_ops", + ":map_defun", ":optimization", ":prefetching_ops", ":readers", diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index a4914f4cde71925af477636c91d98b54ce0cce0e..9f059942a65177186132164531237f838ecd63a2 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -31,7 +31,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops @@ -186,7 +185,7 @@ def dense_to_sparse_batch(batch_size, row_shape): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): @@ -402,7 +401,7 @@ def unbatch(): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): @@ -439,54 +438,12 @@ def unbatch(): return _apply_fn -def _filter_irregular_batches(batch_size): - """Transformation that filters out batches that are not of size batch_size.""" - - def _apply_fn(dataset): - """Function from `Dataset` to `Dataset` that applies the transformation.""" - tensor_batch_size = ops.convert_to_tensor( - batch_size, dtype=dtypes.int64, name="batch_size") - - flattened = _RestructuredDataset( - dataset, - tuple(nest.flatten(dataset.output_types)), - output_classes=tuple(nest.flatten(dataset.output_classes))) - - def _predicate(*xs): - """Return `True` if this element is a full batch.""" - # Extract the dynamic batch size from the first component of the flattened - # batched element. - first_component = xs[0] - first_component_batch_size = array_ops.shape( - first_component, out_type=dtypes.int64)[0] - - return math_ops.equal(first_component_batch_size, tensor_batch_size) - - filtered = flattened.filter(_predicate) - - maybe_constant_batch_size = tensor_util.constant_value(tensor_batch_size) - - def _set_first_dimension(shape): - return shape.merge_with( - tensor_shape.vector(maybe_constant_batch_size).concatenate(shape[1:])) - - known_shapes = nest.map_structure(_set_first_dimension, - dataset.output_shapes) - return _RestructuredDataset( - filtered, - dataset.output_types, - known_shapes, - output_classes=dataset.output_classes) - - return _apply_fn - - @deprecation.deprecated( None, "Use `tf.data.Dataset.batch(..., drop_remainder=True)`.") def batch_and_drop_remainder(batch_size): """A batching transformation that omits the final small batch (if present). - Like @{tf.data.Dataset.batch}, this transformation combines + Like `tf.data.Dataset.batch`, this transformation combines consecutive elements of this dataset into batches. However, if the batch size does not evenly divide the input dataset size, this transformation will drop the final smaller element. @@ -510,15 +467,12 @@ def batch_and_drop_remainder(batch_size): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply} + `tf.data.Dataset.apply` """ def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" - # TODO(jsimsa): Switch to using `batch(..., drop_remainder=True)` any time - # after 6/30/2018. - batched = dataset.batch(batch_size) - return _filter_irregular_batches(batch_size)(batched) + return dataset.batch(batch_size, drop_remainder=True) return _apply_fn @@ -530,34 +484,32 @@ def padded_batch_and_drop_remainder(batch_size, padding_values=None): """A batching and padding transformation that omits the final small batch. - Like @{tf.data.Dataset.padded_batch}, this transformation combines + Like `tf.data.Dataset.padded_batch`, this transformation combines consecutive elements of this dataset into batches. However, if the batch size does not evenly divide the input dataset size, this transformation will drop the final smaller element. - See `@{tf.contrib.data.batch_and_drop_remainder}` for more details. + See `tf.contrib.data.batch_and_drop_remainder` for more details. Args: batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of consecutive elements of this dataset to combine in a single batch. padded_shapes: A nested structure of `tf.TensorShape` or `tf.int64` vector tensor-like objects. See - @{tf.data.Dataset.padded_batch} for details. + `tf.data.Dataset.padded_batch` for details. padding_values: (Optional.) A nested structure of scalar-shaped - `tf.Tensor`. See @{tf.data.Dataset.padded_batch} for details. + `tf.Tensor`. See `tf.data.Dataset.padded_batch` for details. Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply} + `tf.data.Dataset.apply` """ def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" - # TODO(jsimsa): Switch to using `padded_batch(..., drop_remainder=True)` - # any time after 6/30/2018. - batched = dataset.padded_batch( - batch_size, padded_shapes=padded_shapes, padding_values=padding_values) - return _filter_irregular_batches(batch_size)(batched) + return dataset.padded_batch( + batch_size, padded_shapes=padded_shapes, padding_values=padding_values, + drop_remainder=True) return _apply_fn @@ -709,7 +661,7 @@ def assert_element_shape(expected_shapes): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply} + `tf.data.Dataset.apply` """ def _check_shape(*elements): @@ -808,7 +760,7 @@ def map_and_batch(map_func, Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. Raises: ValueError: If both `num_parallel_batches` and `num_parallel_calls` are diff --git a/tensorflow/contrib/data/python/ops/enumerate_ops.py b/tensorflow/contrib/data/python/ops/enumerate_ops.py index ac2b386b81532b801139baa00fd5edd4ecd6ef0a..490281e0d2da7a454a2f63f95753c7c436b87a76 100644 --- a/tensorflow/contrib/data/python/ops/enumerate_ops.py +++ b/tensorflow/contrib/data/python/ops/enumerate_ops.py @@ -47,7 +47,7 @@ def enumerate_dataset(start=0): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py index d46d96c461ad4cc0ac25a8ddc285cec23d09c682..b4a7521e0875089c39ac7aa8b7b49e44feb2b4ad 100644 --- a/tensorflow/contrib/data/python/ops/error_ops.py +++ b/tensorflow/contrib/data/python/ops/error_ops.py @@ -42,7 +42,7 @@ def ignore_errors(): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): diff --git a/tensorflow/contrib/data/python/ops/get_single_element.py b/tensorflow/contrib/data/python/ops/get_single_element.py index ef9284456eb35099db804e0680abfacd6384d503..a6713b017afa315edec9389d0a6c1c7135e6aeb9 100644 --- a/tensorflow/contrib/data/python/ops/get_single_element.py +++ b/tensorflow/contrib/data/python/ops/get_single_element.py @@ -29,8 +29,8 @@ from tensorflow.python.ops import gen_dataset_ops def get_single_element(dataset): """Returns the single element in `dataset` as a nested structure of tensors. - This function enables you to use a @{tf.data.Dataset} in a stateless - "tensor-in tensor-out" expression, without creating a @{tf.data.Iterator}. + This function enables you to use a `tf.data.Dataset` in a stateless + "tensor-in tensor-out" expression, without creating a `tf.data.Iterator`. This can be useful when your preprocessing transformations are expressed as a `Dataset`, and you want to use the transformation at serving time. For example: @@ -50,10 +50,10 @@ def get_single_element(dataset): ``` Args: - dataset: A @{tf.data.Dataset} object containing a single element. + dataset: A `tf.data.Dataset` object containing a single element. Returns: - A nested structure of @{tf.Tensor} objects, corresponding to the single + A nested structure of `tf.Tensor` objects, corresponding to the single element of `dataset`. Raises: @@ -77,11 +77,11 @@ def reduce_dataset(dataset, reducer): """Returns the result of reducing the `dataset` using `reducer`. Args: - dataset: A @{tf.data.Dataset} object. - reducer: A @{tf.contrib.data.Reducer} object representing the reduce logic. + dataset: A `tf.data.Dataset` object. + reducer: A `tf.contrib.data.Reducer` object representing the reduce logic. Returns: - A nested structure of @{tf.Tensor} objects, corresponding to the result + A nested structure of `tf.Tensor` objects, corresponding to the result of reducing `dataset` using `reducer`. Raises: diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index bd8d398c58cc1825616c1ab5337cf6668c66697e..6edc1d79902c571b34b6a0a108c4d62cb6097ccb 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -50,7 +50,7 @@ def group_by_reducer(key_func, reducer): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): @@ -92,7 +92,7 @@ def group_by_window(key_func, Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. Raises: ValueError: if neither or both of {`window_size`, `window_size_func`} are @@ -142,11 +142,11 @@ def bucket_by_sequence_length(element_length_func, bucket_batch_sizes: `list`, batch size per bucket. Length should be `len(bucket_boundaries) + 1`. padded_shapes: Nested structure of `tf.TensorShape` to pass to - @{tf.data.Dataset.padded_batch}. If not provided, will use + `tf.data.Dataset.padded_batch`. If not provided, will use `dataset.output_shapes`, which will result in variable length dimensions being padded out to the maximum length in each batch. padding_values: Values to pad with, passed to - @{tf.data.Dataset.padded_batch}. Defaults to padding with 0. + `tf.data.Dataset.padded_batch`. Defaults to padding with 0. pad_to_bucket_boundary: bool, if `False`, will pad dimensions with unknown size to maximum length in batch. If `True`, will pad dimensions with unknown size to bucket boundary minus 1 (i.e., the maximum length in each @@ -155,7 +155,7 @@ def bucket_by_sequence_length(element_length_func, Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. Raises: ValueError: if `len(bucket_batch_sizes) != len(bucket_boundaries) + 1`. diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index bcc959594a6b311a3c60bb4696ac97be5c448756..5a1a35199abecc3890d5733ddf678af8d4098f33 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -42,7 +42,7 @@ def parallel_interleave(map_func, `parallel_interleave()` maps `map_func` across its input to produce nested datasets, and outputs their elements interleaved. Unlike - @{tf.data.Dataset.interleave}, it gets elements from `cycle_length` nested + `tf.data.Dataset.interleave`, it gets elements from `cycle_length` nested datasets in parallel, which increases the throughput, especially in the presence of stragglers. Furthermore, the `sloppy` argument can be used to improve performance, by relaxing the requirement that the outputs are produced @@ -79,7 +79,7 @@ def parallel_interleave(map_func, Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): return readers.ParallelInterleaveDataset( @@ -138,7 +138,7 @@ def sloppy_interleave(map_func, cycle_length, block_length=1): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): return readers.ParallelInterleaveDataset( @@ -196,15 +196,15 @@ def sample_from_datasets(datasets, weights=None, seed=None): """Samples elements at random from the datasets in `datasets`. Args: - datasets: A list of @{tf.data.Dataset} objects with compatible structure. + datasets: A list of `tf.data.Dataset` objects with compatible structure. weights: (Optional.) A list of `len(datasets)` floating-point values where `weights[i]` represents the probability with which an element should be - sampled from `datasets[i]`, or a @{tf.data.Dataset} object where each + sampled from `datasets[i]`, or a `tf.data.Dataset` object where each element is such a list. Defaults to a uniform distribution across `datasets`. seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random seed that will be used to create the distribution. See - @{tf.set_random_seed} for behavior. + `tf.set_random_seed` for behavior. Returns: A dataset that interleaves elements from `datasets` at random, according to @@ -262,8 +262,8 @@ def choose_from_datasets(datasets, choice_dataset): ``` Args: - datasets: A list of @{tf.data.Dataset} objects with compatible structure. - choice_dataset: A @{tf.data.Dataset} of scalar `tf.int64` tensors between + datasets: A list of `tf.data.Dataset` objects with compatible structure. + choice_dataset: A `tf.data.Dataset` of scalar `tf.int64` tensors between `0` and `len(datasets) - 1`. Returns: diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py index 0d71be66018eeebe60de9deff24ceb6854d209d9..18515e21edfe0449514ab4f21683a600eaf48910 100644 --- a/tensorflow/contrib/data/python/ops/iterator_ops.py +++ b/tensorflow/contrib/data/python/ops/iterator_ops.py @@ -20,6 +20,7 @@ from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import session_run_hook @@ -117,7 +118,7 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook): pipeline. For saving the input pipeline checkpoint alongside the model weights use - @{tf.contrib.data.make_saveable_from_iterator} directly to create a + `tf.contrib.data.make_saveable_from_iterator` directly to create a `SaveableObject` and add to the `SAVEABLE_OBJECTS` collection. Note, however, that you will need to be careful not to restore the training iterator during eval. You can do that by not adding the iterator to the SAVEABLE_OBJECTS @@ -206,7 +207,7 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook): # Check if there is an existing checkpoint. If so, restore from it. # pylint: disable=protected-access - latest_checkpoint_path = saver_lib.latest_checkpoint( + latest_checkpoint_path = checkpoint_management.latest_checkpoint( self._checkpoint_saver_hook._checkpoint_dir, latest_filename=self._latest_filename) if latest_checkpoint_path: diff --git a/tensorflow/contrib/data/python/ops/map_defun.py b/tensorflow/contrib/data/python/ops/map_defun.py new file mode 100644 index 0000000000000000000000000000000000000000..54d5cd6da068fa5471b7beafcc66d76b5972e7d5 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/map_defun.py @@ -0,0 +1,58 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental API for optimizing `tf.data` pipelines.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_dataset_ops + + +def map_defun(fn, elems, output_dtypes, output_shapes): + """Map a function on the list of tensors unpacked from `elems` on dimension 0. + + Args: + fn: A function (`function.Defun`) that takes a list of tensors and returns + another list of tensors. The output list has the same types as + output_dtypes. The elements of the output list have the same dimension 0 + as `elems`, and the remaining dimensions correspond to those of + `fn_output_shapes`. + elems: A list of tensors. + output_dtypes: A list of dtypes corresponding to the output types of the + function. + output_shapes: A list of `TensorShape`s corresponding to the output + shapes from each invocation of the function on slices of inputs. + + Raises: + ValueError: if any of the inputs are malformed. + + Returns: + A list of `Tensor` objects with the same types as `output_dtypes`. + """ + if not isinstance(elems, list): + raise ValueError("`elems` must be a list of tensors.") + if not isinstance(output_dtypes, list): + raise ValueError("`output_dtypes` must be a list of tensors.") + if not isinstance(output_shapes, list): + raise ValueError("`output_shapes` must be a list of tensors.") + + elems = [ops.convert_to_tensor(e) for e in elems] + output_shapes = [tensor_shape.TensorShape(s) for s in output_shapes] + if not all(s.is_fully_defined() for s in output_shapes): + raise ValueError("All fn output shapes must be fully defined.") + return gen_dataset_ops.map_defun(elems, output_dtypes, output_shapes, fn) diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py index 018c5115e1d5599e48bf99ccf832c7962794fc40..fa1b851ad74bcf2cff69d42bce3eaa38822cd663 100644 --- a/tensorflow/contrib/data/python/ops/optimization.py +++ b/tensorflow/contrib/data/python/ops/optimization.py @@ -36,7 +36,7 @@ def assert_next(transformations): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): @@ -56,7 +56,7 @@ def optimize(optimizations=None): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py index 0edd7c9fe974784f199c272a649b302e72d8c218..5222011d045efd9a64b4e89b248303cffbcb0b37 100644 --- a/tensorflow/contrib/data/python/ops/prefetching_ops.py +++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py @@ -92,7 +92,7 @@ def function_buffering_resource_reset(function_buffer_resource, name=None): # pylint: disable=protected-access class _PrefetchToDeviceIterator(object): - """A replacement for @{tf.data.Iterator} that prefetches to another device. + """A replacement for `tf.data.Iterator` that prefetches to another device. Args: input_dataset: The input dataset @@ -158,7 +158,7 @@ class _PrefetchToDeviceIterator(object): self._input_dataset) def get_next(self, name=None): - """See @{tf.data.Iterator.get_next}.""" + """See `tf.data.Iterator.get_next`.""" self._get_next_call_count += 1 if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD: warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE) @@ -199,7 +199,7 @@ class _PrefetchToDeviceIterator(object): class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator): - """A replacement for @{tf.data.Iterator} that prefetches to another device. + """A replacement for `tf.data.Iterator` that prefetches to another device. Args: input_dataset: The input dataset @@ -334,7 +334,7 @@ class _PrefetchToDeviceDataset(dataset_ops.Dataset): def prefetch_to_device(device, buffer_size=None): """A transformation that prefetches dataset values to the given `device`. - NOTE: Although the transformation creates a @{tf.data.Dataset}, the + NOTE: Although the transformation creates a `tf.data.Dataset`, the transformation must be the final `Dataset` in the input pipeline. Args: @@ -344,7 +344,7 @@ def prefetch_to_device(device, buffer_size=None): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): return _PrefetchToDeviceDataset(dataset, device, buffer_size) @@ -361,7 +361,7 @@ def copy_to_device(target_device, source_device="/cpu:0"): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): @@ -631,8 +631,19 @@ class MultiDeviceIterator(object): def __init__(self, dataset, devices, + max_buffer_size=1, prefetch_buffer_size=1, source_device="/cpu:0"): + """Constructs a MultiDeviceIterator. + + Args: + dataset: The input dataset to be iterated over. + devices: The list of devices to fetch data to. + max_buffer_size: Maximum size of the host side per device buffer to keep. + prefetch_buffer_size: if > 1, then we setup a buffer on each device + to prefetch into. + source_device: The host device to place the `dataset` on. + """ self._dataset = dataset self._devices = devices self._source_device = source_device @@ -659,7 +670,8 @@ class MultiDeviceIterator(object): # iterators and the multi-device iterator. self._incarnation_id = gen_dataset_ops.multi_device_iterator_init( self._dataset._as_variant_tensor(), # pylint: disable=protected-access - self._multi_device_iterator_resource) + self._multi_device_iterator_resource, + max_buffer_size=max_buffer_size) # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to # initialize the device side of the pipeline. This would allow the @@ -673,7 +685,8 @@ class MultiDeviceIterator(object): i, self._multi_device_iterator_resource, self._incarnation_id, self._source_device_tensor, device, self._dataset.output_shapes, self._dataset.output_types, self._dataset.output_classes) - ds = ds.prefetch(prefetch_buffer_size) + if prefetch_buffer_size > 0: + ds = ds.prefetch(prefetch_buffer_size) with ops.device(device): self._device_iterators.append(ds.make_initializable_iterator()) i += 1 diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index f018dd02e6ae9de69c7364677e1756d1e11bf484..3882d4bfdbe899c2ce92f829cb331b32d3d50398 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -234,7 +234,7 @@ def make_tf_record_dataset( Args: file_pattern: List of files or patterns of TFRecord file paths. - See @{tf.gfile.Glob} for pattern rules. + See `tf.gfile.Glob` for pattern rules. batch_size: An int representing the number of records to combine in a single batch. parser_fn: (Optional.) A function accepting string input to parse @@ -286,11 +286,14 @@ def make_tf_record_dataset( dataset = _maybe_shuffle_and_repeat( dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) + # NOTE(mrry): We set `drop_final_batch=True` when `num_epochs is None` to + # improve the shape inference, because it makes the batch dimension static. + # It is safe to do this because in that case we are repeating the input + # indefinitely, and all batches will be full-sized. + drop_final_batch = drop_final_batch or num_epochs is None + if parser_fn is None: - if drop_final_batch: - dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size)) - else: - dataset = dataset.batch(batch_size) + dataset = dataset.batch(batch_size, drop_remainder=drop_final_batch) else: # TODO(josh11b): if num_parallel_parser_calls is None, use some function # of num cores instead of map_and_batch's default behavior of one batch. @@ -337,7 +340,7 @@ def make_csv_dataset( Args: file_pattern: List of files or patterns of file paths containing CSV - records. See @{tf.gfile.Glob} for pattern rules. + records. See `tf.gfile.Glob` for pattern rules. batch_size: An int representing the number of records to combine in a single batch. column_names: An optional list of strings that corresponds to the CSV @@ -493,8 +496,13 @@ def make_csv_dataset( dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) # Apply batch before map for perf, because map has high overhead relative - # to the size of the computation in each map - dataset = dataset.batch(batch_size=batch_size) + # to the size of the computation in each map. + # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to + # improve the shape inference, because it makes the batch dimension static. + # It is safe to do this because in that case we are repeating the input + # indefinitely, and all batches will be full-sized. + dataset = dataset.batch(batch_size=batch_size, + drop_remainder=num_epochs is None) dataset = dataset.map(map_fn, num_parallel_calls=num_parallel_parser_calls) dataset = dataset.prefetch(prefetch_buffer_size) @@ -772,10 +780,12 @@ def make_batched_features_dataset(file_pattern, dataset = dataset.apply(stats_ops.feature_stats("record_stats")) - if drop_final_batch: - dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size)) - else: - dataset = dataset.batch(batch_size) + # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to + # improve the shape inference, because it makes the batch dimension static. + # It is safe to do this because in that case we are repeating the input + # indefinitely, and all batches will be full-sized. + dataset = dataset.batch( + batch_size, drop_remainder=drop_final_batch or num_epochs is None) # Parse `Example` tensors to a dictionary of `Feature` tensors. dataset = dataset.map( diff --git a/tensorflow/contrib/data/python/ops/resampling.py b/tensorflow/contrib/data/python/ops/resampling.py index 182a5c6ff36fcda8c9e2c522cce07bed0c2daec9..75642f143e19c3d77e675384362c4dab94e10932 100644 --- a/tensorflow/contrib/data/python/ops/resampling.py +++ b/tensorflow/contrib/data/python/ops/resampling.py @@ -50,7 +50,7 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py index ea9dcfe68fa2630d915323fa295031af7d48cdfb..6b002b4a533669dd0f5e82a00aa29224a83a7e57 100644 --- a/tensorflow/contrib/data/python/ops/scan_ops.py +++ b/tensorflow/contrib/data/python/ops/scan_ops.py @@ -151,7 +151,7 @@ class _ScanDataset(dataset_ops.Dataset): def scan(initial_state, scan_func): """A transformation that scans a function across an input dataset. - This transformation is a stateful relative of @{tf.data.Dataset.map}. + This transformation is a stateful relative of `tf.data.Dataset.map`. In addition to mapping `scan_func` across the elements of the input dataset, `scan()` accumulates one or more state tensors, whose initial values are `initial_state`. @@ -166,7 +166,7 @@ def scan(initial_state, scan_func): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): return _ScanDataset(dataset, initial_state, scan_func) diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py index d7f8a73fe3d67bb83e44e962832ce34c116aef66..4356721704046199e8ef2938bde6d7d8bce68cc1 100644 --- a/tensorflow/contrib/data/python/ops/shuffle_ops.py +++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py @@ -92,11 +92,11 @@ def shuffle_and_repeat(buffer_size, count=None, seed=None): indefinitely. seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random seed that will be used to create the distribution. See - @{tf.set_random_seed} for behavior. + `tf.set_random_seed` for behavior. Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): # pylint: disable=missing-docstring diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py index e9dd74530ac64cd414d53eab5294eaa95c919131..8025dcdd16b0180aeb951a31de21e22b8e8c31c7 100644 --- a/tensorflow/contrib/data/python/ops/sliding.py +++ b/tensorflow/contrib/data/python/ops/sliding.py @@ -109,7 +109,7 @@ def sliding_window_batch(window_size, Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. Raises: ValueError: if invalid arguments are provided. diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py index 97931f75bd37d9e45864fe477c6e1620b5e4f193..3b4e98140234af0bf2128ac32f95dbdbf183cb54 100644 --- a/tensorflow/contrib/data/python/ops/stats_ops.py +++ b/tensorflow/contrib/data/python/ops/stats_ops.py @@ -29,7 +29,7 @@ class StatsAggregator(object): """A stateful resource that aggregates statistics from one or more iterators. To record statistics, use one of the custom transformation functions defined - in this module when defining your @{tf.data.Dataset}. All statistics will be + in this module when defining your `tf.data.Dataset`. All statistics will be aggregated by the `StatsAggregator` that is associated with a particular iterator (see below). For example, to record the total number of bytes produced by iterating over a dataset: @@ -39,7 +39,7 @@ class StatsAggregator(object): dataset = dataset.apply(stats_ops.bytes_produced_stats("total_bytes")) ``` - To associate a `StatsAggregator` with a @{tf.data.Iterator} object, use + To associate a `StatsAggregator` with a `tf.data.Iterator` object, use the following pattern: ```python @@ -55,7 +55,7 @@ class StatsAggregator(object): To get a protocol buffer summary of the currently aggregated statistics, use the `StatsAggregator.get_summary()` tensor. The easiest way to do this - is to add the returned tensor to the @{tf.GraphKeys.SUMMARIES} collection, + is to add the returned tensor to the `tf.GraphKeys.SUMMARIES` collection, so that the summaries will be included with any existing summaries. ```python @@ -74,13 +74,13 @@ class StatsAggregator(object): self._resource = gen_dataset_ops.stats_aggregator_handle() def get_summary(self): - """Returns a string @{tf.Tensor} that summarizes the aggregated statistics. + """Returns a string `tf.Tensor` that summarizes the aggregated statistics. - The returned tensor will contain a serialized @{tf.summary.Summary} protocol + The returned tensor will contain a serialized `tf.summary.Summary` protocol buffer, which can be used with the standard TensorBoard logging facilities. Returns: - A scalar string @{tf.Tensor} that summarizes the aggregated statistics. + A scalar string `tf.Tensor` that summarizes the aggregated statistics. """ return gen_dataset_ops.stats_aggregator_summary(self._resource) @@ -122,7 +122,7 @@ def set_stats_aggregator(stats_aggregator): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): @@ -145,7 +145,7 @@ def bytes_produced_stats(tag): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): @@ -169,7 +169,7 @@ def latency_stats(tag): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): @@ -192,7 +192,7 @@ def feature_stats(tag): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): diff --git a/tensorflow/contrib/data/python/ops/threadpool.py b/tensorflow/contrib/data/python/ops/threadpool.py index 9af1e784ffb4f6d71da25f09d60343b649c5079b..dc67accdcfbc2692cbe0c961521897a316f40647 100644 --- a/tensorflow/contrib/data/python/ops/threadpool.py +++ b/tensorflow/contrib/data/python/ops/threadpool.py @@ -100,6 +100,6 @@ def override_threadpool(dataset, thread_pool): Returns: A dataset containing the same values as `dataset`, but which uses `thread_pool` to compute any of its parallel operations (such as - @{tf.data.Dataset.map}). + `tf.data.Dataset.map`). """ return _ThreadPoolDataset(dataset, thread_pool) diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py index e0ce0a4ef15f6b9181bce92fb4d73bf1fab2e66c..e0d606311c4f2f678970113c1faa578dbf44b2ba 100644 --- a/tensorflow/contrib/data/python/ops/unique.py +++ b/tensorflow/contrib/data/python/ops/unique.py @@ -38,7 +38,7 @@ def unique(): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): diff --git a/tensorflow/contrib/data/python/ops/writers.py b/tensorflow/contrib/data/python/ops/writers.py index f53bd3f7383950d6cfdb35e12811fb1daf24b320..c455fdcba673853079ff0d162c4799e72bc8e627 100644 --- a/tensorflow/contrib/data/python/ops/writers.py +++ b/tensorflow/contrib/data/python/ops/writers.py @@ -38,13 +38,13 @@ class TFRecordWriter(object): argument_dtype=dtypes.string) def write(self, dataset): - """Returns a @{tf.Operation} to write a dataset to a file. + """Returns a `tf.Operation` to write a dataset to a file. Args: - dataset: a @{tf.data.Dataset} whose elements are to be written to a file + dataset: a `tf.data.Dataset` whose elements are to be written to a file Returns: - A @{tf.Operation} that, when run, writes contents of `dataset` to a file. + A `tf.Operation` that, when run, writes contents of `dataset` to a file. """ if not isinstance(dataset, dataset_ops.Dataset): raise TypeError("`dataset` must be a `tf.data.Dataset` object.") diff --git a/tensorflow/contrib/distribute/BUILD b/tensorflow/contrib/distribute/BUILD index 1126f76f5854932bcb6a9550c100768069bbd1cc..c16f1d6035d9fb4c5ffe29a713edfeaff299affc 100644 --- a/tensorflow/contrib/distribute/BUILD +++ b/tensorflow/contrib/distribute/BUILD @@ -25,10 +25,12 @@ py_library( srcs = ["__init__.py"], visibility = ["//tensorflow:internal"], deps = [ + "//tensorflow/contrib/distribute/python:collective_all_reduce_strategy", "//tensorflow/contrib/distribute/python:cross_tower_ops", "//tensorflow/contrib/distribute/python:mirrored_strategy", "//tensorflow/contrib/distribute/python:monitor", "//tensorflow/contrib/distribute/python:one_device_strategy", + "//tensorflow/contrib/distribute/python:parameter_server_strategy", "//tensorflow/contrib/distribute/python:step_fn", "//tensorflow/contrib/distribute/python:tpu_strategy", "//tensorflow/python:training", diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py index 2e2c3be853cc5503c86121c142394d49e5037405..588a4f2898b2b7d818898990e4ce7bd343a32bfe 100644 --- a/tensorflow/contrib/distribute/__init__.py +++ b/tensorflow/contrib/distribute/__init__.py @@ -19,24 +19,29 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.distribute.python.collective_all_reduce_strategy import CollectiveAllReduceStrategy from tensorflow.contrib.distribute.python.cross_tower_ops import * from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy from tensorflow.contrib.distribute.python.monitor import Monitor from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy +from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy from tensorflow.contrib.distribute.python.step_fn import * from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy from tensorflow.python.training.distribute import * +from tensorflow.python.training.distribution_strategy_context import * from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'AllReduceCrossTowerOps', + 'CollectiveAllReduceStrategy', 'CrossTowerOps', 'DistributionStrategy', 'MirroredStrategy', 'Monitor', 'OneDeviceStrategy', + 'ParameterServerStrategy', 'ReductionToOneDeviceCrossTowerOps', 'Step', 'StandardInputStep', @@ -49,6 +54,7 @@ _allowed_symbols = [ 'get_tower_context', 'has_distribution_strategy', 'require_tower_context', + 'UpdateContext', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index f5d7e24ae2e3aa76efc50f4da93411f66edea651..59efd17746d98ba4fd736e4e3b7772f52c2f5bd7 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -57,7 +57,7 @@ cuda_py_test( "//tensorflow/python/eager:context", "//tensorflow/python:device_util", "//tensorflow/python/eager:test", - "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:estimator_py", ], tags = [ "no_pip", @@ -72,31 +72,39 @@ py_library( ":cross_tower_ops", ":shared_variable_creator", ":values", + "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:device", "//tensorflow/python:device_util", "//tensorflow/python:distribute", "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", "//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:training", + "//tensorflow/python:util", "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", "//tensorflow/python/eager:context", "//tensorflow/python/eager:tape", - "@six_archive//:six", ], ) py_library( - name = "multi_worker_strategy", - srcs = ["multi_worker_strategy.py"], + name = "parameter_server_strategy", + srcs = ["parameter_server_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ + ":cross_tower_ops", ":mirrored_strategy", ":values", "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:resource_variable_ops", "//tensorflow/python:training", "//tensorflow/python:util", + "//tensorflow/python/distribute:multi_worker_util", ], ) @@ -116,6 +124,24 @@ py_library( ], ) +py_library( + name = "collective_all_reduce_strategy", + srcs = ["collective_all_reduce_strategy.py"], + visibility = ["//tensorflow:internal"], + deps = [ + ":cross_tower_ops", + ":cross_tower_utils", + ":mirrored_strategy", + ":values", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:collective_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/eager:context", + ], +) + py_library( name = "strategy_test_lib", testonly = 1, @@ -149,9 +175,9 @@ py_library( ], deps = [ ":mirrored_strategy", - ":multi_worker_strategy", ":one_device_strategy", ":tpu_strategy", + "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", "//tensorflow/contrib/optimizer_v2:training", "//tensorflow/python:distribute", "//tensorflow/python:framework_ops", @@ -183,9 +209,13 @@ py_test( ], deps = [ ":mirrored_strategy", + ":multi_worker_test_base", ":strategy_test_lib", + "//tensorflow/python:constant_op", "//tensorflow/python:distribute", + "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", @@ -207,6 +237,35 @@ py_test( ], ) +py_test( + name = "parameter_server_strategy_test", + srcs = ["parameter_server_strategy_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + ":combinations", + ":multi_worker_test_base", + ":parameter_server_strategy", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:layers", + "//tensorflow/python:session", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + "//tensorflow/python/estimator:estimator_py", + "@absl_py//absl/testing:parameterized", + ], +) + cuda_py_test( name = "mirrored_strategy_multigpu_test", srcs = ["mirrored_strategy_multigpu_test.py"], @@ -247,11 +306,11 @@ py_library( ], deps = [ "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", "//tensorflow/python:distributed_framework_test_lib", - "//tensorflow/python:platform", "//tensorflow/python:session", - "//tensorflow/python:training", - "//tensorflow/python/eager:test", + "//tensorflow/python/estimator:estimator_py", + "//third_party/py/numpy", ], ) @@ -272,8 +331,7 @@ py_library( deps = [ ":one_device_strategy", ":values", - "//tensorflow/contrib/tpu", - "//tensorflow/contrib/tpu:tpu_py", + "//tensorflow/contrib/tpu:tpu_lib", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", @@ -281,6 +339,37 @@ py_library( ], ) +py_test( + name = "collective_all_reduce_strategy_test", + srcs = ["collective_all_reduce_strategy_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + ":collective_all_reduce_strategy", + ":combinations", + ":cross_tower_utils", + ":multi_worker_test_base", + ":strategy_test_lib", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:init_ops", + "//tensorflow/python:layers", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + "//tensorflow/python/estimator:estimator_py", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + py_library( name = "minimize_loss_test_lib", testonly = 1, @@ -345,11 +434,7 @@ cuda_py_test( "//tensorflow/contrib/optimizer_v2:training", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/eager:test", - "//tensorflow/python/estimator:dnn_linear_combined", - "//tensorflow/python/estimator:export_export", - "//tensorflow/python/estimator:numpy_io", - "//tensorflow/python/estimator:prediction_keys", - "//tensorflow/python/estimator:run_config", + "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/feature_column", "//tensorflow/python:framework_ops", "//tensorflow/python:platform", @@ -375,17 +460,27 @@ py_library( ], ) -cuda_py_test( - name = "step_fn_test", +py_library( + name = "step_fn_test_lib", + testonly = 1, srcs = ["step_fn_test.py"], - additional_deps = [ - ":single_loss_example", + deps = [ ":combinations", - "@absl_py//absl/testing:parameterized", - "//third_party/py/numpy", + ":single_loss_example", + "//tensorflow/contrib/tpu:tpu_lib", "//tensorflow/python:variables", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +cuda_py_test( + name = "step_fn_test", + srcs = ["step_fn_test.py"], + additional_deps = [ + ":step_fn_test_lib", ], tags = [ "multi_and_single_gpu", @@ -451,8 +546,11 @@ py_library( "//tensorflow/contrib/all_reduce:all_reduce_py", "//tensorflow/contrib/nccl:nccl_py", "//tensorflow/python:array_ops", + "//tensorflow/python:collective_ops", + "//tensorflow/python:device", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", "//tensorflow/python:math_ops", ], ) @@ -487,7 +585,9 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:platform", + "//tensorflow/python:resource_variable_ops", "//tensorflow/python:training", + "//tensorflow/python:variable_scope", "//tensorflow/python/eager:context", "@six_archive//:six", ], @@ -495,6 +595,7 @@ py_library( cuda_py_test( name = "cross_tower_ops_test", + size = "large", srcs = ["cross_tower_ops_test.py"], additional_deps = [ ":combinations", @@ -509,7 +610,6 @@ cuda_py_test( "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", ], - shard_count = 15, tags = [ "multi_and_single_gpu", "no_pip", @@ -581,8 +681,7 @@ cuda_py_test( "//tensorflow/contrib/distribute/python:mirrored_strategy", "//tensorflow/python:client_testlib", "//tensorflow/python:training", - "//tensorflow/python/estimator:keras", - "//tensorflow/python/estimator:run_config", + "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/keras", ], tags = [ diff --git a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py index fe3df9cbb95308251581005fdb858cccd5d19a1d..bcb977f64073b1d15ef5c872eb0d6b09d5307b54 100644 --- a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py +++ b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py @@ -49,17 +49,23 @@ class CheckpointUtilsWithDistributionStrategyTest( def testInitFromCheckpoint(self, distribution, in_tower_mode): checkpoint_dir = self.get_temp_dir() with self.test_session() as session: - v1_value, _, _, _ = checkpoint_utils_test._create_checkpoints( + v1_value, v2_value, _, _ = checkpoint_utils_test._create_checkpoints( session, checkpoint_dir) def init_and_verify(g): v1 = variable_scope.get_variable("new_var1", [1, 10]) + v2 = variable_scope.get_variable( + "new_var2", [10, 10], + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.MEAN) checkpoint_utils.init_from_checkpoint(checkpoint_dir, { "var1": "new_var1", + "var2": "new_var2" }) with self.test_session(graph=g) as session: session.run(variables.global_variables_initializer()) self.assertAllEqual(v1_value, self.evaluate(v1)) + self.assertAllEqual(v2_value, self.evaluate(v2)) with ops.Graph().as_default() as g, distribution.scope(): if in_tower_mode: diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..9afcaecf78844b011a9dbc30bb95fa3bfeda8470 --- /dev/null +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -0,0 +1,205 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Class CollectiveAllReduceStrategy implementing DistributionStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import os + +from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib +from tensorflow.contrib.distribute.python import cross_tower_utils +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import values +from tensorflow.core.protobuf import cluster_pb2 +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import collective_ops +from tensorflow.python.training import server_lib + + +# TODO(yuefengz): move this function to a common util file. +def _normalize_cluster_spec(cluster_spec): + if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)): + return server_lib.ClusterSpec(cluster_spec) + elif not isinstance(cluster_spec, server_lib.ClusterSpec): + raise ValueError( + "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a " + "`tf.train.ClusterDef` object") + return cluster_spec + + +# TODO(yuefengz): shard the dataset. +# TODO(yuefengz): support in-graph replication. +# TODO(yuefengz): it only works with a cluster without a chief node, maybe +# support chief node? +class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): + """Distribution strategy that uses collective ops for all-reduce. + + It is similar to the MirroredStrategy but it uses collective ops for + reduction. It currently only works for between-graph replication and its + reduction will reduce across all workers. + """ + + def __init__(self, + num_gpus_per_worker=0, + cluster_spec=None, + task_type="worker", + task_id=0): + """Initializes the object. + + Args: + num_gpus_per_worker: number of local GPUs or GPUs per worker. + cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the + cluster configurations. + task_type: the current task type, such as "worker". + task_id: the current task id. + + Raises: + ValueError: if `task_type` is not in the `cluster_spec`. + """ + self._num_gpus_per_worker = num_gpus_per_worker + self._initialize(cluster_spec, task_type, task_id) + + def _initialize(self, cluster_spec, task_type, task_id): + if task_type not in ["chief", "worker"]: + raise ValueError( + "Unrecognized task_type: %r, valid task types are: \"chief\", " + "\"worker\"." % task_type) + if cluster_spec: + self._cluster_spec = _normalize_cluster_spec(cluster_spec) + worker_device = "/job:%s/task:%d" % (task_type, task_id) + num_workers = len(self._cluster_spec.as_dict().get(task_type, [])) + if "chief" in self._cluster_spec.as_dict(): + num_workers += 1 + if not num_workers: + raise ValueError("`task_type` shoud be in `cluster_spec`.") + + # TODO(yuefengz): create a utility to infer chief. + if "chief" in self._cluster_spec.as_dict() and task_type == "chief": + assert task_id == 0 + self._is_chief = True + else: + assert task_type == "worker" + self._is_chief = task_id == 0 + else: + self._cluster_spec = None + self._is_chief = True + worker_device = "" + num_workers = 1 + self._num_workers = num_workers + + if self._num_gpus_per_worker: + local_devices = [ + "%s/device:GPU:%d" % (worker_device, i) + for i in range(self._num_gpus_per_worker) + ] + else: + local_devices = [worker_device] + + self._collective_keys = cross_tower_utils.CollectiveKeys() + super(CollectiveAllReduceStrategy, self).__init__( + devices=local_devices, + cross_tower_ops=cross_tower_ops_lib.CollectiveAllReduce( + num_workers=num_workers, + num_gpus_per_worker=self._num_gpus_per_worker, + collective_keys=self._collective_keys)) + + # Add a default device so that ops without specified devices will not end up + # on other workers. + if cluster_spec: + self._default_device = "/job:%s/replica:0/task:%d" % (task_type, task_id) + + def _create_variable(self, next_creator, *args, **kwargs): + colocate_with = kwargs.pop("colocate_with", None) + devices = self._get_devices_from(colocate_with) + group_size = len(devices) * self._num_workers + group_key = self._collective_keys.get_group_key(self._devices) + + def _real_mirrored_creator(devices, *args, **kwargs): + """Creates one MirroredVariable on the current worker.""" + index = {} + collective_instance_key = self._collective_keys.get_instance_key( + key_id=kwargs["name"]) + if "initial_value" not in kwargs: + raise ValueError("Initial value must be specified.") + initial_value = kwargs["initial_value"] + if callable(initial_value): + initial_value_fn = initial_value + else: + initial_value_fn = lambda: initial_value + + for i, d in enumerate(devices): + with ops.device(d): + if i > 0: + # Give replicas meaningful distinct names: + var0name = index[devices[0]].name.split(":")[0] + # We append a / to variable names created on towers with id > 0 to + # ensure that we ignore the name scope and instead use the given + # name as the absolute name of the variable. + kwargs["name"] = "%s/replica_%d/" % (var0name, i) + + # The initial value fn makes sure variables all initialized to + # same values. The first device of the chief worker will send their + # variable values to other devices and other workers. + def _overridden_initial_value_fn(device=d, index=i): # pylint: disable=g-missing-docstring + with ops.device(device): + initial_value = initial_value_fn() + assert not callable(initial_value) + initial_value = ops.convert_to_tensor(initial_value) + + if self._is_chief and index == 0: + bcast_send = collective_ops.broadcast_send( + initial_value, initial_value.shape, initial_value.dtype, + group_size, group_key, collective_instance_key) + with ops.control_dependencies([bcast_send]): + return array_ops.identity(initial_value) + else: + return collective_ops.broadcast_recv( + initial_value.shape, initial_value.dtype, group_size, + group_key, collective_instance_key) + + kwargs["initial_value"] = _overridden_initial_value_fn + + with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): + v = next_creator(*args, **kwargs) + + assert not isinstance(v, values.DistributedVariable) + index[d] = v + return index + + # pylint: disable=protected-access + return mirrored_strategy._create_mirrored_variable( + devices, _real_mirrored_creator, *args, **kwargs) + + def configure(self, session_config=None): + # Use TF_CONFIG to get the cluster spec and the current job. + if not self._cluster_spec: + tf_config = json.loads(os.environ.get("TF_CONFIG", "{}")) + cluster_spec = _normalize_cluster_spec(tf_config.get("cluster", {})) + + task_env = tf_config.get("task", {}) + if task_env: + task_type = task_env.get("type", "worker") + task_id = int(task_env.get("index", "0")) + else: + task_type = "worker" + task_id = 0 + + if cluster_spec: + self._initialize(cluster_spec, task_type, task_id) diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b5e54e3b7d7156e87731e6f79aa66262d127232c --- /dev/null +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -0,0 +1,217 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for CollectiveAllReduceStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.distribute.python import collective_all_reduce_strategy +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import cross_tower_utils +from tensorflow.contrib.distribute.python import multi_worker_test_base +from tensorflow.contrib.distribute.python import strategy_test_lib +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.eager import context +from tensorflow.python.estimator import run_config +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.layers import core +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class DistributedCollectiveAllReduceStrategyTest( + multi_worker_test_base.MultiWorkerTestBase, parameterized.TestCase): + + collective_key_base = 0 + + @classmethod + def setUpClass(cls): + """Create a local cluster with 2 workers.""" + cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster( + num_workers=3, num_ps=0) + cls._cluster_spec = { + run_config.TaskType.WORKER: [ + 'fake_worker_0', 'fake_worker_1', 'fake_worker_2' + ] + } + + def setUp(self): + self._run_options = config_pb2.RunOptions() + self._run_options.experimental.collective_graph_key = 6 + + self._sess_config = config_pb2.ConfigProto() + self._sess_config.experimental.collective_group_leader = ( + '/job:worker/replica:0/task:0') + + # We use a different key_base for each test so that collective keys won't be + # reused. + # TODO(yuefengz, tucker): enable it to reuse collective keys in different + # tests. + DistributedCollectiveAllReduceStrategyTest.collective_key_base += 100000 + super(DistributedCollectiveAllReduceStrategyTest, self).setUp() + + def _get_test_object(self, task_type, task_id, num_gpus=0): + distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( + num_gpus_per_worker=num_gpus, + cluster_spec=self._cluster_spec, + task_type=task_type, + task_id=task_id) + collective_keys = cross_tower_utils.CollectiveKeys( + group_key_start=10 * num_gpus + + DistributedCollectiveAllReduceStrategyTest.collective_key_base, + instance_key_start=num_gpus * 100 + + DistributedCollectiveAllReduceStrategyTest.collective_key_base, + instance_key_with_id_start=num_gpus * 10000 + + DistributedCollectiveAllReduceStrategyTest.collective_key_base) + distribution._collective_keys = collective_keys + distribution._cross_tower_ops._collective_keys = collective_keys + return distribution, self._workers[task_id].target + + def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): + d, master_target = self._get_test_object(task_type, task_id, num_gpus) + with ops.Graph().as_default(), \ + self.test_session(config=self._sess_config, + target=master_target) as sess, \ + d.scope(): + l = core.Dense(1, use_bias=False, name='gpu_%d' % d._num_gpus_per_worker) + + def loss_fn(x): + y = array_ops.reshape(l(x), []) - constant_op.constant(1.) + return y * y + + # TODO(yuefengz, apassos): eager.backprop.implicit_grad is not safe for + # multiple graphs (b/111216820). + def grad_fn(x): + loss = loss_fn(x) + var_list = ( + variables.trainable_variables() + ops.get_collection( + ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)) + grads = gradients.gradients(loss, var_list) + ret = list(zip(grads, var_list)) + return ret + + def update(v, g): + return v.assign_sub(0.05 * g, use_locking=True) + + one = d.broadcast(constant_op.constant([[1.]])) + + def step(): + """Perform one optimization step.""" + # Run forward & backward to get gradients, variables list. + g_v = d.call_for_each_tower(grad_fn, one) + # Update the variables using the gradients and the update() function. + before_list = [] + after_list = [] + for g, v in g_v: + fetched = d.read_var(v) + before_list.append(fetched) + with ops.control_dependencies([fetched]): + # TODO(yuefengz): support non-Mirrored variable as destinations. + g = d.reduce( + variable_scope.VariableAggregation.SUM, g, destinations=v) + with ops.control_dependencies(d.unwrap(d.update(v, update, g))): + after_list.append(d.read_var(v)) + return before_list, after_list + + before_out, after_out = step() + + if context.num_gpus() < d._num_gpus_per_worker: + return True + + sess.run( + variables.global_variables_initializer(), options=self._run_options) + + for i in range(10): + b, a = sess.run((before_out, after_out), options=self._run_options) + if i == 0: + before, = b + after, = a + + error_before = abs(before - 1) + error_after = abs(after - 1) + # Error should go down + self.assertLess(error_after, error_before) + return error_after < error_before + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testMinimizeLossGraph(self, num_gpus): + self._run_between_graph_clients(self._test_minimize_loss_graph, + self._cluster_spec, num_gpus) + + def _test_variable_initialization(self, task_type, task_id, num_gpus): + distribution, master_target = self._get_test_object(task_type, task_id, + num_gpus) + with ops.Graph().as_default(), \ + self.test_session(config=self._sess_config, + target=master_target) as sess, \ + distribution.scope(): + + def model_fn(): + x = variable_scope.get_variable( + 'x', + shape=(2, 3), + initializer=init_ops.random_uniform_initializer( + 1.0, 10.0, dtype=dtypes.float32)) + return array_ops.identity(x) + + x = distribution.call_for_each_tower(model_fn) + reduced_x = distribution.unwrap( + distribution.reduce( + variable_scope.VariableAggregation.MEAN, x, + destinations='/cpu:0'))[0] + + sess.run( + variables.global_variables_initializer(), options=self._run_options) + x_value, reduced_x_value = sess.run( + [x, reduced_x], options=self._run_options) + self.assertTrue(np.array_equal(x_value, reduced_x_value)) + return np.array_equal(x_value, reduced_x_value) + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testVariableInitialization(self, num_gpus): + if context.num_gpus() < num_gpus: + return + self._run_between_graph_clients( + self._test_variable_initialization, + self._cluster_spec, + num_gpus=num_gpus) + + +class LocalCollectiveAllReduceStrategy(strategy_test_lib.DistributionTestBase, + parameterized.TestCase): + + def testMinimizeLossGraph(self, num_gpus=2): + # Collective ops doesn't support strategy with one device. + if context.num_gpus() < num_gpus: + return + distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( + num_gpus_per_worker=num_gpus) + self._test_minimize_loss_graph(distribution) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index 9a8ea4aa48b8cf4c5906f18d8bddacc224e0b644..2fbadfe0f5ad9ef0a4255f51abe4aad5a0646efe 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -46,8 +46,8 @@ import unittest from absl.testing import parameterized import six +from tensorflow.contrib.cluster_resolver import TPUClusterResolver from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib -from tensorflow.contrib.distribute.python import multi_worker_strategy from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib from tensorflow.contrib.optimizer_v2 import adam as adam_v2 @@ -55,7 +55,7 @@ from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.training import adam -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import gradient_descent from tensorflow.python.util import tf_inspect @@ -144,7 +144,7 @@ def _augment_with_special_arguments(test_method): """A wrapped test method that treats some arguments in a special way.""" mode = kwargs.pop("mode", "graph") - distribution = kwargs.pop("distribution", None) + distribution = kwargs.get("distribution", None) required_tpu = kwargs.pop("required_tpu", False) required_gpus = kwargs.pop("required_gpus", None) @@ -153,7 +153,6 @@ def _augment_with_special_arguments(test_method): "Do not use `required_gpus` and `distribution` together.") assert required_tpu is False, ( "Do not use `required_tpu` and `distribution` together.") - kwargs["distribution"] = distribution.strategy required_gpus = distribution.required_gpus required_tpu = distribution.required_tpu @@ -189,9 +188,13 @@ def _augment_with_special_arguments(test_method): if mode == "eager": with ops.Graph().as_default(), context.eager_mode(): + if distribution: + kwargs_to_pass["distribution"] = distribution.strategy test_method(**kwargs_to_pass) elif mode == "graph": with ops.Graph().as_default(), context.graph_mode(): + if distribution: + kwargs_to_pass["distribution"] = distribution.strategy test_method(**kwargs_to_pass) else: raise ValueError( @@ -316,12 +319,15 @@ class NamedDistribution(object): # pylint: disable=g-long-lambda default_strategy = NamedDistribution( "Default", - lambda: distribute_lib._default_distribution_strategy, # pylint: disable=protected-access + distribution_strategy_context._get_default_distribution_strategy, # pylint: disable=protected-access required_gpus=None) one_device_strategy = NamedDistribution( "OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"), required_gpus=None) -tpu_strategy = NamedDistribution("TPU", tpu_lib.TPUStrategy, required_tpu=True) +tpu_strategy = NamedDistribution( + "TPU", lambda: tpu_lib.TPUStrategy( + TPUClusterResolver(""), steps_per_run=5), + required_tpu=True) # Note that we disable prefetching for testing since prefetching makes # the input non-deterministic. mirrored_strategy_with_gpu_and_cpu = NamedDistribution( @@ -337,42 +343,44 @@ mirrored_strategy_with_two_gpus = NamedDistribution( multi_worker_strategy_with_cpu = NamedDistribution( "MultiWorkerCPU", - lambda: multi_worker_strategy.MultiWorkerMirroredStrategy( - cluster={ + lambda: mirrored_lib.MirroredStrategy( + cluster_spec={ "worker": [ "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" ] }, - num_gpus_per_worker=0), 0) + num_gpus=0), 0) multi_worker_strategy_with_one_gpu = NamedDistribution( "MultiWorker1GPU", - lambda: multi_worker_strategy.MultiWorkerMirroredStrategy( - cluster={ + lambda: mirrored_lib.MirroredStrategy( + cluster_spec={ "worker": [ "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" ] }, - num_gpus_per_worker=1), 1) + num_gpus=1), 1) multi_worker_strategy_with_two_gpus = NamedDistribution( "MultiWorker2GPUs", - lambda: multi_worker_strategy.MultiWorkerMirroredStrategy( - cluster={ + lambda: mirrored_lib.MirroredStrategy( + cluster_spec={ "worker": [ "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" ] }, - num_gpus_per_worker=2), 2) + num_gpus=2), 2) adam_optimizer_v1_fn = NamedObject( "AdamV1", lambda: adam.AdamOptimizer(0.2, epsilon=1)) gradient_descent_optimizer_v1_fn = NamedObject( "GradientDescentV1", lambda: gradient_descent.GradientDescentOptimizer(0.2)) +optimizers_v1 = [adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn] adam_optimizer_v2_fn = NamedObject( "AdamV2", lambda: adam_v2.AdamOptimizer(0.2, epsilon=1)) gradient_descent_optimizer_v2_fn = NamedObject( "GradientDescentV2", lambda: gradient_descent_v2.GradientDescentOptimizer(0.2)) +optimizers_v2 = [adam_optimizer_v2_fn, gradient_descent_optimizer_v2_fn] graph_and_eager_modes = ["graph", "eager"] @@ -384,7 +392,7 @@ def distributions_and_v1_optimizers(): one_device_strategy, mirrored_strategy_with_gpu_and_cpu, mirrored_strategy_with_two_gpus ], - optimizer_fn=[adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn]) + optimizer_fn=optimizers_v1) def distributions_and_v2_optimizers(): @@ -394,4 +402,4 @@ def distributions_and_v2_optimizers(): one_device_strategy, mirrored_strategy_with_gpu_and_cpu, mirrored_strategy_with_two_gpus ], - optimizer_fn=[adam_optimizer_v2_fn, gradient_descent_optimizer_v2_fn]) + optimizer_fn=optimizers_v2) diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py index b0baf0dad1d55eafac5338d1eb43465927e428a1..163559587da3b8b6f175e295602f767c08468a28 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py @@ -28,18 +28,37 @@ from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import device_util +def check_destinations(destinations): + """Checks whether `destinations` is not None and not empty. + + Args: + destinations: a DistributedValues, Variable, string or a list of strings. + + Returns: + Boolean indicating whether `destinations` is not None and not empty. + """ + # Calling bool() on a ResourceVariable is not allowed. + if isinstance(destinations, resource_variable_ops.ResourceVariable): + return bool(destinations.device) + return bool(destinations) + + def validate_destinations(destinations): - if not isinstance(destinations, - (value_lib.DistributedValues, six.string_types, list)): + if not isinstance( + destinations, + (value_lib.DistributedValues, resource_variable_ops.ResourceVariable, + value_lib.AggregatingVariable, six.string_types, list)): raise ValueError("destinations must be one of a `DistributedValues` object," - " a device string, a list of device strings or None") + " a tf.Variable object, a device string, a list of device " + "strings or None") - if not destinations: + if not check_destinations(destinations): raise ValueError("destinations can not be empty") @@ -59,6 +78,9 @@ def _validate_value_destination_pairs(value_destination_pairs): def get_devices_from(destinations): if isinstance(destinations, value_lib.DistributedValues): return list(destinations.devices) + elif isinstance(destinations, (resource_variable_ops.ResourceVariable, + value_lib.AggregatingVariable)): + return [destinations.device] elif isinstance(destinations, six.string_types): return [device_util.resolve(destinations)] else: @@ -136,7 +158,7 @@ class CrossTowerOps(object): Args: aggregation: Indicates how a variable will be aggregated. Accepted values - are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}. + are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`. per_device_value: a PerDevice object. destinations: the reduction destinations. @@ -160,7 +182,7 @@ class CrossTowerOps(object): Args: aggregation: Indicates how a variable will be aggregated. Accepted values - are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}. + are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`. value_destination_pairs: a list or a tuple of tuples of PerDevice objects and destinations. If a destination is None, then the destinations are set to match the devices of the input PerDevice object. @@ -225,7 +247,10 @@ class ReductionToOneDeviceCrossTowerOps(CrossTowerOps): super(ReductionToOneDeviceCrossTowerOps, self).__init__() def _reduce(self, aggregation, per_device_value, destinations): - devices = get_devices_from(destinations or per_device_value) + if check_destinations(destinations): + devices = get_devices_from(destinations) + else: + devices = get_devices_from(per_device_value) reduce_to_device = self.reduce_to_device or devices[0] reduced = _simple_reduce(per_device_value, reduce_to_device, self.accumulation_fn, aggregation) @@ -243,9 +268,9 @@ def _group_value_by_device(per_device_values): This grouping is needed to call the all-reduce library because it expects a list of the following form: - [(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ... - (grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ... - (grad0_gpu2, v0_gpu2), (grad1_gpu0, v1_gpu2), (grad2_gpu0, v2_gpu2) ... + [[(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ...], + [(grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ...], + [(grad0_gpu2, v0_gpu2), (grad1_gpu0, v1_gpu2), (grad2_gpu0, v2_gpu2) ...], ... ] @@ -266,7 +291,10 @@ def _group_value_by_device(per_device_values): return grouped -def _ungroup_and_make_mirrored(grouped_reduced, destinations, aggregation): +def _ungroup_and_make_mirrored(grouped_reduced, + destinations, + aggregation, + num_between_graph_workers=1): """Ungroup results from all-reduce and make Mirrored objects. Each all-reduce result will be divided by the number of destinations before @@ -278,7 +306,9 @@ def _ungroup_and_make_mirrored(grouped_reduced, destinations, aggregation): cross_tower_utils.aggregate_gradients_using*. destinations: a list of device strings for returned Mirrored objects. aggregation: Indicates how a variable will be aggregated. Accepted values - are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}. + are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`. + num_between_graph_workers: number of workers in the between-graph + replication. Returns: a list of Mirrored objects. @@ -287,7 +317,8 @@ def _ungroup_and_make_mirrored(grouped_reduced, destinations, aggregation): for d, per_device_reduced in enumerate(grouped_reduced): for i, (v, _) in enumerate(per_device_reduced): if aggregation == vs.VariableAggregation.MEAN: - index[i][destinations[d]] = v / len(destinations) + index[i][destinations[d]] = v / ( + len(destinations) * num_between_graph_workers) else: index[i][destinations[d]] = v return [value_lib.Mirrored(v) for v in index] @@ -508,7 +539,10 @@ class AllReduceCrossTowerOps(CrossTowerOps): logging.WARN, "Efficient allreduce is not supported for IndexedSlices.", 10) - devices = get_devices_from(destinations or per_device_value) + if check_destinations(destinations): + devices = get_devices_from(destinations) + else: + devices = get_devices_from(per_device_value) reduce_to_device = devices[0] reduced = _simple_reduce(per_device_value, reduce_to_device, math_ops.add_n, aggregation) @@ -534,12 +568,12 @@ class AllReduceCrossTowerOps(CrossTowerOps): def _batch_all_reduce(self, aggregation, per_device_values): """All reduce algorithm in a batch.""" - logging.info( - "batch_all_reduce invoked for batches size = %d with " + logging.log_first_n( + logging.INFO, "batch_all_reduce invoked for batches size = %d with " "algorithm = %s, num_packs = %d, agg_small_grads_max_bytes = %d and " - "agg_small_grads_max_group = %d", len(per_device_values), - self._all_reduce_alg, self._num_packs, self._agg_small_grads_max_bytes, - self._agg_small_grads_max_group) + "agg_small_grads_max_group = %d" % + (len(per_device_values), self._all_reduce_alg, self._num_packs, + self._agg_small_grads_max_bytes, self._agg_small_grads_max_group), 10) destinations = per_device_values[0].devices grouped = _group_value_by_device(per_device_values) @@ -644,12 +678,13 @@ class MultiWorkerAllReduce(AllReduceCrossTowerOps): def _batch_all_reduce(self, aggregation, per_device_values): """All reduce algorithm in a batch.""" - logging.info( + logging.log_first_n( + logging.INFO, "distributed batch_all_reduce invoked for batches size = %d with " "allreduce_spec = %r, num_packs = %d, agg_small_grads_max_bytes = %d " - "and agg_small_grads_max_group = %d", len(per_device_values), - self._all_reduce_spec, self._num_packs, self._agg_small_grads_max_bytes, - self._agg_small_grads_max_group) + "and agg_small_grads_max_group = %d" % + (len(per_device_values), self._all_reduce_spec, self._num_packs, + self._agg_small_grads_max_bytes, self._agg_small_grads_max_group), 10) destinations = sorted(per_device_values[0].devices) device_grads = _group_value_by_device(per_device_values) @@ -692,6 +727,104 @@ class MultiWorkerAllReduce(AllReduceCrossTowerOps): aggregation) +# TODO(yuefengz): support in-graph collective all-reduce. +class CollectiveAllReduce(CrossTowerOps): + """All-reduce cross tower ops using collective ops. + + In the between-graph replicated training, it will still do all-reduces across + all workers and then put results on the right destinations. + """ + + def __init__(self, + num_workers=1, + num_gpus_per_worker=0, + all_reduce_merge_scope=1, + collective_keys=None): + """Initializes the object. + + Args: + num_workers: number of workers in the between-graph replicated training. + num_gpus_per_worker: number of GPUs per worker. + all_reduce_merge_scope: size of groups into which to partition consecutive + gradients grouped under a common 'allreduce' name scope. This is useful + for some optimization of collective ops. + collective_keys: an optional CollectiveKey object. + """ + self._num_workers = num_workers + self._num_gpus_per_worker = num_gpus_per_worker + self._all_reduce_merge_scope = all_reduce_merge_scope + self._collective_keys = collective_keys or cross_tower_utils.CollectiveKeys( + ) + super(CollectiveAllReduce, self).__init__() + + # TODO(yuefengz, tucker): is indexed slices supported by collective ops? + def _reduce(self, aggregation, per_device_value, destinations): + all_reduced = self._batch_all_reduce(aggregation, [per_device_value])[0] + if destinations is None or _devices_match(per_device_value, destinations): + return all_reduced + else: + index = {} + for d in get_devices_from(destinations): + # pylint: disable=protected-access + if d in all_reduced._index: + index[d] = all_reduced._index[d] + else: + with ops.control_dependencies(list( + all_reduced._index.values())), ops.device(d): + index[d] = array_ops.identity(list(all_reduced._index.values())[0]) + + return value_lib.Mirrored(index) + + def _batch_reduce(self, aggregation, value_destination_pairs): + return [ + self._reduce(aggregation, t, destinations=v) + for t, v in value_destination_pairs + ] + + def _batch_all_reduce(self, aggregation, per_device_values): + """All-reduce across all workers in a batch.""" + if context.executing_eagerly(): + raise ValueError("Eager mode with collective ops is not supported yet.") + + logging.log_first_n( + logging.INFO, "Collective All-reduce invoked with batches size = %d, " + "num_workers = %d" % (len(per_device_values), self._num_workers), 10) + + grouped_by_tower = _group_value_by_device(per_device_values) + + grouped_by_var = list(zip(*grouped_by_tower)) + # grouped_by_var is grouped by variables and takes the following format: + # [((grad0_gpu0, v0_gpu0), (grad0_gpu1, v0_gpu1), (grad0_gpu2, v0_gpu2) ..), + # ((grad1_gpu0, v1_gpu0), (grad1_gpu1, v1_gpu1), (grad1_gpu0, v1_gpu2) ..), + # ((grad2_gpu0, v2_gpu0), (grad2_gpu1, v2_gpu1), (grad2_gpu0, v2_gpu2) ..), + # ... + # ] + chunked_gv = [ + grouped_by_var[x:x + self._all_reduce_merge_scope] + for x in range(0, len(grouped_by_var), self._all_reduce_merge_scope) + ] + + reduced_gv_list = [] + for chunk in chunked_gv: + with ops.name_scope("allreduce"): + for grad_and_vars in chunk: + scaled_grads = [g for g, _ in grad_and_vars] + collective_reduced = cross_tower_utils.build_collective_reduce( + scaled_grads, self._num_workers, self._collective_keys, "Add", + "Id") + result = [] + for (_, v), g in zip(grad_and_vars, collective_reduced): + result.append([g, v]) + reduced_gv_list.append(result) + + new_tower_grads = [list(x) for x in zip(*reduced_gv_list)] + return _ungroup_and_make_mirrored( + new_tower_grads, + per_device_values[0].devices, + aggregation, + num_between_graph_workers=self._num_workers) + + _dgx1_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]] diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py index 6a780ff60ffcd59d416278bfde6d005d7ad37a68..3508c9d5997070ef1350d4f08f98bf2d9c8b6837 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py @@ -21,13 +21,17 @@ from __future__ import print_function import itertools from absl.testing import parameterized +import numpy as np from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib +from tensorflow.contrib.distribute.python import cross_tower_utils from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import values as value_lib +from tensorflow.core.protobuf import config_pb2 from tensorflow.python.eager import context from tensorflow.python.eager import test +from tensorflow.python.estimator import run_config from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -376,5 +380,172 @@ class MultiWorkerCrossTowerOpsTest(multi_worker_test_base.MultiWorkerTestBase, self._testReductionAndBroadcast(cross_tower_ops, distribution) +class MultiWorkerCollectiveAllReduceTest( + multi_worker_test_base.MultiWorkerTestBase, parameterized.TestCase): + + collective_key_base = 100000 + + @classmethod + def setUpClass(cls): + """Create a local cluster with 2 workers.""" + cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster( + num_workers=3, num_ps=0) + cls._cluster_spec = { + run_config.TaskType.WORKER: [ + "fake_worker_0", "fake_worker_1", "fake_worker_2" + ] + } + + def setUp(self): + super(MultiWorkerCollectiveAllReduceTest, self).setUp() + # Reusing keys are not supported well. So we have to give a different + # collective key base for different tests. + MultiWorkerCollectiveAllReduceTest.collective_key_base += 100000 + + def _get_test_objects(self, task_type, task_id, num_gpus=0, local_mode=False): + collective_keys = cross_tower_utils.CollectiveKeys( + group_key_start=10 * num_gpus + + MultiWorkerCollectiveAllReduceTest.collective_key_base, + instance_key_start=num_gpus * 100 + + MultiWorkerCollectiveAllReduceTest.collective_key_base, + instance_key_with_id_start=num_gpus * 10000 + + MultiWorkerCollectiveAllReduceTest.collective_key_base) + if local_mode: + collective_all_reduce_ops = cross_tower_ops_lib.CollectiveAllReduce( + 1, num_gpus, collective_keys=collective_keys) + if num_gpus: + devices = ["/device:GPU:%d" % i for i in range(num_gpus)] + else: + devices = ["/device:CPU:0"] + return collective_all_reduce_ops, devices, "" + else: + collective_all_reduce_ops = cross_tower_ops_lib.CollectiveAllReduce( + 3, num_gpus, collective_keys=collective_keys) + if num_gpus: + devices = [ + "/job:%s/task:%d/device:GPU:%d" % (task_type, task_id, i) + for i in range(num_gpus) + ] + else: + devices = ["/job:%s/task:%d" % (task_type, task_id)] + return collective_all_reduce_ops, devices, self._workers[task_id].target + + def _assert_values_equal(self, left, right, sess): + if isinstance(left, list): + for l, r in zip(left, right): + self._assert_values_equal(l, r, sess) + else: + self.assertEqual(type(left), type(right)) + self.assertEqual(set(left.devices), set(right.devices)) + + run_options = config_pb2.RunOptions() + run_options.experimental.collective_graph_key = 6 + + left_values = np.array( + sess.run(list(left._index.values()), options=run_options)).flatten() + right_values = np.array(list(right._index.values())).flatten() + self.assertEqual(len(left_values), len(right_values)) + for l, r in zip(left_values, right_values): + self.assertEqual(l, r) + + def _test_reduction(self, task_type, task_id, num_gpus, local_mode=False): + collective_all_reduce, devices, master_target = self._get_test_objects( + task_type, task_id, num_gpus, local_mode=local_mode) + if local_mode: + num_workers = 1 + worker_device = None + else: + num_workers = len(self._workers) + worker_device = "/job:%s/task:%d" % (task_type, task_id) + with ops.Graph().as_default(), \ + ops.device(worker_device), \ + self.test_session(target=master_target) as sess: + # Collective ops doesn't support scalar tensors, so we have to construct + # 1-d tensors. + values = [constant_op.constant([float(d)]) for d in range(len(devices))] + per_device = _make_per_device(values, devices) + mean = np.array([(len(devices) - 1.) / 2.]) + + values_2 = [constant_op.constant([d + 1.0]) for d in range(len(devices))] + per_device_2 = _make_per_device(values_2, devices) + mean_2 = np.array([mean[0] + 1.]) + + destination_mirrored = _fake_mirrored(1., devices) + destination_different = _fake_mirrored(1., _cpu_device) + destination_str = _cpu_device + destination_list = devices + + all_destinations = [ + destination_different, None, destination_mirrored, destination_str, + destination_list + ] + + # test reduce() + for destinations in all_destinations: + self._assert_values_equal( + collective_all_reduce.reduce( + vs.VariableAggregation.MEAN, + per_device, + destinations=destinations), + _fake_mirrored(mean, destinations or per_device), sess) + self._assert_values_equal( + collective_all_reduce.reduce( + vs.VariableAggregation.MEAN, + per_device_2, + destinations=destinations), + _fake_mirrored(mean_2, destinations or per_device), sess) + self._assert_values_equal( + collective_all_reduce.reduce( + vs.VariableAggregation.SUM, + per_device, + destinations=destinations), + _fake_mirrored(mean * len(devices) * num_workers, destinations or + per_device), sess) + self._assert_values_equal( + collective_all_reduce.reduce( + vs.VariableAggregation.SUM, + per_device_2, + destinations=destinations), + _fake_mirrored(mean_2 * len(devices) * num_workers, destinations or + per_device), sess) + + # test batch_reduce() + for d1, d2 in itertools.product(all_destinations, all_destinations): + self._assert_values_equal( + collective_all_reduce.batch_reduce(vs.VariableAggregation.MEAN, + [(per_device, d1), + (per_device_2, d2)]), + [ + _fake_mirrored(mean, d1 or per_device), + _fake_mirrored(mean_2, d2 or per_device_2) + ], sess) + self._assert_values_equal( + collective_all_reduce.batch_reduce(vs.VariableAggregation.SUM, + [(per_device, d1), + (per_device_2, d2)]), + [ + _fake_mirrored(mean * len(devices) * num_workers, d1 or + per_device), + _fake_mirrored(mean_2 * len(devices) * num_workers, d2 or + per_device_2) + ], sess) + + return True + + @combinations.generate( + combinations.combine(mode=["graph"], num_gpus=[0, 1, 2])) + def testReductionDistributed(self, num_gpus): + if context.num_gpus() < num_gpus: + return + self._run_between_graph_clients(self._test_reduction, self._cluster_spec, + num_gpus) + + # Collective ops doesn't support strategy with one device. + def testReductionLocal(self, num_gpus=2): + if context.num_gpus() < num_gpus: + return + self._test_reduction(None, None, num_gpus, local_mode=True) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py index 2bb088e704c584598b863b1b836166af2a5bb12c..24cb08fb48f832572da5ae2113e6c224557c6a81 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_utils.py +++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py @@ -19,13 +19,16 @@ from __future__ import division from __future__ import print_function import collections as pycoll +import threading from tensorflow.contrib import nccl from tensorflow.contrib.all_reduce.python import all_reduce from tensorflow.contrib.distribute.python import values as value_lib +from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import collective_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops @@ -218,6 +221,146 @@ def split_grads_by_size(threshold_size, device_grads): return small_grads, large_grads +# threading.Lock() cannot be pickled and therefore cannot be a field of +# CollectiveKeys. +_lock = threading.Lock() + + +# TODO(yuefengz): use random key starts to avoid reusing keys? +class CollectiveKeys(object): + """Class that manages collective keys. + + We need to manage three different keys for collective: + + *Group key*: an integer key to identify the set of cooperative devices. + Collective ops work under the same set of devices must using the same group + key. + + *Instance key*: an integer key to identify the set of same counterpart of + tensors on different devices in a device group that need to be all-reduced. + + "Graph key": an integer key that is unique key graph. This is used to support + multiple graphs per client session. It must be non-zero and set in the + `config` argument of each call to `session.run`. + """ + + def __init__(self, + group_key_start=1, + instance_key_start=100, + instance_key_with_id_start=10000): + """Initializes the object. + + Args: + group_key_start: the starting integer of group key. + instance_key_start: the starting integer of instance key. + instance_key_with_id_start: the starting integer of instance key that is + recorded with an id. + """ + self._group_key = group_key_start + self._group_key_table = dict() + + # For instance keys with ids + self._instance_key_id_to_key_table = dict() + self._instance_key_with_id_counter = instance_key_with_id_start + + # For instance keys without ids + self._instance_key_start = instance_key_start + + self._thread_local = threading.local() + + def _get_thread_local_object(self): + # We make instance key without key ids thread local so that it will work + # with MirroredStrategy and distribute coordinator. + if not hasattr(self._thread_local, 'instance_key'): + self._thread_local.instance_key = self._instance_key_start + return self._thread_local + + def get_group_key(self, devices): + """Returns a group key for the set of devices. + + Args: + devices: list of strings naming devices in a collective group. + + Returns: + int key uniquely identifying the set of device names. + """ + parsed = [pydev.DeviceSpec.from_string(d) for d in devices] + # In the between-graph replicated training, different workers need to get + # the same device key. So we remove the task_type and task_id from the + # devices. + # TODO(yuefengz): in the in-graph replicated training, we need to include + # task_type and task_id. + names = sorted(['%s:%d' % (d.device_type, d.device_index) for d in parsed]) + key_id = ','.join(names) + with _lock: + if key_id not in self._group_key_table: + new_key = self._group_key + self._group_key += 1 + self._group_key_table[key_id] = new_key + return self._group_key_table[key_id] + + def get_instance_key(self, key_id=None): + """Returns a new instance key for use in defining a collective op. + + Args: + key_id: optional string. If set, key will be recorded and the same key + will be returned when the same key_id is provided. If not, an increasing + instance key will be returned. + """ + if key_id: + with _lock: + if key_id not in self._instance_key_id_to_key_table: + self._instance_key_with_id_counter += 1 + self._instance_key_id_to_key_table[key_id] = ( + self._instance_key_with_id_counter) + return self._instance_key_id_to_key_table[key_id] + else: + v = self._get_thread_local_object().instance_key + self._get_thread_local_object().instance_key += 1 + return v + + +def build_collective_reduce(input_tensors, + num_workers, + collective_keys, + reduction_op='Add', + unary_op='Id'): + """Build a subgraph that does one full all-reduce, using the collective Op. + + Args: + input_tensors: tensors within a single worker graph that are to be reduced + together; must be one per device. + num_workers: total number of workers with identical independent graphs that + will be doing this same reduction. The reduction will actually include + the corresponding tensors at all these workers. + collective_keys: a CollectiveKeys object. + reduction_op: string naming the reduction op. + unary_op: string naming the unary final op. + + Returns: + An array of final tensors, one per device, computed by the full reduction. + + Raises: + ValueError: There must be at least two tensors over all the workers. + """ + group_size = len(input_tensors) * num_workers + if group_size < 2: + raise ValueError('num_workers * len(input_tensors) must be 2 or greater') + devices = [t.device for t in input_tensors] + num_devices = len(devices) + group_key = collective_keys.get_group_key(devices) + instance_key = collective_keys.get_instance_key() + out_tensors = [] + subdiv_offsets = [0] # TODO(tucker): maybe support non-default subdiv spec + for d in range(num_devices): + with ops.device(devices[d]): + reduce_op = collective_ops.all_reduce( + input_tensors[d], group_size, group_key, instance_key, reduction_op, + unary_op, subdiv_offsets) + out_tensors.append(reduce_op) + return out_tensors + + def sum_grad_and_var_all_reduce(grad_and_vars, num_workers, alg, @@ -253,10 +396,10 @@ def sum_grad_and_var_all_reduce(grad_and_vars, else: raise ValueError('unsupported all_reduce alg: ', alg) - result = [] - for (_, v), g in zip(grad_and_vars, summed_grads): - result.append([g, v]) - return result + result = [] + for (_, v), g in zip(grad_and_vars, summed_grads): + result.append([g, v]) + return result def sum_gradients_all_reduce(dev_prefixes, tower_grads, num_workers, alg, diff --git a/tensorflow/contrib/distribute/python/estimator_integration_test.py b/tensorflow/contrib/distribute/python/estimator_integration_test.py index 34410a6470185ac2821bc6a59de9230ff478aeb6..cc626c33bf8e282736f8e6e0c151e5a3d3f3244b 100644 --- a/tensorflow/contrib/distribute/python/estimator_integration_test.py +++ b/tensorflow/contrib/distribute/python/estimator_integration_test.py @@ -29,6 +29,7 @@ from tensorflow.contrib.optimizer_v2 import adagrad from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import test from tensorflow.python.estimator import run_config +from tensorflow.python.estimator import training from tensorflow.python.estimator.canned import dnn_linear_combined from tensorflow.python.estimator.canned import prediction_keys from tensorflow.python.estimator.export import export @@ -63,8 +64,9 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, combinations.mirrored_strategy_with_two_gpus - ])) - def test_complete_flow_with_mode(self, distribution): + ], + use_train_and_evaluate=[True, False])) + def test_complete_flow_with_mode(self, distribution, use_train_and_evaluate): label_dimension = 2 input_dimension = label_dimension batch_size = 10 @@ -75,8 +77,11 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, y=data, batch_size=batch_size // len(distribution.worker_devices), shuffle=True) - eval_input_fn = numpy_io.numpy_input_fn( - x={'x': data}, y=data, batch_size=batch_size, shuffle=False) + eval_input_fn = self.dataset_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size // len(distribution.worker_devices), + shuffle=False) predict_input_fn = numpy_io.numpy_input_fn( x={'x': data}, batch_size=batch_size, shuffle=False) @@ -96,12 +101,19 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, # TODO(isaprykin): Work around the colocate_with error. dnn_optimizer=adagrad.AdagradOptimizer(0.001), linear_optimizer=adagrad.AdagradOptimizer(0.001), - config=run_config.RunConfig(train_distribute=distribution)) + config=run_config.RunConfig( + train_distribute=distribution, eval_distribute=distribution)) num_steps = 10 - estimator.train(train_input_fn, steps=num_steps) + if use_train_and_evaluate: + scores, _ = training.train_and_evaluate( + estimator, + training.TrainSpec(train_input_fn, max_steps=num_steps), + training.EvalSpec(eval_input_fn)) + else: + estimator.train(train_input_fn, steps=num_steps) + scores = estimator.evaluate(eval_input_fn) - scores = estimator.evaluate(eval_input_fn) self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) self.assertIn('loss', six.iterkeys(scores)) diff --git a/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py index 00c25c7a2482a559c8b94ff3be86c4961dfb439f..44a69ed23a4e00ab81d5b51ae0c14550bd493f14 100644 --- a/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py +++ b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py @@ -59,7 +59,8 @@ def build_model_fn_optimizer(): def main(_): distribution = tf.contrib.distribute.MirroredStrategy( ["/device:GPU:0", "/device:GPU:1"]) - config = tf.estimator.RunConfig(train_distribute=distribution) + config = tf.estimator.RunConfig(train_distribute=distribution, + eval_distribute=distribution) def input_fn(): features = tf.data.Dataset.from_tensors([[1.]]).repeat(10) @@ -70,7 +71,7 @@ def main(_): model_fn=build_model_fn_optimizer(), config=config) estimator.train(input_fn=input_fn, steps=10) - eval_result = estimator.evaluate(input_fn=input_fn) + eval_result = estimator.evaluate(input_fn=input_fn, steps=10) print("Eval result: {}".format(eval_result)) def predict_input_fn(): diff --git a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py index 2b05884b9b93470ef9a764cbedbc91bd3912c611..518ec9c4232465c3ecd0e4161f707dac499430c7 100644 --- a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py +++ b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py @@ -57,7 +57,8 @@ def main(args): # tf.Estimator that utilizes the DistributionStrategy. strategy = tf.contrib.distribute.MirroredStrategy( ['/device:GPU:0', '/device:GPU:1']) - config = tf.estimator.RunConfig(train_distribute=strategy) + config = tf.estimator.RunConfig( + train_distribute=strategy, eval_distribute=strategy) keras_estimator = tf.keras.estimator.model_to_estimator( keras_model=model, config=config, model_dir=model_dir) diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index 75ecd90dcffa7a786b78238ef453c4c8e4346afa..a262d7666e7be2c28857b7b38ad0ccbd1b053463 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -12,33 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for Keras Sequential and Functional models.""" +"""Tests for tf.keras models using DistributionStrategy.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import os - import numpy as np from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import values from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import keras as keras_lib from tensorflow.python.estimator import run_config as run_config_lib +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.engine import distributed_training_utils from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import gradient_descent from tensorflow.python.training import rmsprop + _RANDOM_SEED = 1337 _TRAIN_SIZE = 200 _INPUT_SIZE = (10,) _NUM_CLASS = 2 +# TODO(anjalisridhar): Add a decorator that will allow us to run these tests as +# part of the tf.keras unit tests suite. def simple_sequential_model(): model = keras.models.Sequential() model.add(keras.layers.Dense(16, activation='relu', input_shape=_INPUT_SIZE)) @@ -84,7 +91,7 @@ def get_ds_test_input_fn(): return dataset -class TestKerasDistributionStrategy(test_util.TensorFlowTestCase): +class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): def setUp(self): self._base_dir = os.path.join(self.get_temp_dir(), @@ -107,7 +114,8 @@ class TestKerasDistributionStrategy(test_util.TensorFlowTestCase): optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01)) config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir, - train_distribute=dist) + train_distribute=dist, + eval_distribute=dist) with self.test_session(): est_keras = keras_lib.model_to_estimator( keras_model=keras_model, config=config) @@ -144,5 +152,457 @@ class TestKerasDistributionStrategy(test_util.TensorFlowTestCase): writer_cache.FileWriterCache.clear() gfile.DeleteRecursively(self._config.model_dir) + def test_keras_optimizer_with_distribution_strategy(self): + dist = mirrored_strategy.MirroredStrategy( + devices=['/device:GPU:0', '/device:GPU:1']) + keras_model = simple_sequential_model() + keras_model.compile( + loss='categorical_crossentropy', + optimizer=keras.optimizers.rmsprop(lr=0.01)) + + config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, + model_dir=self._base_dir, + train_distribute=dist) + with self.test_session(): + est_keras = keras_lib.model_to_estimator(keras_model=keras_model, + config=config) + with self.assertRaisesRegexp(ValueError, + 'Only TensorFlow native optimizers are ' + 'supported with DistributionStrategy.'): + est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE / 16) + + writer_cache.FileWriterCache.clear() + gfile.DeleteRecursively(self._config.model_dir) + + +class TestWithDistributionStrategy(test.TestCase): + + def test_validating_dataset_input_tensors_with_shape_mismatch(self): + with self.test_session(): + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', + '/device:CPU:0']) + a = constant_op.constant([1, 2], shape=(1, 2)) + b = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2)) + x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b}) + y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) + with strategy.scope(): + # Removed device and input tensor shape details from the error message + # since the order of the device and the corresponding input tensor shape + # is not deterministic over different runs. + with self.assertRaisesRegexp(ValueError, + 'Input tensor shapes do not match for ' + 'distributed tensor inputs ' + 'DistributedValues:.+'): + distributed_training_utils.validate_distributed_dataset_inputs( + strategy, x, y) + + def test_validating_dataset_input_tensors_with_dtype_mismatch(self): + with self.test_session(): + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', + '/device:CPU:0']) + a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32) + b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64) + x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b}) + y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) + with strategy.scope(): + # Removed device and input tensor dtype details from the error message + # since the order of the device and the corresponding input tensor dtype + # is not deterministic over different runs. + with self.assertRaisesRegexp(ValueError, + 'Input tensor dtypes do not match for ' + 'distributed tensor inputs ' + 'DistributedValues:.+'): + distributed_training_utils.validate_distributed_dataset_inputs( + strategy, x, y) + + def test_calling_model_on_same_dataset(self): + with self.test_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', + '/device:GPU:0']) + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + inputs = np.zeros((10, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + # Call fit with validation data + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + validation_data=dataset, validation_steps=2) + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + validation_data=dataset, validation_steps=2) + model.predict(dataset, steps=2) + + def test_fit_with_tuple_and_dict_dataset_inputs(self): + with self.test_session(): + a = keras.layers.Input(shape=(3,), name='input_a') + b = keras.layers.Input(shape=(3,), name='input_b') + + dense = keras.layers.Dense(4, name='dense') + c = dense(a) + d = dense(b) + e = keras.layers.Dropout(0.5, name='dropout')(c) + + model = keras.models.Model([a, b], [d, e]) + + optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001) + loss = 'mse' + metrics = ['mae'] + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', + '/device:CPU:0']) + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + input_a_np = np.random.random((10, 3)) + input_b_np = np.random.random((10, 3)) + output_d_np = np.random.random((10, 4)) + output_e_np = np.random.random((10, 4)) + + # Test with tuples + dataset_tuple = dataset_ops.Dataset.from_tensor_slices(( + (input_a_np, input_b_np), (output_d_np, output_e_np))) + dataset_tuple = dataset_tuple.repeat(100) + dataset_tuple = dataset_tuple.batch(10) + + model.fit(dataset_tuple, epochs=1, steps_per_epoch=2, verbose=1) + + # Test with dict + dataset_dict = dataset_ops.Dataset.from_tensor_slices(( + {'input_a': input_a_np, 'input_b': input_b_np}, + (output_d_np, output_e_np))) + dataset_dict = dataset_dict.repeat(100) + dataset_dict = dataset_dict.batch(10) + + model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1) + + def test_fit_eval_and_predict_methods_on_dataset(self): + with self.test_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', + '/device:CPU:0']) + + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + inputs = np.zeros((10, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) + model.evaluate(dataset, steps=2, verbose=1) + model.predict(dataset, steps=2) + # Test with validation data + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + validation_data=dataset, validation_steps=2) + + def test_raise_error_for_stateful_metrics(self): + + class ExampleStatefulMetric(keras.layers.Layer): + + def __init__(self, name='true_positives', **kwargs): + super(ExampleStatefulMetric, self).__init__(name=name, **kwargs) + self.stateful = True + + def __call__(self, y_true, y_pred): + return y_pred - y_true + + with self.test_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', ExampleStatefulMetric()] + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', + '/device:GPU:0']) + with self.assertRaisesRegexp( + NotImplementedError, 'Stateful metrics are not supported with ' + 'DistributionStrategy.'): + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + def test_unsupported_features(self): + with self.test_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', + '/device:GPU:0']) + + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + inputs = np.zeros((10, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + # Test with validation split + with self.assertRaisesRegexp( + ValueError, '`validation_split` argument is not ' + 'supported when input `x` is a dataset or a ' + 'dataset iterator.+'): + model.fit(dataset, + epochs=1, steps_per_epoch=2, verbose=0, + validation_split=0.5, validation_steps=2) + + # Test with sample weight. + sample_weight = np.random.random((10,)) + with self.assertRaisesRegexp( + NotImplementedError, '`sample_weight` is currently not supported ' + 'when using DistributionStrategy.'): + model.fit( + dataset, + epochs=1, + steps_per_epoch=2, + verbose=0, + sample_weight=sample_weight) + + # Test with not specifying the `steps` argument. + with self.assertRaisesRegexp( + ValueError, 'you should specify the `steps_per_epoch` argument'): + model.fit(dataset, epochs=1, verbose=0) + with self.assertRaisesRegexp(ValueError, + 'you should specify the `steps` argument'): + model.evaluate(dataset, verbose=0) + + with self.assertRaisesRegexp(ValueError, + 'you should specify the `steps` argument'): + model.predict(dataset, verbose=0) + + def test_calling_with_unsupported_predefined_callbacks(self): + with self.test_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', + '/device:GPU:0']) + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + inputs = np.zeros((10, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + def schedule(_): + return 0.001 + with self.assertRaisesRegexp(ValueError, + 'LearningRateScheduler callback is not ' + 'supported with DistributionStrategy.'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + callbacks=[keras.callbacks.LearningRateScheduler(schedule)]) + + with self.assertRaisesRegexp(ValueError, + 'ReduceLROnPlateau callback is not ' + 'supported with DistributionStrategy.'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + callbacks=[keras.callbacks.ReduceLROnPlateau()]) + with self.assertRaisesRegexp(ValueError, + 'histogram_freq in the TensorBoard callback ' + 'is not supported when using ' + 'DistributionStrategy.'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + callbacks=[keras.callbacks.TensorBoard(histogram_freq=10)]) + + def test_dataset_input_shape_validation(self): + with self.test_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', + '/device:GPU:0']) + + model.compile(optimizer, loss, distribute=strategy) + + # User forgets to batch the dataset + inputs = np.zeros((10, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + + with self.assertRaisesRegexp(ValueError, + 'expected input to have 2 dimensions'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) + + # Wrong input shape + inputs = np.zeros((10, 5), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + with self.assertRaisesRegexp(ValueError, + 'expected input to have shape'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) + + def test_learning_phase_value(self): + # TODO(anjalisridhar): Modify this test to use Lambdas since we can compare + # meaningful values. Currently we don't pass the learning phase if the + # Lambda layer uses the learning phase. + with self.test_session(): + x = keras.layers.Input(shape=(16,), name='input') + y = keras.layers.Dense(16)(x) + z = keras.layers.Dropout(0.9999)(y) + model = keras.Model(x, z) + + optimizer = gradient_descent.GradientDescentOptimizer(0.005) + loss = 'mse' + metrics = ['acc'] + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', + '/device:CPU:0']) + + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + inputs = np.random.rand(10, 16) + targets = np.ones((10, 16), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(8) + + hist = model.fit(dataset, epochs=5, steps_per_epoch=20, verbose=1) + self.assertEqual(hist.history['acc'][0], 1) + + evaluate_output = model.evaluate(dataset, steps=20) + self.assertEqual(evaluate_output[1], 0) + + predict_output = model.predict(dataset, steps=1) + self.assertNotEqual(np.mean(predict_output), 0) + + +class LossMaskingWithDistributionStrategyTest(test.TestCase): + + def test_masking(self): + with self.test_session(): + np.random.seed(1337) + x = np.array([[[1], [1]], [[0], [0]]]) + model = keras.models.Sequential() + model.add(keras.layers.Masking(mask_value=0, input_shape=(2, 1))) + model.add( + keras.layers.TimeDistributed( + keras.layers.Dense(1, kernel_initializer='one'))) + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', + '/device:GPU:0']) + + model.compile(loss='mse', + optimizer=gradient_descent.GradientDescentOptimizer(0.01), + distribute=strategy) + y = np.array([[[1], [1]], [[1], [1]]]) + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + hist = model.fit(x=dataset, epochs=1, steps_per_epoch=2) + self.assertEqual(hist.history['loss'][0], 0) + + +class NormalizationLayerWithDistributionStrategyTest(test.TestCase): + + def test_batchnorm_correctness(self): + with self.test_session(): + model = keras.models.Sequential() + norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8) + model.add(norm) + strategy = mirrored_strategy.MirroredStrategy(['/device:CPU:0', + '/device:GPU:0']) + model.compile(loss='mse', + optimizer=gradient_descent.GradientDescentOptimizer(0.01), + distribute=strategy) + + # centered on 5.0, variance 10.0 + x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10)) + dataset = dataset_ops.Dataset.from_tensor_slices((x, x)) + dataset = dataset.repeat(100) + dataset = dataset.batch(32) + + model.fit(dataset, epochs=4, verbose=0, steps_per_epoch=10) + out = model.predict(dataset, steps=2) + out -= keras.backend.eval(norm.beta) + out /= keras.backend.eval(norm.gamma) + np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1) + np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) + + +class CorrectnessWithDistributionStrategyTest(test.TestCase): + + def test_correctness(self): + with self.test_session(): + keras.backend.set_image_data_format('channels_last') + num_samples = 10000 + x_train = np.random.rand(num_samples, 1) + y_train = 3 * x_train + x_train = x_train.astype('float32') + y_train = y_train.astype('float32') + + model = keras.Sequential() + model.add(keras.layers.Dense(1, input_shape=(1,))) + + # With DistributionStrategy + dataset_with = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) + dataset_with = dataset_with.batch(32) + strategy = mirrored_strategy.MirroredStrategy(devices=['/device:CPU:0', + '/device:GPU:0'], + prefetch_on_device=False) + + model.compile(loss=keras.losses.mean_squared_error, + optimizer=gradient_descent.GradientDescentOptimizer(0.5), + distribute=strategy) + model.fit(x=dataset_with, epochs=1, steps_per_epoch=310) + wts_with_ds = model.get_weights() + + x_predict = [[1], [2], [3], [4]] + predict_dataset_with = dataset_ops.Dataset.from_tensor_slices((x_predict, + x_predict)) + predict_dataset_with = predict_dataset_with.batch(2) + predict_with_ds = model.predict(predict_dataset_with, steps=1) + predict_with_ds = np.reshape(predict_with_ds, (4, 1)) + + # Without DistributionStrategy + dataset_without = dataset_ops.Dataset.from_tensor_slices((x_train, + y_train)) + dataset_without = dataset_without.batch(64) + + model.compile(loss=keras.losses.mean_squared_error, + optimizer=gradient_descent.GradientDescentOptimizer(0.5)) + model.fit(x=dataset_without, epochs=1, steps_per_epoch=310) + wts_without_ds = model.get_weights() + + x_predict = [[1], [2], [3], [4]] + predict_dataset_without = dataset_ops.Dataset.from_tensor_slices(( + x_predict, x_predict)) + predict_dataset_without = predict_dataset_without.batch(4) + predict_without_ds = model.predict(predict_dataset_without, steps=1) + + # Verify that the weights are the same within some limits of tolerance. + np.testing.assert_allclose(wts_with_ds[0], wts_without_ds[0], rtol=1e-3) + # Verify that the predicted outputs are the same within some limits of + # tolerance. + np.testing.assert_allclose(predict_with_ds, predict_without_ds, rtol=1e-3) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py index 6c6bf143098c1bba64d47efce1bfface7682683d..8163494c8ed2c5c2164df2e731d09ebb794414cd 100644 --- a/tensorflow/contrib/distribute/python/metrics_v1_test.py +++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py @@ -19,7 +19,6 @@ from __future__ import print_function from absl.testing import parameterized -from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.distribute.python import combinations from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import test @@ -69,6 +68,8 @@ def _regression_dataset_fn(): "predictions": [1., .75, .25, 0.]}).repeat() +# TODO(priyag): Add TPU Strategy to this once metrics aggregate correctly using +# TowerLocalVariables on TPUs. Submit http://cl/208914352. def all_combinations(): return combinations.combine( distribution=[combinations.default_strategy, @@ -183,7 +184,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): def _dataset_fn(): dataset = dataset_ops.Dataset.range(1000).map(math_ops.to_float) # Want to produce a fixed, known shape, so drop remainder when batching. - dataset = dataset.apply(batching.batch_and_drop_remainder(4)) + dataset = dataset.batch(4, drop_remainder=True) return dataset def _expected_fn(num_batches): diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index aeeb9553e6044a0a928936597400e582e0329b95..516ede7ade7d8c9d09198993f919f15377b1c565 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -25,11 +25,13 @@ from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python.single_loss_example import batchnorm_example from tensorflow.contrib.distribute.python.single_loss_example import minimize_loss_example -from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops +from tensorflow.python.layers import core +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope @@ -43,32 +45,60 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): combinations.times( combinations.distributions_and_v1_optimizers(), combinations.combine(mode=["graph"], use_callable_loss=[True, False]) - + combinations.combine(mode=["eager"], use_callable_loss=[True]), - combinations.combine(is_tpu=[False])) + combinations.combine( - distribution=[combinations.tpu_strategy], - optimizer_fn=[ - combinations.adam_optimizer_v1_fn, - # TODO(isaprykin): Make Adam v2 work with while_loops - # and TPUs. - ], - mode=["graph"], - use_callable_loss=[False], - is_tpu=[True])) - def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss, - is_tpu): - # TODO(priyag): Remove this once the step TPU Strategy is stable. - if is_tpu: - self.skipTest("TPU tests are WIP.") + + combinations.combine(mode=["eager"], use_callable_loss=[True])) + + combinations.combine( + distribution=[combinations.tpu_strategy], + optimizer_fn=combinations.optimizers_v1, + mode=["graph"], + use_callable_loss=[True, False])) + def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss): + with distribution.scope(): + model_fn, dataset_fn, layer = minimize_loss_example( + optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) + + def step_fn(ctx, *inputs): + del ctx # Unused + return distribution.group( + distribution.call_for_each_tower( + model_fn, *inputs, run_concurrently=layer.built)) + + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() + + def run_step(): + return distribution.run_steps_on_dataset( + step_fn, iterator, iterations=2).run_op + + self.evaluate(distribution.initialize()) + if not context.executing_eagerly(): + with self.test_session() as sess: + run_step = sess.make_callable(run_step()) + self.evaluate(variables_lib.global_variables_initializer()) + + weights, biases = [], [] + for _ in range(5): + run_step() + weights.append(self.evaluate(layer.kernel)) + biases.append(self.evaluate(layer.bias)) + + self.evaluate(distribution.finalize()) + + error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) + is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) + self.assertTrue(is_not_increasing) + + @combinations.generate( + combinations.times( + combinations.distributions_and_v1_optimizers(), + combinations.combine(mode=["graph"], use_callable_loss=[True, False]) + + combinations.combine(mode=["eager"], use_callable_loss=[True]))) + def testTrainNetworkByCallForEachTower(self, distribution, optimizer_fn, + use_callable_loss): with distribution.scope(): model_fn, dataset_fn, layer = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) - # TODO(isaprykin): Eliminate `is_tpu`. Probably add a - # `DistributionStrategy.create_monitor` so that each DistributionStrategy - # could influence its training loop. That method would return an instance - # of Monitor. TPUMonitor would execute tpu.initialize_system() and - # tpu.shutdown_system(). iterator = distribution.distribute_dataset( dataset_fn).make_one_shot_iterator() @@ -79,8 +109,6 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): if not context.executing_eagerly(): with self.test_session() as sess: - if is_tpu: - sess.run(tpu.initialize_system()) run_step = sess.make_callable(run_step()) self.evaluate(variables_lib.global_variables_initializer()) @@ -91,10 +119,6 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): weights.append(self.evaluate(layer.kernel)) biases.append(self.evaluate(layer.bias)) - if is_tpu: - with self.test_session() as sess: - sess.run(tpu.shutdown_system()) - error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) self.assertTrue(is_not_increasing) @@ -103,22 +127,12 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): combinations.times( combinations.distributions_and_v1_optimizers() + combinations.distributions_and_v2_optimizers(), - combinations.combine(mode=["graph", "eager"], is_tpu=[False])) + + combinations.combine(mode=["graph", "eager"])) + combinations.combine( distribution=[combinations.tpu_strategy], - optimizer_fn=[ - combinations.adam_optimizer_v1_fn, - combinations.gradient_descent_optimizer_v1_fn, - combinations.gradient_descent_optimizer_v2_fn, - ], - mode=["graph"], - is_tpu=[True])) - - def testOptimizerInsideModelFn(self, distribution, optimizer_fn, is_tpu): - # TODO(priyag): Remove this once the step TPU Strategy is stable. - if is_tpu: - self.skipTest("TPU tests are WIP.") - + optimizer_fn=combinations.optimizers_v1+combinations.optimizers_v2, + mode=["graph"])) + def testOptimizerInsideModelFn(self, distribution, optimizer_fn): created_variables = [] trainable_variables = [] @@ -139,26 +153,28 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): use_callable_loss=True, create_optimizer_inside_model_fn=True) + def step_fn(ctx, *inputs): + del ctx # Unused + return distribution.group( + distribution.call_for_each_tower( + model_fn, *inputs, run_concurrently=layer.built)) + iterator = distribution.distribute_dataset( dataset_fn).make_one_shot_iterator() def run_step(): - return distribution.group( - distribution.call_for_each_tower( - model_fn, iterator.get_next(), run_concurrently=layer.built)) + return distribution.run_steps_on_dataset( + step_fn, iterator, iterations=1).run_op + self.evaluate(distribution.initialize()) if not context.executing_eagerly(): with self.test_session() as sess: - if is_tpu: - sess.run(tpu.initialize_system()) run_step = sess.make_callable(run_step()) - self.evaluate(variables_lib.global_variables_initializer()) + self.evaluate(variables_lib.global_variables_initializer()) run_step() - if is_tpu: - with self.test_session() as sess: - sess.run(tpu.shutdown_system()) + self.evaluate(distribution.finalize()) def get_expected_variables(optimizer_fn, num_parameter_devices): variables_map = { @@ -189,27 +205,17 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): combinations.distributions_and_v1_optimizers(), combinations.combine( mode=["graph", "eager"], - is_tpu=[False], # TODO(isaprykin): Allow False here. Currently subsequent # towers will re-execute UPDATE_OPS of previous towers. update_ops_in_cross_tower_mode=[True])) + combinations.combine( distribution=[combinations.tpu_strategy], - optimizer_fn=[ - combinations.gradient_descent_optimizer_v1_fn, - combinations.gradient_descent_optimizer_v2_fn - ], + optimizer_fn=combinations.optimizers_v1, mode=["graph"], - is_tpu=[True], update_ops_in_cross_tower_mode=[False]))) def testTrainNetworkWithBatchNorm(self, distribution, optimizer_fn, momentum, - renorm, is_tpu, - update_ops_in_cross_tower_mode): + renorm, update_ops_in_cross_tower_mode): """Verifies that moving mean updates are reduced across towers.""" - # TODO(priyag): Remove this once the step TPU Strategy is stable. - if is_tpu: - self.skipTest("TPU tests are WIP.") - with distribution.scope(): num_towers = len(distribution.worker_devices) model_fn, dataset_fn, batchnorm = batchnorm_example( @@ -224,24 +230,28 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): # this test relies on specific input being on each device. if isinstance(distribution, mirrored_strategy.MirroredStrategy): self.assertFalse(distribution._prefetch_on_device) - iterator = distribution.distribute_dataset( - dataset_fn).make_one_shot_iterator() - def run_step(): + def step_fn(ctx, *inputs): + del ctx # Unused fetches = distribution.unwrap( distribution.call_for_each_tower( - model_fn, iterator.get_next(), - run_concurrently=batchnorm.built)) + model_fn, *inputs, run_concurrently=batchnorm.built)) if update_ops_in_cross_tower_mode: fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS) return control_flow_ops.group(fetches) + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() + + def run_step(): + return distribution.run_steps_on_dataset( + step_fn, iterator, iterations=1).run_op + + self.evaluate(distribution.initialize()) if not context.executing_eagerly(): with self.test_session() as sess: - if is_tpu: - sess.run(tpu.initialize_system()) run_step = sess.make_callable(run_step()) - self.evaluate(variables_lib.global_variables_initializer()) + self.evaluate(variables_lib.global_variables_initializer()) expected_moving_means = [0.] * 8 @@ -263,9 +273,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): expected_moving_mean - averaged_batch_mean(i)) * (1.0 - momentum)) self.assertNear(expected_moving_means[i], moving_means[i], 0.0001) - if is_tpu: - with self.test_session() as sess: - sess.run(tpu.shutdown_system()) + self.evaluate(distribution.finalize()) @combinations.generate( combinations.times( @@ -285,22 +293,16 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, combinations.mirrored_strategy_with_two_gpus - ], - is_tpu=[False]), + ]), combinations.combine( mode=["graph"], use_callable_loss=[True, False]) + combinations.combine(mode=["eager"], use_callable_loss=[True])) + combinations.combine( distribution=[combinations.tpu_strategy], - is_tpu=[True], mode=["graph"], use_callable_loss=[True, False]))) def testMeanVsSum(self, distribution, optimizer_fn, loss_reduction, - use_callable_loss, is_tpu): - # TODO(priyag): Remove this once the step TPU Strategy is stable. - if is_tpu: - self.skipTest("TPU tests are WIP.") - + use_callable_loss): with distribution.scope(): all_vars = [] @@ -326,20 +328,24 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): labels = dataset_ops.Dataset.from_tensors([[6.], [21.]]) return dataset_ops.Dataset.zip((features, labels)).repeat() + def step_fn(ctx, x, y): + del ctx # Unused + return distribution.group( + distribution.call_for_each_tower( + model_fn, x, y, run_concurrently=False)) + iterator = distribution.distribute_dataset( dataset_fn).make_one_shot_iterator() def run_step(): - return distribution.group( - distribution.call_for_each_tower( - model_fn, *iterator.get_next(), run_concurrently=False)) + return distribution.run_steps_on_dataset( + step_fn, iterator, iterations=1).run_op + self.evaluate(distribution.initialize()) if not context.executing_eagerly(): with self.test_session() as sess: - if is_tpu: - sess.run(tpu.initialize_system()) run_step = sess.make_callable(run_step()) - self.evaluate(variables_lib.global_variables_initializer()) + self.evaluate(variables_lib.global_variables_initializer()) run_step() @@ -369,10 +375,132 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): # One of the mean loss reductions. self.assertNear(weight, 2 + 10.6, 0.0001) - if is_tpu: + self.evaluate(distribution.finalize()) + + @combinations.generate( + combinations.times( + combinations.distributions_and_v1_optimizers(), + combinations.combine(mode=["graph", "eager"]), + combinations.combine(is_tpu=[False])) + + combinations.combine( + distribution=[combinations.tpu_strategy], + optimizer_fn=combinations.optimizers_v1, + mode=["graph"], + is_tpu=[True])) + def testRunStepsWithOutputContext(self, distribution, optimizer_fn, is_tpu): + with distribution.scope(): + def dataset_fn(): + dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat() + # TODO(priyag): batch with drop_remainder=True causes shapes to be + # fully defined for TPU. Remove this when XLA supports dynamic shapes. + return dataset.batch(batch_size=1, drop_remainder=True) + + optimizer = optimizer_fn() + layer = core.Dense(1, use_bias=True) + + key1 = "foo" + value1 = "bar" + + def model_fn(output_context, x): + """A very simple model written by the user.""" + def loss_fn(): + y = array_ops.reshape(layer(x), []) - constant_op.constant(1.) + return y * y + + train_op = optimizer.minimize(loss_fn) + loss = loss_fn() + output_context.set_last_step_output( + name="tower_loss_agg", + output=loss, + aggregation=variables_lib.VariableAggregation.MEAN) + output_context.set_non_tensor_output(key1, value1) + return (train_op, loss) + + def step_fn(output_context, *inputs): + (train_op, loss) = distribution.call_for_each_tower( + model_fn, output_context, *inputs, run_concurrently=False) + output_context.set_last_step_output( + name="cross_tower_loss_agg", + output=loss, + aggregation=variables_lib.VariableAggregation.MEAN) + output_context.set_last_step_output( + name="cross_tower_loss_noagg", + output=loss) + return distribution.group(train_op) + + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() + + def run_step(): + initial_loss = lambda: constant_op.constant(1e7) + # Initial values corresponding to aggregated losses are just single + # tensors. But for non aggregated losses, we need to have initial + # values that are of the same structure as non reduced losses. In + # MirroredStrategy, this will be a list of losses, in TPUStrategy + # it will be single tensor. Using `broadcast` followed by `unwrap` + # gives us the desired initial value structure. + initial_loop_values = { + "tower_loss_agg": initial_loss(), + "cross_tower_loss_agg": initial_loss(), + "cross_tower_loss_noagg": + distribution.unwrap(distribution.broadcast(initial_loss())) + } + ctx = distribution.run_steps_on_dataset( + step_fn, iterator, iterations=2, + initial_loop_values=initial_loop_values) + + self.assertEqual({key1: [value1]}, ctx.non_tensor_outputs) + self._verify_loss_output( + initial_loss(), + loss_output=ctx.last_step_outputs["tower_loss_agg"], + aggregated=True, distribution=distribution) + self._verify_loss_output( + initial_loss(), + loss_output=ctx.last_step_outputs["cross_tower_loss_agg"], + aggregated=True, distribution=distribution) + self._verify_loss_output( + initial_loss(), + loss_output=ctx.last_step_outputs["cross_tower_loss_noagg"], + aggregated=False, distribution=distribution) + return (ctx.run_op, ctx.last_step_outputs["tower_loss_agg"]) + + self.evaluate(distribution.initialize()) + if not context.executing_eagerly(): with self.test_session() as sess: - sess.run(tpu.shutdown_system()) + run_step = sess.make_callable(run_step()) + self.evaluate(variables_lib.global_variables_initializer()) + + weights, biases, losses = [], [], [] + for _ in range(5): + _, loss = run_step() + losses.append(loss) + weights.append(self.evaluate(layer.kernel)) + biases.append(self.evaluate(layer.bias)) + self.evaluate(distribution.finalize()) + + loss_is_not_increasing = all(y <= x for x, y in zip(losses, losses[1:])) + self.assertTrue(loss_is_not_increasing) + + error = abs( + numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) + error_is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) + self.assertTrue(error_is_not_increasing) + + def _verify_loss_output(self, initial_loss, loss_output, aggregated, + distribution): + if not aggregated: + self.assertEqual(distribution.num_towers, + len(distribution.unwrap(loss_output))) + loss_output = distribution.reduce( + aggregation=variables_lib.VariableAggregation.MEAN, + value=loss_output, destinations="/device:CPU:0") + + unwrapped_output = distribution.unwrap(loss_output) + self.assertEqual(1, len(unwrapped_output)) + loss_tensor = unwrapped_output[0] + self.assertEqual(initial_loss.dtype, loss_tensor.dtype) + self.assertEqual(initial_loss.shape, loss_tensor.shape) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index dcbc6b0878b89cbb5b9779de315429e6f9478d15..6981449a4cc9d15ebc3a0edd145fa5766e9b6503 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -19,22 +19,28 @@ from __future__ import division from __future__ import print_function import contextlib +from functools import partial import threading -import six from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib from tensorflow.contrib.distribute.python import shared_variable_creator from tensorflow.contrib.distribute.python import values +from tensorflow.core.protobuf import cluster_pb2 from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context from tensorflow.python.eager import tape +from tensorflow.python.framework import constant_op from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as variables_lib from tensorflow.python.training import coordinator from tensorflow.python.training import device_util from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import server_lib +from tensorflow.python.util import nest # TODO(josh11b): Replace asserts in this file with if ...: raise ... @@ -60,25 +66,340 @@ class _RequestedStop(Exception): pass -class MirroredStrategy(distribute_lib.DistributionStrategy): - """Mirrors vars to distribute across multiple devices on a single machine. +# Make _call_for_each_tower and _reduce_non_distributed_value not members of +# MirroredStrategy so that they are generally not allowed to use anything +# specific to MirroredStrategy and thus can be shared with other distribution +# strategies. + + +# TODO(yuefengz): maybe create a common class for those who need to call this +# _call_for_each_tower. +def _call_for_each_tower(distribution, fn, *args, **kwargs): + """Run `fn` in separate threads, once per tower/worker device. + + Args: + distribution: the DistributionStrategy object. + fn: function to run (will be run once per device, each in its own thread). + *args: positional arguments for `fn` + **kwargs: keyword arguments for `fn`. + `"run_concurrently"`: Boolean indicating whether executions of `fn` + can be run concurrently (under eager execution only), defaults to + `True`. - This strategy uses one tower per device and sync replication. + Returns: + Merged return value of `fn` across all towers. + + Raises: + RuntimeError: If fn() calls get_tower_context().merge_call() a different + number of times from the available devices. + """ + run_concurrently = kwargs.pop("run_concurrently", True) + if not context.executing_eagerly(): + # Lots of TF library code isn't thread-safe in graph mode, and + # there is little to be gained by turning on multithreading when + # constructing a graph. + run_concurrently = False + # Needed for per-thread device, etc. contexts in graph mode. + ops.get_default_graph().switch_to_thread_local() + elif run_concurrently is None: + run_concurrently = True + + coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,)) + + shared_variable_store = {} + + # TODO(isaprykin): Create these threads once instead of during every run() + # call. + threads = [] + for index, d in enumerate(distribution.worker_devices): + variable_creator_fn = shared_variable_creator.make_fn( + shared_variable_store, index) + t = MirroredStrategy._MirroredTowerThread( # pylint: disable=protected-access + distribution, coord, d, variable_creator_fn, fn, + *values.select_device(d, args), **values.select_device(d, kwargs)) + threads.append(t) + + for t in threads: + t.start() + + # When `fn` starts `should_run` event is set on _MirroredTowerThread + # (`MTT`) threads. The execution waits until + # `MTT.has_paused` is set, which indicates that either `fn` is + # complete or a `get_tower_context().merge_call()` is called. If `fn` is + # complete, then `MTT.done` is set to True. Otherwise, arguments + # of `get_tower_context().merge_call` from all paused threads are grouped + # and the `merge_fn` is performed. Results of the + # `get_tower_context().merge_call` are then set to `MTT.merge_result`. + # Each such `get_tower_context().merge_call` call returns the + # `MTT.merge_result` for that thread when `MTT.should_run` event + # is reset again. Execution of `fn` resumes. + + try: + with coord.stop_on_exception(): + all_done = False + while not all_done and not coord.should_stop(): + done = [] + if run_concurrently: + for t in threads: + t.should_run.set() + for t in threads: + t.has_paused.wait() + t.has_paused.clear() + if coord.should_stop(): + return None + done.append(t.done) + else: + for t in threads: + t.should_run.set() + t.has_paused.wait() + t.has_paused.clear() + if coord.should_stop(): + return None + done.append(t.done) + if coord.should_stop(): + return None + all_done = all(done) + if not all_done: + if any(done): + raise RuntimeError("Some towers made a different number of " + "tower_context().merge_call() calls.") + # get_tower_context().merge_call() case + merge_args = values.regroup({t.device: t.merge_args for t in threads}) + merge_kwargs = values.regroup( + {t.device: t.merge_kwargs for t in threads}) + # We capture the name_scope of the MTT when we call merge_fn + # to ensure that if we have opened a name scope in the MTT, + # it will be respected when executing the merge function. We only + # capture the name_scope from the first MTT and assume it is + # the same for all other MTTs. + mtt_captured_name_scope = threads[0].captured_name_scope + with ops.name_scope(mtt_captured_name_scope): + merge_result = threads[0].merge_fn(distribution, *merge_args, + **merge_kwargs) + for t in threads: + t.merge_result = values.select_device(t.device, merge_result) + finally: + for t in threads: + t.should_run.set() + coord.join(threads) + + return values.regroup({t.device: t.main_result for t in threads}) + + +def _reduce_non_distributed_value(distribution, aggregation, value, + destinations): + """Reduce a non-DistributedValue `value` to `destinations`.""" + if isinstance(value, values.DistributedValues): + raise ValueError("You are passing a `DistributedValue` to " + "`_reduce_non_distributed_value`, which is not allowed.") + + # If the same value is present on all towers then the PerDevice value will + # be a single value. We also handle the case when `value` is a single value + # and equal to 0. + if value == 0: + return 0 + # If the aggregation type is MEAN, then this essentially means that the same + # value should be on all destinations. + if aggregation == variable_scope.VariableAggregation.MEAN: + return distribution.broadcast(value, destinations) + + cross_tower_ops_lib.validate_destinations(destinations) + # We do not support an aggregation type of SUM if the value is the same across + # all towers. We call this as part of assign functions for MirroredVariables + # and summing up identical values across towers is not clearly defined. + if (len(distribution.worker_devices) != 1 or + not cross_tower_ops_lib.check_destinations(destinations)): + raise ValueError("A non-DistributedValues value cannot be reduced with the " + "given aggregation.") + # TODO(anjalisridhar): Moves these methods to a device utility file? + devices = cross_tower_ops_lib.get_devices_from(destinations) + if len(devices) == 1: + with ops.device(devices[0]): + return array_ops.identity(value) + else: + value_updates = {} + for d in devices: + with ops.device(d): + value_updates[d] = array_ops.identity(value) + return values.Mirrored(value_updates) + + +def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs): # pylint: disable=g-missing-docstring + # Figure out what collections this variable should be added to. + # We'll add the MirroredVariable to those collections instead. + collections = kwargs.pop("collections", None) + if collections is None: + collections = [ops.GraphKeys.GLOBAL_VARIABLES] + kwargs["collections"] = [] + + # Get synchronization value + synchronization = kwargs.get("synchronization", + variable_scope.VariableSynchronization.ON_WRITE) + if synchronization == variable_scope.VariableSynchronization.NONE: + raise ValueError("`NONE` variable synchronization mode is not " + "supported with `Mirrored` distribution strategy. Please" + " change the `synchronization` for variable: " + + kwargs["name"]) + elif synchronization == variable_scope.VariableSynchronization.ON_READ: + # Variables that are to be synced on read are tower local. + is_tower_local = True + kwargs["trainable"] = False + elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or + synchronization == variable_scope.VariableSynchronization.AUTO): + # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`. + is_tower_local = False + else: + raise ValueError("Invalid variable synchronization mode: " + + synchronization + " for variable: " + kwargs["name"]) + + # Get aggregation value + aggregation = kwargs.pop("aggregation", + variable_scope.VariableAggregation.NONE) + if aggregation not in [ + variable_scope.VariableAggregation.NONE, + variable_scope.VariableAggregation.SUM, + variable_scope.VariableAggregation.MEAN + ]: + raise ValueError("Invalid variable aggregation mode: " + aggregation + + " for variable: " + kwargs["name"]) + + # Ignore user-specified caching device, not needed for mirrored variables. + kwargs.pop("caching_device", None) + + # TODO(josh11b,apassos): It would be better if variable initialization + # was never recorded on the tape instead of having to do this manually + # here. + with tape.stop_recording(): + index = real_mirrored_creator(devices, *args, **kwargs) + + if is_tower_local: + result = values.TowerLocalVariable(index, index[devices[0]], aggregation) + else: + result = values.MirroredVariable(index, index[devices[0]], aggregation) + + if not context.executing_eagerly(): + g = ops.get_default_graph() + # If "trainable" is True, next_creator() will add the member variables + # to the TRAINABLE_VARIABLES collection, so we manually remove + # them and replace with the MirroredVariable. We can't set + # "trainable" to False for next_creator() since that causes functions + # like implicit_gradients to skip those variables. + if kwargs.get("trainable", True): + collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) + l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) + for v in index.values(): + l.remove(v) + g.add_to_collections(collections, result) + return result + + +class MirroredStrategy(distribute_lib.DistributionStrategy): + """Mirrors vars to distribute across multiple devices and machines. + + This strategy uses one tower per device and sync replication for its multi-GPU + version. + + When `cluster_spec` is given, it turns into the mulit-worker version that + works on multiple workers with in-graph replication. + + There are several important concepts for distributed TensorFlow, e.g. + `client`, `job`, 'task', `cluster`, `in-graph replication` and + 'synchronous training' and they have already been defined in the + [TensorFlow's documentation](https://www.tensorflow.org/deploy/distributed). + The distribution strategy inherits these concepts as well and in addition to + that we also clarify several more concepts: + * **In-graph replication**: the `client` creates a single `tf.Graph` that + specifies tasks for devices on all workers. The `client` then creates a + client session which will talk to the `master` service of a `worker`. Then + the `master` will partition the graph and distribute the work to all + participating workers. + * **Worker**: A `worker` is a TensorFlow `task` that usually maps to one + physical machine. We will have multiple `worker`s with different `task` + index. They all do similar things except for one worker checkpointing model + variables, writing summaries, etc. in addition to its ordinary work. + + The multi-worker version of this class maps one tower to one device on a + worker. It mirrors all model variables on all towers. For example, if you have + two `worker`s and each `worker` has 4 GPUs, it will create 8 copies of the + model variables on these 8 GPUs. Then like in MirroredStrategy, each tower + performs their computation with their own copy of variables unless in + cross-tower model where variable or tensor reduction happens. + + Args: + devices: a list of device strings. + num_gpus: number of GPUs. For local training, either specify `devices` or + `num_gpus`. In distributed training, this must be specified as number of + GPUs on each worker. + cluster_spec: if this is set, it turns into the multi-worker version and + `devices` must not be set but `num_gpus` must be set. + cross_tower_ops: optional, a descedant of `CrossTowerOps`. If this is not + set, the `configure` method will try to find the best one. + prefetch_on_device: optional boolean to specify whether to prefetch input + data to devices. """ def __init__(self, devices=None, num_gpus=None, + cluster_spec=None, cross_tower_ops=None, prefetch_on_device=None): super(MirroredStrategy, self).__init__() - # Convert `num_gpus` into `devices`, shouldn't specify both. - if devices is None: + + if cluster_spec: + if devices is not None: + raise ValueError("Specifying devices when `cluster_spec` is also given " + "is not supported in MirroredStrategy.") + + # TODO(yuefengz): use the utility method to normalize cluster_spec. + if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)): + cluster_spec = server_lib.ClusterSpec(cluster_spec) + elif not isinstance(cluster_spec, server_lib.ClusterSpec): + raise ValueError( + "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a " + "`tf.train.ClusterDef` object") + self._cluster_spec = cluster_spec + + self._workers = [] + for job in sorted(cluster_spec.jobs): + for task in range(cluster_spec.num_tasks(job)): + self._workers.append("/job:%s/task:%d" % (job, task)) + if num_gpus is None: - num_gpus = context.num_gpus() - devices = ["/device:GPU:%d" % d for d in range(num_gpus)] - elif num_gpus is not None: - raise ValueError("Must only specify one of `devices` and `num_gpus`.") + raise ValueError("`num_gpus` is required if `cluster_spec` is given.") + self._num_gpus = num_gpus + if num_gpus > 0: + self._worker_device_map = { + worker: [ + device_util.canonicalize(worker + "/device:GPU:%d" % gpu) + for gpu in range(num_gpus) + ] for worker in self._workers + } + else: + self._worker_device_map = { + worker: [device_util.canonicalize(worker, "/device:CPU:0")] + for worker in self._workers + } + devices = nest.flatten(self._worker_device_map) + + # Setting `_default_device` will add a device scope in the + # distribution.scope. We set the default device to the first worker. When + # users specify device under distribution.scope by + # with tf.device("/cpu:0"): + # ... + # their ops will end up on the cpu device of its first worker, e.g. + # "/job:worker/task:0/device:CPU:0". Note this is not used in tower mode. + self._default_device = self._workers[0] + else: + self._cluster_spec = None + # Convert `num_gpus` into `devices`, shouldn't specify both. + if devices is None: + if num_gpus is None: + num_gpus = context.num_gpus() + devices = ["/device:GPU:%d" % d for d in range(num_gpus)] + elif num_gpus is not None: + raise ValueError("Must only specify one of `devices` and `num_gpus`.") + # TODO(yuefengz): consider setting the default device. assert devices, "Must specify at least one device." assert len(set(devices)) == len(devices), ( @@ -87,61 +408,16 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): self._devices = [device_util.resolve(d) for d in devices] self._canonical_device_set = set(self._devices) self._device_index = values.PerDevice( - dict((d, i) for i, d in enumerate(devices))) + {d: i for i, d in enumerate(devices)}) self._cross_tower_ops = cross_tower_ops self._prefetch_on_device = prefetch_on_device - # TODO(yuefengz): consider setting the default device. def _create_variable(self, next_creator, *args, **kwargs): """Create a mirrored variable. See `DistributionStrategy.scope`.""" - # Figure out what collections this variable should be added to. - # We'll add the MirroredVariable to those collections instead. - collections = kwargs.pop("collections", None) - if collections is None: - collections = [ops.GraphKeys.GLOBAL_VARIABLES] - kwargs["collections"] = [] - colocate_with = kwargs.pop("colocate_with", None) devices = self._get_devices_from(colocate_with) - # Get synchronization value - synchronization = kwargs.get( - "synchronization", variable_scope.VariableSynchronization.ON_WRITE) - if synchronization == variable_scope.VariableSynchronization.NONE: - raise ValueError("`NONE` variable synchronization mode is not " - "supported with `Mirrored` distribution strategy. Please" - " change the `synchronization` for variable: " + - kwargs["name"]) - elif synchronization == variable_scope.VariableSynchronization.ON_READ: - # Variables that are to be synced on read are tower local. - is_tower_local = True - kwargs["trainable"] = False - elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or - synchronization == variable_scope.VariableSynchronization.AUTO): - # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`. - is_tower_local = False - else: - raise ValueError("Invalid variable synchronization mode: " + - synchronization + " for variable: " + kwargs["name"]) - - # Get aggregation value - aggregation = kwargs.pop("aggregation", - variable_scope.VariableAggregation.NONE) - if aggregation not in [ - variable_scope.VariableAggregation.NONE, - variable_scope.VariableAggregation.SUM, - variable_scope.VariableAggregation.MEAN - ]: - raise ValueError("Invalid variable aggregation mode: " + aggregation + - " for variable: " + kwargs["name"]) - - # Ignore user-specified caching device, not needed for mirrored variables. - kwargs.pop("caching_device", None) - - # TODO(josh11b,apassos): It would be better if variable initialization - # was never recorded on the tape instead of having to do this manually - # here. - with tape.stop_recording(): + def _real_mirrored_creator(devices, *args, **kwargs): # pylint: disable=g-missing-docstring index = {} for i, d in enumerate(devices): with ops.device(d): @@ -165,32 +441,80 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): v = next_creator(*args, **kwargs) assert not isinstance(v, values.DistributedVariable) index[d] = v + return index - if is_tower_local: - result = values.TowerLocalVariable(index, index[devices[0]], - aggregation) - else: - result = values.MirroredVariable(index, index[devices[0]], aggregation) - - if not context.executing_eagerly(): - g = ops.get_default_graph() - # If "trainable" is True, next_creator() will add the member variables - # to the TRAINABLE_VARIABLES collection, so we manually remove - # them and replace with the MirroredVariable. We can't set - # "trainable" to False for next_creator() since that causes functions - # like implicit_gradients to skip those variables. - if kwargs.get("trainable", True): - collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) - l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) - for v in index.values(): - l.remove(v) - g.add_to_collections(collections, result) - return result + return _create_mirrored_variable(devices, _real_mirrored_creator, *args, + **kwargs) def distribute_dataset(self, dataset_fn): - return values.PerDeviceDataset( - self._call_dataset_fn(dataset_fn), self._devices, - self._prefetch_on_device) + if self._cluster_spec: + return values.MultiWorkerDataset( + partial(self._call_dataset_fn, dataset_fn), self._worker_device_map, + self._prefetch_on_device) + else: + return values.PerDeviceDataset( + self._call_dataset_fn(dataset_fn), self._devices, + self._prefetch_on_device) + + # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. + def _run_steps_on_dataset(self, fn, iterator, iterations, + initial_loop_values=None): + if initial_loop_values is None: + initial_loop_values = {} + initial_loop_values = nest.flatten(initial_loop_values) + + ctx = values.MultiStepContext() + def body(i, *args): + """A wrapper around `fn` to create the while loop body.""" + del args + fn_inputs = iterator.get_next() + if not isinstance(fn_inputs, tuple): + fn_inputs = (fn_inputs,) + fn_result = fn(ctx, *fn_inputs) + for (name, output) in ctx.last_step_outputs.items(): + # Convert all outputs to tensors, potentially from `DistributedValues`. + ctx.last_step_outputs[name] = self.unwrap(output) + flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) + with ops.control_dependencies([fn_result]): + return [i + 1] + flat_last_step_outputs + + # We capture the control_flow_context at this point, before we run `fn` + # inside a while_loop. This is useful in cases where we might need to exit + # these contexts and get back to the outer context to do some things, for + # e.g. create an op which should be evaluated only once at the end of the + # loop on the host. One such usage is in creating metrics' value op. + self._outer_control_flow_context = ( + ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access + + cond = lambda i, *args: i < iterations + i = constant_op.constant(0) + loop_result = control_flow_ops.while_loop( + cond, body, [i] + initial_loop_values, name="", + parallel_iterations=1, back_prop=False, swap_memory=False, + return_same_structure=True) + del self._outer_control_flow_context + + ctx.run_op = control_flow_ops.group(loop_result) + + # Convert the last_step_outputs from a list to the original dict structure + # of last_step_outputs. + last_step_tensor_outputs = loop_result[1:] + last_step_tensor_outputs_dict = nest.pack_sequence_as( + ctx.last_step_outputs, last_step_tensor_outputs) + + for (name, aggregation) in ctx._last_step_outputs_aggregations.items(): # pylint: disable=protected-access + output = last_step_tensor_outputs_dict[name] + # For outputs that have already been aggregated, wrap them in a Mirrored + # container, else in a PerDevice container. + if aggregation is variables_lib.VariableAggregation.NONE: + last_step_tensor_outputs_dict[name] = values.regroup( + {d: t for d, t in zip(self._devices, output)}, values.PerDevice) + else: + assert len(output) == 1 + last_step_tensor_outputs_dict[name] = output[0] + + ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access + return ctx def _broadcast(self, tensor, destinations): # TODO(josh11b): In eager mode, use one thread per device, or async mode. @@ -198,116 +522,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): self._devices) def _call_for_each_tower(self, fn, *args, **kwargs): - """Run `fn` in separate threads, once per tower/worker device. - - Args: - fn: function to run (will be run once per device, each in its own thread). - *args: positional arguments for `fn` - **kwargs: keyword arguments for `fn`. - `"run_concurrently"`: Boolean indicating whether executions of `fn` - can be run concurrently (under eager execution only), defaults to - `True`. - - Returns: - Merged return value of `fn` across all towers. - - Raises: - RuntimeError: If fn() calls get_tower_context().merge_call() a different - number of times for when called for different devices. - """ - run_concurrently = kwargs.pop("run_concurrently", True) - if not context.executing_eagerly(): - # Lots of TF library code isn't thread-safe in graph mode, and - # there is little to be gained by turning on multithreading when - # constructing a graph. - run_concurrently = False - # Needed for per-thread device, etc. contexts in graph mode. - ops.get_default_graph().switch_to_thread_local() - elif run_concurrently is None: - run_concurrently = True - - coord = coordinator.Coordinator( - clean_stop_exception_types=(_RequestedStop,)) - - shared_variable_store = {} - - # TODO(isaprykin): Create these threads once instead of during every run() - # call. - threads = [] - for index, d in enumerate(self._devices): - variable_creator_fn = shared_variable_creator.make_fn( - shared_variable_store, index) - t = MirroredStrategy._MirroredTowerThread( - self, coord, d, variable_creator_fn, fn, - *values.select_device(d, args), **values.select_device(d, kwargs)) - threads.append(t) - - for t in threads: - t.start() - - # When `fn` starts `should_run` event is set on _MirroredTowerThread - # (`MTT`) threads. The execution waits until - # `MTT.has_paused` is set, which indicates that either `fn` is - # complete or a `get_tower_context().merge_call()` is called. If `fn` is - # complete, then `MTT.done` is set to True. Otherwise, arguments - # of `get_tower_context().merge_call` from all paused threads are grouped - # and the `merge_fn` is performed. Results of the - # `get_tower_context().merge_call` are then set to `MTT.merge_result`. - # Each such `get_tower_context().merge_call` call returns the - # `MTT.merge_result` for that thread when `MTT.should_run` event - # is reset again. Execution of `fn` resumes. - - try: - with coord.stop_on_exception(): - all_done = False - while not all_done and not coord.should_stop(): - done = [] - if run_concurrently: - for t in threads: - t.should_run.set() - for t in threads: - t.has_paused.wait() - t.has_paused.clear() - if coord.should_stop(): - return None - done.append(t.done) - else: - for t in threads: - t.should_run.set() - t.has_paused.wait() - t.has_paused.clear() - if coord.should_stop(): - return None - done.append(t.done) - if coord.should_stop(): - return None - all_done = all(done) - if not all_done: - if any(done): - raise RuntimeError("Some towers made a different number of " - "tower_context().merge_call() calls.") - # get_tower_context().merge_call() case - merge_args = values.regroup( - {t.device: t.merge_args for t in threads}) - merge_kwargs = values.regroup( - {t.device: t.merge_kwargs for t in threads}) - # We capture the name_scope of the MTT when we call merge_fn - # to ensure that if we have opened a name scope in the MTT, - # it will be respected when executing the merge function. We only - # capture the name_scope from the first MTT and assume it is - # the same for all other MTTs. - mtt_captured_name_scope = threads[0].captured_name_scope - with ops.name_scope(mtt_captured_name_scope): - merge_result = threads[0].merge_fn( - self, *merge_args, **merge_kwargs) - for t in threads: - t.merge_result = values.select_device(t.device, merge_result) - finally: - for t in threads: - t.should_run.set() - coord.join(threads) - - return values.regroup({t.device: t.main_result for t in threads}) + return _call_for_each_tower(self, fn, *args, **kwargs) def map(self, map_over, fn, *args, **kwargs): # TODO(josh11b): In eager mode, use one thread per device. @@ -324,10 +539,19 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): # in addition to PerDevice data. return values.PerDevice({k: values.MapOutput(v) for k, v in index.items()}) - def configure(self, session_config=None): + def configure(self, + session_config=None, + cluster_spec=None, + task_type=None, + task_id=None): + del cluster_spec, task_type, task_id if self._cross_tower_ops is None: - self._cross_tower_ops = cross_tower_ops_lib.choose_the_best( - self._devices, session_config=session_config) + if self._cluster_spec: + self._cross_tower_ops = cross_tower_ops_lib.MultiWorkerAllReduce( + self._workers, self._num_gpus) + else: + self._cross_tower_ops = cross_tower_ops_lib.choose_the_best( + self._devices, session_config=session_config) def _get_cross_tower_ops(self): if self._cross_tower_ops is None: @@ -337,29 +561,12 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): def _reduce(self, aggregation, value, destinations): assert not isinstance(value, values.Mirrored) - if not isinstance(value, values.PerDevice): - if value == 0: - return 0 - if aggregation == variable_scope.VariableAggregation.MEAN: - return self._broadcast(value, destinations) - - cross_tower_ops_lib.validate_destinations(destinations) - if len(self._devices) == 1: - if destinations: - # TODO(anjalisridhar): Moves these methods to a device utility file? - devices = cross_tower_ops_lib.get_devices_from(destinations) - if len(devices) == 1: - with ops.device(devices[0]): - return array_ops.identity(value) - else: - value_updates = {} - for d in devices: - with ops.device(d): - value_updates[d] = array_ops.identity(value) - return values.Mirrored(value_updates) - raise ValueError("A non PerDevice value cannot be reduced with the given " - "aggregation.") - + if not isinstance(value, values.DistributedValues): + # This function handles reducing values that are not PerDevice or Mirrored + # values. For example, the same value could be present on all towers in + # which case `value` would be a single value or value could be 0. + return _reduce_non_distributed_value(self, aggregation, value, + destinations) return self._get_cross_tower_ops().reduce( aggregation, value, destinations=destinations) @@ -406,6 +613,9 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): return [val.get(device=d) for d in sorted(val.devices)] return [val] + def value_container(self, val): + return values.value_container(val) + @property def is_single_tower(self): return len(self._devices) == 1 @@ -433,15 +643,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): def _get_devices_from(self, colocate_with=None): if colocate_with is None: return self._devices - elif isinstance(colocate_with, values.DistributedValues): - # pylint: disable=protected-access - return list(colocate_with._index.keys()) - elif isinstance(colocate_with, six.string_types): - return [device_util.resolve(colocate_with)] - elif isinstance(colocate_with, list): - return [device_util.resolve(d) for d in colocate_with] else: - return colocate_with + return cross_tower_ops_lib.get_devices_from(colocate_with) class _MirroredTowerThread(threading.Thread): """A thread that runs() a function on a device.""" diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 9807ce43515a9f1000f62c279f9dcf16491e4fba..9a4cc0a8975c39cf82e474d660968afc17991db0 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -25,7 +25,9 @@ from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.contrib.distribute.python import values from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import backprop from tensorflow.python.eager import context +from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -37,7 +39,8 @@ from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import device_util +from tensorflow.python.training import distribution_strategy_context GPU_TEST = "test_gpu" in sys.argv[0] @@ -161,7 +164,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): # This variable should be created only once across the threads because of # special variable_creator functions used by `dist.call_for_each_tower`. v = variable_scope.variable(1.0, name="foo") - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call(lambda _: _) return v dist = mirrored_strategy.MirroredStrategy( @@ -178,7 +181,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(): v = variable_scope.variable(1.0) - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call(lambda _: _) return v dist = mirrored_strategy.MirroredStrategy( @@ -198,7 +201,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): vs = [] for i in range(5): vs.append(variable_scope.variable(1.0, name="foo" + str(i))) - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call(lambda _: _) return vs dist = mirrored_strategy.MirroredStrategy( @@ -220,7 +223,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): vs.append(variable_scope.variable(1.0, name="foo_1/bar")) vs.append(variable_scope.variable(1.0, name="foo_1/bar_1")) vs.append(variable_scope.variable(1.0, name="foo/bar_1")) - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call(lambda _: _) return vs dist = mirrored_strategy.MirroredStrategy( @@ -242,7 +245,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(device_id): v = variable_scope.variable(1.0, name="foo_" + str(device_id)) - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call(lambda _: _) return v dist = mirrored_strategy.MirroredStrategy( @@ -265,7 +268,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): layer2 = core.Dense(1) layer2(features) # This will pause the current thread, and execute the other thread. - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call( + lambda _: _) layer3 = core.Dense(1) layer3(features) return [(layer1.kernel, layer1.bias), @@ -297,7 +301,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with variable_scope.variable_scope("common"): v1 = variable_scope.variable(1.0, name="var1") # This will pause the current thread, and execute the other thread. - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call( + lambda _: _) v2 = variable_scope.variable( 1.0, name="var2", @@ -340,7 +345,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with variable_scope.variable_scope("common"): v1 = variable_scope.get_variable("var1", [1]) # This will pause the current thread, and execute the other thread. - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call( + lambda _: _) v2 = variable_scope.get_variable( "var2", [1], synchronization=variable_scope.VariableSynchronization.ON_READ, @@ -450,7 +456,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(): v = variable_scope.variable(1.0, name="foo") - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call(lambda _: _) return v dist = mirrored_strategy.MirroredStrategy( @@ -467,7 +473,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(name): v = variable_scope.variable(1.0, name=name) - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call(lambda _: _) return v dist = mirrored_strategy.MirroredStrategy( @@ -567,7 +573,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(): with ops.name_scope("foo"): a = constant_op.constant(1.0, name="a") - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call( + lambda _: _) b = constant_op.constant(1.0, name="b") return a, b @@ -588,7 +595,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(): with ops.name_scope(None, "foo"): a = constant_op.constant(1.0, name="a") - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call( + lambda _: _) b = constant_op.constant(2.0, name="b") return a, b @@ -616,7 +624,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(): b = variable_scope.variable(1.0, name="b") with ops.name_scope("foo"): - c = distribute_lib.get_tower_context().merge_call(in_cross_tower) + c = distribution_strategy_context.get_tower_context().merge_call( + in_cross_tower) return b, c dist = mirrored_strategy.MirroredStrategy( @@ -648,7 +657,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(): b = variable_scope.get_variable("b", [1]) with ops.name_scope("foo"): - c = distribute_lib.get_tower_context().merge_call(in_cross_tower) + c = distribution_strategy_context.get_tower_context().merge_call( + in_cross_tower) return b, c dist = mirrored_strategy.MirroredStrategy( @@ -792,8 +802,8 @@ class MirroredVariableUpdateTest(test.TestCase): return mirrored_var.assign(5.0) with self.assertRaisesRegexp( - ValueError, "A non PerDevice value cannot be reduced with the given " - "aggregation."): + ValueError, "A non-DistributedValues value cannot be reduced with " + "the given aggregation."): self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn))) @test_util.run_in_graph_and_eager_modes(config=config) @@ -830,14 +840,38 @@ class MirroredVariableUpdateTest(test.TestCase): self.assertEquals(1.0, self.evaluate(mirrored_var)) def model_fn(): - value = math_ops.cast(distribute_lib.get_tower_context().tower_id, - mirrored_var.dtype) + value = math_ops.cast( + distribution_strategy_context.get_tower_context().tower_id, + mirrored_var.dtype) return mirrored_var.assign(value) self.evaluate(dist.unwrap(dist.call_for_each_tower( model_fn, run_concurrently=False))) self.assertEquals(0.5, self.evaluate(mirrored_var)) + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignMirroredVarTowerContextWithSingleValue(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable( + 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(1.0, self.evaluate(mirrored_var)) + + def model_fn(): + return mirrored_var.assign(5.0) + + self.evaluate(dist.unwrap(dist.call_for_each_tower( + model_fn, run_concurrently=False))) + self.assertEquals(5.0, self.evaluate(mirrored_var)) + @test_util.run_in_graph_and_eager_modes(config=config) def testAssignAddMirroredVarCrossTowerContext(self): self._skip_eager_if_gpus_less_than(1) @@ -872,14 +906,38 @@ class MirroredVariableUpdateTest(test.TestCase): self.assertEquals(1.0, self.evaluate(mirrored_var)) def model_fn(): - value = math_ops.cast(distribute_lib.get_tower_context().tower_id, - mirrored_var.dtype) + value = math_ops.cast( + distribution_strategy_context.get_tower_context().tower_id, + mirrored_var.dtype) return mirrored_var.assign_add(value) self.evaluate(dist.unwrap(dist.call_for_each_tower( model_fn, run_concurrently=False))) self.assertEquals(1.5, self.evaluate(mirrored_var)) + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignAddMirroredVarTowerContextWithSingleValue(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable( + 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(1.0, self.evaluate(mirrored_var)) + + def model_fn(): + return mirrored_var.assign_add(5.0) + + self.evaluate(dist.unwrap(dist.call_for_each_tower( + model_fn, run_concurrently=False))) + self.assertEquals(6.0, self.evaluate(mirrored_var)) + @test_util.run_in_graph_and_eager_modes(config=config) def testAssignSubMirroredVarCrossTowerContext(self): self._skip_eager_if_gpus_less_than(1) @@ -914,14 +972,38 @@ class MirroredVariableUpdateTest(test.TestCase): self.assertEquals(5.0, self.evaluate(mirrored_var)) def model_fn(): - value = math_ops.cast(distribute_lib.get_tower_context().tower_id, - mirrored_var.dtype) + value = math_ops.cast( + distribution_strategy_context.get_tower_context().tower_id, + mirrored_var.dtype) return mirrored_var.assign_sub(value) self.evaluate(dist.unwrap(dist.call_for_each_tower( model_fn, run_concurrently=False))) self.assertEquals(4.5, self.evaluate(mirrored_var)) + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignSubMirroredVarTowerContextWithSingleValue(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable( + 5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(5.0, self.evaluate(mirrored_var)) + + def model_fn(): + return mirrored_var.assign_sub(1.0) + + self.evaluate(dist.unwrap(dist.call_for_each_tower( + model_fn, run_concurrently=False))) + self.assertEquals(4.0, self.evaluate(mirrored_var)) + class MirroredAndTowerLocalVariableInitializerTest(test.TestCase): config = config_pb2.ConfigProto() @@ -974,7 +1056,7 @@ class TowerLocalVariableAssignTest(test.TestCase): def _skip_eager_if_gpus_less_than(self, num_gpus): if context.num_gpus() < num_gpus and context.executing_eagerly(): - self.skipTest("Enough GPUs not available for this test in eager mode.") + self.skipTest("Not enough GPUs available for this test in eager mode.") @test_util.run_in_graph_and_eager_modes(config=config) def testAssignTowerLocalVarSumAggregation(self): @@ -1036,5 +1118,131 @@ class TowerLocalVariableAssignTest(test.TestCase): self.assertEqual(6.0, self.evaluate(dist.read_var(tower_local_var))) +class MockModel(object): + + def __init__(self, two_variables=False): + self.variables = [] + self.variables.append(variable_scope.variable(1.25, name="dummy_var1")) + if two_variables: + self.variables.append(variable_scope.variable(2.0, name="dummy_var2")) + + def __call__(self, factor=2): + x = factor * self.variables[0] + if len(self.variables) > 1: + x += self.variables[1] + return x + + +class MirroredStrategyDefunTest(test.TestCase): + + def _skip_eager_if_gpus_less_than(self, num_gpus): + if context.num_gpus() < num_gpus and context.executing_eagerly(): + self.skipTest("Not enough GPUs available for this test in eager mode.") + + def _call_and_check(self, model_fn, inputs, expected_result, defuns, + two_variables=False): + cpu_dev = device_util.canonicalize("CPU:0") + gpu_dev = device_util.canonicalize("GPU:0") + devices = [cpu_dev, gpu_dev] + dist = mirrored_strategy.MirroredStrategy(devices) + + with dist.scope(): + mock_model = MockModel(two_variables) + self.evaluate(variables.global_variables_initializer()) + + result = dist.call_for_each_tower(model_fn, mock_model, *inputs, + run_concurrently=False) + for device in devices: + device_result = values.select_device(device, result) + device_expected_result = values.select_device(device, expected_result) + self.assertAllClose(device_expected_result, + self.evaluate(device_result)) + + for defun in defuns: + self.assertEqual(set(mock_model.variables), set(defun.variables)) + + @test_util.run_in_graph_and_eager_modes() + def testVariableInDefun(self): + self._skip_eager_if_gpus_less_than(1) + + @function.defun + def times_two(mock_model): + return mock_model() + + def model_fn(mock_model): + return times_two(mock_model) + + self._call_and_check(model_fn, [], 2.5, [times_two]) + + @test_util.run_in_graph_and_eager_modes() + def testVariableInNestedDefun(self): + self._skip_eager_if_gpus_less_than(1) + + @function.defun + def times_two(mock_model): + return mock_model() + + @function.defun + def two_x_plus_one(mock_model): + return times_two(mock_model) + 1 + + def model_fn(mock_model): + return two_x_plus_one(mock_model) + + self._call_and_check(model_fn, [], 3.5, [times_two, two_x_plus_one]) + + @test_util.run_in_graph_and_eager_modes() + def testTwoVariablesInNestedDefun(self): + self._skip_eager_if_gpus_less_than(1) + + @function.defun + def fn1(mock_model): + return mock_model() + + @function.defun + def fn2(mock_model): + return fn1(mock_model) + 1 + + def model_fn(mock_model): + return fn2(mock_model) + + self._call_and_check(model_fn, [], 5.5, [fn1, fn2], two_variables=True) + + @test_util.run_in_graph_and_eager_modes() + def testGradientTapeOverNestedDefuns(self): + self._skip_eager_if_gpus_less_than(1) + + @function.defun + def fn1(mock_model): + return mock_model() + + @function.defun + def fn2(mock_model): + return fn1(mock_model) + 1 + + def model_fn(mock_model): + with backprop.GradientTape(persistent=True) as gtape: + result = fn2(mock_model) + grads = gtape.gradient(result, + [v.get() for v in mock_model.variables]) + return grads + + self._call_and_check(model_fn, [], [2.0, 1.0], [fn1, fn2], + two_variables=True) + + @test_util.run_in_graph_and_eager_modes() + def testPassPerDevice(self): + self._skip_eager_if_gpus_less_than(1) + + @function.defun + def fn1(mock_model, factor): + return mock_model(factor) + + factors = values.PerDevice({"CPU:0": 5.0, "GPU:0": 3.0}) + expected_result = values.PerDevice({"CPU:0": 5.0 * 1.25, + "GPU:0": 3.0 * 1.25}) + self._call_and_check(fn1, [factors], expected_result, [fn1]) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py index a066adf1246ecd9ab8bd6a85be1f1e9be2c35b17..55d59adc078ad546e4fe0a3acb88741e8666b562 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py @@ -19,12 +19,16 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.python.eager import context from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import variable_scope -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context +from tensorflow.python.training import server_lib class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase): @@ -68,7 +72,8 @@ class VariableCreatorStackTest(test.TestCase): v = variable_scope.variable(1.0) # This will pause the current thread, and execute the other thread. - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call( + lambda _: _) return v def main_thread_creator(next_creator, *args, **kwargs): @@ -85,5 +90,33 @@ class VariableCreatorStackTest(test.TestCase): self.assertEquals(expected, result) +class MultiWorkerMirroredStrategyTest( + multi_worker_test_base.MultiWorkerTestBase, + strategy_test_lib.DistributionTestBase): + + def _get_distribution_strategy(self): + return mirrored_strategy.MirroredStrategy( + cluster_spec=server_lib.ClusterSpec({ + 'worker': ['/job:worker/task:0', '/job:worker/task:1'] + }), + num_gpus=context.num_gpus()) + + def testMinimizeLossGraph(self): + self._test_minimize_loss_graph(self._get_distribution_strategy()) + + def testDeviceScope(self): + """Test the device scope of multi-worker MirroredStrategy.""" + with context.graph_mode(): + strategy = mirrored_strategy.MirroredStrategy( + cluster_spec={'worker': ['/job:worker/task:0', '/job:worker/task:1']}, + num_gpus=context.num_gpus()) + with strategy.scope(): + a = constant_op.constant(1.) + with ops.device('/cpu:0'): + b = constant_op.constant(1.) + self.assertEqual(a.device, '/job:worker/task:0') + self.assertEqual(b.device, '/job:worker/task:0/device:CPU:0') + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy.py b/tensorflow/contrib/distribute/python/multi_worker_strategy.py deleted file mode 100644 index cbfe5df61d1ee6fa1eb9275b715b0721d678a46f..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distribute/python/multi_worker_strategy.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Classes implementing a mirrored DistributionStrategy for multiple workers.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from functools import partial - -from tensorflow.contrib.distribute.python import values -from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy -from tensorflow.core.protobuf import cluster_pb2 -from tensorflow.python.training import device_util -from tensorflow.python.training import server_lib -from tensorflow.python.util import nest - - -# TODO(yuefengz): support between-graph replication. -# TODO(yuefengz): merge this class into its base class. -# TODO(yuefengz): in some cases, we probably want to use configure method to -# configure this class. -# TODO(yuefengz): MirroredStrategy.worker_devices may be confusing after the -# class is introduced. -class MultiWorkerMirroredStrategy(MirroredStrategy): - """Mirrored strategy that works on multiple workers with in-graph replication. - - There are several important concepts for distributed TensorFlow, e.g. - `client`, `job`, 'task', `cluster`, `in-graph replication` and - 'synchronous training' and they have already been defined in the - [TensorFlow's documentation](https://www.tensorflow.org/deploy/distributed). - The distribution strategy inherits these concepts as well and in addition to - that we also clarify several more concepts: - * **In-graph replication**: the `client` creates a single `tf.Graph` that - specifies tasks for devices on all workers. The `client` then creates a - client session which will talk to the `master` service of a `worker`. Then - the `master` will partition the graph and distribute the work to all - participating workers. - * **Worker**: A `worker` is a TensorFlow `task` that usually maps to one - physical machine. We will have multiple `worker`s with different `task` - index. They all do similar things except for one worker checkpointing model - variables, writing summaries, etc. in addition to its ordinary work. - - This class maps one tower to one device on a worker. It mirrors all model - variables on all towers. For example, if you have two `worker`s and each - `worker` has 4 GPUs, it will create 8 copies of the model variables on these 8 - GPUs. Then like in MirroredStrategy, each tower performs their computation - with their own copy of variables unless in cross-tower model where variable or - tensor reduction happens. - """ - - def __init__(self, - num_gpus_per_worker=1, - worker_job_name=None, - num_workers=None, - cluster=None, - cross_tower_ops=None, - prefetch_on_device=None): - """Initialize the strategy object. - - Args: - num_gpus_per_worker: number of GPUs per work. If it is zero, the local - CPU will be used. - worker_job_name: the job name for `worker`, typically just 'worker'. - num_workers: the number of workers. If it is 0, it regenerates to - single-worker MirroredStrategy. - cluster: a `tf.train.ClusterSpec` object or a dict that can be used to - construct a `tf.train.ClusterSpec` object or a `tf.train.ClusterDef` - proto buffer. It is an alternative way to initialize this object. - cross_tower_ops: the cross tower ops to use. If None, a default one will - be used. If configure method is called, a best one for the configuration - will be chosen. - prefetch_on_device: a boolean to specify whether to prefetech input to - each worker's devices. - - Raises: - ValueError: if got an unexpected `cluster`. - """ - if cluster is None: - self._workers = [ - '/job:%s/task:%d' % (worker_job_name, task_index) - for task_index in range(num_workers) - ] - else: - if isinstance(cluster, (dict, cluster_pb2.ClusterDef)): - cluster_spec = server_lib.ClusterSpec(cluster) - elif isinstance(cluster, server_lib.ClusterSpec): - cluster_spec = cluster - else: - raise ValueError( - "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a " - '`tf.train.ClusterDef` object') - - self._workers = [] - for job in sorted(cluster_spec.jobs): - for task in range(cluster_spec.num_tasks(job)): - self._workers.append('/job:%s/task:%d' % (job, task)) - - self._num_gpus_per_worker = num_gpus_per_worker - if num_gpus_per_worker > 0: - self._worker_device_map = { - worker: [ - device_util.canonicalize(worker + '/device:GPU:%d' % gpu) - for gpu in range(num_gpus_per_worker) - ] for worker in self._workers - } - else: - self._worker_device_map = { - worker: [device_util.canonicalize(worker, '/device:CPU:0')] - for worker in self._workers - } - self._devices = nest.flatten(self._worker_device_map) - - super(MultiWorkerMirroredStrategy, self).__init__( - devices=self._devices, prefetch_on_device=prefetch_on_device) - - # Setting `_default_device` will add a device scope in the - # distribution.scope. We set the default device to the first worker. When - # users specify device under distribution.scope by - # with tf.device("/cpu:0"): - # ... - # their ops will end up on the cpu device of its first worker, e.g. - # "/job:worker/task:0/device:CPU:0". Note this is not used in tower mode. - self._default_device = self._workers[0] - - def distribute_dataset(self, dataset_fn): - return values.MultiWorkerDataset( - partial(self._call_dataset_fn, dataset_fn), self._worker_device_map, - self._prefetch_on_device) diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py b/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py deleted file mode 100644 index 09c859b32a3150b95fbfcfa5b62b5eca426ddf18..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for MultiWorkerMirroredStrategy.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.distribute.python import multi_worker_strategy -from tensorflow.contrib.distribute.python import multi_worker_test_base -from tensorflow.contrib.distribute.python import strategy_test_lib -from tensorflow.python.eager import context -from tensorflow.python.eager import test -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.training import server_lib - - -class MultiWorkerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, - strategy_test_lib.DistributionTestBase): - - def _get_distribution_strategy(self): - return multi_worker_strategy.MultiWorkerMirroredStrategy( - cluster=server_lib.ClusterSpec({ - 'worker': ['/job:worker/task:0', '/job:worker/task:1'] - }), - num_gpus_per_worker=context.num_gpus()) - - def testMinimizeLossGraph(self): - self._test_minimize_loss_graph(self._get_distribution_strategy()) - - -class DeviceScopeTest(test.TestCase): - """Test the device scope of MultiWorkerMirroredStrategy.""" - - def testDeviceScope(self): - with context.graph_mode(): - strategy = multi_worker_strategy.MultiWorkerMirroredStrategy( - cluster={'worker': ['/job:worker/task:0', '/job:worker/task:1']}, - num_gpus_per_worker=context.num_gpus()) - with strategy.scope(): - a = constant_op.constant(1.) - with ops.device('/cpu:0'): - b = constant_op.constant(1.) - self.assertEqual(a.device, '/job:worker/task:0') - self.assertEqual(b.device, '/job:worker/task:0/device:CPU:0') - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py index f659be5f42594b275af06435cb0c228e5d594ac9..249de01f0880b02d603687db99692088480f7136 100644 --- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py +++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py @@ -20,35 +20,68 @@ from __future__ import print_function import contextlib import copy +import threading +import numpy as np from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session -from tensorflow.python.eager import test +from tensorflow.python.estimator import run_config +from tensorflow.python.platform import test from tensorflow.python.framework import test_util +def create_in_process_cluster(num_workers, num_ps): + """Create an in-process cluster that consists of only standard server.""" + # Leave some memory for cuda runtime. + gpu_mem_frac = 0.7 / num_workers + worker_config = config_pb2.ConfigProto() + worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac + + # Enable collective ops which has no impact on non-collective ops. + # TODO(yuefengz, tucker): removing this after we move the initialization of + # collective mgr to the session level. + worker_config.experimental.collective_group_leader = ( + '/job:worker/replica:0/task:0') + + ps_config = config_pb2.ConfigProto() + ps_config.device_count['GPU'] = 0 + + # Create in-process servers. Once an in-process tensorflow server is created, + # there is no way to terminate it. So we create one cluster per test process. + # We could've started the server in another process, we could then kill that + # process to terminate the server. The reasons why we don't want multiple + # processes are + # 1) it is more difficult to manage these processes; + # 2) there is something global in CUDA such that if we initialize CUDA in the + # parent process, the child process cannot initialize it again and thus cannot + # use GPUs (https://stackoverflow.com/questions/22950047). + return test_util.create_local_cluster( + num_workers, + num_ps=num_ps, + worker_config=worker_config, + ps_config=ps_config, + protocol='grpc') + + class MultiWorkerTestBase(test.TestCase): """Base class for testing multi node strategy and dataset.""" @classmethod def setUpClass(cls): """Create a local cluster with 2 workers.""" - num_workers = 2 - # Leave some memory for cuda runtime. - gpu_mem_frac = 0.7 / num_workers - default_config = config_pb2.ConfigProto() - default_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac - - # The local cluster takes some portion of the local GPUs and there is no way - # for the cluster to terminate unless using multiple processes. Therefore, - # we have to only create only one cluster throughout a test process. - workers, _ = test_util.create_local_cluster( - num_workers, num_ps=0, worker_config=default_config) - cls._master_target = workers[0].target + cls._workers, cls._ps = create_in_process_cluster(num_workers=2, num_ps=0) + + def setUp(self): + # We only cache the session in one test because another test may have a + # different session config or master target. + self._thread_local = threading.local() + self._thread_local.cached_session = None + self._result = 0 + self._lock = threading.Lock() @contextlib.contextmanager - def test_session(self, graph=None, config=None): + def test_session(self, graph=None, config=None, target=None): """Create a test session with master target set to the testing cluster. This overrides the base class' method, removes arguments that are not needed @@ -59,6 +92,7 @@ class MultiWorkerTestBase(test.TestCase): graph: Optional graph to use during the returned session. config: An optional config_pb2.ConfigProto to use to configure the session. + target: the target of session to connect to. Yields: A Session object that should be used as a context manager to surround @@ -78,13 +112,46 @@ class MultiWorkerTestBase(test.TestCase): rewriter_config_pb2.RewriterConfig.OFF) if graph is None: - if self._cached_session is None: # pylint: disable=access-member-before-definition - self._cached_session = session.Session( - graph=None, config=config, target=self._master_target) - sess = self._cached_session + if getattr(self._thread_local, 'cached_session', None) is None: + self._thread_local.cached_session = session.Session( + graph=None, config=config, target=target or self._workers[0].target) + sess = self._thread_local.cached_session with sess.graph.as_default(), sess.as_default(): yield sess else: with session.Session( - graph=graph, config=config, target=self._master_target) as sess: + graph=graph, config=config, target=target or + self._workers[0].target) as sess: yield sess + + def _run_client(self, client_fn, task_type, task_id, num_gpus, *args, + **kwargs): + result = client_fn(task_type, task_id, num_gpus, *args, **kwargs) + if np.all(result): + with self._lock: + self._result += 1 + + def _run_between_graph_clients(self, client_fn, cluster_spec, num_gpus, *args, + **kwargs): + """Runs several clients for between-graph replication. + + Args: + client_fn: a function that needs to accept `task_type`, `task_id`, + `num_gpus` and returns True if it succeeds. + cluster_spec: a dict specifying jobs in a cluster. + num_gpus: number of GPUs per worker. + *args: will be passed to `client_fn`. + **kwargs: will be passed to `client_fn`. + """ + threads = [] + for task_type in [run_config.TaskType.CHIEF, run_config.TaskType.WORKER]: + for task_id in range(len(cluster_spec.get(task_type, []))): + t = threading.Thread( + target=self._run_client, + args=(client_fn, task_type, task_id, num_gpus) + args, + kwargs=kwargs) + t.start() + threads.append(t) + for t in threads: + t.join() + self.assertEqual(self._result, len(threads)) diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index dbd3514aec7d40d9a04dba4bcbc5c14be639aa33..68561b5bbf06374cb391e2837ff7bc989ac3a2bd 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -21,11 +21,14 @@ from __future__ import print_function import six from tensorflow.contrib.distribute.python import values +from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.util import nest # TODO(josh11b): Replace asserts in this file with if ...: raise ... @@ -66,6 +69,53 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): def _broadcast(self, tensor, destinations): return tensor + # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. + def _run_steps_on_dataset(self, fn, iterator, iterations, + initial_loop_values=None): + if initial_loop_values is None: + initial_loop_values = {} + initial_loop_values = nest.flatten(initial_loop_values) + + ctx = values.MultiStepContext() + def body(i, *args): + """A wrapper around `fn` to create the while loop body.""" + del args + fn_inputs = iterator.get_next() + if not isinstance(fn_inputs, tuple): + fn_inputs = (fn_inputs,) + fn_result = fn(ctx, *fn_inputs) + flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) + with ops.control_dependencies([fn_result]): + return [i + 1] + flat_last_step_outputs + + # We capture the control_flow_context at this point, before we run `fn` + # inside a while_loop. This is useful in cases where we might need to exit + # these contexts and get back to the outer context to do some things, for + # e.g. create an op which should be evaluated only once at the end of the + # loop on the host. One such usage is in creating metrics' value op. + self._outer_control_flow_context = ( + ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access + + # TODO(priyag): Use max_iterations instead of an explicit counter. + cond = lambda i, *args: i < iterations + i = constant_op.constant(0) + loop_result = control_flow_ops.while_loop( + cond, body, [i] + initial_loop_values, name="", + parallel_iterations=1, back_prop=False, swap_memory=False, + return_same_structure=True) + del self._outer_control_flow_context + + ctx.run_op = control_flow_ops.group(loop_result) + + # Convert the last_step_outputs from a list to the original dict structure + # of last_step_outputs. + last_step_tensor_outputs = loop_result[1:] + last_step_tensor_outputs_dict = nest.pack_sequence_as( + ctx.last_step_outputs, last_step_tensor_outputs) + + ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access + return ctx + def _call_for_each_tower(self, fn, *args, **kwargs): # We don't run `fn` in multiple threads in OneDeviceStrategy. kwargs.pop("run_concurrently", None) @@ -105,6 +155,9 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): def _unwrap(self, value): return [value] + def value_container(self, value): + return value + @property def is_single_tower(self): return True diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..96b6519bc4d0a280746632fef57c54a9b1e82fe8 --- /dev/null +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -0,0 +1,389 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Classes implementing a multi-worker ps DistributionStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import values +from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.framework import device as tf_device +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.training import device_setter +from tensorflow.python.training import device_util +from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.util import nest + +_LOCAL_CPU = "/device:CPU:0" +_LOCAL_GPU_0 = "/device:GPU:0" + + +# TODO(yuefengz): maybe cache variables on local CPU. +# TODO(yuefengz): we may want to set session options to disallow communication +# between workers. +class ParameterServerStrategy(distribute_lib.DistributionStrategy): + """A parameter server DistributionStrategy. + + This strategy class works for both local training and between-graph replicated + training for multiple workers. If `cluster_spec` is specified, either passed + in to __init__() method or parsed from the + ["TF_CONFIG" environment + variable](https://www.tensorflow.org/api_docs/python/tf/estimator/RunConfig), + variables and updates to those variables are assigned to parameter servers and + other operations are assigned to workers. If `cluster_spec` is not set, it + becomes local training where variables are assigned to local CPU or the only + GPU. When each worker has more than one GPU, operations will be replicated on + these GPUs. In both cases, operations are replicated but variables are not and + these workers share a common view for which paramater server a variable is + assigned to. + + This class assumes between-graph replication will be used and works on a graph + for a particular worker. Note that each graph and worker is independent. + This means that while each worker will synchronously compute a single gradient + update across all GPUs, updates between workers proceed asynchronously. + Operations that occur only on the first tower (such as incrementing the global + step), will occur on the first tower *of every worker*. + + It is expected to call `call_for_each_tower(fn, *args, **kwargs)` for any + operations which potentially can be replicated across towers (i.e. multiple + GPUs) even if there is only CPU or one GPU. When defining the `fn`, extra + caution needs to be taken: + + 1) Always use `tf.get_variable` instead of `tf.Variable` which is not able + to refer to the same variable on different towers. + + 2) It is generally not recommended to open a device scope under the strategy's + scope. A device scope (i.e. calling `tf.device`) will be merged with or + override the device for operations but will not change the device for + variables. + + 3) It is also not recommended to open a colocation scope (i.e. calling + `tf.colocate_with`) under the strategy's scope. For colocating variables, + use `distribution.colocate_vars_with` instead. Colocation of ops will possibly + create conflicts of device assignment. + """ + + def __init__(self, + num_gpus_per_worker=0, + cluster_spec=None, + task_type=None, + task_id=None): + """Initializes this strategy. + + Args: + num_gpus_per_worker: number of local GPUs or GPUs per worker. + cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the + cluster configurations. + task_type: the current task type. + task_id: the current task id. + """ + super(ParameterServerStrategy, self).__init__() + self._num_gpus_per_worker = num_gpus_per_worker + if cluster_spec: + cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec) + self._cluster_spec = cluster_spec + + # We typically don't need to do all-reduce in this strategy. + self._cross_tower_ops = ( + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps( + reduce_to_device=_LOCAL_CPU)) + + self._initialize_devices(num_gpus_per_worker, cluster_spec, task_type, + task_id) + + def _initialize_devices(self, num_gpus_per_worker, cluster_spec, task_type, + task_id): + """Initialize internal devices. + + It creates variable devices and compute devices. Variables and operations + will be assigned to them respectively. We have one compute device per tower. + The variable device is a device function or device string. The default + variable device assigns variables to parameter servers in a round-robin + fashion. + + Args: + num_gpus_per_worker: number of local GPUs or GPUs per worker. + cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the + cluster configurations. + task_type: the current task type. + task_id: the current task id. + + Raises: + ValueError: if the cluster_spec doesn't have ps jobs. + """ + self._task_type = task_type or "worker" + self._task_id = task_id or 0 + self._worker_device = "/job:%s/task:%d" % (self._task_type, self._task_id) + + # TODO(yuefengz): maybe clearer to split it into two classes, one for + # the distribuetd case and one for the local case, once we have the factory + # class/method. + + # Define compute devices which is a list of device strings and one for each + # tower. When there are GPUs, replicate operations on these GPUs. Otherwise, + # place operations on CPU. + if cluster_spec is None: + # Local mode. + if num_gpus_per_worker > 0: + self._compute_devices = list( + map("/device:GPU:{}".format, range(num_gpus_per_worker))) + else: + self._compute_devices = [_LOCAL_CPU] + else: + # Distributed mode. + if num_gpus_per_worker > 0: + self._compute_devices = [ + "%s/device:GPU:%d" % (self._worker_device, i) + for i in range(num_gpus_per_worker) + ] + else: + self._compute_devices = [self._worker_device] + + self._compute_devices = list( + map(device_util.resolve, self._compute_devices)) + self._canonical_compute_device_set = set(self._compute_devices) + + # Define variable device which is a device string in the local case and a + # device function in the distributed case. It is used to open a device scope + # where varibles are defined. + # The `_parameter_devices` is needed for the `parameter_devices` property + # and is a list of all variable devices. + if cluster_spec is None: + # Local mode. If there is only one GPU, put everything on that GPU. + # Otherwise, place variables on CPU. + if num_gpus_per_worker == 1: + assert len(list(self._compute_devices)) == 1 + self._variable_device = _LOCAL_GPU_0 + self._parameter_devices = [_LOCAL_GPU_0] + else: + self._variable_device = _LOCAL_CPU + self._parameter_devices = [_LOCAL_CPU] + else: + # Distributed mode. Place variables on ps jobs in a round-robin fashion. + # Note that devices returned from `replica_device_setter` are not + # canonical and therefore we don't canonicalize all variable devices to + # make them consistent. + # TODO(yuefengz): support passing a strategy object to control variable + # assignment. + # TODO(yuefengz): merge the logic of replica_device_setter into this + # class. + num_ps_replicas = len(cluster_spec.as_dict().get("ps", [])) + if num_ps_replicas == 0: + raise ValueError("The cluster spec needs to have `ps` jobs.") + self._variable_device = device_setter.replica_device_setter( + ps_tasks=num_ps_replicas, + worker_device=self._worker_device, + merge_devices=True, + cluster=cluster_spec) + + # Parameter devices are all tasks of the "ps" job. + self._parameter_devices = map("/job:ps/task:{}".format, + range(num_ps_replicas)) + + # Define the default device in cross-tower mode. In the distributed case, we + # set the default device to the corresponding worker to prevent these ops + # from being placed on other workers. + if cluster_spec is None: + self._default_device = None + else: + self._default_device = self._worker_device + + self._is_chief = cluster_spec is None or multi_worker_util.is_chief( + cluster_spec, task_type, task_id) + + def distribute_dataset(self, dataset_fn): + """Distributes the dataset to each local GPU.""" + return values.PerDeviceDataset( + self._call_dataset_fn(dataset_fn), self._compute_devices, True) + + def _broadcast(self, tensor, destinations): + if not cross_tower_ops_lib.check_destinations(destinations): + destinations = self._compute_devices + return self._cross_tower_ops.broadcast(tensor, destinations) + + # TODO(yuefengz): not all ops in device_setter.STANDARD_PS_OPS will go through + # this creator, such as "MutableHashTable". + def _create_variable(self, next_creator, *args, **kwargs): + if self.num_towers > 1: + aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE) + if aggregation not in ( + vs.VariableAggregation.NONE, + vs.VariableAggregation.SUM, + vs.VariableAggregation.MEAN + ): + raise ValueError("Invalid variable aggregation mode: " + aggregation + + " for variable: " + kwargs["name"]) + + def var_creator(*args, **kwargs): + v = next_creator(*args, **kwargs) + return values.AggregatingVariable(v, aggregation) + else: + var_creator = next_creator + + if "colocate_with" in kwargs: + with ops.device(None): + with ops.colocate_with(kwargs["colocate_with"]): + return var_creator(*args, **kwargs) + + with ops.colocate_with(None, ignore_existing=True): + with ops.device(self._variable_device): + return var_creator(*args, **kwargs) + + def _call_for_each_tower(self, fn, *args, **kwargs): + # pylint: disable=protected-access + return mirrored_strategy._call_for_each_tower(self, fn, *args, **kwargs) + + def _verify_destinations_not_different_worker(self, destinations): + if destinations is None: + return + for d in cross_tower_ops_lib.get_devices_from(destinations): + d_spec = tf_device.DeviceSpec.from_string(d) + if d_spec.job == self._task_type and d_spec.task != self._task_id: + raise ValueError( + "Cannot reduce to another worker: %r, current worker is %r" % + (d, self._worker_device)) + + def _reduce(self, aggregation, value, destinations): + self._verify_destinations_not_different_worker(destinations) + if not isinstance(value, values.DistributedValues): + # pylint: disable=protected-access + return mirrored_strategy._reduce_non_distributed_value( + self, aggregation, value, destinations) + return self._cross_tower_ops.reduce( + aggregation, value, destinations=destinations) + + def _batch_reduce(self, aggregation, value_destination_pairs): + for _, destinations in value_destination_pairs: + self._verify_destinations_not_different_worker(destinations) + return self._cross_tower_ops.batch_reduce(aggregation, + value_destination_pairs) + + def _select_single_value(self, structured): + """Select any single values in `structured`.""" + + def _select_fn(x): # pylint: disable=g-missing-docstring + if isinstance(x, values.Mirrored): + if len(x.devices) == 1: + return list(x._index.values())[0] # pylint: disable=protected-access + else: + raise ValueError( + "You cannot update variable with a Mirrored object with multiple " + "components %r when using ParameterServerStrategy. You must " + "specify a single value or a Mirrored with a single value." % x) + elif isinstance(x, values.PerDevice): + raise ValueError( + "You cannot update variable with a PerDevice object %r when using " + "ParameterServerStrategy. You must specify a single value or a " + "Mirrored with a single value" % x) + else: + return x + + return nest.map_structure(_select_fn, structured) + + def _update(self, var, fn, *args, **kwargs): + if isinstance(var, values.AggregatingVariable): + var = var.get() + if not isinstance(var, resource_variable_ops.ResourceVariable): + raise ValueError( + "You can not update `var` %r. It must be a Variable." % var) + with ops.colocate_with(var), distribute_lib.UpdateContext(var.device): + return fn(var, *self._select_single_value(args), + **self._select_single_value(kwargs)) + + # TODO(yuefengz): does it need to call _select_single_value? + def _update_non_slot(self, colocate_with, fn, *args, **kwargs): + with ops.device( + colocate_with.device), distribute_lib.UpdateContext(colocate_with): + return fn(*args, **kwargs) + + def _unwrap(self, val): + if isinstance(val, values.DistributedValues): + # Return in a deterministic order. + if set(val.devices) == self._canonical_compute_device_set: + return [val.get(device=d) for d in self._compute_devices] + return [val.get(device=d) for d in sorted(val.devices)] + return [val] + + def value_container(self, val): + return values.value_container(val) + + def read_var(self, var): + # No need to distinguish between normal variables and tower-local variables. + return array_ops.identity(var) + + def configure(self, + session_config=None, + cluster_spec=None, + task_type=None, + task_id=None): + """Configures the strategy class. + + The strategy object will be re-initialized if `cluster_spec` is given but + was not passed in the constructor. + + Args: + session_config: not used currently. + cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the + cluster configurations. + task_type: the current task type. + task_id: the current task id. + """ + del session_config + + # Set the devices if cluster_spec is defined in TF_CONFIG but not passed in + # the constructor. + if not self._cluster_spec and cluster_spec: + self._cluster_spec = multi_worker_util.normalize_cluster_spec( + cluster_spec) + self._initialize_devices(self._num_gpus_per_worker, self._cluster_spec, + task_type, task_id) + + @property + def num_towers(self): + return len(self._compute_devices) + + @property + def worker_devices(self): + # Make a copy to prevent users from accidentally mutating our copy. + return list(self._compute_devices) + + @property + def parameter_devices(self): + return list(self._parameter_devices) + + def non_slot_devices(self, var_list): + return min(var_list, key=lambda x: x.name) + + @property + def between_graph(self): + return True + + @property + def should_init(self): + return self._is_chief + + @property + def should_checkpoint(self): + return self._is_chief + + @property + def should_save_summary(self): + return self._is_chief diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py new file mode 100644 index 0000000000000000000000000000000000000000..adfe3e8b020521d9c2c409da7c6d79e0ba060330 --- /dev/null +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -0,0 +1,448 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ParameterServerStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import multi_worker_test_base +from tensorflow.contrib.distribute.python import parameter_server_strategy +from tensorflow.python.eager import context +from tensorflow.python.estimator import run_config +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.layers import core +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import device_util +from tensorflow.python.training import distribution_strategy_context + + +class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, + parameterized.TestCase): + + @classmethod + def setUpClass(cls): + cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster( + num_workers=3, num_ps=2) + cls._cluster_spec = { + run_config.TaskType.WORKER: [ + 'fake_worker_0', 'fake_worker_1', 'fake_worker_2' + ], + run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1'] + } + + def setUp(self): + self._result = 0 + self._lock = threading.Lock() + self._init_condition = threading.Condition() + self._init_reached = 0 + self._finish_condition = threading.Condition() + self._finish_reached = 0 + super(ParameterServerStrategyTest, self).setUp() + + def _get_test_objects(self, task_type, task_id, num_gpus): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=num_gpus) + if not task_type: + return distribution, '' + + distribution.configure( + cluster_spec=self._cluster_spec, task_type=task_type, task_id=task_id) + return distribution, self._workers[task_id].target + + def _test_device_assignment_distributed(self, task_type, task_id, num_gpus): + worker_device = '/job:%s/replica:0/task:%d' % (task_type, task_id) + d, _ = self._get_test_objects(task_type, task_id, num_gpus) + with ops.Graph().as_default(), \ + self.test_session(target=self._workers[0].target) as sess, \ + d.scope(): + + # Define a variable outside the call_for_each_tower scope. This is not + # recommended. + n = variable_scope.get_variable('n', initializer=10.0) + self.assertEqual(n.device, '/job:ps/task:0') + + def model_fn(): + if num_gpus == 0: + last_part_device = 'device:CPU:0' + else: + last_part_device = ( + 'device:GPU:%d' % + distribution_strategy_context.get_tower_context().tower_id) + + a = constant_op.constant(1.0) + b = constant_op.constant(2.0) + c = a + b + self.assertEqual(a.device, worker_device + '/' + last_part_device) + self.assertEqual(b.device, worker_device + '/' + last_part_device) + self.assertEqual(c.device, worker_device + '/' + last_part_device) + + # The device scope is ignored for variables but not for normal ops. + with ops.device('/job:worker/task:0'): + x = variable_scope.get_variable( + 'x', initializer=10.0, + aggregation=variable_scope.VariableAggregation.SUM) + x_add = x.assign_add(c) + e = a + c + # The variable x is on the task 1 since the device_function has been + # called once before the model_fn. + self.assertEqual(x.device, '/job:ps/task:1') + self.assertEqual(x_add.device, x.device) + self.assertEqual(e.device, + '/job:worker/replica:0/task:0/%s' % last_part_device) + + # The colocate_vars_with can override the distribution's device. + with d.colocate_vars_with(x): + y = variable_scope.get_variable( + 'y', initializer=20.0, + aggregation=variable_scope.VariableAggregation.SUM) + # We add an identity here to avoid complaints about summing + # non-distributed values. + y_add = y.assign_add(array_ops.identity(x_add)) + self.assertEqual(y.device, '/job:ps/task:1') + self.assertEqual(y_add.device, y.device) + self.assertEqual(y.device, x.device) + + z = variable_scope.get_variable( + 'z', initializer=10.0, + aggregation=variable_scope.VariableAggregation.SUM) + self.assertEqual(z.device, '/job:ps/task:0') + self.assertNotEqual(z.device, x.device) + + with ops.control_dependencies([y_add]): + # We add an identity here to avoid complaints about summing + # non-distributed values. + z_add = z.assign_add(array_ops.identity(y)) + with ops.control_dependencies([z_add]): + f = z + c + self.assertEqual(f.device, worker_device + '/' + last_part_device) + + # The device scope would merge with the default worker device. + with ops.device('/CPU:1'): + g = e + 1.0 + self.assertEqual(g.device, worker_device + '/device:CPU:1') + + # Ths ops.colocate_with will be ignored when defining a variale but not + # for a normal tensor. + with ops.colocate_with(x): + u = variable_scope.get_variable('u', initializer=30.0) + v = variable_scope.get_variable('v', initializer=30.0) + h = f + 1.0 + self.assertIn('/job:ps/', u.device) + self.assertIn('/job:ps/', v.device) + # u and v are on different parameter servers. + self.assertTrue(u.device != x.device or v.device != x.device) + self.assertTrue(u.device == x.device or v.device == x.device) + # Here h is not on one worker. Note h.device is canonical while x.device + # is not but. + self.assertIn('/job:ps/', h.device) + return y_add, z_add, f + + y, z, f = d.call_for_each_tower(model_fn) + self.assertNotEqual(y, None) + self.assertNotEqual(z, None) + self.assertNotEqual(f, None) + + if context.num_gpus() >= 1 and num_gpus <= 1: + variables.global_variables_initializer().run() + y_val, z_val, f_val = sess.run([y, z, f]) + self.assertEqual(y_val, 33.0) + self.assertEqual(z_val, 43.0) + self.assertEqual(f_val, 46.0) + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testDeviceAssignmentDistributed(self, num_gpus): + self._test_device_assignment_distributed('worker', 1, num_gpus) + + def _test_device_assignment_local(self, + d, + compute_device='CPU', + variable_device='CPU', + num_gpus=0): + with ops.Graph().as_default(), \ + self.test_session(target=self._workers[0].target) as sess, \ + d.scope(): + + def model_fn(): + if 'CPU' in compute_device: + tower_compute_device = '/device:CPU:0' + else: + tower_compute_device = ( + '/device:GPU:%d' % + distribution_strategy_context.get_tower_context().tower_id) + tower_compute_device = device_util.canonicalize(tower_compute_device) + + if 'CPU' in variable_device: + tower_variable_device = '/device:CPU:0' + else: + tower_variable_device = ( + '/device:GPU:%d' % + distribution_strategy_context.get_tower_context().tower_id) + tower_variable_device = device_util.canonicalize(tower_variable_device) + + a = constant_op.constant(1.0) + b = constant_op.constant(2.0) + c = a + b + self.assertEqual(a.device, tower_compute_device) + self.assertEqual(b.device, tower_compute_device) + self.assertEqual(c.device, tower_compute_device) + + # The device scope is ignored for variables but not for normal ops. + with ops.device('/device:GPU:2'): + x = variable_scope.get_variable( + 'x', initializer=10.0, + aggregation=variable_scope.VariableAggregation.SUM) + x_add = x.assign_add(c) + e = a + c + self.assertEqual( + device_util.canonicalize(x.device), tower_variable_device) + self.assertEqual(x_add.device, x.device) + self.assertEqual(e.device, device_util.canonicalize('/device:GPU:2')) + + # The colocate_vars_with can override the distribution's device. + with d.colocate_vars_with(x): + y = variable_scope.get_variable( + 'y', initializer=20.0, + aggregation=variable_scope.VariableAggregation.SUM) + # We add an identity here to avoid complaints about summing + # non-distributed values. + y_add = y.assign_add(array_ops.identity(x_add)) + self.assertEqual( + device_util.canonicalize(y.device), tower_variable_device) + self.assertEqual(y_add.device, y.device) + self.assertEqual(y.device, x.device) + + z = variable_scope.get_variable( + 'z', initializer=10.0, + aggregation=variable_scope.VariableAggregation.SUM) + self.assertEqual( + device_util.canonicalize(z.device), tower_variable_device) + + with ops.control_dependencies([y_add]): + # We add an identity here to avoid complaints about summing + # non-distributed values. + z_add = z.assign_add(array_ops.identity(y)) + with ops.control_dependencies([z_add]): + f = z + c + self.assertEqual(f.device, tower_compute_device) + + # The device scope would merge with the default worker device. + with ops.device('/CPU:1'): + g = e + 1.0 + self.assertEqual(g.device, device_util.canonicalize('/device:CPU:1')) + + # Ths ops.colocate_with will be ignored when defining a variale but not + # for a normal tensor. + with ops.colocate_with(x): + u = variable_scope.get_variable('u', initializer=30.0) + h = f + 1.0 + self.assertEqual( + device_util.canonicalize(u.device), tower_variable_device) + self.assertEqual(device_util.canonicalize(x.device), h.device) + return y_add, z_add, f + + y, z, f = d.call_for_each_tower(model_fn) + self.assertNotEqual(y, None) + self.assertNotEqual(z, None) + self.assertNotEqual(f, None) + + if context.num_gpus() >= 1 and num_gpus <= 1: + variables.global_variables_initializer().run() + y_val, z_val, f_val = sess.run([y, z, f]) + self.assertEqual(y_val, 33.0) + self.assertEqual(z_val, 43.0) + self.assertEqual(f_val, 46.0) + + def testDeviceAssignmentLocalCPU(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=0) + self._test_device_assignment_local( + distribution, compute_device='CPU', variable_device='CPU', num_gpus=0) + + def testDeviceAssignmentLocalOneGPU(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=1) + self._test_device_assignment_local( + distribution, compute_device='GPU', variable_device='GPU', num_gpus=1) + + def testDeviceAssignmentLocalTwoGPUs(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_device_assignment_local( + distribution, compute_device='GPU', variable_device='CPU', num_gpus=2) + + def _test_simple_increment(self, task_type, task_id, num_gpus): + d, master_target = self._get_test_objects(task_type, task_id, num_gpus) + if hasattr(d, '_cluster_spec') and d._cluster_spec: + num_workers = len(d._cluster_spec.as_dict().get('worker', + ['dummy_worker'])) + else: + num_workers = 1 + with ops.Graph().as_default(), \ + self.test_session(target=master_target) as sess, \ + d.scope(): + + def model_fn(): + x = variable_scope.get_variable( + 'x', initializer=10.0, + aggregation=variable_scope.VariableAggregation.SUM) + y = variable_scope.get_variable( + 'y', initializer=20.0, + aggregation=variable_scope.VariableAggregation.SUM) + + # We explicitly make a constant tensor here to avoid complaints about + # summing non-distributed values. + one = constant_op.constant(1.0) + x_add = x.assign_add(one, use_locking=True) + y_add = y.assign_add(one, use_locking=True) + + train_op = control_flow_ops.group([x_add, y_add]) + return x, y, train_op + + x, y, train_op = d.call_for_each_tower(model_fn) + train_op = d.group(d.unwrap(train_op)) + + if context.num_gpus() < d._num_gpus_per_worker: + return True + + if task_id == 0: + variables.global_variables_initializer().run() + + # Workers waiting for chief worker's initializing variables. + self._init_condition.acquire() + self._init_reached += 1 + while self._init_reached != num_workers: + self._init_condition.wait() + self._init_condition.notify_all() + self._init_condition.release() + + sess.run(train_op) + + # Wait for other workers to finish training. + self._finish_condition.acquire() + self._finish_reached += 1 + while self._finish_reached != num_workers: + self._finish_condition.wait() + self._finish_condition.notify_all() + self._finish_condition.release() + + x_val, y_val = sess.run([x, y]) + self.assertEqual(x_val, 10.0 + 1.0 * num_workers * d.num_towers) + self.assertEqual(y_val, 20.0 + 1.0 * num_workers * d.num_towers) + return (x_val == 10.0 + 1.0 * num_workers * d.num_towers and + y_val == 20.0 + 1.0 * num_workers * d.num_towers) + + def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): + d, master_target = self._get_test_objects(task_type, task_id, num_gpus) + with ops.Graph().as_default(), \ + self.test_session(target=master_target) as sess, \ + d.scope(): + l = core.Dense(1, use_bias=False) + + def loss_fn(x): + y = array_ops.reshape(l(x), []) - constant_op.constant(1.) + return y * y + + # TODO(yuefengz, apassos): eager.backprop.implicit_grad is not safe for + # multiple graphs (b/111216820). + def grad_fn(x): + loss = loss_fn(x) + var_list = ( + variables.trainable_variables() + ops.get_collection( + ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)) + grads = gradients.gradients(loss, var_list) + ret = list(zip(grads, var_list)) + return ret + + def update(v, g): + return v.assign_sub(0.05 * g, use_locking=True) + + one = d.broadcast(constant_op.constant([[1.]])) + + def step(): + """Perform one optimization step.""" + # Run forward & backward to get gradients, variables list. + g_v = d.call_for_each_tower(grad_fn, one) + # Update the variables using the gradients and the update() function. + before_list = [] + after_list = [] + for g, v in g_v: + fetched = d.read_var(v) + before_list.append(fetched) + with ops.control_dependencies([fetched]): + # TODO(yuefengz): support non-Mirrored variable as destinations. + g = d.reduce( + variable_scope.VariableAggregation.SUM, g, destinations=v) + with ops.control_dependencies(d.unwrap(d.update(v, update, g))): + after_list.append(d.read_var(v)) + return before_list, after_list + + before_out, after_out = step() + + if context.num_gpus() < d._num_gpus_per_worker: + return True + + if task_id == 0: + variables.global_variables_initializer().run() + + # Workers waiting for chief worker's initializing variables. + self._init_condition.acquire() + self._init_reached += 1 + while self._init_reached != 3: + self._init_condition.wait() + self._init_condition.notify_all() + self._init_condition.release() + + for i in range(10): + b, a = sess.run((before_out, after_out)) + if i == 0: + before, = b + after, = a + + error_before = abs(before - 1) + error_after = abs(after - 1) + # Error should go down + self.assertLess(error_after, error_before) + return error_after < error_before + + def testSimpleBetweenGraph(self): + self._run_between_graph_clients(self._test_simple_increment, + self._cluster_spec, 0) + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testLocalSimpleIncrement(self, num_gpus): + self._test_simple_increment(None, 0, num_gpus) + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testMinimizeLossGraph(self, num_gpus): + self._run_between_graph_clients(self._test_minimize_loss_graph, + self._cluster_spec, num_gpus) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py index 24cdc627a35f4455cb92484566dc13fa1bbaf2cc..1ff60c076226299a89060a295c1cc0c50817b861 100644 --- a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py +++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py @@ -35,7 +35,7 @@ from tensorflow.python.util import nest # pylint: disable=protected-access class _PrefetchToDeviceIterator(object): - """A replacement for @{tf.data.Iterator} that prefetches to another device. + """A replacement for `tf.data.Iterator` that prefetches to another device. Args: input_dataset: The input dataset. @@ -108,7 +108,7 @@ class _PrefetchToDeviceIterator(object): self._input_dataset) def get_next(self, name=None): - """See @{tf.data.Iterator.get_next}.""" + """See `tf.data.Iterator.get_next`.""" self._get_next_call_count += 1 if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD: warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE) @@ -209,7 +209,7 @@ class _PrefetchToDeviceDataset(dataset_ops.Dataset): def prefetch_to_devices(devices, buffer_size=None): """A transformation that prefetches dataset values to the given `devices`. - NOTE: Although the transformation creates a @{tf.data.Dataset}, the + NOTE: Although the transformation creates a `tf.data.Dataset`, the transformation must be the final `Dataset` in the input pipeline. Args: @@ -220,7 +220,7 @@ def prefetch_to_devices(devices, buffer_size=None): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): return _PrefetchToDeviceDataset(dataset, devices, buffer_size) diff --git a/tensorflow/contrib/distribute/python/single_loss_example.py b/tensorflow/contrib/distribute/python/single_loss_example.py index d1fdb3279cf2a7cba6e2282d58eedccf38bd38a3..5aa19cf6a9f8411120ed929cecaf93dda6c9edf2 100644 --- a/tensorflow/contrib/distribute/python/single_loss_example.py +++ b/tensorflow/contrib/distribute/python/single_loss_example.py @@ -29,7 +29,8 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops -def single_loss_example(optimizer_fn, distribution, use_bias=False): +def single_loss_example(optimizer_fn, distribution, use_bias=False, + iterations_per_step=1): """Build a very simple network to use in tests and examples.""" def dataset_fn(): @@ -38,12 +39,13 @@ def single_loss_example(optimizer_fn, distribution, use_bias=False): optimizer = optimizer_fn() layer = core.Dense(1, use_bias=use_bias) - def loss_fn(x): + def loss_fn(ctx, x): + del ctx y = array_ops.reshape(layer(x), []) - constant_op.constant(1.) return y * y - single_loss_step = step_fn.StandardSingleLossStep(dataset_fn, loss_fn, - optimizer, distribution) + single_loss_step = step_fn.StandardSingleLossStep( + dataset_fn, loss_fn, optimizer, distribution, iterations_per_step) # Layer is returned for inspecting the kernels in tests. return single_loss_step, layer diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py index d1910622b38c748fc5a814f9e83c2294850d5d12..1b5a4f64e5bb1ffabfe1b87c150f713c755bb682 100644 --- a/tensorflow/contrib/distribute/python/step_fn.py +++ b/tensorflow/contrib/distribute/python/step_fn.py @@ -34,15 +34,9 @@ class Step(object): def __call__(self): """Perform one step of this training algorithm.""" - return self.step(self.inputs()) - - def inputs(self): - """For the generating the input to be passed to `step()`.""" raise NotImplementedError("must be implemented in descendants") - def step(self, inputs): - """Perform the main computation of this training algorithm.""" - raise NotImplementedError("must be implemented in descendants") + # TODO(priyag): Add an method to access initialization and finalize ops. class StandardInputStep(Step): @@ -54,12 +48,9 @@ class StandardInputStep(Step): """ def __init__(self, dataset_fn, distribution): - Step.__init__(self, distribution) - self._distributed_input = distribution.distribute_dataset( - dataset_fn).make_one_shot_iterator() - - def inputs(self): - return self._distributed_input.get_next() + super(StandardInputStep, self).__init__(distribution) + self._distributed_input = distribution.distribute_dataset(dataset_fn) + self._iterator = self._distributed_input.make_one_shot_iterator() class StandardSingleLossStep(StandardInputStep): @@ -69,8 +60,8 @@ class StandardSingleLossStep(StandardInputStep): ```python ... - step = step_fn.StandardSingleLossStep(dataset, loss_fn, optimizer) - step.initialize(distribution) + step = step_fn.StandardSingleLossStep( + dataset, loss_fn, optimizer, distribution) # Run a single training step on a given DistributionStrategy: step(distribution) @@ -80,27 +71,43 @@ class StandardSingleLossStep(StandardInputStep): Args: dataset_fn: a function that returns a tf.data Dataset that produces the input for the model. - loss_fn: a function that returns loss. + loss_fn: a function that takes a context and inputs as arguments. It returns + the loss for those inputs. `context` is an instance of + `values.MultiStepContext` that will be passed when `loss_fn` is run. + `context` can be used to specify the outputs to be returned from + `loss_fn`, among other things. optimizer: an optimizer that implements an update rule. distribution: a `DistributionStrategy` object. """ - def __init__(self, dataset_fn, loss_fn, optimizer, distribution): - StandardInputStep.__init__(self, dataset_fn, distribution) + def __init__(self, dataset_fn, loss_fn, optimizer, distribution, + iterations_per_step=1): + super(StandardSingleLossStep, self).__init__(dataset_fn, distribution) self._loss_fn = loss_fn self._optimizer = optimizer self._is_run_concurrently = False + self._iterations_per_step = iterations_per_step - def step(self, inputs): + def __call__(self): with self._distribution.scope(): - gradients_fn = backprop.implicit_grad(self._loss_fn) - gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn) - - grads_and_vars = self.distribution.call_for_each_tower( - gradients_fn, inputs, run_concurrently=self._is_run_concurrently) - # If threads use layers, then we need to run the first step sequentially, - # so that layers.build() is not executed in parallel. Otherwise, multiple - # sets of mirrored variables are going to be created. - self._is_run_concurrently = True - return self._optimizer._distributed_apply( # pylint: disable=protected-access - self.distribution, grads_and_vars) + def step_fn(ctx, *inputs): + """Function to run one iteration with one input.""" + gradients_fn = backprop.implicit_grad(self._loss_fn) + gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn) + + grads_and_vars = self.distribution.call_for_each_tower( + gradients_fn, + ctx, *inputs, + run_concurrently=self._is_run_concurrently) + # If threads use layers, then we need to run the first step + # sequentially, so that layers.build() is not executed in parallel. + # Otherwise, multiple sets of mirrored variables are going to be + # created. + self._is_run_concurrently = True + return self._optimizer._distributed_apply( # pylint: disable=protected-access + self.distribution, grads_and_vars) + + # TODO(priyag): Return the outputs, context, etc as well. + ctx = self.distribution.run_steps_on_dataset( + step_fn, self._iterator, self._iterations_per_step) + return ctx.run_op diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py index 2ee94d8f70868c07ca217dd4d433585458efa8d8..8605ab1f7daeb81e778577ad3c4a18b39c57d743 100644 --- a/tensorflow/contrib/distribute/python/step_fn_test.py +++ b/tensorflow/contrib/distribute/python/step_fn_test.py @@ -33,12 +33,19 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase): @combinations.generate( combinations.times( combinations.distributions_and_v1_optimizers(), - combinations.combine(mode=combinations.graph_and_eager_modes))) - def testTrainNetwork(self, distribution, optimizer_fn): + combinations.combine(mode=combinations.graph_and_eager_modes), + combinations.combine(is_tpu=[False])) + + combinations.combine( + distribution=[combinations.tpu_strategy], + optimizer_fn=combinations.optimizers_v1, + mode=["graph"], + is_tpu=[True])) + def testTrainNetwork(self, distribution, optimizer_fn, is_tpu): with distribution.scope(): single_loss_step, layer = single_loss_example( - optimizer_fn, distribution, use_bias=True) + optimizer_fn, distribution, use_bias=True, iterations_per_step=2) + self.evaluate(distribution.initialize()) if context.executing_eagerly(): run_step = single_loss_step else: @@ -47,12 +54,14 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase): self.evaluate(variables.global_variables_initializer()) weights, biases = [], [] - for _ in range(10): + for _ in range(5): run_step() weights.append(self.evaluate(layer.kernel)) biases.append(self.evaluate(layer.bias)) + self.evaluate(distribution.finalize()) + error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) self.assertTrue(is_not_increasing) diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index baed0ebaae8a3f41c55f309d28203b363336dd16..371b97ba96a826194a6469ba63e485fc67639585 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -28,7 +28,7 @@ from tensorflow.python.layers import core from tensorflow.python.ops import array_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import optimizer @@ -45,7 +45,8 @@ def _raise_exception_fn(_=None): # Must be the argument to a distribution.call_for_each_tower() call, calls a # get_tower_context().merge_call() that raises an exception. def _merge_raises_fn(): - distribute_lib.get_tower_context().merge_call(_raise_exception_fn) + distribution_strategy_context.get_tower_context().merge_call( + _raise_exception_fn) # Must be the argument to a get_tower_context().merge_call() call, calls @@ -58,7 +59,7 @@ def _call_raises_fn(dist): # calls a get_tower_context().merge_call() that calls a # call_for_each_tower() that raises an exception. def _merge_call_raises_fn(): - distribute_lib.get_tower_context().merge_call(_call_raises_fn) + distribution_strategy_context.get_tower_context().merge_call(_call_raises_fn) # Must be the argument to a get_tower_context().merge_call() call, calls @@ -72,7 +73,8 @@ def _call_merge_raises_fn(dist): # get_tower_context().merge_call() that calls a call_for_each_tower() that # calls a get_tower_context().merge_call() that raises an exception. def _merge_call_merge_raises_fn(): - distribute_lib.get_tower_context().merge_call(_call_merge_raises_fn) + distribution_strategy_context.get_tower_context().merge_call( + _call_merge_raises_fn) class DistributionTestBase(test.TestCase): @@ -208,7 +210,7 @@ class DistributionTestBase(test.TestCase): expected_devices = [False] * len(d.worker_devices) def mark_devices_fn(): - tower_id = distribute_lib.get_tower_context().tower_id + tower_id = distribution_strategy_context.get_tower_context().tower_id self.assertLess(tower_id, len(d.worker_devices)) self.assertFalse(expected_devices[tower_id]) expected_devices[tower_id] = True diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index bc53898539d76320e331784f9a717be9491365e1..a4860030769fab92ec946c5a436240e7c88af1bf 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -21,40 +21,79 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import tpu +from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib from tensorflow.contrib.distribute.python import one_device_strategy from tensorflow.contrib.distribute.python import values from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.contrib.tpu.python.tpu import tpu +from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib +from tensorflow.contrib.tpu.python.tpu import training_loop +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.ops import variables as variables_lib +from tensorflow.python.training import device_util from tensorflow.python.util import nest +def get_tpu_system_metadata(tpu_cluster_resolver): + """Retrieves TPU system metadata given a TPUClusterResolver.""" + master = tpu_cluster_resolver.master() + + # pylint: disable=protected-access + cluster_spec = tpu_cluster_resolver.cluster_spec() + cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None + tpu_system_metadata = ( + tpu_system_metadata_lib._query_tpu_system_metadata( + master, + cluster_def=cluster_def, + query_topology=False)) + + return tpu_system_metadata + + class TPUStrategy(one_device_strategy.OneDeviceStrategy): """Experimental TPU distribution strategy implementation.""" - def __init__(self, num_cores_per_host=2): + def __init__(self, tpu_cluster_resolver, steps_per_run): + """Initializes the TPUStrategy object. + + Args: + tpu_cluster_resolver: A tf.contrib.cluster_resolver.TPUClusterResolver, + which provides information about the TPU cluster. + steps_per_run: Number of steps to run on device before returning to the + host. Note that this can have side-effects on performance, hooks, + metrics, summaries etc. + This parameter is only used when Distribution Strategy is used with + estimator or keras. + """ # TODO(isaprykin): Generalize the defaults. They are currently tailored for # the unit test. - super(TPUStrategy, self).__init__('/cpu:0') - # TODO(isaprykin): Auto-detect number of cores and hosts. - self._num_cores_per_host = num_cores_per_host + super(TPUStrategy, self).__init__('/device:CPU:0') + + self._tpu_cluster_resolver = tpu_cluster_resolver + self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver) + # TODO(priyag): This should not be hardcoded here. - self._host = '/task:0/device:CPU:0' + self._host = '/device:CPU:0' + # TODO(sourabhbajaj): Remove this once performance of running one step + # at a time is comparable to multiple steps. + self.steps_per_run = steps_per_run def distribute_dataset(self, dataset_fn): # TODO(priyag): Perhaps distribute across cores here. return self._call_dataset_fn(dataset_fn) - # TODO(priyag): Deal with OutOfRange errors. + # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have # a mechanism to infer the outputs of `fn`. Pending b/110550782. def _run_steps_on_dataset(self, fn, iterator, iterations, initial_loop_values=None): - # Enqueue ops + shapes = nest.flatten(iterator.output_shapes) if any([not s.is_fully_defined() for s in shapes]): raise ValueError( @@ -68,7 +107,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): control_deps = [] sharded_inputs = [] with ops.device(self._host): - for _ in range(self._num_cores_per_host): + for _ in range(self.num_towers): # Use control dependencies to ensure a deterministic ordering. with ops.control_dependencies(control_deps): inputs = nest.flatten(iterator.get_next()) @@ -93,58 +132,130 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): [constant_op.constant(0)], parallel_iterations=1) - # Dequeue ops def dequeue_fn(): - dequeued = tpu.infeed_dequeue_tuple(dtypes=types, shapes=shapes) + dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes) return nest.pack_sequence_as(iterator.output_shapes, dequeued) # Wrap `fn` for repeat. if initial_loop_values is None: - initial_loop_values = [] - ctx = values.MultiStepContext(initial_loop_values) + initial_loop_values = {} + initial_loop_values = nest.flatten(initial_loop_values) + ctx = values.MultiStepContext() def run_fn(*args, **kwargs): del args, kwargs - fn_result = fn(ctx, dequeue_fn()) - if ctx.last_step_outputs is None: - ctx.last_step_outputs = [] - with ops.control_dependencies([fn_result]): - return array_ops.identity(ctx.last_step_outputs) + fn_inputs = dequeue_fn() + if not isinstance(fn_inputs, tuple): + fn_inputs = (fn_inputs,) + fn_result = fn(ctx, *fn_inputs) + flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) + if flat_last_step_outputs: + with ops.control_dependencies([fn_result]): + return [array_ops.identity(f) for f in flat_last_step_outputs] + else: + return fn_result - # Repeat # TODO(sourabhbajaj): The input to while loop should be based on the output # type of the step_fn def iterate_on_tpu(): - return tpu.repeat(iterations, run_fn, [initial_loop_values]) + return training_loop.repeat(iterations, run_fn, initial_loop_values) - # Re-write and distribute computation. - # TODO(sourabhbajaj): Convert the output to PerDevice variable and - # implement support for that in reduce. - last_step_tensor_outputs = tpu.batch_parallel( - iterate_on_tpu, [], num_shards=self._num_cores_per_host) + # We capture the control_flow_context at this point, before we run `fn` + # inside a while_loop and TPU replicate context. This is useful in cases + # where we might need to exit these contexts and get back to the outer + # context to do some things, for e.g. create an op which should be + # evaluated only once at the end of the loop on the host. One such usage + # is in creating metrics' value op. + self._outer_control_flow_context = ( + ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access - # Take index [0] of last_step_tensor_outputs as we wrapped - # initial_loop_values in a list in the `repeat` call. - return (control_flow_ops.group(last_step_tensor_outputs, enqueue_ops), - last_step_tensor_outputs[0], ctx) + replicate_inputs = [[]] * self.num_towers + replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs) + del self._outer_control_flow_context + ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops) + + # Filter out any ops from the outputs, typically this would be the case + # when there were no tensor outputs. + last_step_tensor_outputs = [x for x in replicate_outputs + if not isinstance(x, ops.Operation)] + + # Outputs are currently of the structure (grouped by device) + # [[output0_device0, output1_device0, output2_device0], + # [output0_device1, output1_device1, output2_device1]] + # Convert this to the following structure instead: (grouped by output) + # [[output0_device0, output0_device1], + # [output1_device0, output1_device1], + # [output2_device0, output2_device1]] + last_step_tensor_outputs = [list(x) for x in zip(*last_step_tensor_outputs)] + + # Convert replicate_outputs to the original dict structure of + # last_step_outputs. + last_step_tensor_outputs_dict = nest.pack_sequence_as( + ctx.last_step_outputs, last_step_tensor_outputs) + + for (name, aggregation) in ctx._last_step_outputs_aggregations.items(): # pylint: disable=protected-access + output = last_step_tensor_outputs_dict[name] + # For outputs that have already been aggregated, take the first value + # from the list as each value should be the same. Else return the full + # list of values. + if aggregation is not variables_lib.VariableAggregation.NONE: + # TODO(priyag): Should this return the element or a list with 1 element + last_step_tensor_outputs_dict[name] = output[0] + ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access + + return ctx def _call_for_each_tower(self, fn, *args, **kwargs): kwargs.pop('run_concurrently', None) with one_device_strategy._OneDeviceTowerContext(self): # pylint: disable=protected-access return fn(*args, **kwargs) - def get_initialization_ops(self): - return [tpu.initialize_system()] + def initialize(self): + if context.executing_eagerly(): + # TODO(priyag): Add appopriate call here when eager is supported for TPUs. + raise NotImplementedError('Eager mode not supported in TPUStrategy.') + else: + return [tpu.initialize_system()] - def get_finalize_ops(self): - return [tpu.shutdown_system()] + def finalize(self): + if context.executing_eagerly(): + # TODO(priyag): Add appopriate call here when eager is supported for TPUs. + raise NotImplementedError('Eager mode not supported in TPUStrategy.') + else: + return [tpu.shutdown_system()] def _reduce(self, aggregation, value, destinations): - del destinations # TPU is graph mode only. Rely on implicit Send/Recv. + graph = ops.get_default_graph() + cf_context = graph._get_control_flow_context() # pylint: disable=protected-access + # If we're inside the ReplicateContext, reduction should be done using + # CrossReplicaSum while outside we can directly use an add_n op. + while cf_context: + if isinstance(cf_context, tpu.TPUReplicateContext): + if aggregation == vs.VariableAggregation.MEAN: + # TODO(jhseu): Revisit once we support model-parallelism. + value *= (1. / self.num_towers) + return tpu_ops.cross_replica_sum(value) + cf_context = cf_context.outer_context + + # Validate that the destination is same as the host device + # Note we don't do this when in replicate context as the reduction is + # performed on the TPU device itself. + devices = cross_tower_ops_lib.get_devices_from(destinations) + if len(devices) == 1: + assert device_util.canonicalize(devices[0]) == device_util.canonicalize( + self._host) + else: + raise ValueError('Multiple devices are not supported for TPUStrategy') + + output = math_ops.add_n(value) if aggregation == vs.VariableAggregation.MEAN: - # TODO(jhseu): Revisit once we support model-parallelism. - value *= (1. / self._num_cores_per_host) - return tpu_ops.cross_replica_sum(value) + return output * (1. / len(value)) + return output + + def _unwrap(self, value): + if isinstance(value, list): + return value + return [value] @property def num_towers(self): - return self._num_cores_per_host + return self._tpu_metadata.num_of_cores_per_host diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 47dcf679c2a6280d4be523b7fb04f0d2ba5855e8..a58bb3a8492a372d29089db0943e2e993ba47ad3 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -35,8 +35,10 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.ops import variables as variables_lib from tensorflow.python.training import device_util from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import saver from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.util import nest @@ -55,7 +57,7 @@ class DistributedValues(object): def get(self, device=None): """Returns the value for the current device or raises a ValueError.""" if device is None: - tower_context = distribute_lib.get_tower_context() + tower_context = distribution_strategy_context.get_tower_context() if tower_context: device = tower_context.device else: @@ -210,6 +212,11 @@ class DistributedVariable(DistributedDelegate): # without this it will use `__getattr__` which will delegate to a component # variable. self._keras_initialized = False + # Typically, a `DistributedVariable`'s initializer is composed of the + # initializers of the components variables. However, in some cases, such as + # when restoring from a checkpoint, we may set the _initializer_op + # property on the entire `DistributedVariable`. + self._initializer_op = None super(DistributedVariable, self).__init__(index) def is_initialized(self, name=None): @@ -239,9 +246,14 @@ class DistributedVariable(DistributedDelegate): @property def initializer(self): - # return grouped ops of all the var initializations of component values of - # the mirrored variable - return control_flow_ops.group([v.initializer for v in self._index.values()]) + if self._initializer_op: + init_op = self._initializer_op + else: + # return grouped ops of all the var initializations of component values of + # the mirrored variable + init_op = control_flow_ops.group( + [v.initializer for v in self._index.values()]) + return init_op @property def graph(self): @@ -278,12 +290,16 @@ class DistributedVariable(DistributedDelegate): # We want cross-tower code that does some var.op.X calls # to work (even if the current device isn't in self.devices), but # other uses of var.op in a cross-tower context to fail. - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): return DistributedVarOp(self._primary_var.op.name, self._primary_var.op.graph, self._primary_var.op.type) return self.get().op + def read_value(self): + return distribution_strategy_context.get_distribution_strategy().read_var( + self) + def _should_act_as_resource_variable(self): """Pass resource_variable_ops.is_resource_variable check.""" pass @@ -292,26 +308,6 @@ class DistributedVariable(DistributedDelegate): ops.register_dense_tensor_like_type(DistributedVariable) -def _get_update_device(): - """Validate we are in update/update_non_slot() and return current device. - - This is used in MirroredVariable.assign* members, to make sure they - are only called via an update method, to make sure all components of the - variable are being updated in a consistent way. - - Returns: - A string device. - - Raises: - RuntimeError: If not in distribution.update()/.update_non_slot(). - """ - device = distribute_lib.get_update_device() - if device is None: - raise RuntimeError( - "Use DistributionStrategy.update() to modify a MirroredVariable.") - return device - - class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable): """Class for defining how to restore a MirroredVariable.""" @@ -348,16 +344,17 @@ class MirroredVariable(DistributedVariable, Mirrored, # update several non-slot variables in one call. def _assign_func(self, *args, **kwargs): f = kwargs.pop("f") - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): update_device = distribute_lib.get_update_device() - # We are calling update on the mirrored variable in cross tower context. if update_device is not None: - # We are calling an assign function on the mirrored variable in cross - # tower context. + # We are calling an assign function on the mirrored variable in an + # update context. v = self.get(device=update_device) return f(v, *args, **kwargs) - return distribute_lib.get_distribution_strategy().update( + # We are calling assign on the mirrored variable in cross tower context, + # use update to update the variable. + return distribution_strategy_context.get_distribution_strategy().update( self, f, *args, **kwargs) else: _assert_tower_context() @@ -378,8 +375,8 @@ class MirroredVariable(DistributedVariable, Mirrored, aggregation=self._aggregation, value=value, destinations=self), *other_args, **other_kwargs) - return distribute_lib.get_tower_context().merge_call(merge_fn, *args, - **kwargs) + return distribution_strategy_context.get_tower_context().merge_call( + merge_fn, *args, **kwargs) def assign_sub(self, *args, **kwargs): assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw) @@ -405,7 +402,7 @@ class MirroredVariable(DistributedVariable, Mirrored, def _as_graph_element(self): # pylint: disable=protected-access - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): return self._primary_var._as_graph_element() return self.get()._as_graph_element() @@ -445,7 +442,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): # We use a callable so that we don't have to evaluate this expression # in the case where we are trying to restore instead of save. def tensor(): - return distribute_lib.get_distribution_strategy().read_var( + return distribution_strategy_context.get_distribution_strategy().read_var( tower_local_variable) spec = saver.BaseSaverBuilder.SaveSpec( tensor=tensor, @@ -461,7 +458,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): def _assert_tower_context(): - if not distribute_lib.get_tower_context(): + if not distribution_strategy_context.get_tower_context(): raise RuntimeError( "Tower-local variables may only be assigned in a tower context.") @@ -484,7 +481,7 @@ class TowerLocalVariable(DistributedVariable, PerDevice, return self.get().assign_add(*args, **kwargs) def assign(self, *args, **kwargs): - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): # To preserve the sum across save and restore, we have to divide the # total across all devices when restoring a variable that was summed # when saving. @@ -512,7 +509,7 @@ class TowerLocalVariable(DistributedVariable, PerDevice, def _as_graph_element(self): # pylint: disable=protected-access - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): return self._get_cross_tower() return self.get()._as_graph_element() @@ -921,64 +918,276 @@ class MultiStepContext(object): This context object is useful when running multiple steps at a time using the `run_steps_on_dataset` API. For e.g. it allows the user's step function to - specify which outputs to emit at what frequency. Currently it only supports - capturing output from the last step, but will soon be augmented to support - other use cases such as output each N steps. + specify which outputs to emit at what frequency. Currently it supports + capturing output from the last step, as well as capturing non tensor outputs. + In the future it will be augmented to support other use cases such as output + each N steps. """ - def __init__(self, initial_loop_values=None): + def __init__(self): """Initializes an output context. - Args: - initial_loop_values: Initial values passed to the run steps - while loop. The only purpose is to verify the shapes and types - when the actual output is set. This will be removed once we - automatically infer the output shapes and types (and do not need to - check for user error in specifying them manually). Returns: A context object. """ - self._last_step_outputs = None - self._non_tensor_outputs = None - self._initial_loop_values = initial_loop_values + self._last_step_outputs = {} + self._last_step_outputs_aggregations = {} + self._non_tensor_outputs = {} @property def last_step_outputs(self): - """Return the last step's outputs.""" + """A dictionary consisting of outputs to be captured on last step. + + Keys in the dictionary are names of tensors to be captured, as specified + when `set_last_step_output` is called. + Values in the dictionary are the tensors themselves. If + `set_last_step_output` was called with an `aggregation` for this output, + then the value is the aggregated value. + + Returns: + A dictionary with last step outputs. + """ return self._last_step_outputs - @last_step_outputs.setter - def last_step_outputs(self, outputs): - """Set the last step's outputs.""" - self._verify_structure_shapes_types(outputs, self._initial_loop_values) + def _set_last_step_outputs(self, outputs): + """Replace the entire dictionary of last step outputs.""" + if not isinstance(outputs, dict): + raise ValueError("Need a dictionary to set last_step_outputs.") self._last_step_outputs = outputs + def set_last_step_output(self, name, output, + aggregation=variables_lib.VariableAggregation.NONE): + """Set `output` with `name` to be outputted from the last step. + + Args: + name: String, name to identify the output. Doesn't need to match tensor + name. + output: The tensors that should be outputted with `name`. See below for + actual types supported. + aggregation: Aggregation method to use to aggregate outputs from multiple + towers. Required if `set_last_step_output` is called in a tower context. + Optional in cross_tower_context. + When present, the outputs from all the towers are aggregated using the + current distribution strategy's `reduce` method. Hence, the type of + `output` must be what's supported by the corresponding `reduce` method. + For e.g. if using MirroredStrategy and aggregation is set, output + must be a `PerDevice` value. + The aggregation method is also recorded in a dictionary + `_last_step_outputs_aggregations` for later interpreting of the + outputs as already reduced or not. + + """ + if distribution_strategy_context.get_cross_tower_context(): + self._last_step_outputs_aggregations[name] = aggregation + if aggregation is variables_lib.VariableAggregation.NONE: + self._last_step_outputs[name] = output + else: + distribution = distribution_strategy_context.get_distribution_strategy() + self._last_step_outputs[name] = distribution.reduce( + aggregation, output, destinations="/device:CPU:0") + else: + assert aggregation is not variables_lib.VariableAggregation.NONE + def merge_fn(distribution, value): + self._last_step_outputs[name] = distribution.reduce( + aggregation, value, destinations="/device:CPU:0") + # Setting this inside the `merge_fn` because all towers share the same + # context object, so it's more robust to set it only once (even if all + # the towers are trying to set the same value). + self._last_step_outputs_aggregations[name] = aggregation + + distribution_strategy_context.get_tower_context().merge_call( + merge_fn, output) + @property def non_tensor_outputs(self): - """Return the non tensor outputs.""" + """A dictionary consisting of any non tensor outputs to be captured.""" return self._non_tensor_outputs - @non_tensor_outputs.setter - def non_tensor_outputs(self, outputs): - """Set any non tensor outputs.""" - self._non_tensor_outputs = outputs - - def _verify_structure_shapes_types(self, left, right): - """Verify that the structure, shapes and types of left are same as right.""" - nest.assert_same_structure(left, right) - flat_left = nest.flatten(left) - flat_right = nest.flatten(right) - assert len(flat_left) == len(flat_right), ( - "Length of left {} and right {} should be same.". - format(len(flat_left), len(flat_right))) - - for o, i in zip(flat_left, flat_right): - # TODO(priyag): Add checks for other types like IndexedSlices. - if isinstance(o, ops.Tensor): - assert isinstance(i, ops.Tensor) - assert o.shape == i.shape, ( - "Shape {} of left {} doesn't match shape {} of right {}.". - format(o.shape, o, i.shape, i)) - assert o.dtype == i.dtype, ( - "Dtype {} of left {} doesn't match dtype {} of right {}.". - format(o.dtype, o, i.dtype, i)) + def set_non_tensor_output(self, name, output): + """Set `output` with `name` to be captured as a non tensor output.""" + if distribution_strategy_context.get_cross_tower_context(): + self._non_tensor_outputs[name] = output + else: + def merge_fn(distribution, value): + # NOTE(priyag): For non tensor outputs, we simply return all the values + # in a list as aggregation doesn't make sense on non tensors. + self._non_tensor_outputs[name] = distribution.unwrap(value) + distribution_strategy_context.get_tower_context().merge_call( + merge_fn, output) + + +def value_container(val): + """Returns the container that this per-device `value` belongs to. + + Args: + val: A value returned by `call_for_each_tower()` or a variable + created in `scope()`. + + Returns: + A container that `value` belongs to. + If value does not belong to any container (including the case of + container having been destroyed), returns the value itself. + """ + # pylint: disable=protected-access + if (hasattr(val, "_distributed_container") and + # DistributedVariable has _distributed_container defined + # but we don't want to return it. + not isinstance(val, DistributedVariable)): + container = val._distributed_container() + # pylint: disable=protected-access + if container is not None: + return container + return val + + +# TODO(josh11b): Descend from Variable. +class AggregatingVariable(checkpointable.CheckpointableBase): + """A wrapper around a variable that aggregates updates across towers.""" + + def __init__(self, v, aggregation): + self._v = v + # TODO(josh11b): Set v._distributed_container? + # v._distributed_container = weakref.ref(self) # pylint: disable=protected-access + self._aggregation = aggregation + + def get(self): + return self._v + + def __getattr__(self, name): + return getattr(self._v, name) + + def _assign_func(self, *args, **kwargs): + f = kwargs.pop("f") + if distribution_strategy_context.get_cross_tower_context(): + update_device = distribute_lib.get_update_device() + if update_device is not None: + # We are calling an assign function in an update context. + return f(self._v, *args, **kwargs) + + # We are calling an assign function in cross tower context, wrap it in an + # update call. + return distribution_strategy_context.get_distribution_strategy().update( + self, f, *args, **kwargs) + else: + assert distribution_strategy_context.get_tower_context() + # We are calling an assign function in tower context. + # We reduce the value we want to assign/add/sub. More details about how we + # handle the different use cases can be found in the _reduce method. + # We call the function with the reduced value. + if self._aggregation == vs.VariableAggregation.NONE: + raise ValueError("You must specify an aggregation method to update a " + "a variable in Tower Context.") + + def merge_fn(strategy, value, *other_args, **other_kwargs): + return strategy.update( + self, f, + strategy.reduce( + aggregation=self._aggregation, value=value, destinations=self), + *other_args, **other_kwargs) + + return distribution_strategy_context.get_tower_context().merge_call( + merge_fn, *args, **kwargs) + + def assign_sub(self, *args, **kwargs): + assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw) + return self._assign_func(f=assign_sub_fn, *args, **kwargs) + + def assign_add(self, *args, **kwargs): + assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw) + return self._assign_func(f=assign_add_fn, *args, **kwargs) + + def assign(self, *args, **kwargs): + assign_fn = lambda var, *a, **kw: var.assign(*a, **kw) + return self._assign_func(f=assign_fn, *args, **kwargs) + + @property + def aggregation(self): + return self._aggregation + + @property + def name(self): + return self._v.name + + @property + def dtype(self): + return self._v.dtype + + # TODO(josh11b): Test saving & restoring. + def _gather_saveables_for_checkpoint(self): + return {checkpointable.VARIABLE_VALUE_KEY: self._v} + + # pylint: disable=multiple-statements + def __add__(self, o): return self._v + o + def __radd__(self, o): return o + self._v + def __sub__(self, o): return self._v - o + def __rsub__(self, o): return o - self._v + def __mul__(self, o): return self._v * o + def __rmul__(self, o): return o * self._v + def __truediv__(self, o): return self._v / o + def __rtruediv__(self, o): return o / self._v + def __floordiv__(self, o): return self._v // o + def __rfloordiv__(self, o): return o // self._v + def __mod__(self, o): return self._v % o + def __rmod__(self, o): return o % self._v + def __lt__(self, o): return self._v < o + def __le__(self, o): return self._v <= o + def __gt__(self, o): return self._v > o + def __ge__(self, o): return self._v >= o + def __and__(self, o): return self._v & o + def __rand__(self, o): return o & self._v + def __or__(self, o): return self._v | o + def __ror__(self, o): return o | self._v + def __xor__(self, o): return self._v ^ o + def __rxor__(self, o): return o ^ self._v + def __getitem__(self, o): return self._v[o] + def __pow__(self, o, modulo=None): return pow(self._v, o, modulo) + def __rpow__(self, o): return pow(o, self._v) + def __invert__(self): return ~self._v + def __neg__(self): return -self._v + def __abs__(self): return abs(self._v) + + def __div__(self, o): + try: + return self._v.__div__(o) + except AttributeError: + # See https://docs.python.org/3/library/constants.html#NotImplemented + return NotImplemented + + def __rdiv__(self, o): + try: + return self._v.__rdiv__(o) + except AttributeError: + # See https://docs.python.org/3/library/constants.html#NotImplemented + return NotImplemented + + def __matmul__(self, o): + try: + return self._v.__matmul__(o) + except AttributeError: + # See https://docs.python.org/3/library/constants.html#NotImplemented + return NotImplemented + + def __rmatmul__(self, o): + try: + return self._v.__rmatmul__(o) + except AttributeError: + # See https://docs.python.org/3/library/constants.html#NotImplemented + return NotImplemented + + def __str__(self): + return str(self._v) + + def __repr__(self): + return repr(self._v) + + +# Register a conversion function which reads the value of the variable, +# allowing instances of the class to be used as tensors. +def _tensor_conversion_aggregate(var, dtype=None, name=None, as_ref=False): + return ops.internal_convert_to_tensor( + var.get(), dtype=dtype, name=name, as_ref=as_ref) + + +ops.register_tensor_conversion_function( + AggregatingVariable, _tensor_conversion_aggregate) +ops.register_dense_tensor_like_type(AggregatingVariable) diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index ad00d1734dd14ed846522a33d888a5387cb25cc6..a8d0d493abcd7de540799f6b94c3cdb9ce9dafae 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -124,7 +124,7 @@ cuda_py_test( cuda_py_test( name = "conditional_distribution_test", - size = "small", + size = "medium", srcs = [ "python/kernel_tests/conditional_distribution_test.py", "python/kernel_tests/distribution_test.py", diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py index 85d604e34ac25cf94b601470b7f166d9d414a8e3..49a9afe3f6debe048369c52328fb5534946ab9e5 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py @@ -29,6 +29,17 @@ from tensorflow.python.platform import test class MatrixInverseTriLBijectorTest(test.TestCase): """Tests the correctness of the Y = inv(tril) transformation.""" + #The inverse of 0 is undefined, as the numbers above the main + #diagonal must be zero, we zero out these numbers after running inverse. + #See: https://github.com/numpy/numpy/issues/11445 + def _inv(self, x): + y = np.linalg.inv(x) + #triu_indices only works on 2d arrays + #need to iterate over all the 2d arrays in a x-dimensional array. + for idx in np.ndindex(y.shape[0:-2]): + y[idx][np.triu_indices(y[idx].shape[-1], 1)] = 0 + return y + @test_util.run_in_graph_and_eager_modes def testComputesCorrectValues(self): inv = bijectors.MatrixInverseTriL(validate_args=True) @@ -98,7 +109,7 @@ class MatrixInverseTriLBijectorTest(test.TestCase): [2., 3.]]], [[[4., 0.], [5., -6.]]]], dtype=np.float32) - x_inv_ = np.linalg.inv(x_) + x_inv_ = self._inv(x_) expected_fldj_ = -4. * np.sum( np.log(np.abs(np.diagonal(x_, axis1=-2, axis2=-1))), axis=-1) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py b/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py index 90910f3839b1a4e882debf396b90955a42762794..200310bc414b6703d0683ce9f81b0aa5441f677d 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py @@ -173,6 +173,13 @@ class DeterministicTest(test.TestCase): self.assertAllClose( np.zeros(sample_shape_ + (2,)).astype(np.float32), sample_) + def testEntropy(self): + loc = np.array([-0.1, -3.2, 7.]) + deterministic = deterministic_lib.Deterministic(loc=loc) + with self.test_session() as sess: + entropy_ = sess.run(deterministic.entropy()) + self.assertAllEqual(np.zeros(3), entropy_) + class VectorDeterministicTest(test.TestCase): @@ -290,6 +297,13 @@ class VectorDeterministicTest(test.TestCase): self.assertAllClose( np.zeros(sample_shape_ + (2, 1)).astype(np.float32), sample_) + def testEntropy(self): + loc = np.array([[8.3, 1.2, 3.3], [-0.1, -3.2, 7.]]) + deterministic = deterministic_lib.VectorDeterministic(loc=loc) + with self.test_session() as sess: + entropy_ = sess.run(deterministic.entropy()) + self.assertAllEqual(np.zeros(2), entropy_) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/ops/deterministic.py b/tensorflow/contrib/distributions/python/ops/deterministic.py index ad853ee293f86565c1af601214522f53d936b70a..affc64a14f6fe9ae6e08ceff2298bc99ee7caa43 100644 --- a/tensorflow/contrib/distributions/python/ops/deterministic.py +++ b/tensorflow/contrib/distributions/python/ops/deterministic.py @@ -152,6 +152,9 @@ class _BaseDeterministic(distribution.Distribution): """Relative tolerance for comparing points to `self.loc`.""" return self._rtol + def _entropy(self): + return array_ops.zeros(self.batch_shape_tensor(), dtype=self.dtype) + def _mean(self): return array_ops.identity(self.loc) diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py index ef3bdfa75fcaa8df17db1238ceadadf788601356..18a0f754e6e618f240db109f593a80dec57e200b 100644 --- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py @@ -326,6 +326,21 @@ class QuantizedDistribution(distributions.Distribution): graph_parents=graph_parents, name=name) + @property + def distribution(self): + """Base distribution, p(x).""" + return self._dist + + @property + def low(self): + """Lowest value that quantization returns.""" + return self._low + + @property + def high(self): + """Highest value that quantization returns.""" + return self._high + def _batch_shape_tensor(self): return self.distribution.batch_shape_tensor() @@ -569,8 +584,3 @@ class QuantizedDistribution(distributions.Distribution): dependencies = [distribution_util.assert_integer_form( value, message="value has non-integer components.")] return control_flow_ops.with_dependencies(dependencies, value) - - @property - def distribution(self): - """Base distribution, p(x).""" - return self._dist diff --git a/tensorflow/contrib/distributions/python/ops/sample_stats.py b/tensorflow/contrib/distributions/python/ops/sample_stats.py index f5aaa5cf34abde3ea4d25de1ecf3adaef3f2a770..aa680a92be64cf0f099acd335369f2a1610c5953 100644 --- a/tensorflow/contrib/distributions/python/ops/sample_stats.py +++ b/tensorflow/contrib/distributions/python/ops/sample_stats.py @@ -134,7 +134,7 @@ def auto_correlation( x_len = util.prefer_static_shape(x_rotated)[-1] # TODO(langmore) Investigate whether this zero padding helps or hurts. At - # the moment is is necessary so that all FFT implementations work. + # the moment is necessary so that all FFT implementations work. # Zero pad to the next power of 2 greater than 2 * x_len, which equals # 2**(ceil(Log_2(2 * x_len))). Note: Log_2(X) = Log_e(X) / Log_e(2). x_len_float64 = math_ops.cast(x_len, np.float64) @@ -198,7 +198,7 @@ def auto_correlation( # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]). The # other terms were zeros arising only due to zero padding. # `denominator = (N / 2 - m)` (defined below) is the proper term to - # divide by by to make this an unbiased estimate of the expectation + # divide by to make this an unbiased estimate of the expectation # E[X[n] Conj(X[n - m])]. x_len = math_ops.cast(x_len, dtype.real_dtype) max_lags = math_ops.cast(max_lags, dtype.real_dtype) diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index 0cc764d2208c5b061b7b836bdf57a035f52c6fcf..fa3f1bb7ad187993379afeedf3790c789b4538aa 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -104,7 +104,6 @@ cuda_py_test( "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:client_testlib", - "//tensorflow/python/eager:graph_callable", "//tensorflow/python/eager:test", "//tensorflow/python:variables", ], @@ -199,7 +198,7 @@ py_library( "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python/eager:context", - "//tensorflow/python/estimator:util", + "//tensorflow/python/estimator:estimator_py", ], ) @@ -223,3 +222,17 @@ py_test( "//tensorflow/python/eager:test", ], ) + +py_test( + name = "remote_test", + srcs = ["remote_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/eager/python:tfe", + "//tensorflow/python:array_ops", + "//tensorflow/python:client", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python/eager:function", + ], +) diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index e31dbbe80f9634e8e45ec91bf395eab82942c8ce..135095a97980da8988b976948fb18492526e390c 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -22,16 +22,13 @@ from tensorflow.contrib.data.python.ops import prefetching_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.eager import context from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.training.checkpointable import base as checkpointable -from tensorflow.python.training.saver import BaseSaverBuilder -class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase): +class Iterator(iterator_ops.EagerIterator): """An iterator producing tf.Tensor objects from a tf.data.Dataset. NOTE: Unlike the iterator created by the - @{tf.data.Dataset.make_one_shot_iterator} method, this class enables + `tf.data.Dataset.make_one_shot_iterator` method, this class enables additional experimental functionality, such as prefetching to the GPU. """ @@ -82,30 +79,3 @@ class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase): # TODO(b/77291417): Fix with context.execution_mode(context.SYNC): return super(Iterator, self)._next_internal() - - # TODO(shivaniagrawal): Expose checkpointable stateful objects from dataset - # attributes(potential). - - class _Saveable(BaseSaverBuilder.SaveableObject): - """SaveableObject for saving/restoring iterator state.""" - - def __init__(self, iterator_resource, name): - serialized_iterator = gen_dataset_ops.serialize_iterator( - iterator_resource) - specs = [ - BaseSaverBuilder.SaveSpec(serialized_iterator, "", name + "_STATE") - ] - # pylint: disable=protected-access - super(Iterator._Saveable, self).__init__(iterator_resource, specs, name) - - def restore(self, restored_tensors, restored_shapes): - with ops.colocate_with(self.op): - return gen_dataset_ops.deserialize_iterator(self.op, - restored_tensors[0]) - - def _gather_saveables_for_checkpoint(self): - - def _saveable_factory(name): - return self._Saveable(self._resource, name) - - return {"ITERATOR": _saveable_factory} diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index acc605247faffcf7ba83891dacdab13fc8c8574a..a753d77580758af9de8410de4a08f7ea278c4c79 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -37,6 +37,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops +from tensorflow.python.training import checkpoint_management from tensorflow.python.training.checkpointable import util as checkpointable_utils @@ -306,6 +307,19 @@ class IteratorTest(test.TestCase): checkpoint.restore(save_path) self.assertEqual(2, iterator.get_next().numpy()) + def testRestoreInReconstructedIterator(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') + dataset = Dataset.range(10) + for i in range(5): + iterator = datasets.Iterator(dataset) + checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + checkpoint.restore(checkpoint_management.latest_checkpoint( + checkpoint_directory)) + for j in range(2): + self.assertEqual(i * 2 + j, iterator.get_next().numpy()) + checkpoint.save(file_prefix=checkpoint_prefix) + class DatasetConstructorBenchmark(test.Benchmark): diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD index 12155a459c29c353c57679c407e7dda25047a35c..6f02c90368d966b8cf8d0dee09f9d2a5013c90c1 100644 --- a/tensorflow/contrib/eager/python/examples/BUILD +++ b/tensorflow/contrib/eager/python/examples/BUILD @@ -15,8 +15,6 @@ py_library( "//tensorflow/contrib/eager/python/examples/revnet:config", "//tensorflow/contrib/eager/python/examples/rnn_colorbot", "//tensorflow/contrib/eager/python/examples/rnn_ptb", - "//tensorflow/contrib/eager/python/examples/sagan", - "//tensorflow/contrib/eager/python/examples/sagan:config", "//tensorflow/contrib/eager/python/examples/spinn:data", ], ) diff --git a/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py b/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py index bd0057fb1a0175a805a0f7a1e4dcaa2bdc3c435a..4b3cb624bc947a1d1956eff6accb6d4da3bf3b87 100644 --- a/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py @@ -128,8 +128,10 @@ class DensenetBenchmark(tf.test.Benchmark): weight_decay=1e-4, dropout_rate=0, pool_initial=True, include_top=True) logits = model(images, training=True) - loss = tf.losses.softmax_cross_entropy( + cross_ent = tf.losses.softmax_cross_entropy( logits=logits, onehot_labels=labels) + regularization = tf.add_n(model.losses) + loss = cross_ent + regularization optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0) train_op = optimizer.minimize(loss) diff --git a/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py b/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py index 4f19711fb87d6b5558302fd69104aca7e2cf403e..e5058bfd9480e25b3cf040f0d96bf21242a147b8 100644 --- a/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py +++ b/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py @@ -98,12 +98,52 @@ class DensenetTest(tf.test.TestCase): output_shape = model(rand_input).shape self.assertEqual(output_shape, (batch_size, output_classes)) + def test_regularization(self): + if tf.test.is_gpu_available(): + rand_input = tf.random_uniform((10, 3, 32, 32)) + data_format = 'channels_first' + else: + rand_input = tf.random_uniform((10, 32, 32, 3)) + data_format = 'channels_last' + weight_decay = 1e-4 + + conv = tf.keras.layers.Conv2D( + 3, (3, 3), + padding='same', + use_bias=False, + data_format=data_format, + kernel_regularizer=tf.keras.regularizers.l2(weight_decay)) + optimizer = tf.train.GradientDescentOptimizer(0.1) + conv(rand_input) # Initialize the variables in the layer + + def compute_true_l2(vs, wd): + return tf.reduce_sum(tf.square(vs)) * wd + + true_l2 = compute_true_l2(conv.variables, weight_decay) + keras_l2 = tf.add_n(conv.losses) + self.assertAllClose(true_l2, keras_l2) + + with tf.GradientTape() as tape_true, tf.GradientTape() as tape_keras: + loss = tf.reduce_sum(conv(rand_input)) + loss_with_true_l2 = loss + compute_true_l2(conv.variables, weight_decay) + loss_with_keras_l2 = loss + tf.add_n(conv.losses) + + true_grads = tape_true.gradient(loss_with_true_l2, conv.variables) + keras_grads = tape_keras.gradient(loss_with_keras_l2, conv.variables) + self.assertAllClose(true_grads, keras_grads) + + optimizer.apply_gradients(zip(keras_grads, conv.variables)) + keras_l2_after_update = tf.add_n(conv.losses) + self.assertNotAllClose(keras_l2, keras_l2_after_update) + def compute_gradients(model, images, labels): with tf.GradientTape() as tape: logits = model(images, training=True) - loss = tf.losses.softmax_cross_entropy( + cross_ent = tf.losses.softmax_cross_entropy( logits=logits, onehot_labels=labels) + regularization = tf.add_n(model.losses) + loss = cross_ent + regularization tf.contrib.summary.scalar(name='loss', tensor=loss) return tape.gradient(loss, model.variables) @@ -178,7 +218,7 @@ class DensenetBenchmark(tf.test.Benchmark): tf.constant(1.).cpu() def _benchmark_eager_apply(self, label, device_and_format, defun=False, - execution_mode=None, compiled=False): + execution_mode=None): with tfe.execution_mode(execution_mode): device, data_format = device_and_format model = densenet.DenseNet(self.depth, self.growth_rate, self.num_blocks, @@ -188,7 +228,7 @@ class DensenetBenchmark(tf.test.Benchmark): weight_decay=1e-4, dropout_rate=0, pool_initial=True, include_top=True) if defun: - model.call = tfe.defun(model.call, compiled=compiled) + model.call = tfe.defun(model.call) batch_size = 64 num_burn = 5 num_iters = 30 @@ -224,8 +264,7 @@ class DensenetBenchmark(tf.test.Benchmark): make_iterator, device_and_format, defun=False, - execution_mode=None, - compiled=False): + execution_mode=None): with tfe.execution_mode(execution_mode): device, data_format = device_and_format for batch_size in self._train_batch_sizes(): @@ -239,8 +278,8 @@ class DensenetBenchmark(tf.test.Benchmark): optimizer = tf.train.GradientDescentOptimizer(0.1) apply_grads = apply_gradients if defun: - model.call = tfe.defun(model.call, compiled=compiled) - apply_grads = tfe.defun(apply_gradients, compiled=compiled) + model.call = tfe.defun(model.call) + apply_grads = tfe.defun(apply_gradients) num_burn = 3 num_iters = 10 diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..ca27a85a229d41a85fa26ecdc982da478fe9e202 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb @@ -0,0 +1,649 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "0TD5ZrvEMbhZ" + }, + "source": [ + "##### Copyright 2018 The TensorFlow Authors.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\").\n", + "\n", + "# Convolutional VAE: An example with tf.keras and eager\n", + "\n", + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ITZuApL56Mny" + }, + "source": [ + "![evolution of output during training](https://tensorflow.org/images/autoencoders/cvae.gif)\n", + "\n", + "This notebook demonstrates how to generate images of handwritten digits using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager) by training a Variational Autoencoder. (VAE, [[1]](https://arxiv.org/abs/1312.6114), [[2]](https://arxiv.org/abs/1401.4082)).\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "P-JuIu2N_SQf" + }, + "outputs": [], + "source": [ + "# to generate gifs\n", + "!pip install imageio" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "e1_Y75QXJS6h" + }, + "source": [ + "## Import TensorFlow and enable Eager execution" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "YfIk2es3hJEd" + }, + "outputs": [], + "source": [ + "from __future__ import absolute_import, division, print_function\n", + "\n", + "# Import TensorFlow \u003e= 1.9 and enable eager execution\n", + "import tensorflow as tf\n", + "tfe = tf.contrib.eager\n", + "tf.enable_eager_execution()\n", + "\n", + "import os\n", + "import time\n", + "import numpy as np\n", + "import glob\n", + "import matplotlib.pyplot as plt\n", + "import PIL\n", + "import imageio\n", + "from IPython import display" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "iYn4MdZnKCey" + }, + "source": [ + "## Load the MNIST dataset\n", + "Each MNIST image is originally a vector of 784 integers, each of which is between 0-255 and represents the intensity of a pixel. We model each pixel with a Bernoulli distribution in our model, and we statically binarize the dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "a4fYMGxGhrna" + }, + "outputs": [], + "source": [ + "(train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "NFC2ghIdiZYE" + }, + "outputs": [], + "source": [ + "train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')\n", + "test_images = test_images.reshape(test_images.shape[0], 28, 28, 1).astype('float32')\n", + "\n", + "# Normalizing the images to the range of [0., 1.]\n", + "train_images /= 255.\n", + "test_images /= 255.\n", + "\n", + "# Binarization\n", + "train_images[train_images \u003e= .5] = 1.\n", + "train_images[train_images \u003c .5] = 0.\n", + "test_images[test_images \u003e= .5] = 1.\n", + "test_images[test_images \u003c .5] = 0." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "S4PIDhoDLbsZ" + }, + "outputs": [], + "source": [ + "TRAIN_BUF = 60000\n", + "BATCH_SIZE = 100\n", + "\n", + "TEST_BUF = 10000" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "PIGN6ouoQxt3" + }, + "source": [ + "## Use *tf.data* to create batches and shuffle the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "-yKCCQOoJ7cn" + }, + "outputs": [], + "source": [ + "train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(TRAIN_BUF).batch(BATCH_SIZE)\n", + "test_dataset = tf.data.Dataset.from_tensor_slices(test_images).shuffle(TEST_BUF).batch(BATCH_SIZE)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "THY-sZMiQ4UV" + }, + "source": [ + "## Wire up the generative and inference network with *tf.keras.Sequential*\n", + "\n", + "In our VAE example, we use two small ConvNets for the generative and inference network. Since these neural nets are small, we use `tf.keras.Sequential` to simplify our code. Let $x$ and $z$ denote the observation and latent variable respectively in the following descriptions. \n", + "\n", + "### Generative Network\n", + "This defines the generative model which takes a latent encoding as input, and outputs the parameters for a conditional distribution of the observation, i.e. $p(x|z)$. Additionally, we use a unit Gaussian prior $p(z)$ for the latent variable.\n", + "\n", + "### Inference Network\n", + "This defines an approximate posterior distribution $q(z|x)$, which takes as input an observation and outputs a set of parameters for the conditional distribution of the latent representation. In this example, we simply model this distribution as a diagonal Gaussian. In this case, the inference network outputs the mean and log-variance parameters of a factorized Gaussian (log-variance instead of the variance directly is for numerical stability).\n", + "\n", + "### Reparameterization Trick\n", + "During optimization, we can sample from $q(z|x)$ by first sampling from a unit Gaussian, and then multiplying by the standard deviation and adding the mean. This ensures the gradients could pass through the sample to the inference network parameters.\n", + "\n", + "### Network architecture\n", + "For the inference network, we use two convolutional layers followed by a fully-connected layer. In the generative network, we mirror this architecture by using a fully-connected layer followed by three convolution transpose layers (a.k.a. deconvolutional layers in some contexts). Note, it's common practice to avoid using batch normalization when training VAEs, since the additional stochasticity due to using mini-batches may aggravate instability on top of the stochasticity from sampling." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "VGLbvBEmjK0a" + }, + "outputs": [], + "source": [ + "class CVAE(tf.keras.Model):\n", + " def __init__(self, latent_dim):\n", + " super(CVAE, self).__init__()\n", + " self.latent_dim = latent_dim\n", + " self.inference_net = tf.keras.Sequential(\n", + " [\n", + " tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),\n", + " tf.keras.layers.Conv2D(\n", + " filters=32, kernel_size=3, strides=(2, 2), activation=tf.nn.relu),\n", + " tf.keras.layers.Conv2D(\n", + " filters=64, kernel_size=3, strides=(2, 2), activation=tf.nn.relu),\n", + " tf.keras.layers.Flatten(),\n", + " # No activation\n", + " tf.keras.layers.Dense(latent_dim + latent_dim),\n", + " ]\n", + " )\n", + "\n", + " self.generative_net = tf.keras.Sequential(\n", + " [\n", + " tf.keras.layers.InputLayer(input_shape=(latent_dim,)),\n", + " tf.keras.layers.Dense(units=7*7*32, activation=tf.nn.relu),\n", + " tf.keras.layers.Reshape(target_shape=(7, 7, 32)),\n", + " tf.keras.layers.Conv2DTranspose(\n", + " filters=64,\n", + " kernel_size=3,\n", + " strides=(2, 2),\n", + " padding=\"SAME\",\n", + " activation=tf.nn.relu),\n", + " tf.keras.layers.Conv2DTranspose(\n", + " filters=32,\n", + " kernel_size=3,\n", + " strides=(2, 2),\n", + " padding=\"SAME\",\n", + " activation=tf.nn.relu),\n", + " # No activation\n", + " tf.keras.layers.Conv2DTranspose(\n", + " filters=1, kernel_size=3, strides=(1, 1), padding=\"SAME\"),\n", + " ]\n", + " )\n", + "\n", + " def sample(self, eps=None):\n", + " if eps is None:\n", + " eps = tf.random_normal(shape=(100, self.latent_dim))\n", + " return self.decode(eps, apply_sigmoid=True)\n", + "\n", + " def encode(self, x):\n", + " mean, logvar = tf.split(self.inference_net(x), num_or_size_splits=2, axis=1)\n", + " return mean, logvar\n", + "\n", + " def reparameterize(self, mean, logvar):\n", + " eps = tf.random_normal(shape=mean.shape)\n", + " return eps * tf.exp(logvar * .5) + mean\n", + "\n", + " def decode(self, z, apply_sigmoid=False):\n", + " logits = self.generative_net(z)\n", + " if apply_sigmoid:\n", + " probs = tf.sigmoid(logits)\n", + " return probs\n", + "\n", + " return logits" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "0FMYgY_mPfTi" + }, + "source": [ + "## Define the loss function and the optimizer\n", + "\n", + "VAEs train by maximizing the evidence lower bound (ELBO) on the marginal log-likelihood:\n", + "\n", + "$$\\log p(x) \\ge \\text{ELBO} = \\mathbb{E}_{q(z|x)}\\left[\\log \\frac{p(x, z)}{q(z|x)}\\right].$$\n", + "\n", + "In practice, we optimize the single sample Monte Carlo estimate of this expectation:\n", + "\n", + "$$\\log p(x| z) + \\log p(z) - \\log q(z|x),$$\n", + "where $z$ is sampled from $q(z|x)$.\n", + "\n", + "**Note**: we could also analytically compute the KL term, but here we incorporate all three terms in the Monte Carlo estimator for simplicity." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "iWCn_PVdEJZ7" + }, + "outputs": [], + "source": [ + "def log_normal_pdf(sample, mean, logvar, raxis=1):\n", + " log2pi = tf.log(2. * np.pi)\n", + " return tf.reduce_sum(\n", + " -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),\n", + " axis=raxis)\n", + "\n", + "def compute_loss(model, x):\n", + " mean, logvar = model.encode(x)\n", + " z = model.reparameterize(mean, logvar)\n", + " x_logit = model.decode(z)\n", + "\n", + " cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x)\n", + " logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])\n", + " logpz = log_normal_pdf(z, 0., 0.)\n", + " logqz_x = log_normal_pdf(z, mean, logvar)\n", + " return -tf.reduce_mean(logpx_z + logpz - logqz_x)\n", + "\n", + "def compute_gradients(model, x):\n", + " with tf.GradientTape() as tape:\n", + " loss = compute_loss(model, x)\n", + " return tape.gradient(loss, model.trainable_variables), loss\n", + "\n", + "optimizer = tf.train.AdamOptimizer(1e-4)\n", + "def apply_gradients(optimizer, gradients, variables, global_step=None):\n", + " optimizer.apply_gradients(zip(gradients, variables), global_step=global_step)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Rw1fkAczTQYh" + }, + "source": [ + "## Training\n", + "\n", + "* We start by iterating over the dataset\n", + "* During each iteration, we pass the image to the encoder to obtain a set of mean and log-variance parameters of the approximate posterior $q(z|x)$\n", + "* We then apply the *reparameterization trick* to sample from $q(z|x)$\n", + "* Finally, we pass the reparameterized samples to the decoder to obtain the logits of the generative distribution $p(x|z)$\n", + "* **Note:** Since we use the dataset loaded by keras with 60k datapoints in the training set and 10k datapoints in the test set, our resulting ELBO on the test set is slightly higher than reported results in the literature which uses dynamic binarization of Larochelle's MNIST.\n", + "\n", + "## Generate Images\n", + "\n", + "* After training, it is time to generate some images\n", + "* We start by sampling a set of latent vectors from the unit Gaussian prior distribution $p(z)$\n", + "* The generator will then convert the latent sample $z$ to logits of the observation, giving a distribution $p(x|z)$\n", + "* Here we plot the probabilities of Bernoulli distributions\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "NS2GWywBbAWo" + }, + "outputs": [], + "source": [ + "epochs = 100\n", + "latent_dim = 50\n", + "num_examples_to_generate = 16\n", + "\n", + "# keeping the random vector constant for generation (prediction) so\n", + "# it will be easier to see the improvement.\n", + "random_vector_for_generation = tf.random_normal(\n", + " shape=[num_examples_to_generate, latent_dim])\n", + "model = CVAE(latent_dim)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "RmdVsmvhPxyy" + }, + "outputs": [], + "source": [ + "def generate_and_save_images(model, epoch, test_input):\n", + " predictions = model.sample(test_input)\n", + " fig = plt.figure(figsize=(4,4))\n", + "\n", + " for i in range(predictions.shape[0]):\n", + " plt.subplot(4, 4, i+1)\n", + " plt.imshow(predictions[i, :, :, 0], cmap='gray')\n", + " plt.axis('off')\n", + "\n", + " # tight_layout minimizes the overlap between 2 sub-plots\n", + " plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "2M7LmLtGEMQJ" + }, + "outputs": [], + "source": [ + "generate_and_save_images(model, 0, random_vector_for_generation)\n", + "\n", + "for epoch in range(1, epochs + 1):\n", + " start_time = time.time()\n", + " for train_x in train_dataset:\n", + " gradients, loss = compute_gradients(model, train_x)\n", + " apply_gradients(optimizer, gradients, model.trainable_variables)\n", + " end_time = time.time()\n", + "\n", + " if epoch % 1 == 0:\n", + " loss = tfe.metrics.Mean()\n", + " for test_x in test_dataset.make_one_shot_iterator():\n", + " loss(compute_loss(model, test_x))\n", + " elbo = -loss.result()\n", + " display.clear_output(wait=False)\n", + " print('Epoch: {}, Test set ELBO: {}, '\n", + " 'time elapse for current epoch {}'.format(epoch,\n", + " elbo,\n", + " end_time - start_time))\n", + " generate_and_save_images(\n", + " model, epoch, random_vector_for_generation)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "P4M_vIbUi7c0" + }, + "source": [ + "### Display an image using the epoch number" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "WfO5wCdclHGL" + }, + "outputs": [], + "source": [ + "def display_image(epoch_no):\n", + " return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "5x3q9_Oe5q0A" + }, + "outputs": [], + "source": [ + "display_image(epochs) # Display images" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "NywiH3nL8guF" + }, + "source": [ + "### Generate a GIF of all the saved images." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "IGKQgENQ8lEI" + }, + "outputs": [], + "source": [ + "with imageio.get_writer('cvae.gif', mode='I') as writer:\n", + " filenames = glob.glob('image*.png')\n", + " filenames = sorted(filenames)\n", + " last = -1\n", + " for i,filename in enumerate(filenames):\n", + " frame = 2*(i**0.5)\n", + " if round(frame) \u003e round(last):\n", + " last = frame\n", + " else:\n", + " continue\n", + " image = imageio.imread(filename)\n", + " writer.append_data(image)\n", + " image = imageio.imread(filename)\n", + " writer.append_data(image)\n", + " \n", + "# this is a hack to display the gif inside the notebook\n", + "os.system('cp cvae.gif cvae.gif.png')" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "uV0yiKpzNP1b" + }, + "outputs": [], + "source": [ + "display.Image(filename=\"cvae.gif.png\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "yQXO_dlXkKsT" + }, + "source": [ + "To downlod the animation from Colab uncomment the code below:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "4fSJS3m5HLFM" + }, + "outputs": [], + "source": [ + "#from google.colab import files\n", + "#files.download('cvae.gif')" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "cvae.ipynb", + "private_outputs": true, + "provenance": [ + { + "file_id": "1eb0NOTQapkYs3X0v-zL1x5_LFKgDISnp", + "timestamp": 1527173385672 + } + ], + "toc_visible": true, + "version": "0.3.2", + "views": {} + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb index 44ff43a1112e771eb6c91c398286a003e17632e0..5621d6a358e8969ea1a6663c1c770987de41ce0c 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb @@ -40,12 +40,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "u_2z-B3piVsw" }, @@ -69,12 +64,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "YfIk2es3hJEd" }, @@ -82,7 +72,7 @@ "source": [ "from __future__ import absolute_import, division, print_function\n", "\n", - "# Import TensorFlow \u003e= 1.9 and enable eager execution\n", + "# Import TensorFlow \u003e= 1.10 and enable eager execution\n", "import tensorflow as tf\n", "tf.enable_eager_execution()\n", "\n", @@ -112,12 +102,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "a4fYMGxGhrna" }, @@ -130,12 +115,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "NFC2ghIdiZYE" }, @@ -150,12 +130,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "S4PIDhoDLbsZ" }, @@ -179,12 +154,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "-yKCCQOoJ7cn" }, @@ -217,12 +187,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "VGLbvBEmjK0a" }, @@ -265,12 +230,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "bkOfJxk5j5Hi" }, @@ -299,12 +259,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "gDkA05NE6QMs" }, @@ -318,12 +273,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "k1HpMSLImuRi" }, @@ -360,12 +310,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "wkMNfBWlT-PV" }, @@ -388,12 +333,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "90BIcCKcDMxz" }, @@ -407,12 +347,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "iWCn_PVdEJZ7" }, @@ -422,6 +357,34 @@ "generator_optimizer = tf.train.AdamOptimizer(1e-4)" ] }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "mWtinsGDPJlV" + }, + "source": [ + "## Checkpoints (Object-based saving)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "CA1w-7s2POEy" + }, + "outputs": [], + "source": [ + "checkpoint_dir = './training_checkpoints'\n", + "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", + "checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,\n", + " discriminator_optimizer=discriminator_optimizer,\n", + " generator=generator,\n", + " discriminator=discriminator)" + ] + }, { "cell_type": "markdown", "metadata": { @@ -449,12 +412,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "NS2GWywBbAWo" }, @@ -462,7 +420,7 @@ "source": [ "EPOCHS = 150\n", "noise_dim = 100\n", - "num_examples_to_generate = 100\n", + "num_examples_to_generate = 16\n", "\n", "# keeping the random vector constant for generation (prediction) so\n", "# it will be easier to see the improvement of the gan.\n", @@ -474,12 +432,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "RmdVsmvhPxyy" }, @@ -490,15 +443,13 @@ " # don't want to train the batchnorm layer when doing inference.\n", " predictions = model(test_input, training=False)\n", "\n", - " fig = plt.figure(figsize=(10,10))\n", + " fig = plt.figure(figsize=(4,4))\n", " \n", " for i in range(predictions.shape[0]):\n", - " plt.subplot(10, 10, i+1)\n", + " plt.subplot(4, 4, i+1)\n", " plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')\n", " plt.axis('off')\n", " \n", - " # tight_layout minimizes the overlap between 2 sub-plots\n", - " plt.tight_layout()\n", " plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n", " plt.show()" ] @@ -507,12 +458,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "2M7LmLtGEMQJ" }, @@ -542,15 +488,20 @@ " discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.variables))\n", "\n", " \n", - " if epoch % 10 == 0:\n", + " if epoch % 1 == 0:\n", " display.clear_output(wait=True)\n", " generate_and_save_images(generator,\n", " epoch + 1,\n", " random_vector_for_generation)\n", - "\n", + " \n", + " # saving (checkpoint) the model every 15 epochs\n", + " if (epoch + 1) % 15 == 0:\n", + " checkpoint.save(file_prefix = checkpoint_prefix)\n", + " \n", " print ('Time taken for epoch {} is {} sec'.format(epoch + 1,\n", " time.time()-start))\n", " # generating after the final epoch\n", + " display.clear_output(wait=True)\n", " generate_and_save_images(generator,\n", " epochs,\n", " random_vector_for_generation)" @@ -560,12 +511,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "Ly3UN0SLLY2l" }, @@ -574,6 +520,30 @@ "train(train_dataset, EPOCHS, noise_dim)" ] }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "rfM4YcPVPkNO" + }, + "source": [ + "## Restore the latest checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "XhXsd0srPo8c" + }, + "outputs": [], + "source": [ + "# restoring the latest checkpoint in checkpoint_dir\n", + "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" + ] + }, { "cell_type": "markdown", "metadata": { @@ -581,40 +551,28 @@ "id": "P4M_vIbUi7c0" }, "source": [ - "# Display an image using the epoch number" + "## Display an image using the epoch number" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "WfO5wCdclHGL" }, "outputs": [], "source": [ "def display_image(epoch_no):\n", - " plt.figure(figsize=(15,15))\n", - " plt.imshow(np.array(PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))))\n", - " plt.axis('off')" + " return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "5x3q9_Oe5q0A" }, @@ -647,12 +605,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "IGKQgENQ8lEI" }, @@ -661,23 +614,27 @@ "with imageio.get_writer('dcgan.gif', mode='I') as writer:\n", " filenames = glob.glob('image*.png')\n", " filenames = sorted(filenames)\n", - " for filename in filenames:\n", + " last = -1\n", + " for i,filename in enumerate(filenames):\n", + " frame = 2*(i**0.5)\n", + " if round(frame) \u003e round(last):\n", + " last = frame\n", + " else:\n", + " continue\n", " image = imageio.imread(filename)\n", " writer.append_data(image)\n", - " # this is a hack to display the gif inside the notebook\n", - " os.system('mv dcgan.gif dcgan.gif.png')" + " image = imageio.imread(filename)\n", + " writer.append_data(image)\n", + " \n", + "# this is a hack to display the gif inside the notebook\n", + "os.system('cp dcgan.gif dcgan.gif.png')" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "uV0yiKpzNP1b" }, @@ -686,22 +643,28 @@ "display.Image(filename=\"dcgan.gif.png\")" ] }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "6EEG-wePkmJQ" + }, + "source": [ + "To downlod the animation from Colab uncomment the code below:" + ] + }, { "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "4UJjSnIMOzOJ" }, "outputs": [], "source": [ - "" + "#from google.colab import files\n", + "#files.download('dcgan.gif')" ] } ], @@ -709,7 +672,6 @@ "accelerator": "GPU", "colab": { "collapsed_sections": [], - "default_view": {}, "name": "dcgan.ipynb", "private_outputs": true, "provenance": [ @@ -719,8 +681,7 @@ } ], "toc_visible": true, - "version": "0.3.2", - "views": {} + "version": "0.3.2" }, "kernelspec": { "display_name": "Python 3", diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb index b173f856c641b4d7dca96adda113f904c97a25a7..027097908f2c62724830c556d72b6b6bee218eec 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb @@ -96,12 +96,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "wZ6LOM12wKGH" }, @@ -124,24 +119,20 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "yG_n40gFzf9s" }, "outputs": [], "source": [ - "# Import TensorFlow \u003e= 1.9 and enable eager execution\n", + "# Import TensorFlow \u003e= 1.10 and enable eager execution\n", "import tensorflow as tf\n", "\n", "# Note: Once you enable eager execution, it cannot be disabled. \n", "tf.enable_eager_execution()\n", "\n", "import numpy as np\n", + "import os\n", "import re\n", "import random\n", "import unidecode\n", @@ -165,12 +156,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "pD_55cOxLkAb" }, @@ -194,12 +180,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "-E5JvY3wzf94" }, @@ -224,12 +205,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "IalZLbvOzf-F" }, @@ -247,12 +223,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "1v_qUYfAzf-I" }, @@ -302,12 +273,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "0UHJDA39zf-O" }, @@ -341,19 +307,14 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "p2pGotuNzf-S" }, "outputs": [], "source": [ "dataset = tf.data.Dataset.from_tensor_slices((input_text, target_text)).shuffle(BUFFER_SIZE)\n", - "dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(BATCH_SIZE))" + "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)" ] }, { @@ -376,12 +337,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "P3KTiiInzf-a" }, @@ -445,12 +401,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "7t2XrzEOzf-e" }, @@ -463,12 +414,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "dkjWIATszf-h" }, @@ -481,6 +427,32 @@ " return tf.losses.sparse_softmax_cross_entropy(labels=real, logits=preds)" ] }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "3K6s6F79P7za" + }, + "source": [ + "## Checkpoints (Object-based saving)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "oAGisDdfP9rL" + }, + "outputs": [], + "source": [ + "checkpoint_dir = './training_checkpoints'\n", + "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", + "checkpoint = tf.train.Checkpoint(optimizer=optimizer,\n", + " model=model)" + ] + }, { "cell_type": "markdown", "metadata": { @@ -514,12 +486,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "d4tSNwymzf-q" }, @@ -527,7 +494,7 @@ "source": [ "# Training step\n", "\n", - "EPOCHS = 30\n", + "EPOCHS = 20\n", "\n", "for epoch in range(EPOCHS):\n", " start = time.time()\n", @@ -547,17 +514,44 @@ " loss = loss_function(target, predictions)\n", " \n", " grads = tape.gradient(loss, model.variables)\n", - " optimizer.apply_gradients(zip(grads, model.variables), global_step=tf.train.get_or_create_global_step())\n", + " optimizer.apply_gradients(zip(grads, model.variables))\n", "\n", " if batch % 100 == 0:\n", " print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch+1,\n", " batch,\n", " loss))\n", - " \n", + " # saving (checkpoint) the model every 5 epochs\n", + " if (epoch + 1) % 5 == 0:\n", + " checkpoint.save(file_prefix = checkpoint_prefix)\n", + "\n", " print ('Epoch {} Loss {:.4f}'.format(epoch+1, loss))\n", " print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" ] }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "01AR9vpNQMFF" + }, + "source": [ + "## Restore the latest checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "tyvpYomYQQkF" + }, + "outputs": [], + "source": [ + "# restoring the latest checkpoint in checkpoint_dir\n", + "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" + ] + }, { "cell_type": "markdown", "metadata": { @@ -584,12 +578,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "WvuwZBX5Ogfd" }, @@ -651,12 +640,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "gtEd86sX5cB2" }, @@ -670,13 +654,11 @@ "accelerator": "GPU", "colab": { "collapsed_sections": [], - "default_view": {}, "name": "text_generation.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true, - "version": "0.3.2", - "views": {} + "version": "0.3.2" }, "kernelspec": { "display_name": "Python 3", diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/README.md b/tensorflow/contrib/eager/python/examples/l2hmc/README.md index d6a2ff7558c76c714df1674c4c8c627fa433f197..f171806e379da7213b6ee33e0d454056068fe7a5 100644 --- a/tensorflow/contrib/eager/python/examples/l2hmc/README.md +++ b/tensorflow/contrib/eager/python/examples/l2hmc/README.md @@ -4,16 +4,15 @@ This folder contains an implementation of [L2HMC](https://arxiv.org/pdf/1711.092 With eager execution enabled, longer sample chains can be handled compared to graph mode, since no graph is explicitly stored. Moreover, with eager execution enabled, there is no need to use a `tf.while_loop`. ## What is L2HMC? -L2HMC is an algorithm that learns a non-volume preserving transformation -for an HMC-like sampling algorithm. More specifically, the non-volume preserving +L2HMC is an adaptive Markov Chain Monte Carlo (MCMC) algorithm that learns a non-volume preserving transformation +for a Hamiltonian Monte Carlo (HMC) sampling algorithm. More specifically, the non-volume preserving transformation is learned with neural nets instantiated within Normalizing Flows -(more precisely, real-NVPs). +(real-NVPs). ## Content - `l2hmc.py`: Dynamics definitions and example energy functions, -including the 2D strongly correlated Gaussian, the rough well energy function, -and a Gaussian mixture model. +including the 2D strongly correlated Gaussian and the rough well energy function, - `l2hmc_test.py`: Unit tests and benchmarks for training a sampler on the energy functions in both eager and graph mode. - `neural_nets.py`: The neural net for learning the kernel on the 2D strongly correlated example. - `main.py`: Run to train a samplers on 2D energy landscapes. @@ -32,7 +31,7 @@ tensorboard and a plot of sampled chain from the trained sampler. Specifying the optional argument `use_defun` will let the program use compiled graphs when running specific sections and improve the overall speed. -## Boosting Performance with `defun` +## Boosting Performance with `tfe.defun` Currently, some models may experience increased overhead with eager execution enabled. To improve performance, we could wrap certain functions with the decorator `@tfe.defun`. For example, we could wrap the function that does the sampling step: diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb index 1f66d7e75299df0c7db9bc8ec67cb6c0b5d4de40..08d8364978f6a9b4e8e15b5caac7db14c1d721b4 100644 --- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb @@ -1,39 +1,11 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "nmt_with_attention.ipynb", - "version": "0.3.2", - "views": {}, - "default_view": {}, - "provenance": [ - { - "file_id": "1C4fpM7_7IL8ZzF7Gc5abywqQjeQNS2-U", - "timestamp": 1527858391290 - }, - { - "file_id": "1pExo6aUuw0S6MISFWoinfJv0Ftm9V4qv", - "timestamp": 1527776041613 - } - ], - "private_outputs": true, - "collapsed_sections": [], - "toc_visible": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "accelerator": "GPU" - }, "cells": [ { + "cell_type": "markdown", "metadata": { - "id": "AOpGoE2T-YXS", - "colab_type": "text" + "colab_type": "text", + "id": "AOpGoE2T-YXS" }, - "cell_type": "markdown", "source": [ "##### Copyright 2018 The TensorFlow Authors.\n", "\n", @@ -41,19 +13,19 @@ "\n", "# Neural Machine Translation with Attention\n", "\n", - "
\n", - "\n", - " Run in Google Colab \n", - "\n", - "View source on GitHub
" + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" ] }, { + "cell_type": "markdown", "metadata": { - "id": "CiwtNgENbx2g", - "colab_type": "text" + "colab_type": "text", + "id": "CiwtNgENbx2g" }, - "cell_type": "markdown", "source": [ "This notebook trains a sequence to sequence (seq2seq) model for Spanish to English translation using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager). This is an advanced example that assumes some knowledge of sequence to sequence models.\n", "\n", @@ -61,27 +33,24 @@ "\n", "The translation quality is reasonable for a toy example, but the generated attention plot is perhaps more interesting. This shows which parts of the input sentence has the model's attention while translating:\n", "\n", - "\"spanish-english\n", + "\u003cimg src=\"https://tensorflow.org/images/spanish-english.png\" alt=\"spanish-english attention plot\"\u003e\n", "\n", "Note: This example takes approximately 10 mintues to run on a single P100 GPU." ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "tnxXKDjq3jEL", + "colab": {}, "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "id": "tnxXKDjq3jEL" }, - "cell_type": "code", + "outputs": [], "source": [ "from __future__ import absolute_import, division, print_function\n", "\n", - "# Import TensorFlow >= 1.9 and enable eager execution\n", + "# Import TensorFlow \u003e= 1.10 and enable eager execution\n", "import tensorflow as tf\n", "\n", "tf.enable_eager_execution()\n", @@ -96,16 +65,14 @@ "import time\n", "\n", "print(tf.__version__)" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "wfodePkj3jEa", - "colab_type": "text" + "colab_type": "text", + "id": "wfodePkj3jEa" }, - "cell_type": "markdown", "source": [ "## Download and prepare the dataset\n", "\n", @@ -124,17 +91,14 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "kRVATYOgJs1b", + "colab": {}, "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "id": "kRVATYOgJs1b" }, - "cell_type": "code", + "outputs": [], "source": [ "# Download the file\n", "path_to_zip = tf.keras.utils.get_file(\n", @@ -142,22 +106,17 @@ " extract=True)\n", "\n", "path_to_file = os.path.dirname(path_to_zip)+\"/spa-eng/spa.txt\"" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "rd0jw-eC3jEh", + "colab": {}, "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "id": "rd0jw-eC3jEh" }, - "cell_type": "code", + "outputs": [], "source": [ "# Converts the unicode file to ascii\n", "def unicode_to_ascii(s):\n", @@ -169,7 +128,7 @@ " w = unicode_to_ascii(w.lower().strip())\n", " \n", " # creating a space between a word and the punctuation following it\n", - " # eg: \"he is a boy.\" => \"he is a boy .\" \n", + " # eg: \"he is a boy.\" =\u003e \"he is a boy .\" \n", " # Reference:- https://stackoverflow.com/questions/3645931/python-padding-punctuation-with-white-spaces-keeping-punctuation\n", " w = re.sub(r\"([?.!,¿])\", r\" \\1 \", w)\n", " w = re.sub(r'[\" \"]+', \" \", w)\n", @@ -181,24 +140,19 @@ " \n", " # adding a start and an end token to the sentence\n", " # so that the model know when to start and stop predicting.\n", - " w = ' ' + w + ' '\n", + " w = '\u003cstart\u003e ' + w + ' \u003cend\u003e'\n", " return w" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "OHn4Dct23jEm", + "colab": {}, "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "id": "OHn4Dct23jEm" }, - "cell_type": "code", + "outputs": [], "source": [ "# 1. Remove the accents\n", "# 2. Clean the sentences\n", @@ -209,25 +163,20 @@ " word_pairs = [[preprocess_sentence(w) for w in l.split('\\t')] for l in lines[:num_examples]]\n", " \n", " return word_pairs" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "9xbqO7Iie9bb", + "colab": {}, "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "id": "9xbqO7Iie9bb" }, - "cell_type": "code", + "outputs": [], "source": [ - "# This class creates a word -> index mapping (e.g,. \"dad\" -> 5) and vice-versa \n", - "# (e.g., 5 -> \"dad\") for each language,\n", + "# This class creates a word -\u003e index mapping (e.g,. \"dad\" -\u003e 5) and vice-versa \n", + "# (e.g., 5 -\u003e \"dad\") for each language,\n", "class LanguageIndex():\n", " def __init__(self, lang):\n", " self.lang = lang\n", @@ -243,28 +192,23 @@ " \n", " self.vocab = sorted(self.vocab)\n", " \n", - " self.word2idx[''] = 0\n", + " self.word2idx['\u003cpad\u003e'] = 0\n", " for index, word in enumerate(self.vocab):\n", " self.word2idx[word] = index + 1\n", " \n", " for word, index in self.word2idx.items():\n", " self.idx2word[index] = word" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "eAY9k49G3jE_", + "colab": {}, "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "id": "eAY9k49G3jE_" }, - "cell_type": "code", + "outputs": [], "source": [ "def max_length(tensor):\n", " return max(len(t) for t in tensor)\n", @@ -300,119 +244,103 @@ " padding='post')\n", " \n", " return input_tensor, target_tensor, inp_lang, targ_lang, max_length_inp, max_length_tar" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "GOi42V79Ydlr", - "colab_type": "text" + "colab_type": "text", + "id": "GOi42V79Ydlr" }, - "cell_type": "markdown", "source": [ "### Limit the size of the dataset to experiment faster (optional)\n", "\n", - "Training on the complete dataset of >100,000 sentences will take a long time. To train faster, we can limit the size of the dataset to 30,000 sentences (of course, translation quality degrades with less data):" + "Training on the complete dataset of \u003e100,000 sentences will take a long time. To train faster, we can limit the size of the dataset to 30,000 sentences (of course, translation quality degrades with less data):" ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "cnxC7q-j3jFD", + "colab": {}, "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "id": "cnxC7q-j3jFD" }, - "cell_type": "code", + "outputs": [], "source": [ "# Try experimenting with the size of that dataset\n", "num_examples = 30000\n", "input_tensor, target_tensor, inp_lang, targ_lang, max_length_inp, max_length_targ = load_dataset(path_to_file, num_examples)" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "4QILQkOs3jFG", + "colab": {}, "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "id": "4QILQkOs3jFG" }, - "cell_type": "code", + "outputs": [], "source": [ "# Creating training and validation sets using an 80-20 split\n", "input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.2)\n", "\n", "# Show length\n", "len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val)" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "rgCLkfv5uO3d", - "colab_type": "text" + "colab_type": "text", + "id": "rgCLkfv5uO3d" }, - "cell_type": "markdown", "source": [ "### Create a tf.data dataset" ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "TqHsArVZ3jFS", + "colab": {}, "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "id": "TqHsArVZ3jFS" }, - "cell_type": "code", + "outputs": [], "source": [ "BUFFER_SIZE = len(input_tensor_train)\n", "BATCH_SIZE = 64\n", + "N_BATCH = BUFFER_SIZE//BATCH_SIZE\n", "embedding_dim = 256\n", "units = 1024\n", "vocab_inp_size = len(inp_lang.word2idx)\n", "vocab_tar_size = len(targ_lang.word2idx)\n", "\n", "dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)\n", - "dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(BATCH_SIZE))" - ], - "execution_count": 0, - "outputs": [] + "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)" + ] }, { + "cell_type": "markdown", "metadata": { - "id": "TNfHIF71ulLu", - "colab_type": "text" + "colab_type": "text", + "id": "TNfHIF71ulLu" }, - "cell_type": "markdown", "source": [ "## Write the encoder and decoder model\n", "\n", "Here, we'll implement an encoder-decoder model with attention which you can read about in the TensorFlow [Neural Machine Translation (seq2seq) tutorial](https://www.tensorflow.org/tutorials/seq2seq). This example uses a more recent set of APIs. This notebook implements the [attention equations](https://www.tensorflow.org/tutorials/seq2seq#background_on_the_attention_mechanism) from the seq2seq tutorial. The following diagram shows that each input words is assigned a weight by the attention mechanism which is then used by the decoder to predict the next word in the sentence.\n", "\n", - "\"attention\n", + "\u003cimg src=\"https://www.tensorflow.org/images/seq2seq/attention_mechanism.jpg\" width=\"500\" alt=\"attention mechanism\"\u003e\n", "\n", "The input is put through an encoder model which gives us the encoder output of shape *(batch_size, max_length, hidden_size)* and the encoder hidden state of shape *(batch_size, hidden_size)*. \n", "\n", "Here are the equations that are implemented:\n", "\n", - "\"attention\n", - "\"attention\n", + "\u003cimg src=\"https://www.tensorflow.org/images/seq2seq/attention_equation_0.jpg\" alt=\"attention equation 0\" width=\"800\"\u003e\n", + "\u003cimg src=\"https://www.tensorflow.org/images/seq2seq/attention_equation_1.jpg\" alt=\"attention equation 1\" width=\"800\"\u003e\n", "\n", "We're using *Bahdanau attention*. Lets decide on notation before writing the simplified form:\n", "\n", @@ -434,17 +362,14 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "avyJ_4VIUoHb", + "colab": {}, "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "id": "avyJ_4VIUoHb" }, - "cell_type": "code", + "outputs": [], "source": [ "def gru(units):\n", " # If you have a GPU, we recommend using CuDNNGRU(provides a 3x speedup than GRU)\n", @@ -460,22 +385,17 @@ " return_state=True, \n", " recurrent_activation='sigmoid', \n", " recurrent_initializer='glorot_uniform')" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "nZ2rI24i3jFg", + "colab": {}, "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "id": "nZ2rI24i3jFg" }, - "cell_type": "code", + "outputs": [], "source": [ "class Encoder(tf.keras.Model):\n", " def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):\n", @@ -492,22 +412,17 @@ " \n", " def initialize_hidden_state(self):\n", " return tf.zeros((self.batch_sz, self.enc_units))" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "yJ_B3mhW3jFk", + "colab": {}, "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "id": "yJ_B3mhW3jFk" }, - "cell_type": "code", + "outputs": [], "source": [ "class Decoder(tf.keras.Model):\n", " def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):\n", @@ -561,51 +476,41 @@ " \n", " def initialize_hidden_state(self):\n", " return tf.zeros((self.batch_sz, self.dec_units))" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "P5UY8wko3jFp", + "colab": {}, "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "id": "P5UY8wko3jFp" }, - "cell_type": "code", + "outputs": [], "source": [ "encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)\n", "decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "_ch_71VbIRfK", - "colab_type": "text" + "colab_type": "text", + "id": "_ch_71VbIRfK" }, - "cell_type": "markdown", "source": [ "## Define the optimizer and the loss function" ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "WmTHr5iV3jFr", + "colab": {}, "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "id": "WmTHr5iV3jFr" }, - "cell_type": "code", + "outputs": [], "source": [ "optimizer = tf.train.AdamOptimizer()\n", "\n", @@ -614,16 +519,41 @@ " mask = 1 - np.equal(real, 0)\n", " loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n", " return tf.reduce_mean(loss_)" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "hpObfY22IddU", - "colab_type": "text" + "colab_type": "text", + "id": "DMVWzzsfNl4e" }, + "source": [ + "## Checkpoints (Object-based saving)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Zj8bXQTgNwrF" + }, + "outputs": [], + "source": [ + "checkpoint_dir = './training_checkpoints'\n", + "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", + "checkpoint = tf.train.Checkpoint(optimizer=optimizer,\n", + " encoder=encoder,\n", + " decoder=decoder)" + ] + }, + { "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "hpObfY22IddU" + }, "source": [ "## Training\n", "\n", @@ -637,17 +567,14 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "ddefjBMa3jF0", + "colab": {}, "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "id": "ddefjBMa3jF0" }, - "cell_type": "code", + "outputs": [], "source": [ "EPOCHS = 10\n", "\n", @@ -665,7 +592,7 @@ " \n", " dec_hidden = enc_hidden\n", " \n", - " dec_input = tf.expand_dims([targ_lang.word2idx['']] * BATCH_SIZE, 1) \n", + " dec_input = tf.expand_dims([targ_lang.word2idx['\u003cstart\u003e']] * BATCH_SIZE, 1) \n", " \n", " # Teacher forcing - feeding the target as the next input\n", " for t in range(1, targ.shape[1]):\n", @@ -677,32 +604,35 @@ " # using teacher forcing\n", " dec_input = tf.expand_dims(targ[:, t], 1)\n", " \n", - " total_loss += (loss / int(targ.shape[1]))\n", + " batch_loss = (loss / int(targ.shape[1]))\n", + " \n", + " total_loss += batch_loss\n", " \n", " variables = encoder.variables + decoder.variables\n", " \n", " gradients = tape.gradient(loss, variables)\n", - " \n", - " optimizer.apply_gradients(zip(gradients, variables), tf.train.get_or_create_global_step())\n", - "\n", + " \n", + " optimizer.apply_gradients(zip(gradients, variables))\n", + " \n", " if batch % 100 == 0:\n", " print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,\n", " batch,\n", - " loss.numpy() / int(targ.shape[1])))\n", + " batch_loss.numpy()))\n", + " # saving (checkpoint) the model every 2 epochs\n", + " if (epoch + 1) % 2 == 0:\n", + " checkpoint.save(file_prefix = checkpoint_prefix)\n", " \n", " print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n", - " total_loss/len(input_tensor)))\n", + " total_loss / N_BATCH))\n", " print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "mU3Ce8M6I3rz", - "colab_type": "text" + "colab_type": "text", + "id": "mU3Ce8M6I3rz" }, - "cell_type": "markdown", "source": [ "## Translate\n", "\n", @@ -714,17 +644,14 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "EbQpyYs13jF_", + "colab": {}, "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "id": "EbQpyYs13jF_" }, - "cell_type": "code", + "outputs": [], "source": [ "def evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):\n", " attention_plot = np.zeros((max_length_targ, max_length_inp))\n", @@ -741,7 +668,7 @@ " enc_out, enc_hidden = encoder(inputs, hidden)\n", "\n", " dec_hidden = enc_hidden\n", - " dec_input = tf.expand_dims([targ_lang.word2idx['']], 0)\n", + " dec_input = tf.expand_dims([targ_lang.word2idx['\u003cstart\u003e']], 0)\n", "\n", " for t in range(max_length_targ):\n", " predictions, dec_hidden, attention_weights = decoder(dec_input, dec_hidden, enc_out)\n", @@ -754,29 +681,24 @@ "\n", " result += targ_lang.idx2word[predicted_id] + ' '\n", "\n", - " if targ_lang.idx2word[predicted_id] == '':\n", + " if targ_lang.idx2word[predicted_id] == '\u003cend\u003e':\n", " return result, sentence, attention_plot\n", " \n", " # the predicted ID is fed back into the model\n", " dec_input = tf.expand_dims([predicted_id], 0)\n", "\n", " return result, sentence, attention_plot" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "s5hQWlbN3jGF", + "colab": {}, "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "id": "s5hQWlbN3jGF" }, - "cell_type": "code", + "outputs": [], "source": [ "# function for plotting the attention weights\n", "def plot_attention(attention, sentence, predicted_sentence):\n", @@ -790,22 +712,17 @@ " ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)\n", "\n", " plt.show()" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "sl9zUHzg3jGI", + "colab": {}, "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "id": "sl9zUHzg3jGI" }, - "cell_type": "code", + "outputs": [], "source": [ "def translate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):\n", " result, sentence, attention_plot = evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)\n", @@ -815,89 +732,91 @@ " \n", " attention_plot = attention_plot[:len(result.split(' ')), :len(sentence.split(' '))]\n", " plot_attention(attention_plot, sentence.split(' '), result.split(' '))" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "WrAM0FDomq3E", + "colab_type": "text", + "id": "n250XbnjOaqP" + }, + "source": [ + "## Restore the latest checkpoint and test" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "id": "UJpT9D5_OgP6" }, + "outputs": [], + "source": [ + "# restoring the latest checkpoint in checkpoint_dir\n", + "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" + ] + }, + { "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "WrAM0FDomq3E" + }, + "outputs": [], "source": [ "translate('hace mucho frio aqui.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "zSx2iM36EZQZ", + "colab": {}, "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "id": "zSx2iM36EZQZ" }, - "cell_type": "code", + "outputs": [], "source": [ "translate('esta es mi vida.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "A3LLCx3ZE0Ls", + "colab": {}, "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "id": "A3LLCx3ZE0Ls" }, - "cell_type": "code", + "outputs": [], "source": [ "translate('¿todavia estan en casa?', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "DUQVLVqUE1YW", + "colab": {}, "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "id": "DUQVLVqUE1YW" }, - "cell_type": "code", + "outputs": [], "source": [ "# wrong translation\n", "translate('trata de averiguarlo.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "RTe5P5ioMJwN", - "colab_type": "text" + "colab_type": "text", + "id": "RTe5P5ioMJwN" }, - "cell_type": "markdown", "source": [ "## Next steps\n", "\n", @@ -905,5 +824,31 @@ "* Experiment with training on a larger dataset, or using more epochs\n" ] } - ] -} \ No newline at end of file + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "nmt_with_attention.ipynb", + "private_outputs": true, + "provenance": [ + { + "file_id": "1C4fpM7_7IL8ZzF7Gc5abywqQjeQNS2-U", + "timestamp": 1527858391290 + }, + { + "file_id": "1pExo6aUuw0S6MISFWoinfJv0Ftm9V4qv", + "timestamp": 1527776041613 + } + ], + "toc_visible": true, + "version": "0.3.2" + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb index 7c0f9b5b8161a763c4153ebdeece7e0d1b90b384..51b7ffc4de0cee31f7a907ae7bf90f17056f9bcf 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb @@ -1,46 +1,30 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "automatic_differentiation.ipynb", - "version": "0.3.2", - "views": {}, - "default_view": {}, - "provenance": [], - "private_outputs": true, - "collapsed_sections": [], - "toc_visible": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - } - }, "cells": [ { + "cell_type": "markdown", "metadata": { - "id": "t09eeeR5prIJ", - "colab_type": "text" + "colab_type": "text", + "id": "t09eeeR5prIJ" }, - "cell_type": "markdown", "source": [ "##### Copyright 2018 The TensorFlow Authors." ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "GCCk8_dHpuNf", - "colab_type": "code", + "cellView": "form", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, - "cellView": "form" + "colab_type": "code", + "id": "GCCk8_dHpuNf" }, - "cell_type": "code", + "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", @@ -53,81 +37,79 @@ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "xh8WkEwWpnm7", - "colab_type": "text" + "colab_type": "text", + "id": "xh8WkEwWpnm7" }, - "cell_type": "markdown", "source": [ "# Automatic differentiation and gradient tape" ] }, { + "cell_type": "markdown", "metadata": { - "id": "idv0bPeCp325", - "colab_type": "text" + "colab_type": "text", + "id": "idv0bPeCp325" }, - "cell_type": "markdown", "source": [ - "
\n", - "\n", - " Run in Google Colab\n", - "\n", - "View source on GitHub
" + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" ] }, { + "cell_type": "markdown", "metadata": { - "id": "vDJ4XzMqodTy", - "colab_type": "text" + "colab_type": "text", + "id": "vDJ4XzMqodTy" }, - "cell_type": "markdown", "source": [ "In the previous tutorial we introduced `Tensor`s and operations on them. In this tutorial we will cover [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation), a key technique for optimizing machine learning models." ] }, { + "cell_type": "markdown", "metadata": { - "id": "GQJysDM__Qb0", - "colab_type": "text" + "colab_type": "text", + "id": "GQJysDM__Qb0" }, - "cell_type": "markdown", "source": [ "## Setup\n" ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "OiMPZStlibBv", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "OiMPZStlibBv" }, - "cell_type": "code", + "outputs": [], "source": [ "import tensorflow as tf\n", "tf.enable_eager_execution()\n", "\n", "tfe = tf.contrib.eager # Shorthand for some symbols" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "1CLWJl0QliB0", - "colab_type": "text" + "colab_type": "text", + "id": "1CLWJl0QliB0" }, - "cell_type": "markdown", "source": [ "## Derivatives of a function\n", "\n", @@ -135,17 +117,19 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "9FViq92UX7P8", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "9FViq92UX7P8" }, - "cell_type": "code", + "outputs": [], "source": [ "from math import pi\n", "\n", @@ -159,17 +143,15 @@ "# with respect to its arguments. Since f() has a single argument,\n", "# grad_f will return a list with a single element.\n", "grad_f = tfe.gradients_function(f)\n", - "assert tf.abs(grad_f(pi/2)[0]).numpy() < 1e-7" - ], - "execution_count": 0, - "outputs": [] + "assert tf.abs(grad_f(pi/2)[0]).numpy() \u003c 1e-7" + ] }, { + "cell_type": "markdown", "metadata": { - "id": "v9fPs8RyopCf", - "colab_type": "text" + "colab_type": "text", + "id": "v9fPs8RyopCf" }, - "cell_type": "markdown", "source": [ "### Higher-order gradients\n", "\n", @@ -177,17 +159,19 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "3D0ZvnGYo0rW", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "3D0ZvnGYo0rW" }, - "cell_type": "code", + "outputs": [], "source": [ "def f(x):\n", " return tf.square(tf.sin(x))\n", @@ -205,16 +189,14 @@ "plt.plot(x, grad(grad(grad(f)))(x), label=\"third derivative\")\n", "plt.legend()\n", "plt.show()" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "-39gouo7mtgu", - "colab_type": "text" + "colab_type": "text", + "id": "-39gouo7mtgu" }, - "cell_type": "markdown", "source": [ "## Gradient tapes\n", "\n", @@ -225,21 +207,25 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "MH0UfjympWf7", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "MH0UfjympWf7" }, - "cell_type": "code", + "outputs": [], "source": [ "def f(x, y):\n", " output = 1\n", - " for i in range(y):\n", + " # Must use range(int(y)) instead of range(y) in Python 3 when\n", + " # using TensorFlow 1.10 and earlier. Can use range(y) in 1.11+\n", + " for i in range(int(y)):\n", " output = tf.multiply(output, x)\n", " return output\n", "\n", @@ -251,16 +237,14 @@ "assert g(3.0, 2).numpy() == 6.0 # And its gradient will be 2 * x\n", "assert f(4.0, 3).numpy() == 64.0 # f(x, 3) is essentially x * x * x\n", "assert g(4.0, 3).numpy() == 48.0 # And its gradient will be 3 * x * x" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "aNmR5-jhpX2t", - "colab_type": "text" + "colab_type": "text", + "id": "aNmR5-jhpX2t" }, - "cell_type": "markdown", "source": [ "At times it may be inconvenient to encapsulate computation of interest into a function. For example, if you want the gradient of the output with respect to intermediate values computed in the function. In such cases, the slightly more verbose but explicit [tf.GradientTape](https://www.tensorflow.org/api_docs/python/tf/GradientTape) context is useful. All computation inside the context of a `tf.GradientTape` is \"recorded\".\n", "\n", @@ -268,17 +252,19 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "bAFeIE8EuVIq", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "bAFeIE8EuVIq" }, - "cell_type": "code", + "outputs": [], "source": [ "x = tf.ones((2, 2))\n", " \n", @@ -300,16 +286,14 @@ "for i in [0, 1]:\n", " for j in [0, 1]:\n", " assert dz_dx[i][j].numpy() == 8.0" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "DK05KXrAAld3", - "colab_type": "text" + "colab_type": "text", + "id": "DK05KXrAAld3" }, - "cell_type": "markdown", "source": [ "### Higher-order gradients\n", "\n", @@ -317,17 +301,19 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "cPQgthZ7ugRJ", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "cPQgthZ7ugRJ" }, - "cell_type": "code", + "outputs": [], "source": [ "# TODO(ashankar): Should we use the persistent tape here instead? Follow up on Tom and Alex's discussion\n", "\n", @@ -344,21 +330,37 @@ "\n", "assert dy_dx.numpy() == 3.0\n", "assert d2y_dx2.numpy() == 6.0" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "4U1KKzUpNl58", - "colab_type": "text" + "colab_type": "text", + "id": "4U1KKzUpNl58" }, - "cell_type": "markdown", "source": [ "## Next Steps\n", "\n", "In this tutorial we covered gradient computation in TensorFlow. With that we have enough of the primitives required to build an train neural networks, which we will cover in the [next tutorial](https://github.com/tensorflow/models/tree/master/official/contrib/eager/python/examples/notebooks/3_neural_networks.ipynb)." ] } - ] -} \ No newline at end of file + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "automatic_differentiation.ipynb", + "private_outputs": true, + "provenance": [], + "toc_visible": true, + "version": "0.3.2", + "views": {} + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb b/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..ee25d25b52a2e06d9f99bdbe295afd228a3c6ce1 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb @@ -0,0 +1,810 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "0TD5ZrvEMbhZ" + }, + "source": [ + "##### Copyright 2018 The TensorFlow Authors.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\").\n", + "\n", + "# Pix2Pix: An example with tf.keras and eager\n", + "\n", + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ITZuApL56Mny" + }, + "source": [ + "This notebook demonstrates image to image translation using conditional GAN's, as described in [Image-to-Image Translation with Conditional Adversarial Networks](https://arxiv.org/abs/1611.07004). Using this technique we can colorize black and white photos, convert google maps to google earth, etc. Here, we convert building facades to real buildings. We use [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager) to achieve this.\n", + "\n", + "In example, we will use the [CMP Facade Database](http://cmp.felk.cvut.cz/~tylecr1/facade/), helpfully provided by the [Center for Machine Perception](http://cmp.felk.cvut.cz/) at the [Czech Technical University in Prague](https://www.cvut.cz/). To keep our example short, we will use a preprocessed [copy](https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/) of this dataset, created by the authors of the [paper](https://arxiv.org/abs/1611.07004) above.\n", + "\n", + "Each epoch takes around 58 seconds on a single P100 GPU.\n", + "\n", + "Below is the output generated after training the model for 200 epochs.\n", + "\n", + "\n", + "![sample output_1](https://www.tensorflow.org/images/gan/pix2pix_1.png)\n", + "![sample output_2](https://www.tensorflow.org/images/gan/pix2pix_2.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "e1_Y75QXJS6h" + }, + "source": [ + "## Import TensorFlow and enable eager execution" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "YfIk2es3hJEd" + }, + "outputs": [], + "source": [ + "# Import TensorFlow \u003e= 1.10 and enable eager execution\n", + "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", + "\n", + "import os\n", + "import time\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import PIL\n", + "from IPython.display import clear_output" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "iYn4MdZnKCey" + }, + "source": [ + "## Load the dataset\n", + "\n", + "You can download this dataset and similar datasets from [here](https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets). As mentioned in the [paper](https://arxiv.org/abs/1611.07004) we apply random jittering and mirroring to the training dataset.\n", + "* In random jittering, the image is resized to `286 x 286` and then randomly cropped to `256 x 256`\n", + "* In random mirroring, the image is randomly flipped horizontally i.e left to right." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Kn-k8kTXuAlv" + }, + "outputs": [], + "source": [ + "path_to_zip = tf.keras.utils.get_file('facades.tar.gz',\n", + " cache_subdir=os.path.abspath('.'),\n", + " origin='https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz', \n", + " extract=True)\n", + "\n", + "PATH = os.path.join(os.path.dirname(path_to_zip), 'facades/')" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "2CbTEt448b4R" + }, + "outputs": [], + "source": [ + "BUFFER_SIZE = 400\n", + "BATCH_SIZE = 1\n", + "IMG_WIDTH = 256\n", + "IMG_HEIGHT = 256" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "tyaP4hLJ8b4W" + }, + "outputs": [], + "source": [ + "def load_image(image_file, is_train):\n", + " image = tf.read_file(image_file)\n", + " image = tf.image.decode_jpeg(image)\n", + "\n", + " w = tf.shape(image)[1]\n", + "\n", + " w = w // 2\n", + " real_image = image[:, :w, :]\n", + " input_image = image[:, w:, :]\n", + "\n", + " input_image = tf.cast(input_image, tf.float32)\n", + " real_image = tf.cast(real_image, tf.float32)\n", + "\n", + " if is_train:\n", + " # random jittering\n", + " \n", + " # resizing to 286 x 286 x 3\n", + " # method = 2 indicates using \"ResizeMethod.NEAREST_NEIGHBOR\"\n", + " input_image = tf.image.resize_images(input_image, [286, 286], \n", + " align_corners=True, method=2)\n", + " real_image = tf.image.resize_images(real_image, [286, 286], \n", + " align_corners=True, method=2)\n", + " \n", + " # randomly cropping to 256 x 256 x 3\n", + " stacked_image = tf.stack([input_image, real_image], axis=0)\n", + " cropped_image = tf.random_crop(stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])\n", + " input_image, real_image = cropped_image[0], cropped_image[1]\n", + "\n", + " if np.random.random() \u003e 0.5:\n", + " # random mirroring\n", + " input_image = tf.image.flip_left_right(input_image)\n", + " real_image = tf.image.flip_left_right(real_image)\n", + " else:\n", + " input_image = tf.image.resize_images(input_image, size=[IMG_HEIGHT, IMG_WIDTH], \n", + " align_corners=True, method=2)\n", + " real_image = tf.image.resize_images(real_image, size=[IMG_HEIGHT, IMG_WIDTH], \n", + " align_corners=True, method=2)\n", + " \n", + " # normalizing the images to [-1, 1]\n", + " input_image = (input_image / 127.5) - 1\n", + " real_image = (real_image / 127.5) - 1\n", + "\n", + " return input_image, real_image" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "PIGN6ouoQxt3" + }, + "source": [ + "## Use tf.data to create batches, map(do preprocessing) and shuffle the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "SQHmYSmk8b4b" + }, + "outputs": [], + "source": [ + "train_dataset = tf.data.Dataset.list_files(PATH+'train/*.jpg')\n", + "train_dataset = train_dataset.shuffle(BUFFER_SIZE)\n", + "train_dataset = train_dataset.map(lambda x: load_image(x, True))\n", + "train_dataset = train_dataset.batch(1)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "MS9J0yA58b4g" + }, + "outputs": [], + "source": [ + "test_dataset = tf.data.Dataset.list_files(PATH+'test/*.jpg')\n", + "test_dataset = test_dataset.map(lambda x: load_image(x, False))\n", + "test_dataset = test_dataset.batch(1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "THY-sZMiQ4UV" + }, + "source": [ + "## Write the generator and discriminator models\n", + "\n", + "* **Generator** \n", + " * The architecture of generator is a modified U-Net.\n", + " * Each block in the encoder is (Conv -\u003e Batchnorm -\u003e Leaky ReLU)\n", + " * Each block in the decoder is (Transposed Conv -\u003e Batchnorm -\u003e Dropout(applied to the first 3 blocks) -\u003e ReLU)\n", + " * There are skip connections between the encoder and decoder (as in U-Net).\n", + " \n", + "* **Discriminator**\n", + " * The Discriminator is a PatchGAN.\n", + " * Each block in the discriminator is (Conv -\u003e BatchNorm -\u003e Leaky ReLU)\n", + " * The shape of the output after the last layer is (batch_size, 30, 30, 1)\n", + " * Each 30x30 patch of the output classifies a 70x70 portion of the input image (such an architecture is called a PatchGAN).\n", + " * Discriminator receives 2 inputs.\n", + " * Input image and the target image, which it should classify as real.\n", + " * Input image and the generated image (output of generator), which it should classify as fake. \n", + " * We concatenate these 2 inputs together in the code (`tf.concat([inp, tar], axis=-1)`)\n", + "\n", + "* Shape of the input travelling through the generator and the discriminator is in the comments in the code.\n", + "\n", + "To learn more about the architecture and the hyperparameters you can refer the [paper](https://arxiv.org/abs/1611.07004).\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "tqqvWxlw8b4l" + }, + "outputs": [], + "source": [ + "OUTPUT_CHANNELS = 3" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "lFPI4Nu-8b4q" + }, + "outputs": [], + "source": [ + "class Downsample(tf.keras.Model):\n", + " \n", + " def __init__(self, filters, size, apply_batchnorm=True):\n", + " super(Downsample, self).__init__()\n", + " self.apply_batchnorm = apply_batchnorm\n", + " initializer = tf.random_normal_initializer(0., 0.02)\n", + "\n", + " self.conv1 = tf.keras.layers.Conv2D(filters, \n", + " (size, size), \n", + " strides=2, \n", + " padding='same',\n", + " kernel_initializer=initializer,\n", + " use_bias=False)\n", + " if self.apply_batchnorm:\n", + " self.batchnorm = tf.keras.layers.BatchNormalization()\n", + " \n", + " def call(self, x, training):\n", + " x = self.conv1(x)\n", + " if self.apply_batchnorm:\n", + " x = self.batchnorm(x, training=training)\n", + " x = tf.nn.leaky_relu(x)\n", + " return x \n", + "\n", + "\n", + "class Upsample(tf.keras.Model):\n", + " \n", + " def __init__(self, filters, size, apply_dropout=False):\n", + " super(Upsample, self).__init__()\n", + " self.apply_dropout = apply_dropout\n", + " initializer = tf.random_normal_initializer(0., 0.02)\n", + "\n", + " self.up_conv = tf.keras.layers.Conv2DTranspose(filters, \n", + " (size, size), \n", + " strides=2, \n", + " padding='same',\n", + " kernel_initializer=initializer,\n", + " use_bias=False)\n", + " self.batchnorm = tf.keras.layers.BatchNormalization()\n", + " if self.apply_dropout:\n", + " self.dropout = tf.keras.layers.Dropout(0.5)\n", + "\n", + " def call(self, x1, x2, training):\n", + " x = self.up_conv(x1)\n", + " x = self.batchnorm(x, training=training)\n", + " if self.apply_dropout:\n", + " x = self.dropout(x, training=training)\n", + " x = tf.nn.relu(x)\n", + " x = tf.concat([x, x2], axis=-1)\n", + " return x\n", + "\n", + "\n", + "class Generator(tf.keras.Model):\n", + " \n", + " def __init__(self):\n", + " super(Generator, self).__init__()\n", + " initializer = tf.random_normal_initializer(0., 0.02)\n", + " \n", + " self.down1 = Downsample(64, 4, apply_batchnorm=False)\n", + " self.down2 = Downsample(128, 4)\n", + " self.down3 = Downsample(256, 4)\n", + " self.down4 = Downsample(512, 4)\n", + " self.down5 = Downsample(512, 4)\n", + " self.down6 = Downsample(512, 4)\n", + " self.down7 = Downsample(512, 4)\n", + " self.down8 = Downsample(512, 4)\n", + "\n", + " self.up1 = Upsample(512, 4, apply_dropout=True)\n", + " self.up2 = Upsample(512, 4, apply_dropout=True)\n", + " self.up3 = Upsample(512, 4, apply_dropout=True)\n", + " self.up4 = Upsample(512, 4)\n", + " self.up5 = Upsample(256, 4)\n", + " self.up6 = Upsample(128, 4)\n", + " self.up7 = Upsample(64, 4)\n", + "\n", + " self.last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, \n", + " (4, 4), \n", + " strides=2, \n", + " padding='same',\n", + " kernel_initializer=initializer)\n", + " \n", + " @tf.contrib.eager.defun\n", + " def call(self, x, training):\n", + " # x shape == (bs, 256, 256, 3) \n", + " x1 = self.down1(x, training=training) # (bs, 128, 128, 64)\n", + " x2 = self.down2(x1, training=training) # (bs, 64, 64, 128)\n", + " x3 = self.down3(x2, training=training) # (bs, 32, 32, 256)\n", + " x4 = self.down4(x3, training=training) # (bs, 16, 16, 512)\n", + " x5 = self.down5(x4, training=training) # (bs, 8, 8, 512)\n", + " x6 = self.down6(x5, training=training) # (bs, 4, 4, 512)\n", + " x7 = self.down7(x6, training=training) # (bs, 2, 2, 512)\n", + " x8 = self.down8(x7, training=training) # (bs, 1, 1, 512)\n", + "\n", + " x9 = self.up1(x8, x7, training=training) # (bs, 2, 2, 1024)\n", + " x10 = self.up2(x9, x6, training=training) # (bs, 4, 4, 1024)\n", + " x11 = self.up3(x10, x5, training=training) # (bs, 8, 8, 1024)\n", + " x12 = self.up4(x11, x4, training=training) # (bs, 16, 16, 1024)\n", + " x13 = self.up5(x12, x3, training=training) # (bs, 32, 32, 512)\n", + " x14 = self.up6(x13, x2, training=training) # (bs, 64, 64, 256)\n", + " x15 = self.up7(x14, x1, training=training) # (bs, 128, 128, 128)\n", + "\n", + " x16 = self.last(x15) # (bs, 256, 256, 3)\n", + " x16 = tf.nn.tanh(x16)\n", + "\n", + " return x16" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "ll6aNeQx8b4v" + }, + "outputs": [], + "source": [ + "class DiscDownsample(tf.keras.Model):\n", + " \n", + " def __init__(self, filters, size, apply_batchnorm=True):\n", + " super(DiscDownsample, self).__init__()\n", + " self.apply_batchnorm = apply_batchnorm\n", + " initializer = tf.random_normal_initializer(0., 0.02)\n", + "\n", + " self.conv1 = tf.keras.layers.Conv2D(filters, \n", + " (size, size), \n", + " strides=2, \n", + " padding='same',\n", + " kernel_initializer=initializer,\n", + " use_bias=False)\n", + " if self.apply_batchnorm:\n", + " self.batchnorm = tf.keras.layers.BatchNormalization()\n", + " \n", + " def call(self, x, training):\n", + " x = self.conv1(x)\n", + " if self.apply_batchnorm:\n", + " x = self.batchnorm(x, training=training)\n", + " x = tf.nn.leaky_relu(x)\n", + " return x \n", + "\n", + "class Discriminator(tf.keras.Model):\n", + " \n", + " def __init__(self):\n", + " super(Discriminator, self).__init__()\n", + " initializer = tf.random_normal_initializer(0., 0.02)\n", + " \n", + " self.down1 = DiscDownsample(64, 4, False)\n", + " self.down2 = DiscDownsample(128, 4)\n", + " self.down3 = DiscDownsample(256, 4)\n", + " \n", + " # we are zero padding here with 1 because we need our shape to \n", + " # go from (batch_size, 32, 32, 256) to (batch_size, 31, 31, 512)\n", + " self.zero_pad1 = tf.keras.layers.ZeroPadding2D()\n", + " self.conv = tf.keras.layers.Conv2D(512, \n", + " (4, 4), \n", + " strides=1, \n", + " kernel_initializer=initializer, \n", + " use_bias=False)\n", + " self.batchnorm1 = tf.keras.layers.BatchNormalization()\n", + " \n", + " # shape change from (batch_size, 31, 31, 512) to (batch_size, 30, 30, 1)\n", + " self.zero_pad2 = tf.keras.layers.ZeroPadding2D()\n", + " self.last = tf.keras.layers.Conv2D(1, \n", + " (4, 4), \n", + " strides=1,\n", + " kernel_initializer=initializer)\n", + " \n", + " @tf.contrib.eager.defun\n", + " def call(self, inp, tar, training):\n", + " # concatenating the input and the target\n", + " x = tf.concat([inp, tar], axis=-1) # (bs, 256, 256, channels*2)\n", + " x = self.down1(x, training=training) # (bs, 128, 128, 64)\n", + " x = self.down2(x, training=training) # (bs, 64, 64, 128)\n", + " x = self.down3(x, training=training) # (bs, 32, 32, 256)\n", + "\n", + " x = self.zero_pad1(x) # (bs, 34, 34, 256)\n", + " x = self.conv(x) # (bs, 31, 31, 512)\n", + " x = self.batchnorm1(x, training=training)\n", + " x = tf.nn.leaky_relu(x)\n", + " \n", + " x = self.zero_pad2(x) # (bs, 33, 33, 512)\n", + " # don't add a sigmoid activation here since\n", + " # the loss function expects raw logits.\n", + " x = self.last(x) # (bs, 30, 30, 1)\n", + "\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "gDkA05NE6QMs" + }, + "outputs": [], + "source": [ + "# The call function of Generator and Discriminator have been decorated\n", + "# with tf.contrib.eager.defun()\n", + "# We get a performance speedup if defun is used (~25 seconds per epoch)\n", + "generator = Generator()\n", + "discriminator = Discriminator()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "0FMYgY_mPfTi" + }, + "source": [ + "## Define the loss functions and the optimizer\n", + "\n", + "* **Discriminator loss**\n", + " * The discriminator loss function takes 2 inputs; **real images, generated images**\n", + " * real_loss is a sigmoid cross entropy loss of the **real images** and an **array of ones(since these are the real images)**\n", + " * generated_loss is a sigmoid cross entropy loss of the **generated images** and an **array of zeros(since these are the fake images)**\n", + " * Then the total_loss is the sum of real_loss and the generated_loss\n", + " \n", + "* **Generator loss**\n", + " * It is a sigmoid cross entropy loss of the generated images and an **array of ones**.\n", + " * The [paper](https://arxiv.org/abs/1611.07004) also includes L1 loss which is MAE (mean absolute error) between the generated image and the target image.\n", + " * This allows the generated image to become structurally similar to the target image.\n", + " * The formula to calculate the total generator loss = gan_loss + LAMBDA * l1_loss, where LAMBDA = 100. This value was decided by the authors of the [paper](https://arxiv.org/abs/1611.07004)." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "cyhxTuvJyIHV" + }, + "outputs": [], + "source": [ + "LAMBDA = 100" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "wkMNfBWlT-PV" + }, + "outputs": [], + "source": [ + "def discriminator_loss(disc_real_output, disc_generated_output):\n", + " real_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels = tf.ones_like(disc_real_output), \n", + " logits = disc_real_output)\n", + " generated_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels = tf.zeros_like(disc_generated_output), \n", + " logits = disc_generated_output)\n", + "\n", + " total_disc_loss = real_loss + generated_loss\n", + "\n", + " return total_disc_loss" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "90BIcCKcDMxz" + }, + "outputs": [], + "source": [ + "def generator_loss(disc_generated_output, gen_output, target):\n", + " gan_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels = tf.ones_like(disc_generated_output),\n", + " logits = disc_generated_output) \n", + " # mean absolute error\n", + " l1_loss = tf.reduce_mean(tf.abs(target - gen_output))\n", + "\n", + " total_gen_loss = gan_loss + (LAMBDA * l1_loss)\n", + "\n", + " return total_gen_loss" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "iWCn_PVdEJZ7" + }, + "outputs": [], + "source": [ + "generator_optimizer = tf.train.AdamOptimizer(2e-4, beta1=0.5)\n", + "discriminator_optimizer = tf.train.AdamOptimizer(2e-4, beta1=0.5)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "aKUZnDiqQrAh" + }, + "source": [ + "## Checkpoints (Object-based saving)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "WJnftd5sQsv6" + }, + "outputs": [], + "source": [ + "checkpoint_dir = './training_checkpoints'\n", + "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", + "checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,\n", + " discriminator_optimizer=discriminator_optimizer,\n", + " generator=generator,\n", + " discriminator=discriminator)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Rw1fkAczTQYh" + }, + "source": [ + "## Training\n", + "\n", + "* We start by iterating over the dataset\n", + "* The generator gets the input image and we get a generated output.\n", + "* The discriminator receives the input_image and the generated image as the first input. The second input is the input_image and the target_image.\n", + "* Next, we calculate the generator and the discriminator loss.\n", + "* Then, we calculate the gradients of loss with respect to both the generator and the discriminator variables(inputs) and apply those to the optimizer.\n", + "\n", + "## Generate Images\n", + "\n", + "* After training, its time to generate some images!\n", + "* We pass images from the test dataset to the generator.\n", + "* The generator will then translate the input image into the output we expect.\n", + "* Last step is to plot the predictions and **voila!**" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "NS2GWywBbAWo" + }, + "outputs": [], + "source": [ + "EPOCHS = 200" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "RmdVsmvhPxyy" + }, + "outputs": [], + "source": [ + "def generate_images(model, test_input, tar):\n", + " # the training=True is intentional here since\n", + " # we want the batch statistics while running the model\n", + " # on the test dataset. If we use training=False, we will get \n", + " # the accumulated statistics learned from the training dataset\n", + " # (which we don't want)\n", + " prediction = model(test_input, training=True)\n", + " plt.figure(figsize=(15,15))\n", + "\n", + " display_list = [test_input[0], tar[0], prediction[0]]\n", + " title = ['Input Image', 'Ground Truth', 'Predicted Image']\n", + "\n", + " for i in range(3):\n", + " plt.subplot(1, 3, i+1)\n", + " plt.title(title[i])\n", + " # getting the pixel values between [0, 1] to plot it.\n", + " plt.imshow(display_list[i] * 0.5 + 0.5)\n", + " plt.axis('off')\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "2M7LmLtGEMQJ" + }, + "outputs": [], + "source": [ + "def train(dataset, epochs): \n", + " for epoch in range(epochs):\n", + " start = time.time()\n", + "\n", + " for input_image, target in dataset:\n", + "\n", + " with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n", + " gen_output = generator(input_image, training=True)\n", + "\n", + " disc_real_output = discriminator(input_image, target, training=True)\n", + " disc_generated_output = discriminator(input_image, gen_output, training=True)\n", + "\n", + " gen_loss = generator_loss(disc_generated_output, gen_output, target)\n", + " disc_loss = discriminator_loss(disc_real_output, disc_generated_output)\n", + "\n", + " generator_gradients = gen_tape.gradient(gen_loss, \n", + " generator.variables)\n", + " discriminator_gradients = disc_tape.gradient(disc_loss, \n", + " discriminator.variables)\n", + "\n", + " generator_optimizer.apply_gradients(zip(generator_gradients, \n", + " generator.variables))\n", + " discriminator_optimizer.apply_gradients(zip(discriminator_gradients, \n", + " discriminator.variables))\n", + "\n", + " if epoch % 1 == 0:\n", + " clear_output(wait=True)\n", + " for inp, tar in test_dataset.take(1):\n", + " generate_images(generator, inp, tar)\n", + " \n", + " # saving (checkpoint) the model every 20 epochs\n", + " if (epoch + 1) % 20 == 0:\n", + " checkpoint.save(file_prefix = checkpoint_prefix)\n", + "\n", + " print ('Time taken for epoch {} is {} sec\\n'.format(epoch + 1,\n", + " time.time()-start))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "a1zZmKmvOH85" + }, + "outputs": [], + "source": [ + "train(train_dataset, EPOCHS)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "kz80bY3aQ1VZ" + }, + "source": [ + "## Restore the latest checkpoint and test" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "4t4x69adQ5xb" + }, + "outputs": [], + "source": [ + "# restoring the latest checkpoint in checkpoint_dir\n", + "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "1RGysMU_BZhx" + }, + "source": [ + "## Testing on the entire test dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "KUgSnmy2nqSP" + }, + "outputs": [], + "source": [ + "# Run the trained model on the entire test dataset\n", + "for inp, tar in test_dataset:\n", + " generate_images(generator, inp, tar)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "3AJXOByaZVOf" + }, + "outputs": [], + "source": [ + "" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "pix2pix_eager.ipynb", + "private_outputs": true, + "provenance": [ + { + "file_id": "1eb0NOTQapkYs3X0v-zL1x5_LFKgDISnp", + "timestamp": 1527173385672 + } + ], + "toc_visible": true, + "version": "0.3.2" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py index 07d8788882c2d831dfb041fe7409af51857190bf..d265169b5eff685f7b79fb221b9bd52be37ead9c 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py @@ -216,12 +216,12 @@ class ResNet50Benchmarks(tf.test.Benchmark): tf.constant(1.).cpu() def _benchmark_eager_apply(self, label, device_and_format, defun=False, - execution_mode=None, compiled=False): + execution_mode=None): with tfe.execution_mode(execution_mode): device, data_format = device_and_format model = resnet50.ResNet50(data_format) if defun: - model.call = tfe.defun(model.call, compiled=compiled) + model.call = tfe.defun(model.call) batch_size = 64 num_burn = 5 num_iters = 30 @@ -257,8 +257,7 @@ class ResNet50Benchmarks(tf.test.Benchmark): make_iterator, device_and_format, defun=False, - execution_mode=None, - compiled=False): + execution_mode=None): with tfe.execution_mode(execution_mode): device, data_format = device_and_format for batch_size in self._train_batch_sizes(): @@ -267,8 +266,8 @@ class ResNet50Benchmarks(tf.test.Benchmark): optimizer = tf.train.GradientDescentOptimizer(0.1) apply_grads = apply_gradients if defun: - model.call = tfe.defun(model.call, compiled=compiled) - apply_grads = tfe.defun(apply_gradients, compiled=compiled) + model.call = tfe.defun(model.call) + apply_grads = tfe.defun(apply_gradients) num_burn = 3 num_iters = 10 diff --git a/tensorflow/contrib/eager/python/examples/revnet/README.md b/tensorflow/contrib/eager/python/examples/revnet/README.md index 21fc44febc8abdc30daad1b35d8434b083360bdf..822d86e9c7a7e620da3b84ded9af98b1c1d4b701 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/README.md +++ b/tensorflow/contrib/eager/python/examples/revnet/README.md @@ -1,19 +1,22 @@ # RevNet with TensorFlow eager execution -This folder contains an TensorFlow eager implementation of the [Reversible Residual Network](https://arxiv.org/pdf/1707.04585.pdf) adapted from the released implementation by the authors. The presented implementation can be ran both in eager and graph mode. The code is considerably simplified with `tf.GradientTape`. Moreover, we reduce the step of reconstructing the outputs. This saves us from using `tf.stop_gradient` and makes the model run faster. +This folder contains a TensorFlow eager implementation of the [Reversible Residual Network](https://arxiv.org/pdf/1707.04585.pdf) adapted from the released implementation by the authors. The presented implementation can be ran with both eager and graph execution. The code is considerably simplified with `tf.GradientTape`. Moreover, we reduce the a redundant forward pass in the implementation by the authors. This saves us from using `tf.stop_gradient` and makes the model run faster. ## Content - `revnet.py`: The RevNet model. - `blocks.py`: The relevant reversible blocks. +- `ops.py`: Auxiliary downsampling operation. - `cifar_tfrecords.py`: Script to generate the TFRecords for both CIFAR-10 and CIFAR-100. - `cifar_input.py`: Script to read from TFRecords and generate dataset objects with the `tf.data` API. - `config.py`: Configuration file for network architectures and training hyperparameters. - `main.py`: Main training and evaluation script. -- `ops.py`: Auxiliary downsampling operation. +- `main_estimator.py`: Script to train RevNet models on CIFAR-10 and CIFAR-100 with the `tf.estimator` API. +- `main_estimator_tpu.py`: Script to train RevNet models on ImageNet with TPU estimators on Cloud TPUs. +- `resnet_preprocessing.py`, `imagenet_input.py`: Boilerplate to read ImageNet data from TFRecords. -## To run -- Make sure you have installed TensorFlow 1.9+ or the latest `tf-nightly` +## Train on CIFAR-10/CIFAR-100 +- Make sure you have installed TensorFlow 1.10+ or the latest `tf-nightly` or `tf-nightly-gpu` pip package in order to access the eager execution feature. - First run @@ -24,7 +27,7 @@ python cifar_tfrecords.py --data_dir ${PWD}/cifar to download the cifar dataset and convert them to TFRecords. This produces TFRecord files for both CIFAR-10 and CIFAR-100. -- To train a model run +- To train a model, run ```bash python main.py --data_dir ${PWD}/cifar @@ -34,11 +37,75 @@ python main.py --data_dir ${PWD}/cifar - `train_dir`: Directory to store eventfiles and checkpoints. - `restore`: Restore the latest checkpoint. - `validate`: Use validation set for training monitoring. - - `manual_grad`: Use the manually defined gradient map given by the authors. - - `dataset`: Use either `cifar-10` or `cifar-100` + - `dataset`: Use either `cifar-10` or `cifar-100`. + - `config`: RevNet configuration. + - `use_defun`: Use `tfe.defun` to boost performance. + +- To train a model with estimators in graph execution, run + +```bash +python main_estimator.py --data_dir ${PWD}/cifar +``` +To ensure our code works properly when using the Keras model in an estimator, +`tf-nightly` or `tf-nightly-gpu` is highly recommended as of August 2018. + +- Optional arguments for `main.py` include + - `model_dir`: Directory to store eventfiles and checkpoints. + - `dataset`: Use either `cifar-10` or `cifar-100`. + - `config`: RevNet configuration. + - `export`: Export the model for serving if True. + +## Speed up with `tfe.defun` +To ensure that `tf.contrib.eager.defun` in our code works properly with all +part of the model during training, the latest `tf-nightly` or `tf-nightly-gpu` +is highly recommended as of August 2018. + +Even though the speed difference between pure eager execution and graph execution is noticeable, +the difference between fully "defunned" model training and graph +training is negligible. + +## Train on ImageNet with Cloud TPUs +The standard way to train models on Cloud TPUs is via TPU estimators and graph +execution. Models built with the `tf.keras` API are fully compatible with TPU estimators. +To ensure our code works properly in this setting, +`tf-nightly` or `tf-nightly-gpu` is highly recommended as of August 2018. + +### Setup a Google Cloud project + +Follow the instructions at the [Quickstart Guide](https://cloud.google.com/tpu/docs/quickstart) +to get a GCE VM with access to Cloud TPU. + +To run this model, you will need: + +* A GCE VM instance with an associated Cloud TPU resource +* A GCS bucket to store your training checkpoints +* (Optional): The ImageNet training and validation data preprocessed into + TFRecord format, and stored in GCS. + +### Format the data + +The data is expected to be formatted in TFRecord format, as generated by [this +script](https://github.com/tensorflow/tpu/blob/master/tools/datasets/imagenet_to_gcs.py). + +If you do not have ImageNet dataset prepared, you can use a randomly generated +fake dataset to test the model. It is located at +`gs://cloud-tpu-test-datasets/fake_imagenet`. + +### Start training + +Train the model by executing the following command (substituting the appropriate +values): + +```bash +python main_estimator_tpu.py \ + --tpu=$TPU_NAME \ + --data_dir=$DATA_DIR \ + --model_dir=$MODEL_DIR +``` ## Performance -- With the current implementation, RevNet-38 achieves >92% on CIFAR-10 and >71% on CIFAR-100. +- RevNet-38 achieves >92% and >71% accuracy on CIFAR-10 and CIFAR-100 respectively. +- RevNet-56 achieves <26% top-1 error rate on ImageNet. ## Reference The Reversible Residual Network: Backpropagation Without Storing Activations. diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks.py b/tensorflow/contrib/eager/python/examples/revnet/blocks.py index 8a530b0d71afab6dfc57ed16120a621cafcc3181..f61354bc38a9fcb941f186cac4eac8097eea742d 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/blocks.py +++ b/tensorflow/contrib/eager/python/examples/revnet/blocks.py @@ -91,32 +91,21 @@ class RevBlock(tf.keras.Model): h = block(h, training=training) return h - def backward_grads_and_vars(self, x, y, dy, training=True): + def backward_grads(self, x, y, dy, training=True): """Apply reversible block backward to outputs.""" grads_all = [] - vars_all = [] - for i in reversed(range(len(self.blocks))): block = self.blocks[i] if i == 0: # First block usually contains downsampling that can't be reversed - with tf.GradientTape() as tape: - tape.watch(x) - y = block(x, training=training) - - grads_combined = tape.gradient( - y, [x] + block.trainable_variables, output_gradients=dy) - dy = grads_combined[0] - grads_all += grads_combined[1:] - vars_all += block.trainable_variables + dy, grads = block.backward_grads_with_downsample( + x, y, dy, training=True) else: - y, dy, grads, vars_ = block.backward_grads_and_vars( - y, dy, training=training) - grads_all += grads - vars_all += vars_ + y, dy, grads = block.backward_grads(y, dy, training=training) + grads_all = grads + grads_all - return dy, grads_all, vars_all + return dy, grads_all class _Residual(tf.keras.Model): @@ -178,10 +167,9 @@ class _Residual(tf.keras.Model): fused=fused, dtype=dtype) - def call(self, x, training=True, concat=True): + def call(self, x, training=True): """Apply residual block to inputs.""" - - x1, x2 = tf.split(x, num_or_size_splits=2, axis=self.axis) + x1, x2 = x f_x2 = self.f(x2, training=training) x1_down = ops.downsample( x1, self.filters // 2, self.strides, axis=self.axis) @@ -190,42 +178,81 @@ class _Residual(tf.keras.Model): y1 = f_x2 + x1_down g_y1 = self.g(y1, training=training) y2 = g_y1 + x2_down - if not concat: # For correct backward grads - return y1, y2 - return tf.concat([y1, y2], axis=self.axis) + return y1, y2 - def backward_grads_and_vars(self, y, dy, training=True): + def backward_grads(self, y, dy, training=True): """Manually compute backward gradients given input and output grads.""" - dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=self.axis) + dy1, dy2 = dy + y1, y2 = y - with tf.GradientTape(persistent=True) as tape: - tape.watch(y) - y1, y2 = tf.split(y, num_or_size_splits=2, axis=self.axis) + with tf.GradientTape() as gtape: + gtape.watch(y1) gy1 = self.g(y1, training=training) + grads_combined = gtape.gradient( + gy1, [y1] + self.g.trainable_variables, output_gradients=dy2) + dg = grads_combined[1:] + dx1 = dy1 + grads_combined[0] + # This doesn't affect eager execution, but improves memory efficiency with + # graphs + with tf.control_dependencies(dg + [dx1]): x2 = y2 - gy1 + + with tf.GradientTape() as ftape: + ftape.watch(x2) fx2 = self.f(x2, training=training) + grads_combined = ftape.gradient( + fx2, [x2] + self.f.trainable_variables, output_gradients=dx1) + df = grads_combined[1:] + dx2 = dy2 + grads_combined[0] + # Same behavior as above + with tf.control_dependencies(df + [dx2]): x1 = y1 - fx2 - grads_combined = tape.gradient( + x = x1, x2 + dx = dx1, dx2 + grads = df + dg + + return x, dx, grads + + def backward_grads_with_downsample(self, x, y, dy, training=True): + """Manually compute backward gradients given input and output grads.""" + # Splitting this from `backward_grads` for better readability + x1, x2 = x + y1, _ = y + dy1, dy2 = dy + + with tf.GradientTape() as gtape: + gtape.watch(y1) + gy1 = self.g(y1, training=training) + grads_combined = gtape.gradient( gy1, [y1] + self.g.trainable_variables, output_gradients=dy2) dg = grads_combined[1:] - dx1 = dy1 + grads_combined[0] + dz1 = dy1 + grads_combined[0] - grads_combined = tape.gradient( - fx2, [x2] + self.f.trainable_variables, output_gradients=dx1) - dx2 = dy2 + grads_combined[0] - df = grads_combined[1:] + # dx1 need one more step to backprop through downsample + with tf.GradientTape() as x1tape: + x1tape.watch(x1) + z1 = ops.downsample(x1, self.filters // 2, self.strides, axis=self.axis) + dx1 = x1tape.gradient(z1, x1, output_gradients=dz1) - del tape + with tf.GradientTape() as ftape: + ftape.watch(x2) + fx2 = self.f(x2, training=training) + grads_combined = ftape.gradient( + fx2, [x2] + self.f.trainable_variables, output_gradients=dz1) + dx2, df = grads_combined[0], grads_combined[1:] - grads = df + dg - vars_ = self.f.trainable_variables + self.g.trainable_variables + # dx2 need one more step to backprop through downsample + with tf.GradientTape() as x2tape: + x2tape.watch(x2) + z2 = ops.downsample(x2, self.filters // 2, self.strides, axis=self.axis) + dx2 += x2tape.gradient(z2, x2, output_gradients=dy2) - x = tf.concat([x1, x2], axis=self.axis) - dx = tf.concat([dx1, dx2], axis=self.axis) + dx = dx1, dx2 + grads = df + dg - return x, dx, grads, vars_ + return dx, grads # Ideally, the following should be wrapped in `tf.keras.Sequential`, however @@ -422,7 +449,7 @@ class InitBlock(tf.keras.Model): if self.config.init_max_pool: net = self.max_pool(net) - return net + return tf.split(net, num_or_size_splits=2, axis=self.axis) class FinalBlock(tf.keras.Model): @@ -468,7 +495,7 @@ class FinalBlock(tf.keras.Model): self.config.n_classes, dtype=self.config.dtype) def call(self, x, training=True): - net = x + net = tf.concat(x, axis=self.axis) net = self.batch_norm(net, training=training) net = self.activation(net) net = self.global_avg_pool(net) diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py index d74785c8fe1c170ee95172974141c1cfe18b9502..9ff6b605b912772a92ab9e07a0ba5b9325030e43 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py +++ b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py @@ -116,70 +116,13 @@ def _validate_block_call_channels_first(block_factory, test): class RevBlockTest(tf.test.TestCase): - def test_call_channels_first(self): - """Test `call` function with `channels_first` data format.""" - if not tf.test.is_gpu_available(): - self.skipTest("GPU not available") - - with tf.device("/gpu:0"): # Default NCHW format - input_shape = (128, 8, 8) - data_shape = (16,) + input_shape - x = tf.random_normal(shape=data_shape) - - # Stride of 1 - block = blocks.RevBlock( - n_res=3, filters=128, strides=(1, 1), input_shape=input_shape) - y_tr, y_ev = block(x, training=True), block(x, training=False) - self.assertEqual(y_tr.shape, y_ev.shape) - self.assertEqual(y_ev.shape, (16, 128, 8, 8)) - self.assertNotAllClose(y_tr, y_ev) - - # Stride of 2 - block = blocks.RevBlock( - n_res=3, filters=128, strides=(2, 2), input_shape=input_shape) - y_tr, y_ev = block(x, training=True), block(x, training=False) - self.assertEqual(y_tr.shape, y_ev.shape) - self.assertEqual(y_ev.shape, [16, 128, 4, 4]) - self.assertNotAllClose(y_tr, y_ev) - - def test_call_channels_last(self): - """Test `call` function with `channels_last` data format.""" - with tf.device("/cpu:0"): # NHWC format - input_shape = (8, 8, 128) - data_shape = (16,) + input_shape - x = tf.random_normal(shape=data_shape) - - # Stride 1 - block = blocks.RevBlock( - n_res=3, - filters=128, - strides=(1, 1), - input_shape=input_shape, - data_format="channels_last") - y_tr, y_ev = block(x, training=True), block(x, training=False) - self.assertEqual(y_tr.shape, y_ev.shape) - self.assertEqual(y_ev.shape, (16, 8, 8, 128)) - self.assertNotAllClose(y_tr, y_ev) - - # Stride of 2 - block = blocks.RevBlock( - n_res=3, - filters=128, - strides=(2, 2), - input_shape=input_shape, - data_format="channels_last") - y_tr, y_ev = block(x, training=True), block(x, training=False) - self.assertEqual(y_tr.shape, y_ev.shape) - self.assertEqual(y_ev.shape, (16, 4, 4, 128)) - self.assertNotAllClose(y_tr, y_ev) - def _check_grad_angle(self, grads, grads_true, atol=1e0): """Check the angle between two list of vectors are all close.""" for g1, g2 in zip(grads, grads_true): degree = compute_degree(g1, g2) self.assertLessEqual(degree, atol) - def test_backward_grads_and_vars_channels_first(self): + def test_backward_grads_channels_first(self): """Test `backward` function with `channels_first` data format.""" if not tf.test.is_gpu_available(): self.skipTest("GPU not available") @@ -190,6 +133,7 @@ class RevBlockTest(tf.test.TestCase): data_shape = (16,) + input_shape x = tf.random_normal(shape=data_shape, dtype=tf.float64) dy = tf.random_normal(shape=data_shape, dtype=tf.float64) + dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=1) block = blocks.RevBlock( n_res=3, filters=128, @@ -199,9 +143,14 @@ class RevBlockTest(tf.test.TestCase): dtype=tf.float64) with tf.GradientTape() as tape: tape.watch(x) - y = block(x, training=True) + x1, x2 = tf.split(x, num_or_size_splits=2, axis=1) + y1, y2 = block((x1, x2), training=True) + y = tf.concat((y1, y2), axis=1) # Compute grads from reconstruction - dx, dw, vars_ = block.backward_grads_and_vars(x, y, dy, training=True) + (dx1, dx2), dw = block.backward_grads( + x=(x1, x2), y=(y1, y2), dy=(dy1, dy2), training=True) + dx = tf.concat((dx1, dx2), axis=1) + vars_ = block.trainable_variables # Compute true grads grads = tape.gradient(y, [x] + vars_, output_gradients=dy) dx_true, dw_true = grads[0], grads[1:] @@ -213,6 +162,7 @@ class RevBlockTest(tf.test.TestCase): # Stride 2 x = tf.random_normal(shape=data_shape, dtype=tf.float64) dy = tf.random_normal(shape=(16, 128, 4, 4), dtype=tf.float64) + dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=1) block = blocks.RevBlock( n_res=3, filters=128, @@ -222,9 +172,14 @@ class RevBlockTest(tf.test.TestCase): dtype=tf.float64) with tf.GradientTape() as tape: tape.watch(x) - y = block(x, training=True) + x1, x2 = tf.split(x, num_or_size_splits=2, axis=1) + y1, y2 = block((x1, x2), training=True) + y = tf.concat((y1, y2), axis=1) # Compute grads from reconstruction - dx, dw, vars_ = block.backward_grads_and_vars(x, y, dy, training=True) + (dx1, dx2), dw = block.backward_grads( + x=(x1, x2), y=(y1, y2), dy=(dy1, dy2), training=True) + dx = tf.concat((dx1, dx2), axis=1) + vars_ = block.trainable_variables # Compute true grads grads = tape.gradient(y, [x] + vars_, output_gradients=dy) dx_true, dw_true = grads[0], grads[1:] @@ -233,19 +188,44 @@ class RevBlockTest(tf.test.TestCase): self._check_grad_angle(dx_true, dx) self._check_grad_angle(dw_true, dw) + def test_backward_grads_with_nativepy(self): + if not tf.test.is_gpu_available(): + self.skipTest("GPU not available") -class _ResidualTest(tf.test.TestCase): + input_shape = (128, 8, 8) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape, dtype=tf.float64) + dy = tf.random_normal(shape=data_shape, dtype=tf.float64) + dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=1) + block = blocks.RevBlock( + n_res=3, + filters=128, + strides=(1, 1), + input_shape=input_shape, + fused=False, + dtype=tf.float64) + with tf.GradientTape() as tape: + tape.watch(x) + x1, x2 = tf.split(x, num_or_size_splits=2, axis=1) + y1, y2 = block((x1, x2), training=True) + y = tf.concat((y1, y2), axis=1) - def test_call(self): - """Test `call` function. + # Compute true grads + dx_true = tape.gradient(y, x, output_gradients=dy) - Varying downsampling and data format options. - """ + # Compute grads from reconstruction + (dx1, dx2), _ = block.backward_grads( + x=(x1, x2), y=(y1, y2), dy=(dy1, dy2), training=True) + dx = tf.concat((dx1, dx2), axis=1) - _validate_block_call_channels_first(blocks._Residual, self) - _validate_block_call_channels_last(blocks._Residual, self) + thres = 1e-5 + diff_abs = tf.reshape(abs(dx - dx_true), [-1]) + assert all(diff_abs < thres) - def test_backward_grads_and_vars_channels_first(self): + +class _ResidualTest(tf.test.TestCase): + + def test_backward_grads_channels_first(self): """Test `backward_grads` function with `channels_first` data format.""" if not tf.test.is_gpu_available(): self.skipTest("GPU not available") @@ -256,6 +236,7 @@ class _ResidualTest(tf.test.TestCase): # Use double precision for testing x_true = tf.random_normal(shape=data_shape, dtype=tf.float64) dy = tf.random_normal(shape=data_shape, dtype=tf.float64) + dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=1) residual = blocks._Residual( filters=128, strides=(1, 1), @@ -264,16 +245,19 @@ class _ResidualTest(tf.test.TestCase): dtype=tf.float64) with tf.GradientTape() as tape: - x_true = tf.identity(x_true) tape.watch(x_true) - y = residual(x_true, training=True) + x1_true, x2_true = tf.split(x_true, num_or_size_splits=2, axis=1) + y1, y2 = residual((x1_true, x2_true), training=True) + y = tf.concat((y1, y2), axis=1) # Gradients computed due to reversibility - x, dx, dw, vars_ = residual.backward_grads_and_vars( - y, dy=dy, training=True) - + (x1, x2), (dx1, dx2), dw = residual.backward_grads( + y=(y1, y2), dy=(dy1, dy2), training=True) + x = tf.concat((x1, x2), axis=1) + dx = tf.concat((dx1, dx2), axis=1) # True gradients computed by the tape - grads = tape.gradient(y, [x_true] + vars_, output_gradients=dy) + grads = tape.gradient( + y, [x_true] + residual.trainable_variables, output_gradients=dy) dx_true, dw_true = grads[0], grads[1:] self.assertAllClose(x_true, x) diff --git a/tensorflow/contrib/eager/python/examples/revnet/config.py b/tensorflow/contrib/eager/python/examples/revnet/config.py index 821a4878c1cfe1aebe3697952059266def0f5817..29f1db0e0367515757413c8e47f7b7280fc4cfbb 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/config.py +++ b/tensorflow/contrib/eager/python/examples/revnet/config.py @@ -82,7 +82,8 @@ def get_hparams_cifar_38(): config.num_train_images // config.tpu_batch_size) config.add_hparam("tpu_epochs", config.max_train_iter // config.tpu_iters_per_epoch) - + config.add_hparam("tpu_eval_steps", + config.num_eval_images // config.tpu_eval_batch_size) return config @@ -162,7 +163,8 @@ def get_hparams_imagenet_56(): config.num_train_images // config.tpu_batch_size) config.add_hparam("tpu_epochs", config.max_train_iter // config.tpu_iters_per_epoch) - + config.add_hparam("tpu_eval_steps", + config.num_eval_images // config.tpu_eval_batch_size) return config diff --git a/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py b/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py index e81351b1b14dbf6973e7430c369774339e2dcdd8..34a9984b0ecc527ad1991c28146246b716e96c98 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py +++ b/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py @@ -211,8 +211,7 @@ class ImageNetInput(object): dataset = tf.data.Dataset.range(1).repeat().map(self._get_null_input) dataset = dataset.prefetch(batch_size) - dataset = dataset.apply( - tf.contrib.data.batch_and_drop_remainder(batch_size)) + dataset = dataset.batch(batch_size, drop_remainder=True) if self.transpose_input: dataset = dataset.map( lambda images, labels: (tf.transpose(images, [1, 2, 3, 0]), labels), diff --git a/tensorflow/contrib/eager/python/examples/revnet/main.py b/tensorflow/contrib/eager/python/examples/revnet/main.py index dcd4e1697faae2a06b1b1581d6f6f0cfebeacde1..b702e91f92220c2a9003a1b82411131332012a9e 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/main.py +++ b/tensorflow/contrib/eager/python/examples/revnet/main.py @@ -29,6 +29,11 @@ from tensorflow.contrib.eager.python.examples.revnet import revnet tfe = tf.contrib.eager +def apply_gradients(optimizer, grads, vars_, global_step=None): + """Functional style apply_grads for `tfe.defun`.""" + optimizer.apply_gradients(zip(grads, vars_), global_step=global_step) + + def main(_): """Eager execution workflow with RevNet trained on CIFAR-10.""" tf.enable_eager_execution() @@ -48,6 +53,11 @@ def main(_): if FLAGS.use_defun: model.call = tfe.defun(model.call) + model.compute_gradients = tfe.defun(model.compute_gradients) + model.get_moving_stats = tfe.defun(model.get_moving_stats) + model.restore_moving_stats = tfe.defun(model.restore_moving_stats) + global apply_gradients # pylint:disable=global-variable-undefined + apply_gradients = tfe.defun(apply_gradients) if FLAGS.train_dir: summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir) @@ -197,9 +207,13 @@ def get_datasets(data_dir, config): def train_one_iter(model, inputs, labels, optimizer, global_step=None): """Train for one iteration.""" - grads, vars_, logits, loss = model.compute_gradients( - inputs, labels, training=True) - optimizer.apply_gradients(zip(grads, vars_), global_step=global_step) + logits, saved_hiddens = model(inputs, training=True) + values = model.get_moving_stats() + grads, loss = model.compute_gradients(saved_hiddens, labels) + # Restore moving averages when executing eagerly to avoid updating twice + model.restore_moving_stats(values) + apply_gradients( + optimizer, grads, model.trainable_variables, global_step=global_step) return logits, loss diff --git a/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py b/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py index 4868f1931f8cd9046e6e233c82f95969e355b6c2..3a17eb30da3b989acb0b33f2fcb730da76546c18 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py +++ b/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py @@ -53,10 +53,11 @@ def model_fn(features, labels, mode, params): global_step, config.lr_decay_steps, config.lr_list) optimizer = tf.train.MomentumOptimizer( learning_rate, momentum=config.momentum) - grads, vars_, logits, loss = model.compute_gradients( - inputs, labels, training=True) - train_op = optimizer.apply_gradients( - zip(grads, vars_), global_step=global_step) + logits, saved_hidden = model(inputs, training=True) + grads, loss = model.compute_gradients(saved_hidden, labels, training=True) + with tf.control_dependencies(model.get_updates_for(inputs)): + train_op = optimizer.apply_gradients( + zip(grads, model.trainable_variables), global_step=global_step) return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op) else: @@ -130,8 +131,7 @@ def get_input_fn(config, data_dir, split): return input_fn -def main(argv): - FLAGS = argv[0] # pylint:disable=invalid-name,redefined-outer-name +def main(_): tf.logging.set_verbosity(tf.logging.INFO) # RevNet specific configuration @@ -139,7 +139,7 @@ def main(argv): # Estimator specific configuration run_config = tf.estimator.RunConfig( - model_dir=FLAGS.train_dir, # Directory for storing checkpoints + model_dir=FLAGS.model_dir, # Directory for storing checkpoints tf_random_seed=config.seed, save_summary_steps=config.log_every, save_checkpoints_steps=config.log_every, @@ -153,7 +153,7 @@ def main(argv): # Construct estimator revnet_estimator = tf.estimator.Estimator( model_fn=model_fn, - model_dir=FLAGS.train_dir, + model_dir=FLAGS.model_dir, config=run_config, params={"config": config}) @@ -173,14 +173,14 @@ def main(argv): input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({ "image": inputs }) - revnet_estimator.export_savedmodel(FLAGS.train_dir, input_fn) + revnet_estimator.export_savedmodel(FLAGS.model_dir, input_fn) if __name__ == "__main__": flags.DEFINE_string( "data_dir", default=None, help="Directory to load tfrecords") flags.DEFINE_string( - "train_dir", + "model_dir", default=None, help="[Optional] Directory to store the training information") flags.DEFINE_string( @@ -197,4 +197,4 @@ if __name__ == "__main__": help="[Optional] Architecture of network. " "Other options include `revnet-110` and `revnet-164`") FLAGS = flags.FLAGS - tf.app.run(main=main, argv=[FLAGS]) + tf.app.run() diff --git a/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py b/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py index d809bcd287ccf26ef2d817168367f37c933b7182..8520cf5b71af503be35d5415707a283fb363a476 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py +++ b/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py @@ -12,22 +12,90 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Cloud TPU Estimator workflow with RevNet train on CIFAR-10.""" +"""Cloud TPU Estimator workflow with RevNet train on ImageNet.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os import time from absl import flags import tensorflow as tf -from tensorflow.contrib.eager.python.examples.revnet import cifar_input -from tensorflow.contrib.eager.python.examples.revnet import main as main_ +from tensorflow.contrib import summary +from tensorflow.contrib.eager.python.examples.revnet import config as config_ +from tensorflow.contrib.eager.python.examples.revnet import imagenet_input from tensorflow.contrib.eager.python.examples.revnet import revnet from tensorflow.contrib.training.python.training import evaluation -from tensorflow.python.estimator import estimator as estimator_ +from tensorflow.python.estimator import estimator + +MEAN_RGB = [0.485, 0.456, 0.406] +STDDEV_RGB = [0.229, 0.224, 0.225] + + +def _host_call_fn(gs, loss, lr): + """Training host call. + + Creates scalar summaries for training metrics. + + This function is executed on the CPU and should not directly reference + any Tensors in the rest of the `model_fn`. To pass Tensors from the + model to the `metric_fn`, provide as part of the `host_call`. See + https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec + for more information. + + Arguments should match the list of `Tensor` objects passed as the second + element in the tuple passed to `host_call`. + + Args: + gs: `Tensor with shape `[batch]` for the global_step + loss: `Tensor` with shape `[batch]` for the training loss. + lr: `Tensor` with shape `[batch]` for the learning_rate. + + Returns: + List of summary ops to run on the CPU host. + """ + # Host call fns are executed FLAGS.iterations_per_loop times after one + # TPU loop is finished, setting max_queue value to the same as number of + # iterations will make the summary writer only flush the data to storage + # once per loop. + gs = gs[0] + with summary.create_file_writer( + FLAGS.model_dir, max_queue=FLAGS.iterations_per_loop).as_default(): + with summary.always_record_summaries(): + summary.scalar("loss", loss[0], step=gs) + summary.scalar("learning_rate", lr[0], step=gs) + return summary.all_summary_ops() + + +def _metric_fn(labels, logits): + """Evaluation metric function. Evaluates accuracy. + + This function is executed on the CPU and should not directly reference + any Tensors in the rest of the `model_fn`. To pass Tensors from the model + to the `metric_fn`, provide as part of the `eval_metrics`. See + https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec + for more information. + + Arguments should match the list of `Tensor` objects passed as the second + element in the tuple passed to `eval_metrics`. + + Args: + labels: `Tensor` with shape `[batch]`. + logits: `Tensor` with shape `[batch, num_classes]`. + + Returns: + A dict of the metrics to return from evaluation. + """ + predictions = tf.argmax(logits, axis=1) + top_1_accuracy = tf.metrics.accuracy(labels, predictions) + in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32) + top_5_accuracy = tf.metrics.mean(in_top_5) + + return { + "top_1_accuracy": top_1_accuracy, + "top_5_accuracy": top_5_accuracy, + } def model_fn(features, labels, mode, params): @@ -42,50 +110,58 @@ def model_fn(features, labels, mode, params): Returns: An instance of `tf.contrib.tpu.TPUEstimatorSpec` """ + revnet_config = params["revnet_config"] + model = revnet.RevNet(config=revnet_config) inputs = features if isinstance(inputs, dict): inputs = features["image"] - FLAGS = params["FLAGS"] # pylint:disable=invalid-name,redefined-outer-name - config = params["config"] - model = revnet.RevNet(config=config) + if revnet_config.data_format == "channels_first": + assert not FLAGS.transpose_input # channels_first only for GPU + inputs = tf.transpose(inputs, [0, 3, 1, 2]) + + if FLAGS.transpose_input and mode != tf.estimator.ModeKeys.PREDICT: + inputs = tf.transpose(inputs, [3, 0, 1, 2]) # HWCN to NHWC + + # Normalize the image to zero mean and unit variance. + inputs -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=inputs.dtype) + inputs /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=inputs.dtype) if mode == tf.estimator.ModeKeys.TRAIN: global_step = tf.train.get_or_create_global_step() learning_rate = tf.train.piecewise_constant( - global_step, config.lr_decay_steps, config.lr_list) - optimizer = tf.train.MomentumOptimizer( - learning_rate, momentum=config.momentum) - + global_step, revnet_config.lr_decay_steps, revnet_config.lr_list) + optimizer = tf.train.MomentumOptimizer(learning_rate, + revnet_config.momentum) if FLAGS.use_tpu: optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) - # Define gradients - grads, vars_, logits, loss = model.compute_gradients( - inputs, labels, training=True) - train_op = optimizer.apply_gradients( - zip(grads, vars_), global_step=global_step) - - names = [v.name for v in model.variables] - tf.logging.warn("{}".format(names)) + logits, saved_hidden = model(inputs, training=True) + grads, loss = model.compute_gradients(saved_hidden, labels, training=True) + with tf.control_dependencies(model.get_updates_for(inputs)): + train_op = optimizer.apply_gradients( + zip(grads, model.trainable_variables), global_step=global_step) + if not FLAGS.skip_host_call: + # To log the loss, current learning rate, and epoch for Tensorboard, the + # summary op needs to be run on the host CPU via host_call. host_call + # expects [batch_size, ...] Tensors, thus reshape to introduce a batch + # dimension. These Tensors are implicitly concatenated to + # [params['batch_size']]. + gs_t = tf.reshape(global_step, [1]) + loss_t = tf.reshape(loss, [1]) + lr_t = tf.reshape(learning_rate, [1]) + host_call = (_host_call_fn, [gs_t, loss_t, lr_t]) return tf.contrib.tpu.TPUEstimatorSpec( - mode=tf.estimator.ModeKeys.TRAIN, loss=loss, train_op=train_op) + mode=mode, loss=loss, train_op=train_op, host_call=host_call) elif mode == tf.estimator.ModeKeys.EVAL: logits, _ = model(inputs, training=False) loss = model.compute_loss(labels=labels, logits=logits) - def metric_fn(labels, logits): - predictions = tf.argmax(logits, axis=1) - accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions) - return { - "accuracy": accuracy, - } - return tf.contrib.tpu.TPUEstimatorSpec( - mode=mode, loss=loss, eval_metrics=(metric_fn, [labels, logits])) + mode=mode, loss=loss, eval_metrics=(_metric_fn, [labels, logits])) else: # Predict or export logits, _ = model(inputs, training=False) @@ -102,117 +178,75 @@ def model_fn(features, labels, mode, params): }) -def get_input_fn(config, data_dir, split): - """Get the input function required by the `tf.contrib.tpu.TPUEstimator` API. - - Args: - config: Customized hyperparameters - data_dir: Directory where the data is stored - split: One of `train`, `validation`, `train_all`, and `test` - - Returns: - Input function required by the `tf.contrib.tpu.TPUEstimator` API - """ - - data_dir = os.path.join(data_dir, config.dataset) - # Fix split-dependent hyperparameters - if split == "train_all" or split == "train": - data_aug = True - epochs = config.tpu_epochs - shuffle = True - else: - data_aug = False - epochs = 1 - shuffle = False - - def input_fn(params): - """Input function required by the `tf.contrib.tpu.TPUEstimator` API.""" - batch_size = params["batch_size"] - return cifar_input.get_ds_from_tfrecords( - data_dir=data_dir, - split=split, - data_aug=data_aug, - batch_size=batch_size, # per-shard batch size - epochs=epochs, - shuffle=shuffle, - prefetch=batch_size, # per-shard batch size - data_format=config.data_format) - - return input_fn - - -def main(argv): - FLAGS = argv[0] # pylint:disable=invalid-name,redefined-outer-name +def main(_): tf.logging.set_verbosity(tf.logging.INFO) # RevNet specific configuration - config = main_.get_config(config_name=FLAGS.config, dataset=FLAGS.dataset) + revnet_config = { + "revnet-56": config_.get_hparams_imagenet_56(), + "revnet-104": config_.get_hparams_imagenet_104() + }[FLAGS.revnet_config] if FLAGS.use_tpu: - tf.logging.info("Using TPU.") - tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( - FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) - else: - tpu_cluster_resolver = None - - # TPU specific configuration - tpu_config = tf.contrib.tpu.TPUConfig( - # Recommended to be set as number of global steps for next checkpoint - iterations_per_loop=FLAGS.iterations_per_loop, - num_shards=FLAGS.num_shards) + revnet_config.data_format = "channels_last" + + tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( + FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) # Estimator specific configuration - run_config = tf.contrib.tpu.RunConfig( + config = tf.contrib.tpu.RunConfig( cluster=tpu_cluster_resolver, model_dir=FLAGS.model_dir, session_config=tf.ConfigProto( - allow_soft_placement=True, log_device_placement=False), - tpu_config=tpu_config, + allow_soft_placement=True, log_device_placement=True), + tpu_config=tf.contrib.tpu.TPUConfig( + iterations_per_loop=FLAGS.iterations_per_loop, + num_shards=FLAGS.num_shards, + per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig. + PER_HOST_V2), ) - # Construct TPU Estimator - estimator = tf.contrib.tpu.TPUEstimator( + # Input pipelines are slightly different (with regards to shuffling and + # preprocessing) between training and evaluation. + imagenet_train, imagenet_eval = [ + imagenet_input.ImageNetInput( + is_training=is_training, + data_dir=FLAGS.data_dir, + transpose_input=FLAGS.transpose_input, + use_bfloat16=False) for is_training in [True, False] + ] + + revnet_classifier = tf.contrib.tpu.TPUEstimator( model_fn=model_fn, use_tpu=FLAGS.use_tpu, - train_batch_size=config.tpu_batch_size, - eval_batch_size=config.tpu_eval_batch_size, - config=run_config, - params={ - "FLAGS": FLAGS, - "config": config, - }) - - # Construct input functions - train_input_fn = get_input_fn( - config=config, data_dir=FLAGS.data_dir, split="train_all") - eval_input_fn = get_input_fn( - config=config, data_dir=FLAGS.data_dir, split="test") - - # Disabling a range within an else block currently doesn't work - # due to https://github.com/PyCQA/pylint/issues/872 + train_batch_size=revnet_config.tpu_batch_size, + eval_batch_size=revnet_config.tpu_eval_batch_size, + config=config, + export_to_tpu=False, + params={"revnet_config": revnet_config}) + + steps_per_epoch = revnet_config.tpu_iters_per_epoch + eval_steps = revnet_config.tpu_eval_steps + # pylint: disable=protected-access if FLAGS.mode == "eval": - # TPUEstimator.evaluate *requires* a steps argument. - # Note that the number of examples used during evaluation is - # --eval_steps * --batch_size. - # So if you change --batch_size then change --eval_steps too. - eval_steps = 10000 // config.tpu_eval_batch_size - # Run evaluation when there's a new checkpoint for ckpt in evaluation.checkpoints_iterator( FLAGS.model_dir, timeout=FLAGS.eval_timeout): tf.logging.info("Starting to evaluate.") try: start_timestamp = time.time() # This time will include compilation time - eval_results = estimator.evaluate( - input_fn=eval_input_fn, steps=eval_steps, checkpoint_path=ckpt) + eval_results = revnet_classifier.evaluate( + input_fn=imagenet_eval.input_fn, + steps=eval_steps, + checkpoint_path=ckpt) elapsed_time = int(time.time() - start_timestamp) tf.logging.info("Eval results: %s. Elapsed seconds: %d" % (eval_results, elapsed_time)) # Terminate eval job when final checkpoint is reached current_step = int(os.path.basename(ckpt).split("-")[1]) - if current_step >= config.max_train_iter: + if current_step >= revnet_config.max_train_iter: tf.logging.info( "Evaluation finished after training step %d" % current_step) break @@ -226,37 +260,56 @@ def main(argv): "Checkpoint %s no longer exists, skipping checkpoint" % ckpt) else: # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval' - current_step = estimator_._load_global_step_from_checkpoint_dir( + current_step = estimator._load_global_step_from_checkpoint_dir( FLAGS.model_dir) - tf.logging.info("Training for %d steps . Current" - " step %d." % (config.max_train_iter, current_step)) + + tf.logging.info( + "Training for %d steps (%.2f epochs in total). Current" + " step %d." % (revnet_config.max_train_iter, + revnet_config.max_train_iter / steps_per_epoch, + current_step)) start_timestamp = time.time() # This time will include compilation time + if FLAGS.mode == "train": - estimator.train(input_fn=train_input_fn, max_steps=config.max_train_iter) + revnet_classifier.train( + input_fn=imagenet_train.input_fn, + max_steps=revnet_config.max_train_iter) + else: - eval_steps = 10000 // config.tpu_eval_batch_size assert FLAGS.mode == "train_and_eval" - while current_step < config.max_train_iter: + while current_step < revnet_config.max_train_iter: # Train for up to steps_per_eval number of steps. # At the end of training, a checkpoint will be written to --model_dir. next_checkpoint = min(current_step + FLAGS.steps_per_eval, - config.max_train_iter) - estimator.train(input_fn=train_input_fn, max_steps=next_checkpoint) + revnet_config.max_train_iter) + revnet_classifier.train( + input_fn=imagenet_train.input_fn, max_steps=next_checkpoint) current_step = next_checkpoint + tf.logging.info("Finished training up to step %d. Elapsed seconds %d." % + (next_checkpoint, int(time.time() - start_timestamp))) + # Evaluate the model on the most recent model in --model_dir. # Since evaluation happens in batches of --eval_batch_size, some images - # may be consistently excluded modulo the batch size. + # may be excluded modulo the batch size. As long as the batch size is + # consistent, the evaluated images are also consistent. tf.logging.info("Starting to evaluate.") - eval_results = estimator.evaluate( - input_fn=eval_input_fn, steps=eval_steps) + eval_results = revnet_classifier.evaluate( + input_fn=imagenet_eval.input_fn, steps=eval_steps) tf.logging.info("Eval results: %s" % eval_results) - elapsed_time = int(time.time() - start_timestamp) - tf.logging.info("Finished training up to step %d. Elapsed seconds %d." % - (config.max_train_iter, elapsed_time)) - # pylint: enable=protected-access + elapsed_time = int(time.time() - start_timestamp) + tf.logging.info("Finished training up to step %d. Elapsed seconds %d." % + (revnet_config.max_train_iter, elapsed_time)) + + if FLAGS.export_dir is not None: + # The guide to serve an exported TensorFlow model is at: + # https://www.tensorflow.org/serving/serving_basic + tf.logging.info("Starting to export model.") + revnet_classifier.export_savedmodel( + export_dir_base=FLAGS.export_dir, + serving_input_receiver_fn=imagenet_input.image_serving_input_fn) if __name__ == "__main__": @@ -288,14 +341,10 @@ if __name__ == "__main__": default=None, help="[Optional] Directory to store the model information") flags.DEFINE_string( - "dataset", - default="cifar-10", - help="[Optional] The dataset used; either `cifar-10` or `cifar-100`") - flags.DEFINE_string( - "config", - default="revnet-38", + "revnet_config", + default="revnet-56", help="[Optional] Architecture of network. " - "Other options include `revnet-110` and `revnet-164`") + "Other options include `revnet-104`") flags.DEFINE_boolean( "use_tpu", default=True, help="[Optional] Whether to use TPU") flags.DEFINE_integer( @@ -309,20 +358,37 @@ if __name__ == "__main__": " train steps, the loop will exit before reaching" " --iterations_per_loop. The larger this value is, the higher the" " utilization on the TPU.")) - flags.DEFINE_string( - "mode", - default="train_and_eval", - help="[Optional] Mode to run: train, eval, train_and_eval") flags.DEFINE_integer( - "eval_timeout", 60 * 60 * 24, - "Maximum seconds between checkpoints before evaluation terminates.") + "eval_timeout", + default=None, + help="Maximum seconds between checkpoints before evaluation terminates.") flags.DEFINE_integer( "steps_per_eval", - default=1000, + default=5000, help=( "Controls how often evaluation is performed. Since evaluation is" " fairly expensive, it is advised to evaluate as infrequently as" " possible (i.e. up to --train_steps, which evaluates the model only" " after finishing the entire training regime).")) + flags.DEFINE_bool( + "transpose_input", + default=True, + help="Use TPU double transpose optimization") + flags.DEFINE_string( + "export_dir", + default=None, + help=("The directory where the exported SavedModel will be stored.")) + flags.DEFINE_bool( + "skip_host_call", + default=False, + help=("Skip the host_call which is executed every training step. This is" + " generally used for generating training summaries (train loss," + " learning rate, etc...). When --skip_host_call=false, there could" + " be a performance drop if host_call function is slow and cannot" + " keep up with the TPU-side computation.")) + flags.DEFINE_string( + "mode", + default="train_and_eval", + help='One of {"train_and_eval", "train", "eval"}.') FLAGS = flags.FLAGS - tf.app.run(main=main, argv=[FLAGS]) + tf.app.run() diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet.py b/tensorflow/contrib/eager/python/examples/revnet/revnet.py index b1cb312b7459eb1d8926e6e6635ed2cfbed79833..1f2cb14972f0b92d29489adff8f94e790e1ec4ed 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/revnet.py +++ b/tensorflow/contrib/eager/python/examples/revnet/revnet.py @@ -24,7 +24,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import six import tensorflow as tf from tensorflow.contrib.eager.python.examples.revnet import blocks @@ -45,6 +44,7 @@ class RevNet(tf.keras.Model): self._init_block = blocks.InitBlock(config=self.config) self._final_block = blocks.FinalBlock(config=self.config) self._block_list = self._construct_intermediate_blocks() + self._moving_average_variables = [] def _construct_intermediate_blocks(self): # Precompute input shape after initial block @@ -128,126 +128,90 @@ class RevNet(tf.keras.Model): return tf.reduce_mean(cross_ent) - def compute_gradients(self, inputs, labels, training=True, l2_reg=True): + def compute_gradients(self, saved_hidden, labels, training=True, l2_reg=True): """Manually computes gradients. - When eager execution is enabled, this method also SILENTLY updates the - running averages of batch normalization when `training` is set to True. + This method silently updates the running averages of batch normalization. Args: - inputs: Image tensor, either NHWC or NCHW, conforming to `data_format` + saved_hidden: List of hidden states Tensors labels: One-hot labels for classification training: Use the mini-batch stats in batch norm if set to True l2_reg: Apply l2 regularization Returns: - A tuple with the first entry being a list of all gradients, the second - entry being a list of respective variables, the third being the logits, - and the forth being the loss + A tuple with the first entry being a list of all gradients and the second + being the loss """ - # Run forward pass to record hidden states - vars_and_vals = self.get_moving_stats() - _, saved_hidden = self(inputs, training=training) # pylint:disable=not-callable - if tf.executing_eagerly(): - # Restore moving averages when executing eagerly to avoid updating twice - self.restore_moving_stats(vars_and_vals) - else: - # Fetch batch norm updates in graph mode - updates = self.get_updates_for(inputs) - - grads_all = [] - vars_all = [] + def _defunable_pop(l): + """Functional style list pop that works with `tfe.defun`.""" + t, l = l[-1], l[:-1] + return t, l - # Manually backprop through last block + # Backprop through last block x = saved_hidden[-1] with tf.GradientTape() as tape: tape.watch(x) - # Running stats updated here logits = self._final_block(x, training=training) loss = self.compute_loss(logits, labels) - grads_combined = tape.gradient(loss, [x] + self._final_block.trainable_variables) - dy, grads_ = grads_combined[0], grads_combined[1:] - grads_all += grads_ - vars_all += self._final_block.trainable_variables + dy, final_grads = grads_combined[0], grads_combined[1:] - # Manually backprop through intermediate blocks + # Backprop through intermediate blocks + intermediate_grads = [] for block in reversed(self._block_list): - y = saved_hidden.pop() + y, saved_hidden = _defunable_pop(saved_hidden) x = saved_hidden[-1] - # Running stats updated here - dy, grads, vars_ = block.backward_grads_and_vars( - x, y, dy, training=training) - grads_all += grads - vars_all += vars_ - - # Manually backprop through first block - saved_hidden.pop() - x = saved_hidden.pop() - assert not saved_hidden # Cleared after backprop + dy, grads = block.backward_grads(x, y, dy, training=training) + intermediate_grads = grads + intermediate_grads + # Backprop through first block + _, saved_hidden = _defunable_pop(saved_hidden) + x, saved_hidden = _defunable_pop(saved_hidden) + assert not saved_hidden with tf.GradientTape() as tape: - # Running stats updated here y = self._init_block(x, training=training) - - grads_all += tape.gradient( + init_grads = tape.gradient( y, self._init_block.trainable_variables, output_gradients=dy) - vars_all += self._init_block.trainable_variables - # Apply weight decay + # Ordering match up with `model.trainable_variables` + grads_all = init_grads + final_grads + intermediate_grads if l2_reg: - grads_all = self._apply_weight_decay(grads_all, vars_all) - - if not tf.executing_eagerly(): - # Force updates to be executed before gradient computation in graph mode - # This does nothing when the function is wrapped in defun - with tf.control_dependencies(updates): - grads_all[0] = tf.identity(grads_all[0]) + grads_all = self._apply_weight_decay(grads_all) - return grads_all, vars_all, logits, loss + return grads_all, loss - def _apply_weight_decay(self, grads, vars_): + def _apply_weight_decay(self, grads): """Update gradients to reflect weight decay.""" - # Don't decay bias return [ g + self.config.weight_decay * v if v.name.endswith("kernel:0") else g - for g, v in zip(grads, vars_) + for g, v in zip(grads, self.trainable_variables) ] def get_moving_stats(self): - """Get moving averages of batch normalization. - - This is needed to avoid updating the running average twice in one iteration. - - Returns: - A dictionary mapping variables for batch normalization moving averages - to their current values. - """ - vars_and_vals = {} - - def _is_moving_var(v): - n = v.name - return n.endswith("moving_mean:0") or n.endswith("moving_variance:0") + """Get moving averages of batch normalization.""" + device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0" + with tf.device(device): + return [v.read_value() for v in self.moving_average_variables] + def restore_moving_stats(self, values): + """Restore moving averages of batch normalization.""" device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0" with tf.device(device): - for v in filter(_is_moving_var, self.variables): - vars_and_vals[v] = v.read_value() + for var_, val in zip(self.moving_average_variables, values): + var_.assign(val) - return vars_and_vals + @property + def moving_average_variables(self): + """Get all variables that are batch norm moving averages.""" - def restore_moving_stats(self, vars_and_vals): - """Restore moving averages of batch normalization. + def _is_moving_avg(v): + n = v.name + return n.endswith("moving_mean:0") or n.endswith("moving_variance:0") - This is needed to avoid updating the running average twice in one iteration. + if not self._moving_average_variables: + self._moving_average_variables = filter(_is_moving_avg, self.variables) - Args: - vars_and_vals: The dictionary mapping variables to their previous values. - """ - device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0" - with tf.device(device): - for var_, val in six.iteritems(vars_and_vals): - # `assign` causes a copy to GPU (if variable is already on GPU) - var_.assign(val) + return self._moving_average_variables diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py index 26b084752330ec6ccc3bbf34bcbfeb95a7429907..6a921e19978fdf6e3c20974b2c349bd6923b5782 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py +++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py @@ -31,9 +31,11 @@ tfe = tf.contrib.eager def train_one_iter(model, inputs, labels, optimizer, global_step=None): """Train for one iteration.""" - grads, vars_, logits, loss = model.compute_gradients( - inputs, labels, training=True) - optimizer.apply_gradients(zip(grads, vars_), global_step=global_step) + logits, saved_hidden = model(inputs) + grads, loss = model.compute_gradients( + saved_hidden=saved_hidden, labels=labels) + optimizer.apply_gradients( + zip(grads, model.trainable_variables), global_step=global_step) return logits, loss @@ -96,9 +98,10 @@ class RevNetTest(tf.test.TestCase): def test_compute_gradients(self): """Test `compute_gradients` function.""" - self.model(self.x, training=False) # Initialize model - grads, vars_, logits, loss = self.model.compute_gradients( - inputs=self.x, labels=self.t, training=True, l2_reg=True) + _, saved_hidden = self.model(self.x) # Initialize model + grads, loss = self.model.compute_gradients( + saved_hidden=saved_hidden, labels=self.t) + vars_ = self.model.trainable_variables self.assertTrue(isinstance(grads, list)) self.assertTrue(isinstance(vars_, list)) self.assertEqual(len(grads), len(vars_)) @@ -107,7 +110,7 @@ class RevNetTest(tf.test.TestCase): # Compare against the true gradient computed by the tape with tf.GradientTape() as tape: - logits, _ = self.model(self.x, training=True) + logits, _ = self.model(self.x) loss_true = self.model.compute_loss(logits=logits, labels=self.t) grads_true = tape.gradient(loss_true, vars_) self.assertAllClose(loss, loss_true) @@ -122,7 +125,9 @@ class RevNetTest(tf.test.TestCase): def test_compute_gradients_defun(self): """Test `compute_gradients` function with defun.""" compute_gradients = tfe.defun(self.model.compute_gradients) - grads, vars_, _, _ = compute_gradients(self.x, self.t, training=True) + _, saved_hidden = self.model(self.x) + grads, _ = compute_gradients(saved_hidden=saved_hidden, labels=self.t) + vars_ = self.model.trainable_variables self.assertTrue(isinstance(grads, list)) self.assertTrue(isinstance(vars_, list)) self.assertEqual(len(grads), len(vars_)) @@ -146,10 +151,11 @@ class RevNetTest(tf.test.TestCase): dtype=tf.int32) global_step = tf.Variable(0., trainable=False) model = revnet.RevNet(config=config) - grads_all, vars_all, _, _ = model.compute_gradients(x, t, training=True) + _, saved_hidden = model(x) + grads, _ = model.compute_gradients(saved_hidden=saved_hidden, labels=t) optimizer = tf.train.AdamOptimizer(learning_rate=1e-3) train_op = optimizer.apply_gradients( - zip(grads_all, vars_all), global_step=global_step) + zip(grads, model.trainable_variables), global_step=global_step) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) @@ -220,14 +226,13 @@ class RevNetBenchmark(tf.test.Benchmark): label, device_and_format, defun=False, - execution_mode=None, - compiled=False): + execution_mode=None): config = config_.get_hparams_imagenet_56() with tfe.execution_mode(execution_mode): device, data_format = device_and_format model = revnet.RevNet(config=config) if defun: - model.call = tfe.defun(model.call, compiled=compiled) + model.call = tfe.defun(model.call) batch_size = 64 num_burn = 5 num_iters = 10 @@ -265,8 +270,7 @@ class RevNetBenchmark(tf.test.Benchmark): make_iterator, device_and_format, defun=False, - execution_mode=None, - compiled=False): + execution_mode=None): config = config_.get_hparams_imagenet_56() with tfe.execution_mode(execution_mode): device, data_format = device_and_format diff --git a/tensorflow/contrib/eager/python/examples/sagan/BUILD b/tensorflow/contrib/eager/python/examples/sagan/BUILD deleted file mode 100644 index b470a41d815ce650731680065cc7341f844e3fdc..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/sagan/BUILD +++ /dev/null @@ -1,59 +0,0 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) - -load("//tensorflow:tensorflow.bzl", "cuda_py_test") - -# Model -py_library( - name = "config", - srcs = ["config.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow:tensorflow_py", - ], -) - -py_library( - name = "ops", - srcs = ["ops.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow:tensorflow_py", - ], -) - -py_library( - name = "sagan", - srcs = ["sagan.py"], - srcs_version = "PY2AND3", - deps = [ - ":ops", - "//tensorflow:tensorflow_py", - ], -) - -# Tests -cuda_py_test( - name = "ops_test", - size = "small", - srcs = ["ops_test.py"], - additional_deps = [ - ":ops", - "//tensorflow:tensorflow_py", - ], -) - -cuda_py_test( - name = "sagan_test", - size = "large", - srcs = ["sagan_test.py"], - additional_deps = [ - ":config", - ":sagan", - "//tensorflow:tensorflow_py", - ], - tags = [ - "optonly", - ], -) diff --git a/tensorflow/contrib/eager/python/examples/sagan/config.py b/tensorflow/contrib/eager/python/examples/sagan/config.py deleted file mode 100644 index 1967bbd867447d9deaf9a7cb3b22a38889276a50..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/sagan/config.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Self-attention generative adversarial with eager execution. - -Configuration in format of tf.contrib.training.HParams. -Supports default 128x128 ImageNet. - -Reference [Self-Attention Generative Adversarial -Networks](https://arxiv.org/pdf/1805.08318.pdf) - -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf -tfe = tf.contrib.eager - - -def get_hparams_imagenet(): - """Configurations to train SAGAN on 128x128 ImageNet dataset.""" - config = tf.contrib.training.HParams() - if tf.test.is_gpu_available(): - config.add_hparam("image_shape", (3, 128, 128)) - config.add_hparam("data_format", "channels_first") - config.add_hparam("g_init_shape", (512, 4, 4)) - else: - config.add_hparam("image_shape", (128, 128, 3)) - config.add_hparam("data_format", "channels_first") - config.add_hparam("g_init_shape", (4, 4, 512)) - - config.add_hparam("latent_dim", 128) - config.add_hparam("update_g_once_every", 1) - config.add_hparam("batch_size", 64) - config.add_hparam("d_init_filters", 32) - config.add_hparam("num_upsamples", 5) - # (512, 4, 4) -> (3, 128, 128) - return config - - -def get_hparams_mock(): - """Configurations of smaller networks for testing.""" - config = tf.contrib.training.HParams() - if tf.test.is_gpu_available(): - config.add_hparam("image_shape", (3, 16, 16)) - config.add_hparam("data_format", "channels_first") - config.add_hparam("g_init_shape", (32, 2, 2)) - else: - config.add_hparam("image_shape", (16, 16, 3)) - config.add_hparam("data_format", "channels_last") - config.add_hparam("g_init_shape", (2, 2, 32)) - - config.add_hparam("latent_dim", 16) - config.add_hparam("update_g_once_every", 1) - config.add_hparam("batch_size", 2) - config.add_hparam("d_init_filters", 4) - config.add_hparam("num_upsamples", 3) - # (32, 2, 2) -> (3, 16, 16) - return config diff --git a/tensorflow/contrib/eager/python/examples/sagan/ops.py b/tensorflow/contrib/eager/python/examples/sagan/ops.py deleted file mode 100644 index 9a03cab1d12fc16baa7343f72ac58ccd39f698bc..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/sagan/ops.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Self-attention generative adversarial with eager execution. - -Auxiliary operations. - -Reference [Self-Attention Generative Adversarial -Networks](https://arxiv.org/pdf/1805.08318.pdf) -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf - - -def flatten_hw(x, data_format="channels_first"): - """Flatten the input tensor across height and width dimensions.""" - if data_format == "channels_last": - x = tf.transpose(x, perm=[0, 3, 1, 2]) # Convert to `channels_first` - - old_shape = tf.shape(x) - new_shape = [old_shape[0], old_shape[2] * old_shape[3], old_shape[1]] - - return tf.reshape(x, new_shape) - - -def broaden_hw(x, h, w, c, data_format="channels_first"): - """Broaden dimension so that output has height and width.""" - if data_format == "channels_first": - shape = [-1, c, h, w] - else: - shape = [-1, h, w, c] - - return tf.reshape(x, shape) - - -class BroadenHW(tf.keras.layers.Layer): - """Wrapper class so that `broaden_hw` can be used in `tf.keras.Sequential`.""" - - def __init__(self, h, w, c, data_format="channels_first"): - super(BroadenHW, self).__init__() - self.h = h - self.w = w - self.c = c - self.data_format = data_format - - def call(self, x): - return broaden_hw( - x, h=self.h, w=self.w, c=self.c, data_format=self.data_format) - - def compute_output_shape(self, input_shape): - input_shape = tf.TensorShape(input_shape).as_list() - if self.data_format == "channels_first": - output_shape = (input_shape[0], self.c, self.h, self.w) - else: - output_shape = (input_shape[0], self.h, self.w, self.c) - - return tf.TensorShape(output_shape) diff --git a/tensorflow/contrib/eager/python/examples/sagan/ops_test.py b/tensorflow/contrib/eager/python/examples/sagan/ops_test.py deleted file mode 100644 index 3454985904215b59d27fc4b76ccb4a8c2c2eff00..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/sagan/ops_test.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for auxiliary operations.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf -from tensorflow.contrib.eager.python.examples.sagan import ops - - -class OpsTest(tf.test.TestCase): - - def test_flatten_hw(self): - """Test `flatten_hw` function with mock object.""" - - batch_size = 1 - # Default NCHW format - if tf.test.is_gpu_available(): - x = tf.random_normal(shape=(batch_size, 3, 4, 4)) - y = ops.flatten_hw(x, data_format="channels_first") - self.assertEqual(y.shape, (batch_size, 4 * 4, 3)) - - # NHWC format - x = tf.random_normal(shape=(batch_size, 4, 4, 3)) - y = ops.flatten_hw(x, data_format="channels_last") - self.assertEqual(y.shape, (batch_size, 4 * 4, 3)) - - def test_broaden_hw(self): - """Test `broaden_hw` function with mock object.""" - - batch_size = 1 - # NHWC format - x = tf.random_normal(shape=[batch_size, 4 * 4 * 16]) - y = ops.broaden_hw(x, h=4, w=4, c=16, data_format="channels_last") - self.assertEqual(y.shape, (batch_size, 4, 4, 16)) - - # Default NCHW format - if tf.test.is_gpu_available(): - y = ops.broaden_hw(x, h=4, w=4, c=16, data_format="channels_first") - self.assertEqual(y.shape, (batch_size, 16, 4, 4)) - - -if __name__ == "__main__": - tf.enable_eager_execution() - tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/sagan/sagan.py b/tensorflow/contrib/eager/python/examples/sagan/sagan.py deleted file mode 100644 index 81304149851675e07a3c7f9ad92697da2017022b..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/sagan/sagan.py +++ /dev/null @@ -1,232 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Self-attention generative adversarial with eager execution. - -Code for main model. - -Reference [Self-Attention Generative Adversarial -Networks](https://arxiv.org/pdf/1805.08318.pdf) -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import tensorflow as tf -from tensorflow.contrib.eager.python.examples.sagan import ops -tfe = tf.contrib.eager - - -class SelfAttentionModule(tf.keras.Model): - """Self-attention module composed of convolutional layers.""" - - def __init__(self, - attention_features, - original_features, - data_format="channels_first"): - """Initialize the module. - - Args: - attention_features: Number of filters for the attention computation. - original_features: Number of filters of the original Tensor. - data_format: Either 'channels_first' or 'channels_last' - """ - super(SelfAttentionModule, self).__init__() - self.data_format = data_format - # Matrix multiplication implemented as 2D Convolution - self.f = tf.keras.layers.Conv2D( - filters=attention_features, - kernel_size=1, - strides=(1, 1), - data_format=data_format) - self.g = tf.keras.layers.Conv2D( - filters=attention_features, - kernel_size=1, - strides=(1, 1), - data_format=data_format) - self.h = tf.keras.layers.Conv2D( - filters=original_features, - kernel_size=1, - strides=(1, 1), - data_format=data_format) - self.scale = tf.Variable(0., trainable=True) - - def call(self, x): - f = self.f(x) - g = self.g(x) - h = self.h(x) - - f_flatten = ops.flatten_hw(f, data_format=self.data_format) - g_flatten = ops.flatten_hw(g, data_format=self.data_format) - h_flatten = ops.flatten_hw(h, data_format=self.data_format) - - s = tf.matmul(g_flatten, f_flatten, transpose_b=True) - b = tf.nn.softmax(s, axis=-1) - o = tf.matmul(b, h_flatten) - y = self.scale * tf.reshape(o, tf.shape(x)) + x - - return y - - def compute_output_shape(self, input_shape): - return input_shape - - -class SAGAN(tf.contrib.checkpoint.Checkpointable): - """Self-attention generative adversarial network.""" - - def __init__(self, config): - """Initialize the model. - - Args: - config: tf.contrib.training.HParams object; specifies hyperparameters - """ - super(SAGAN, self).__init__() - self.config = config - self.generator = self._construct_generator() - self.discriminator = self._construct_discriminator() - - def _construct_generator(self): - """Construct generator.""" - # TODO(lxuechen): Add spectral normalization for WGAN - axis = 1 if self.config.data_format == "channels_first" else 3 - - generator = tf.keras.Sequential() - generator.add( - tf.keras.layers.InputLayer(input_shape=(self.config.latent_dim,))) - generator.add( - tf.keras.layers.Dense( - units=np.prod(self.config.g_init_shape), activation=tf.nn.relu)) - - if self.config.data_format == "channels_first": - c, h, w = self.config.g_init_shape - else: - h, w, c = self.config.g_init_shape - - # Reshape to NHWC/NCHW - generator.add( - ops.BroadenHW(h=h, w=w, c=c, data_format=self.config.data_format)) - - filters_list = [c // 2**p for p in range(1, self.config.num_upsamples + 1)] - filters_list[-1] = 3 # Standard RGB images - - for filters in filters_list[:len(filters_list) // 2]: - generator.add( - tf.keras.layers.Conv2DTranspose( - filters=filters, - kernel_size=4, - strides=(2, 2), - use_bias=False, - padding="SAME", - data_format=self.config.data_format)) - generator.add(tf.keras.layers.BatchNormalization(axis=axis)) - generator.add(tf.keras.layers.Activation("relu")) - - # pylint: disable=undefined-loop-variable - generator.add( - SelfAttentionModule( - original_features=filters, - attention_features=filters // 8, - data_format=self.config.data_format)) - # pylint: enable=undefined-loop-variable - - for filters in filters_list[len(filters_list) // 2:]: - generator.add( - tf.keras.layers.Conv2DTranspose( - filters=filters, - kernel_size=4, - strides=(2, 2), - use_bias=False, - padding="SAME", - data_format=self.config.data_format)) - if filters == 3: - # Assume Image rescaled to [-1, 1] - generator.add(tf.keras.layers.Activation("tanh")) - else: - generator.add(tf.keras.layers.BatchNormalization(axis=axis)) - generator.add(tf.keras.layers.Activation("relu")) - - return generator - - def _construct_discriminator(self): - """Construct discriminator.""" - # TODO(lxuechen): Add spectral normalization for WGAN - discriminator = tf.keras.Sequential() - discriminator.add( - tf.keras.layers.InputLayer(input_shape=self.config.image_shape)) - - filters_list = [ - self.config.d_init_filters * 2**p - for p in range(self.config.num_upsamples) - ] - - for filters in filters_list[:(len(filters_list) + 1) // 2]: - discriminator.add( - tf.keras.layers.Conv2D( - filters=filters, - kernel_size=4, - strides=(2, 2), - padding="SAME", - data_format=self.config.data_format)) - discriminator.add(tf.keras.layers.LeakyReLU(alpha=.1)) - - # pylint: disable=undefined-loop-variable - discriminator.add( - SelfAttentionModule( - original_features=filters, - attention_features=filters // 8, - data_format=self.config.data_format)) - # pylint: enable=undefined-loop-variable - - for filters in filters_list[(len(filters_list) + 1) // 2:]: - discriminator.add( - tf.keras.layers.Conv2D( - filters=filters, - kernel_size=4, - strides=(2, 2), - padding="SAME", - data_format=self.config.data_format)) - discriminator.add(tf.keras.layers.LeakyReLU(alpha=.1)) - - discriminator.add(tf.keras.layers.Flatten()) - discriminator.add(tf.keras.layers.Dense(units=1)) - - return discriminator - - def compute_loss_and_grads(self, real_images, noise, training=True): - """Compute loss and gradients for both generator and discriminator.""" - # TODO(lxuechen): Add gradient penalty for discriminator - with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape: - real_logits = self.discriminator(real_images, training=training) - - fake_images = self.generator.call(noise, training=training) - fake_logits = self.discriminator.call(fake_images) - - g_loss = self.compute_g_loss(fake_logits) - d_loss = self.compute_d_loss(fake_logits, real_logits) - - g_grads = g_tape.gradient(g_loss, self.generator.trainable_variables) - d_grads = d_tape.gradient(d_loss, self.discriminator.trainable_variables) - - return g_loss, d_loss, g_grads, d_grads - - def compute_g_loss(self, fake_logits): - return -tf.reduce_mean(fake_logits) # Hinge loss - - def compute_d_loss(self, fake_logits, real_logits): - # Hinge loss - real_loss = tf.reduce_mean(tf.nn.relu(1. - real_logits)) - fake_loss = tf.reduce_mean(tf.nn.relu(1. + fake_logits)) - return real_loss + fake_loss diff --git a/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py b/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py deleted file mode 100644 index 18345945108111b57c5401c26b7dca0bfc8f8316..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for self-attention generative adversarial network.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf -from tensorflow.contrib.eager.python.examples.sagan import config as config_ -from tensorflow.contrib.eager.python.examples.sagan import sagan -tfe = tf.contrib.eager - - -class SAGANTest(tf.test.TestCase): - - def setUp(self): - super(SAGANTest, self).setUp() - config = config_.get_hparams_mock() - self.noise_shape = (config.batch_size, config.latent_dim) - self.logits_shape = (config.batch_size, 1) - self.images_shape = (config.batch_size,) + config.image_shape - - self.model = sagan.SAGAN(config=config) - self.noise = tf.random_normal(shape=self.noise_shape) - self.real_images = tf.random_normal(shape=self.images_shape) - self.config = config - - def tearDown(self): - del self.model - del self.noise - del self.real_images - super(SAGANTest, self).tearDown() - - def test_generator_call(self): - """Test `generator.__call__` function.""" - fake_images = self.model.generator(self.noise, training=False) - self.assertEqual(fake_images.shape, self.images_shape) - - def test_generator_call_defun(self): - """Test `generator.__call__` function with defun.""" - call_ = tfe.defun(self.model.generator.__call__) - fake_images = call_(self.noise, training=False) - self.assertEqual(fake_images.shape, self.images_shape) - - def test_discriminator_call(self): - """Test `discriminator.__call__` function.""" - real_logits = self.model.discriminator(self.real_images) - self.assertEqual(real_logits.shape, self.logits_shape) - - def test_discriminator_call_defun(self): - """Test `discriminator.__call__` function with defun.""" - call_ = tfe.defun(self.model.discriminator.__call__) - real_logits = call_(self.real_images) - self.assertEqual(real_logits.shape, self.logits_shape) - - def test_compute_loss_and_grads(self): - """Test `compute_loss_and_grads` function.""" - g_loss, d_loss, g_grads, d_grads = self.model.compute_loss_and_grads( - self.real_images, self.noise, training=False) - self.assertEqual(g_loss.shape, ()) - self.assertEqual(d_loss.shape, ()) - self.assertTrue(isinstance(g_grads, list)) - self.assertTrue(isinstance(d_grads, list)) - g_vars = self.model.generator.trainable_variables - d_vars = self.model.discriminator.trainable_variables - - self.assertEqual(len(g_grads), len(g_vars)) - self.assertEqual(len(d_grads), len(d_vars)) - - def test_compute_loss_and_grads_defun(self): - """Test `compute_loss_and_grads` function with defun.""" - compute_loss_and_grads = tfe.defun(self.model.compute_loss_and_grads) - g_loss, d_loss, g_grads, d_grads = compute_loss_and_grads( - self.real_images, self.noise, training=False) - self.assertEqual(g_loss.shape, ()) - self.assertEqual(d_loss.shape, ()) - self.assertTrue(isinstance(g_grads, list)) - self.assertTrue(isinstance(d_grads, list)) - g_vars = self.model.generator.trainable_variables - d_vars = self.model.discriminator.trainable_variables - - self.assertEqual(len(g_grads), len(g_vars)) - self.assertEqual(len(d_grads), len(d_vars)) - - -if __name__ == "__main__": - tf.enable_eager_execution() - tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py index 8ac553e0ae71382966d03d9ef4429adf5137b369..d18a097063c7d25947af3e2e2959ce574edd553f 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -36,7 +36,7 @@ from third_party.examples.eager.spinn import spinn from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import test from tensorflow.python.framework import test_util -from tensorflow.python.training import saver +from tensorflow.python.training import checkpoint_management from tensorflow.python.training.checkpointable import util as checkpointable_utils # pylint: enable=g-bad-import-order @@ -422,7 +422,7 @@ class SpinnTest(test_util.TensorFlowTestCase): # 5. Verify that checkpoints exist and contains all the expected variables. self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*"))) object_graph = checkpointable_utils.object_metadata( - saver.latest_checkpoint(config.logdir)) + checkpoint_management.latest_checkpoint(config.logdir)) ckpt_variable_names = set() for node in object_graph.nodes: for attribute in node.attributes: diff --git a/tensorflow/contrib/eager/python/remote_test.py b/tensorflow/contrib/eager/python/remote_test.py new file mode 100644 index 0000000000000000000000000000000000000000..76f48eeb1cab9d1f014adeafe4827cb5d3a8c77d --- /dev/null +++ b/tensorflow/contrib/eager/python/remote_test.py @@ -0,0 +1,178 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for remote eager execution.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import os + +import numpy as np + +from tensorflow.core.protobuf import cluster_pb2 +from tensorflow.core.protobuf import tensorflow_server_pb2 +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context +from tensorflow.python.eager import function +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.platform import test +from tensorflow.python.training import server_lib + +JOB_NAME = "remote_device" +ALT_JOB_NAME = "alt_remote_device" + + +def run_sync_and_async(f): + """Execute all test methods in the given class in sync and async modes.""" + + @functools.wraps(f) + def decorator(self, *args, **kwargs): + with context.execution_mode(context.ASYNC): + f(self, *args, **kwargs) + + with context.execution_mode(context.SYNC): + f(self, *args, **kwargs) + + return decorator + + +def get_server_def(job_name, local_server_port, remote_server_addresses, + task_index): + """Returns a server def with a single job + multiple tasks.""" + cluster_def = cluster_pb2.ClusterDef() + job_def = cluster_def.job.add() + job_def.name = job_name + job_def.tasks[0] = "localhost:%d" % local_server_port + + for i, remote_server_address in enumerate(remote_server_addresses, start=1): + job_def.tasks[i] = remote_server_address + + server_def = tensorflow_server_pb2.ServerDef( + cluster=cluster_def, + job_name=job_name, + task_index=task_index, + protocol="grpc") + + return server_def + + +class RemoteExecutionTest(test.TestCase): + + def __init__(self, methodName="runTest"): # pylint: disable=invalid-name + super(RemoteExecutionTest, self).__init__(methodName) + self._cached_server1 = server_lib.Server.create_local_server() + self._cached_server2 = server_lib.Server.create_local_server() + + os.environ["TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC"] = "1" + + self._cached_server1_target = self._cached_server1.target[len("grpc://"):] + self._cached_server2_target = self._cached_server2.target[len("grpc://"):] + + # Start the local server. + context.set_server_def( + server_def=get_server_def( + JOB_NAME, + local_server_port=0, + remote_server_addresses=[ + self._cached_server1_target, self._cached_server2_target + ], + task_index=0)) + + @run_sync_and_async + def testDefunMatmul(self): + """Basic remote eager execution with defun.""" + + mm_defun = function.defun(math_ops.matmul) + with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME): + x1 = array_ops.ones([2, 2]) + with ops.device("job:%s/replica:0/task:2/device:CPU:0" % JOB_NAME): + x2 = array_ops.ones([2, 2]) + y = mm_defun(x1, x2) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + + @run_sync_and_async + def testSimpleMatmul(self): + """Basic remote eager execution.""" + + with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME): + x1 = array_ops.ones([2, 2]) + with ops.device("job:%s/replica:0/task:2/device:CPU:0" % JOB_NAME): + x2 = array_ops.ones([2, 2]) + y = math_ops.matmul(x1, x2) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + + @run_sync_and_async + def testSimpleWeightRead(self): + """Basic remote eager weight read.""" + + with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME): + w = resource_variable_ops.ResourceVariable([[2.0]]) + loss = w * w + np.testing.assert_array_equal([[4.0]], loss.numpy()) + + @run_sync_and_async + def testTapeWeightRead(self): + """Remote eager weight read in a tape.""" + + with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME): + w = resource_variable_ops.ResourceVariable([[3.0]]) + with backprop.GradientTape() as tape: + loss = w * w + + grad = tape.gradient(loss, w) + np.testing.assert_array_equal([[9.0]], loss.numpy()) + np.testing.assert_array_equal([[6.0]], grad.numpy()) + + @run_sync_and_async + def testServerDefChanged(self): + """Update server def, and run ops on new cluster.""" + context.set_server_def( + server_def=get_server_def( + ALT_JOB_NAME, + local_server_port=0, + remote_server_addresses=[ + self._cached_server1_target, self._cached_server2_target + ], + task_index=0)) + + with ops.device("job:%s/replica:0/task:1/device:CPU:0" % ALT_JOB_NAME): + x1 = array_ops.ones([2, 2]) + y = math_ops.matmul(x1, x1) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + + # Set the server def back to JOB_NAME + context.set_server_def( + server_def=get_server_def( + JOB_NAME, + local_server_port=0, + remote_server_addresses=[ + self._cached_server1_target, self._cached_server2_target + ], + task_index=0)) + + with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME): + x1 = array_ops.ones([2, 2]) + y = math_ops.matmul(x1, x1) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + + +if __name__ == "__main__": + ops.enable_eager_execution() + test.main() diff --git a/tensorflow/contrib/eager/python/saver.py b/tensorflow/contrib/eager/python/saver.py index d70930864784b3e48140da27ca33ff13f593e663..f9c716360c5755ee1902b576545d776725f9966f 100644 --- a/tensorflow/contrib/eager/python/saver.py +++ b/tensorflow/contrib/eager/python/saver.py @@ -161,7 +161,7 @@ class Saver(object): Args: file_prefix: Path prefix where parameters were previously saved. Typically obtained from a previous `save()` call, or from - @{tf.train.latest_checkpoint}. + `tf.train.latest_checkpoint`. """ with ops.device("/device:CPU:0"): self._saver.restore(None, file_prefix) diff --git a/tensorflow/contrib/eager/python/saver_test.py b/tensorflow/contrib/eager/python/saver_test.py index 90a3711475719a7f991473c6c9067da1e76ab9f2..91bc75213c72a7c44722e2cc2395f6a06a76f948 100644 --- a/tensorflow/contrib/eager/python/saver_test.py +++ b/tensorflow/contrib/eager/python/saver_test.py @@ -21,15 +21,11 @@ import os from tensorflow.contrib.eager.python import saver as _saver from tensorflow.python.eager import context -from tensorflow.python.eager import graph_callable from tensorflow.python.eager import test -from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import variable_scope from tensorflow.python.training import adam from tensorflow.python.training import gradient_descent from tensorflow.python.training import momentum @@ -142,53 +138,6 @@ class SaverTest(test.TestCase): with _saver.restore_variables_on_create(ckpt_prefix): _ = model(resource_variable_ops.ResourceVariable(1.0, name='v2')) - def testSaveRestoreGraphCallable(self): - with ops.device(self._dev()): - @graph_callable.graph_callable( - [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)]) - def model(x): - v = variable_scope.get_variable( - 'v', initializer=init_ops.zeros_initializer(), shape=()) - return v + x - - # Default 2 + 0 = 2 - self.assertEqual( - 2, model(array_ops.constant(2, dtype=dtypes.float32)).numpy()) - - # Save the variable value 0. - ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') - _saver.Saver(model.variables).save(ckpt_prefix) - - # update variable to 1, so that 2 + 1 = 3 - model.variables[0].assign(1.) - self.assertEqual( - 3, model(array_ops.constant(2, dtype=dtypes.float32)).numpy()) - - # load the variable value 0, so that 2 + 0 = 2 - _saver.Saver(model.variables).restore(ckpt_prefix) - self.assertEqual( - 2, model(array_ops.constant(2, dtype=dtypes.float32)).numpy()) - - # update checkpoint variable to 1 and memory value to 2. - model.variables[0].assign(1.) - _saver.Saver(model.variables).save(ckpt_prefix) - model.variables[0].assign(2.) - self.assertEqual( - 4, model(array_ops.constant(2, dtype=dtypes.float32)).numpy()) - - # reset the graph and reload on create, so that 1 + 2 = 3 - ops.reset_default_graph() - with _saver.restore_variables_on_create(ckpt_prefix): - @graph_callable.graph_callable( - [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)]) - def model2(x): - v = variable_scope.get_variable( - 'v', initializer=init_ops.zeros_initializer(), shape=()) - return v + x - - self.assertEqual( - 3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy()) - class GetOptimizerTests(test.TestCase): diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index ca6430253b67d825290b6a376ba3f29b3ae67577..4dfd0834430b2295d1454314e88c824efe4c8b13 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -16,7 +16,7 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice. -To use, at program startup, call `tfe.enable_eager_execution()`. +To use, at program startup, call `tf.enable_eager_execution()`. @@metrics @@ -34,6 +34,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@run @@enable_eager_execution +@@enable_remote_eager_execution @@custom_gradient @@ -66,10 +67,13 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@execution_mode @@async_wait @@async_clear_error +@@set_server_def @@run_test_in_graph_and_eager_modes @@run_all_tests_in_graph_and_eager_modes +@@TensorSpec + @@DEVICE_PLACEMENT_EXPLICIT @@DEVICE_PLACEMENT_WARN @@DEVICE_PLACEMENT_SILENT @@ -107,13 +111,16 @@ from tensorflow.python.eager.context import async_clear_error from tensorflow.python.eager.context import SYNC from tensorflow.python.eager.context import ASYNC from tensorflow.python.eager.context import num_gpus +from tensorflow.python.eager.context import set_server_def from tensorflow.python.eager.execution_callbacks import add_execution_callback from tensorflow.python.eager.execution_callbacks import clear_execution_callbacks from tensorflow.python.eager.execution_callbacks import inf_callback from tensorflow.python.eager.execution_callbacks import inf_nan_callback from tensorflow.python.eager.execution_callbacks import nan_callback from tensorflow.python.eager.execution_callbacks import seterr +from tensorflow.python.framework.tensor_spec import TensorSpec from tensorflow.python.framework.ops import enable_eager_execution +from tensorflow.python.framework.ops import enable_eager_execution_internal as enable_remote_eager_execution from tensorflow.python.framework.ops import eager_run as run from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes from tensorflow.python.framework.test_util import run_all_in_graph_and_eager_modes as run_all_tests_in_graph_and_eager_modes diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 349f48f7f788b458af2639f7ad4cc4cd904465b4..77f62df99d5a052e2df61d3f225e1860d4d1da72 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -20,6 +20,7 @@ py_library( ":dnn_linear_combined", ":early_stopping", ":export", + ":exporter", ":extenders", ":head", ":hooks", @@ -219,6 +220,33 @@ py_test( ], ) +py_library( + name = "exporter", + srcs = [ + "python/estimator/exporter.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:summary", + "//tensorflow/python/estimator:exporter", + ], +) + +py_test( + name = "exporter_test", + size = "medium", + srcs = ["python/estimator/exporter_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":exporter", + "//tensorflow/python:platform", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:exporter", + ], +) + py_library( name = "head", srcs = [ @@ -487,6 +515,9 @@ py_test( size = "medium", srcs = ["python/estimator/saved_model_estimator_test.py"], srcs_version = "PY2AND3", + tags = [ + "notsan", + ], deps = [ ":export", ":saved_model_estimator", diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index e1453ae1d04ebd8d72f812b51480f0b05f7a5416..258860f26340a0934e854f2d1950ead60e413234 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -45,6 +45,7 @@ _allowed_symbols = [ 'clip_gradients_by_norm', 'forward_features', 'InMemoryEvaluatorHook', + 'make_stop_at_checkpoint_step_hook', 'logistic_regression_head', 'multi_class_head', 'multi_head', diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py index 43bfcffd790e7b3c716c3f70820851a8819af225..7ed77bcce6f00ed13e9952951800f1017d582f19 100644 --- a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py +++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py @@ -50,7 +50,8 @@ class _BoostedTreesEstimator(estimator.Estimator): tree_complexity=0., min_node_weight=0., config=None, - center_bias=False): + center_bias=False, + pruning_mode='none'): """Initializes a `BoostedTreesEstimator` instance. Args: @@ -89,13 +90,18 @@ class _BoostedTreesEstimator(estimator.Estimator): regression problems, the first node will return the mean of the labels. For binary classification problems, it will return a logit for a prior probability of label 1. + pruning_mode: one of 'none', 'pre', 'post' to indicate no pruning, pre- + pruning (do not split a node if not enough gain is observed) and post + pruning (build the tree up to a max depth and then prune branches with + negative gain). For pre and post pruning, you MUST provide + tree_complexity >0. """ # pylint:disable=protected-access # HParams for the model. tree_hparams = canned_boosted_trees._TreeHParams( n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, - tree_complexity, min_node_weight, center_bias) + tree_complexity, min_node_weight, center_bias, pruning_mode) def _model_fn(features, labels, mode, config): return canned_boosted_trees._bt_model_fn( @@ -129,7 +135,8 @@ def boosted_trees_classifier_train_in_memory( min_node_weight=0., config=None, train_hooks=None, - center_bias=False): + center_bias=False, + pruning_mode='none'): """Trains a boosted tree classifier with in memory dataset. Example: @@ -208,6 +215,11 @@ def boosted_trees_classifier_train_in_memory( regression problems, the first node will return the mean of the labels. For binary classification problems, it will return a logit for a prior probability of label 1. + pruning_mode: one of 'none', 'pre', 'post' to indicate no pruning, pre- + pruning (do not split a node if not enough gain is observed) and post + pruning (build the tree up to a max depth and then prune branches with + negative gain). For pre and post pruning, you MUST provide + tree_complexity >0. Returns: a `BoostedTreesClassifier` instance created with the given arguments and @@ -228,7 +240,7 @@ def boosted_trees_classifier_train_in_memory( # HParams for the model. tree_hparams = canned_boosted_trees._TreeHParams( n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, - tree_complexity, min_node_weight, center_bias) + tree_complexity, min_node_weight, center_bias, pruning_mode) def _model_fn(features, labels, mode, config): return canned_boosted_trees._bt_model_fn( @@ -269,7 +281,8 @@ def boosted_trees_regressor_train_in_memory( min_node_weight=0., config=None, train_hooks=None, - center_bias=False): + center_bias=False, + pruning_mode='none'): """Trains a boosted tree regressor with in memory dataset. Example: @@ -341,6 +354,11 @@ def boosted_trees_regressor_train_in_memory( regression problems, the first node will return the mean of the labels. For binary classification problems, it will return a logit for a prior probability of label 1. + pruning_mode: one of 'none', 'pre', 'post' to indicate no pruning, pre- + pruning (do not split a node if not enough gain is observed) and post + pruning (build the tree up to a max depth and then prune branches with + negative gain). For pre and post pruning, you MUST provide + tree_complexity >0. Returns: a `BoostedTreesClassifier` instance created with the given arguments and @@ -360,7 +378,7 @@ def boosted_trees_regressor_train_in_memory( # HParams for the model. tree_hparams = canned_boosted_trees._TreeHParams( n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, - tree_complexity, min_node_weight, center_bias) + tree_complexity, min_node_weight, center_bias, pruning_mode) def _model_fn(features, labels, mode, config): return canned_boosted_trees._bt_model_fn( diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py index 999c2aa5e28242f996e12da3807a74c6acf31df9..b1581f37509b5dc2bec98942e88c024905f25d93 100644 --- a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py +++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py @@ -136,6 +136,49 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): eval_res = est.evaluate(input_fn=input_fn, steps=1) self.assertAllClose(eval_res['average_loss'], 0.614642) + def testTrainAndEvaluateEstimatorWithPrePruning(self): + input_fn = _make_train_input_fn(is_classification=False) + + est = boosted_trees._BoostedTreesEstimator( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=2, + head=self._head, + max_depth=5, + tree_complexity=0.001, + pruning_mode='pre') + + num_steps = 100 + # Train for a few steps, and validate final checkpoint. + est.train(input_fn, steps=num_steps) + # We stop actually after 2*depth*n_trees steps (via a hook) because we still + # could not grow 2 trees of depth 5 (due to pre-pruning). + self._assert_checkpoint( + est.model_dir, global_step=21, finalized_trees=0, attempted_layers=21) + eval_res = est.evaluate(input_fn=input_fn, steps=1) + self.assertAllClose(eval_res['average_loss'], 3.83943) + + def testTrainAndEvaluateEstimatorWithPostPruning(self): + input_fn = _make_train_input_fn(is_classification=False) + + est = boosted_trees._BoostedTreesEstimator( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=2, + head=self._head, + max_depth=5, + tree_complexity=0.001, + pruning_mode='post') + + # It will stop after 10 steps because of the max depth and num trees. + num_steps = 100 + # Train for a few steps, and validate final checkpoint. + est.train(input_fn, steps=num_steps) + self._assert_checkpoint( + est.model_dir, global_step=10, finalized_trees=2, attempted_layers=10) + eval_res = est.evaluate(input_fn=input_fn, steps=1) + self.assertAllClose(eval_res['average_loss'], 2.37652) + def testInferEstimator(self): train_input_fn = _make_train_input_fn(is_classification=False) predict_input_fn = numpy_io.numpy_input_fn( @@ -231,6 +274,31 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): self.assertAllClose([[0], [1], [1], [0], [0]], [pred['class_ids'] for pred in predictions]) + def testBinaryClassifierTrainInMemoryAndEvalAndInferWithPrePruning(self): + train_input_fn = _make_train_input_fn(is_classification=True) + predict_input_fn = numpy_io.numpy_input_fn( + x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) + + est = boosted_trees.boosted_trees_classifier_train_in_memory( + train_input_fn=train_input_fn, + feature_columns=self._feature_columns, + n_trees=1, + max_depth=5, + pruning_mode='pre', + tree_complexity=0.01) + # We stop actually after 2*depth*n_trees steps (via a hook) because we still + # could not grow 1 trees of depth 5 (due to pre-pruning). + self._assert_checkpoint( + est.model_dir, global_step=11, finalized_trees=0, attempted_layers=11) + + # Check evaluate and predict. + eval_res = est.evaluate(input_fn=train_input_fn, steps=1) + self.assertAllClose(eval_res['accuracy'], 1.0) + # Validate predictions. + predictions = list(est.predict(input_fn=predict_input_fn)) + self.assertAllClose([[0], [1], [1], [0], [0]], + [pred['class_ids'] for pred in predictions]) + def testBinaryClassifierTrainInMemoryWithDataset(self): train_input_fn = _make_train_input_fn_dataset(is_classification=True) predict_input_fn = numpy_io.numpy_input_fn( diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py index 2eef60c39f54bfb464b7da0eb57a47e9eee9b800..724bc2c82f8289bbaa19a1dbbc1dc81b6e158e02 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py @@ -147,7 +147,7 @@ class DNNLinearCombinedEstimator(estimator.Estimator): if a categorical column is multivalent. One of "mean", "sqrtn", and "sum" -- these are effectively different ways to do example-level normalization, which can be useful for bag-of-words features. For more - details, see @{tf.feature_column.linear_model$linear_model}. + details, see `tf.feature_column.linear_model`. Raises: ValueError: If both linear_feature_columns and dnn_features_columns are diff --git a/tensorflow/contrib/estimator/python/estimator/export.py b/tensorflow/contrib/estimator/python/estimator/export.py index 03cf6f107c1c5589522d7be4946562a466740b0e..b0deb9b494ab3ad0fe8c56967606e5e5952b7ccf 100644 --- a/tensorflow/contrib/estimator/python/estimator/export.py +++ b/tensorflow/contrib/estimator/python/estimator/export.py @@ -31,8 +31,8 @@ def export_saved_model_for_mode( # pylint: disable=line-too-long """Exports a single train/eval/predict graph as a SavedModel. - For a detailed guide, see - @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with Estimators}. + For a detailed guide, see [Using SavedModel with Estimators]( + https://tensorflow.org/guide/saved_model#using_savedmodel_with_estimators). Sample usage: ```python diff --git a/tensorflow/contrib/estimator/python/estimator/exporter.py b/tensorflow/contrib/estimator/python/estimator/exporter.py new file mode 100644 index 0000000000000000000000000000000000000000..09d744060568e458a3af32e9d7497dbfbeec561e --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/exporter.py @@ -0,0 +1,280 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implements StepsExporter to export the model in user specified steps.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.python.estimator import exporter +from tensorflow.python.framework import ops +from tensorflow.python.platform import gfile +from tensorflow.python.platform import tf_logging +from tensorflow.python.summary import summary_iterator + +DEFAULT_GLOBAL_STEP_KEY = ops.GraphKeys.GLOBAL_STEP + + +class StepsExporter(exporter.Exporter): + """This class exports the model in user specified steps. + + This class exports the model at the steps given by the `steps_to_keep` + argument. Each number in the list is treated as a lower bound for model + exports, to handle the case when evaluation is performed at different steps. + + Consider this example: + + ``` + steps_to_keep = [1, 2, 3, 6, 7, 10, 12, 25] + ``` + + The model is evaluated at step increments of 5: `[5, 10, 15, 20, 25, 30]`. + The `StepsExporter` will export the model when it has reached steps + `[5, 10, 15, 25]`. + + This example illustrates the two cases when the model is exported: + + 1. Model is evaluated on a step defined in the list `steps_to_keep`. + + In the example, the model is exported on step `10` and `25`. + + 2. Model is evaluated on a step not defined in the list `steps_to_keep`, but + is still exported because a step in `steps_to_keep` was missed. + + In the example, when the model reaches step `5`, the model is exported even + though `steps_to_keep` does not contain `5`. Step `5` is exported to make + up for step `3`, which was missed. Steps `1` and `2` in `steps_to_keep` are + skipped completely (e.g. say the model is evaluated at step `6`. It will + **not** be exported to make up for step `2`). + + Using the `steps_to_keep` list as a lower bound allows users to define + approximate step boundaries for exporting their models, and avoid frustrating + off-by-one calculation errors. + + Sample Use Cases: + There are specific points during the training when having a saved version of + the model would be useful. One example is at the end of each training phase + when the set of freezed weights is changed. + Another good use case is saving the model at the end of each epoch for + visualization or retraining. + """ + + def __init__(self, + steps_to_keep, + name='steps_exporter', + serving_input_receiver_fn=None, + event_file_pattern='eval/*.tfevents.*', + assets_extra=None, + as_text=False): + """Create an `StepsExporter` to use with `tf.estimator.EvalSpec`. + + Example of creating a StepsExporter for training and evaluation: + + ```python + categorical_feature_a = categorical_column_with_hash_bucket(...) + categorical_feature_b = categorical_column_with_hash_bucket(...) + + categorical_feature_a_emb = embedding_column( + categorical_column=categorical_feature_a, ...) + categorical_feature_b_emb = embedding_column( + categorical_column=categorical_feature_b, ...) + + estimator = tf.estimator.DNNClassifier( + feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb], + hidden_units=[1024, 512, 256]) + + # Input pipeline for train and evaluate. + def train_input_fn: # returns x, y + # please shuffle the data. + pass + def eval_input_fn_eval: # returns x, y + pass + + exporter = tf.contrib.estimator.exporter.StepsExporter( + name="steps_exporter", + serving_input_receiver_fn=serving_input_receiver_fn, + event_file_pattern='eval/*.tfevents.*' + steps_to_keep=[...]) + + train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=1000) + + eval_spec = [tf.estimator.EvalSpec( + input_fn=eval_input_fn, + steps=1, + exporters=exporter, + start_delay_secs=0, + throttle_secs=5)] + + tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) + + # Models will be exported to estimator.model_dir in timestamped directories, + # which can be used for serving, analysis with TFMA, or directly loaded in. + # For example: + export_dir = os.path.join(estimator.model_dir, + ) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + tf.saved_model.loader.load( + sess, [tf.saved_model.tag_constants.SERVING], export_dir) + + ``` + + Args: + steps_to_keep: Non-empty list of positive integers containing + the step numbers at which the model should be exported. All the exports + will be kept, so there is no garbage collection. + name: Unique name of this `Exporter` that is going to be used in the + export path. + serving_input_receiver_fn: A function that takes no arguments and returns + a `ServingInputReceiver`. + event_file_pattern: Event file name pattern relative to model_dir. If + None, however, the exporter would not be preemption-safe. To be + preemption-safe, event_file_pattern should be specified. + assets_extra: An optional dict specifying how to populate the assets.extra + directory within the exported SavedModel. Each key should give the + destination path (including the filename) relative to the assets.extra + directory. The corresponding value gives the full path of the source + file to be copied. For example, the simple case of copying a single + file without renaming it is specified as `{'my_asset_file.txt': + '/path/to/my_asset_file.txt'}`. + as_text: Whether to write the SavedModel proto in text format. Defaults to + `False`. + + Raises: + ValueError: If any arguments is invalid. + """ + # pylint: disable=protected-access + self._saved_model_exporter = exporter._SavedModelExporter( + name, serving_input_receiver_fn, assets_extra, as_text) + # pylint: enable=protected-access + + self._event_file_pattern = event_file_pattern + self._model_dir = None + + self._input_steps_to_keep = steps_to_keep + steps_to_keep = [step for step in steps_to_keep if isinstance(step, int)] + steps_to_keep = [step for step in steps_to_keep if step > 0] + if not steps_to_keep: + raise ValueError( + '`steps_to_keep` list must have at least one positive integer') + elif self._input_steps_to_keep != steps_to_keep: + tf_logging.warn('Changed `steps_to_keep`, by omitting non-integer or' + ' less than 1 elements, to [%s]', + ', '.join(str(step) for step in steps_to_keep)) + self._steps_to_keep = sorted(steps_to_keep) + self._steps_kept = [] + + @property + def name(self): + return self._saved_model_exporter.name + + def export(self, estimator, export_path, checkpoint_path, eval_result, + is_the_final_export): + """Exports the given Estimator to a specific format. + + Args: + estimator: A `tf.estimator.Estimator` instance to export. + export_path: A string containing a directory where to write the export. + checkpoint_path: The checkpoint path to export. + eval_result: The output of Estimator.evaluate on this checkpoint. + is_the_final_export: This boolean is True when this is an export in the + end of training. It is False for the intermediate exports during the + training. When passing Exporter to tf.estimator.train_and_evaluate + is_the_final_export is always False if TrainSpec.max_steps is None. + + Returns: + The string path to the exported directory or None if export is skipped. + + Raises: + ValueError: If `eval_result` is None or doesn't have + `ops.GraphKeys.GLOBAL_STEP` as a key. + """ + export_result = None + + if not eval_result or DEFAULT_GLOBAL_STEP_KEY not in eval_result: + raise ValueError( + '`eval_result` is empty, or does not have global step. This' + ' should never happen as Estimator always sets the global step in ' + '`eval_result`. Please file a bug report. Got eval_result: %s' + % str(eval_result)) + + if self._model_dir != estimator.model_dir and self._event_file_pattern: + tf_logging.info('Loads the steps that the model was already evaluated at,' + 'from event files') + self._model_dir = estimator.model_dir + full_event_file_pattern = os.path.join(self._model_dir, + self._event_file_pattern) + self._steps_kept = self._get_kept_steps(full_event_file_pattern) + + if self._steps_kept: + self._steps_kept = sorted(self._steps_kept) + self._steps_to_keep = [step for step in self._steps_to_keep if + step > self._steps_kept[-1]] + # It is assumed that the model is exported at any evaluated step 'n' if + # there is any `steps_missed` lower than 'n'. As a result, all the steps in + # `_steps_to_keep` lower than the last evaluated step will be removed. + steps_missed = [step for step in self._steps_to_keep + if step <= eval_result[DEFAULT_GLOBAL_STEP_KEY]] + + if steps_missed: + # update the `_steps_to_keep` list by omitting all steps smaller than the + # current global step which are missed to be exported + export_result = self._saved_model_exporter.export(estimator, export_path, + checkpoint_path, + eval_result, + is_the_final_export) + self._steps_to_keep = [step for step in self._steps_to_keep if step + not in steps_missed] + # contains all the steps in which export has happened. + self._steps_kept.append(eval_result[DEFAULT_GLOBAL_STEP_KEY]) + # Show warning for all the missed steps except the last one + if steps_missed[:-1]: + tf_logging.warn('Missed steps [%s] for exporting, as no evaluation' + ' took place at them.', ', '.join(str(step) for step in + steps_missed[:-1])) + # Log model export if the last missed step is the same as the current step + if steps_missed[-1] == eval_result[DEFAULT_GLOBAL_STEP_KEY]: + tf_logging.info('Performing model export at step %d.', + eval_result[DEFAULT_GLOBAL_STEP_KEY]) + # Show warning for exporting model at another step instead of the user + # specified one + else: + tf_logging.warn('Performing model export at step %d instead of %d, as' + ' no evaluation took place at step %d.', + eval_result[DEFAULT_GLOBAL_STEP_KEY], steps_missed[-1], + steps_missed[-1]) + return export_result + + def _get_kept_steps(self, event_files): + """Get the steps that the model was evaluated at, from event files. + + Args: + event_files: Absolute pattern of event files. + + Returns: + steps_kept: A list of steps in which the model was evaluated. + """ + if not event_files: + return None + + steps_kept = [] + for event_file in gfile.Glob(os.path.join(event_files)): + for event in summary_iterator.summary_iterator(event_file): + if event.step not in steps_kept: + steps_kept.append(event.step) + return steps_kept diff --git a/tensorflow/contrib/estimator/python/estimator/exporter_test.py b/tensorflow/contrib/estimator/python/estimator/exporter_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0d009b945e748394074a7278833abb1e12b15e7b --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/exporter_test.py @@ -0,0 +1,206 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for `StepsExporter`.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil +import tempfile + +from tensorflow.contrib.estimator.python.estimator import exporter as exporter_lib +from tensorflow.python.estimator import estimator as estimator_lib +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test + + +class StepsExporterTest(test.TestCase): + + def test_error_out_if_steps_to_keep_has_no_positive_integers(self): + + def _serving_input_receiver_fn(): + pass + + with self.assertRaisesRegexp(ValueError, "positive integer"): + exporter = exporter_lib.StepsExporter( + name="specified_steps_exporter", + serving_input_receiver_fn=_serving_input_receiver_fn, + steps_to_keep=[-1, 0, 1.1]) + self.assertEqual("specified_steps_exporter", exporter.name) + + def test_steps_exporter(self): + + def _serving_input_receiver_fn(): + pass + + export_dir_base = tempfile.mkdtemp() + gfile.MkDir(export_dir_base) + gfile.MkDir(export_dir_base + "/export") + gfile.MkDir(export_dir_base + "/eval") + + exporter = exporter_lib.StepsExporter( + name="steps_exporter", + serving_input_receiver_fn=_serving_input_receiver_fn, + assets_extra={"from/path": "to/path"}, + as_text=False, + steps_to_keep=[1]) + estimator = test.mock.Mock(spec=estimator_lib.Estimator) + estimator.export_savedmodel.return_value = "export_result_path" + estimator.model_dir = export_dir_base + + export_result = exporter.export(estimator, export_dir_base, + "checkpoint_path", {"global_step": 1}, + False) + + self.assertEqual("export_result_path", export_result) + estimator.export_savedmodel.assert_called_with( + export_dir_base, + _serving_input_receiver_fn, + assets_extra={"from/path": "to/path"}, + as_text=False, + checkpoint_path="checkpoint_path", + strip_default_attrs=True) + + shutil.rmtree(export_dir_base, ignore_errors=True) + + def test_steps_exporter_with_preemption(self): + + def _serving_input_receiver_fn(): + pass + + export_dir_base = tempfile.mkdtemp() + gfile.MkDir(export_dir_base) + gfile.MkDir(export_dir_base + "/export") + gfile.MkDir(export_dir_base + "/eval") + + eval_dir_base = os.path.join(export_dir_base, "eval_continuous") + estimator_lib._write_dict_to_summary(eval_dir_base, {}, 1) + estimator_lib._write_dict_to_summary(eval_dir_base, {}, 2) + + exporter = exporter_lib.StepsExporter( + name="steps_exporter", + serving_input_receiver_fn=_serving_input_receiver_fn, + event_file_pattern="eval_continuous/*.tfevents.*", + assets_extra={"from/path": "to/path"}, + as_text=False, + steps_to_keep=[1, 2, 6, 8]) + + estimator = test.mock.Mock(spec=estimator_lib.Estimator) + estimator.model_dir = export_dir_base + estimator.export_savedmodel.return_value = "export_result_path" + + export_result = exporter.export(estimator, export_dir_base, + "checkpoint_path", {"global_step": 3}, + False) + self.assertEqual(None, export_result) + + export_result = exporter.export(estimator, export_dir_base, + "checkpoint_path", {"global_step": 6}, + False) + self.assertEqual("export_result_path", export_result) + + export_result = exporter.export(estimator, export_dir_base, + "checkpoint_path", {"global_step": 7}, + False) + self.assertEqual(None, export_result) + + shutil.rmtree(export_dir_base, ignore_errors=True) + + def test_specified_step_is_saved(self): + + def _serving_input_receiver_fn(): + pass + + export_dir_base = tempfile.mkdtemp() + gfile.MkDir(export_dir_base) + gfile.MkDir(export_dir_base + "/export") + gfile.MkDir(export_dir_base + "/eval") + + exporter = exporter_lib.StepsExporter( + name="steps_exporter", + serving_input_receiver_fn=_serving_input_receiver_fn, + assets_extra={"from/path": "to/path"}, + as_text=False, + steps_to_keep=[1, 5, 8, 10, 11]) + estimator = test.mock.Mock(spec=estimator_lib.Estimator) + estimator.export_savedmodel.return_value = "export_result_path" + estimator.model_dir = export_dir_base + + export_result = exporter.export(estimator, export_dir_base, + "checkpoint_path", {"global_step": 1}, + False) + + self.assertTrue(estimator.export_savedmodel.called) + self.assertEqual("export_result_path", export_result) + + export_result = exporter.export(estimator, export_dir_base, + "checkpoint_path", {"global_step": 2}, + False) + self.assertEqual(None, export_result) + + export_result = exporter.export(estimator, export_dir_base, + "checkpoint_path", {"global_step": 5}, + False) + self.assertTrue(estimator.export_savedmodel.called) + self.assertEqual("export_result_path", export_result) + + export_result = exporter.export(estimator, export_dir_base, + "checkpoint_path", {"global_step": 10}, + False) + self.assertTrue(estimator.export_savedmodel.called) + self.assertEqual("export_result_path", export_result) + + export_result = exporter.export(estimator, export_dir_base, + "checkpoint_path", {"global_step": 15}, + False) + self.assertTrue(estimator.export_savedmodel.called) + self.assertEqual("export_result_path", export_result) + + export_result = exporter.export(estimator, export_dir_base, + "checkpoint_path", {"global_step": 20}, + False) + self.assertEqual(None, export_result) + + shutil.rmtree(export_dir_base, ignore_errors=True) + + def test_steps_exporter_with_no_global_step_key(self): + + def _serving_input_receiver_fn(): + pass + + export_dir_base = tempfile.mkdtemp() + gfile.MkDir(export_dir_base) + gfile.MkDir(export_dir_base + "/export") + gfile.MkDir(export_dir_base + "/eval") + + exporter = exporter_lib.StepsExporter( + name="steps_exporter", + serving_input_receiver_fn=_serving_input_receiver_fn, + assets_extra={"from/path": "to/path"}, + as_text=False, + steps_to_keep=[1]) + estimator = test.mock.Mock(spec=estimator_lib.Estimator) + estimator.export_savedmodel.return_value = "export_result_path" + estimator.model_dir = export_dir_base + + with self.assertRaisesRegexp(ValueError, "does not have global step"): + exporter.export(estimator, export_dir_base, "checkpoint_path", {}, False) + + shutil.rmtree(export_dir_base, ignore_errors=True) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/extenders.py b/tensorflow/contrib/estimator/python/estimator/extenders.py index bf08be09e7baf63e507a6a4db6a91e7b6bb20b74..26449b46516fe1d8c93a8e3567f93801c689a65a 100644 --- a/tensorflow/contrib/estimator/python/estimator/extenders.py +++ b/tensorflow/contrib/estimator/python/estimator/extenders.py @@ -34,7 +34,7 @@ _VALID_METRIC_FN_ARGS = set(['features', 'labels', 'predictions', 'config']) def add_metrics(estimator, metric_fn): - """Creates a new @{tf.estimator.Estimator} which has given metrics. + """Creates a new `tf.estimator.Estimator` which has given metrics. Example: @@ -61,7 +61,7 @@ def add_metrics(estimator, metric_fn): ``` Args: - estimator: A @{tf.estimator.Estimator} object. + estimator: A `tf.estimator.Estimator` object. metric_fn: A function which should obey the following signature: - Args: can only have following four arguments in any order: * predictions: Predictions `Tensor` or dict of `Tensor` created by given @@ -79,7 +79,7 @@ def add_metrics(estimator, metric_fn): function, namely a `(metric_tensor, update_op)` tuple. Returns: - A new @{tf.estimator.Estimator} which has a union of original metrics with + A new `tf.estimator.Estimator` which has a union of original metrics with given ones. """ _verify_metric_fn_args(metric_fn) @@ -165,14 +165,14 @@ def forward_features(estimator, keys=None): ``` Args: - estimator: A @{tf.estimator.Estimator} object. + estimator: A `tf.estimator.Estimator` object. keys: a `string` or a `list` of `string`. If it is `None`, all of the `features` in `dict` is forwarded to the `predictions`. If it is a `string`, only given key is forwarded. If it is a `list` of strings, all the given `keys` are forwarded. Returns: - A new @{tf.estimator.Estimator} which forwards features to predictions. + A new `tf.estimator.Estimator` which forwards features to predictions. Raises: ValueError: diff --git a/tensorflow/contrib/estimator/python/estimator/hooks.py b/tensorflow/contrib/estimator/python/estimator/hooks.py index caadafdfa6972c141d32a705e62a98d220cace41..66c46e66b77e8819268f7fe084abdc785077f116 100644 --- a/tensorflow/contrib/estimator/python/estimator/hooks.py +++ b/tensorflow/contrib/estimator/python/estimator/hooks.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import os +import time from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.framework import ops @@ -26,6 +27,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import state_ops from tensorflow.python.training import training +from tensorflow.python.training import training_util # pylint: disable=protected-access @@ -72,8 +74,9 @@ class InMemoryEvaluatorHook(training.SessionRunHook): estimator: A `tf.estimator.Estimator` instance to call evaluate. input_fn: Equivalent to the `input_fn` arg to `estimator.evaluate`. A function that constructs the input data for evaluation. - See @{$premade_estimators#create_input_functions} for more - information. The function should construct and return one of + See [Createing input functions]( + https://tensorflow.org/guide/premade_estimators#create_input_functions) + for more information. The function should construct and return one of the following: * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a @@ -210,4 +213,72 @@ class InMemoryEvaluatorHook(training.SessionRunHook): self._evaluate(session) +class _StopAtCheckpointStepHook(training.SessionRunHook): + """Hook that requests stop at a specified step based on checkpoint. + + Note: We recommend using 'make_stop_at_checkpoint_step_hook` to get the proper + hook. + """ + + def __init__(self, model_dir, last_step, + wait_after_file_check_secs=30): + """Initializes a `StopAtCheckpointStepHook`. + + This hook requests stop after a last step has been reached. It checks latest + checkpoint to verify last step is written on disk or not. + + Args: + model_dir: Directory to read global step from latest checkpoint. + last_step: Step after which to stop. + wait_after_file_check_secs: Reading same file by many workers may create + I/O issues. To throttle that we will wait given secs after each read of + the file. + + Raises: + ValueError: If one of the arguments is invalid. + """ + if last_step is None: + raise ValueError('last_step must be specified.') + if model_dir is None: + raise ValueError('model_dir must be specified.') + + self._model_dir = model_dir + self._last_step = last_step + self._wait_after_file_check_secs = wait_after_file_check_secs + + def begin(self): + self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access + if self._global_step_tensor is None: + raise RuntimeError( + 'Global step should be created to use StopAtCheckpointStepHook.') + + def before_run(self, run_context): # pylint: disable=unused-argument + return training.SessionRunArgs(self._global_step_tensor) + + def after_run(self, run_context, run_values): + global_step = run_values.results + 1 + if global_step >= self._last_step: + # Check latest global step in the checkpoint to ensure that the targeted + # last step is written on disk. + + step = estimator_lib._load_global_step_from_checkpoint_dir( + self._model_dir) + if step >= self._last_step: + run_context.request_stop() + else: + time.sleep(self._wait_after_file_check_secs) + + +def make_stop_at_checkpoint_step_hook(estimator, + last_step, + wait_after_file_check_secs=30): + """Creates a proper StopAtCheckpointStepHook based on chief status.""" + + if estimator.config.is_chief: + return training.StopAtStepHook(last_step=last_step) + return _StopAtCheckpointStepHook( + model_dir=estimator.model_dir, + last_step=last_step, + wait_after_file_check_secs=wait_after_file_check_secs) + # pylint: enable=protected-access diff --git a/tensorflow/contrib/estimator/python/estimator/hooks_test.py b/tensorflow/contrib/estimator/python/estimator/hooks_test.py index ee88d5ecf50aa15b2faa0f3e136c686b5b0ef62a..c6c6cad95a7575224c47bb5ec36e243691fed371 100644 --- a/tensorflow/contrib/estimator/python/estimator/hooks_test.py +++ b/tensorflow/contrib/estimator/python/estimator/hooks_test.py @@ -21,8 +21,11 @@ from __future__ import print_function import glob import json import os +import tempfile +import time from tensorflow.contrib.estimator.python.estimator import hooks as hooks_lib +from tensorflow.python.client import session as tf_session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator_lib from tensorflow.python.estimator import run_config as run_config_lib @@ -316,5 +319,85 @@ class InMemoryEvaluatorHookTest(test.TestCase): estimator.train(input_fn, hooks=[evaluator]) +class StopAtCheckpointStepHookTest(test.TestCase): + + def test_do_not_stop_if_checkpoint_is_not_there(self): + with ops.Graph().as_default(): + step = training.create_global_step() + assign_ten = step.assign(10) + no_op = control_flow_ops.no_op() + hook = hooks_lib._StopAtCheckpointStepHook( + model_dir=tempfile.mkdtemp(), last_step=10) + with training.SingularMonitoredSession(hooks=[hook]) as mon_sess: + mon_sess.raw_session().run(assign_ten) + with test.mock.patch.object(time, 'sleep') as mock_sleep: + mon_sess.run(no_op) + self.assertTrue(mock_sleep.called) + self.assertFalse(mon_sess.should_stop()) + + def test_do_not_stop_if_checkpoint_step_is_smaller(self): + model_dir = tempfile.mkdtemp() + with ops.Graph().as_default(): + step = training.create_global_step() + assign_nine = step.assign(9) + assign_ten = step.assign(10) + no_op = control_flow_ops.no_op() + hook = hooks_lib._StopAtCheckpointStepHook( + model_dir=model_dir, last_step=10) + with tf_session.Session() as sess: + sess.run(assign_nine) + training.Saver().save(sess, os.path.join(model_dir, 'model.ckpt')) + with training.SingularMonitoredSession(hooks=[hook]) as mon_sess: + mon_sess.raw_session().run(assign_ten) + with test.mock.patch.object(time, 'sleep') as mock_sleep: + mon_sess.run(no_op) + self.assertTrue(mock_sleep.called) + self.assertFalse(mon_sess.should_stop()) + + def test_stop_if_checkpoint_step_is_laststep(self): + model_dir = tempfile.mkdtemp() + with ops.Graph().as_default(): + step = training.create_global_step() + assign_ten = step.assign(10) + no_op = control_flow_ops.no_op() + hook = hooks_lib._StopAtCheckpointStepHook( + model_dir=model_dir, last_step=10) + with tf_session.Session() as sess: + sess.run(assign_ten) + training.Saver().save(sess, os.path.join(model_dir, 'model.ckpt')) + with training.SingularMonitoredSession(hooks=[hook]) as mon_sess: + mon_sess.raw_session().run(assign_ten) + with test.mock.patch.object(time, 'sleep') as mock_sleep: + mon_sess.run(no_op) + self.assertFalse(mock_sleep.called) + self.assertTrue(mon_sess.should_stop()) + + def test_creates_regular_stop_at_step_hook_for_chief(self): + # by default an estimator is in chief mode + dnn = estimator_lib.DNNClassifier( + feature_columns=[feature_column_lib.numeric_column('x')], + hidden_units=[3, 1]) + hook = hooks_lib.make_stop_at_checkpoint_step_hook(dnn, 300) + self.assertIsInstance(hook, training.StopAtStepHook) + self.assertEqual(300, hook._last_step) + + def test_creates_checkpoint_hook_for_workers(self): + + class FakeWorkerConfig(estimator_lib.RunConfig): + + @property + def is_chief(self): + return False + + dnn = estimator_lib.DNNClassifier( + feature_columns=[feature_column_lib.numeric_column('x')], + hidden_units=[3, 1], + config=FakeWorkerConfig()) + hook = hooks_lib.make_stop_at_checkpoint_step_hook(dnn, 300) + self.assertIsInstance(hook, hooks_lib._StopAtCheckpointStepHook) + self.assertEqual(300, hook._last_step) + self.assertEqual(dnn.model_dir, hook._model_dir) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/linear.py b/tensorflow/contrib/estimator/python/estimator/linear.py index 62a37abefb1f6ed291df1df3da6de35bfd2b6c52..2b68f24eb2d4c528bc1cb87e7d858014f66c0433 100644 --- a/tensorflow/contrib/estimator/python/estimator/linear.py +++ b/tensorflow/contrib/estimator/python/estimator/linear.py @@ -121,7 +121,7 @@ class LinearEstimator(estimator.Estimator): is multivalent. One of "mean", "sqrtn", and "sum" -- these are effectively different ways to do example-level normalization, which can be useful for bag-of-words features. for more details, see - @{tf.feature_column.linear_model$linear_model}. + `tf.feature_column.linear_model`. """ def _model_fn(features, labels, mode, config): return linear_lib._linear_model_fn( # pylint: disable=protected-access diff --git a/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py b/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py index f3d0f6b0470bbbe148d251e8d2ab20d8e3c3d01c..ce98e9987ec728fadf170e56fe4bfe24fc9a0105 100644 --- a/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py +++ b/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py @@ -46,6 +46,7 @@ class SavedModelEstimator(estimator_lib.Estimator): Example with `tf.estimator.DNNClassifier`: **Step 1: Create and train DNNClassifier.** + ```python feature1 = tf.feature_column.embedding_column( tf.feature_column.categorical_column_with_vocabulary_list( @@ -66,13 +67,14 @@ class SavedModelEstimator(estimator_lib.Estimator): **Step 2: Export classifier.** First, build functions that specify the expected inputs. + ```python # During train and evaluation, both the features and labels should be defined. supervised_input_receiver_fn = ( tf.contrib.estimator.build_raw_supervised_input_receiver_fn( - {'feature1': tf.placeholder(dtype=tf.string, shape=[None]), - 'feature2': tf.placeholder(dtype=tf.float32, shape=[None])}, - tf.placeholder(dtype=tf.float32, shape=[None]))) + {'feature1': tf.placeholder(dtype=tf.string, shape=[None]), + 'feature2': tf.placeholder(dtype=tf.float32, shape=[None])}, + tf.placeholder(dtype=tf.float32, shape=[None]))) # During predict mode, expect to receive a `tf.Example` proto, so a parsing # function is used. @@ -83,6 +85,7 @@ class SavedModelEstimator(estimator_lib.Estimator): Next, export the model as a SavedModel. A timestamped directory will be created (for example `/tmp/export_all/1234567890`). + ```python # Option 1: Save all modes (train, eval, predict) export_dir = tf.contrib.estimator.export_all_saved_models( @@ -93,10 +96,11 @@ class SavedModelEstimator(estimator_lib.Estimator): # Option 2: Only export predict mode export_dir = classifier.export_savedmodel( - '/tmp/export_predict', serving_input_receiver_fn) + '/tmp/export_predict', serving_input_receiver_fn) ``` **Step 3: Create a SavedModelEstimator from the exported SavedModel.** + ```python est = tf.contrib.estimator.SavedModelEstimator(export_dir) @@ -108,7 +112,7 @@ class SavedModelEstimator(estimator_lib.Estimator): est.train(input_fn=input_fn, steps=20) def predict_input_fn(): - example = example_pb2.Example() + example = tf.train.Example() example.features.feature['feature1'].bytes_list.value.extend(['yellow']) example.features.feature['feature2'].float_list.value.extend([1.]) return {'inputs':tf.constant([example.SerializeToString()])} @@ -144,7 +148,7 @@ class SavedModelEstimator(estimator_lib.Estimator): super(SavedModelEstimator, self).__init__( model_fn=self._model_fn_from_saved_model, model_dir=model_dir, warm_start_from=warm_start_settings) - if self._distribution is not None: + if self._train_distribution or self._eval_distribution: raise NotImplementedError( 'SavedModelEstimator currently does not support ' 'DistributionStrategy.') diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index effec42f028fe472593a8d06e15a0831346d6f50..9e1f14f9905d584287864c15d9b6f9c152d17787 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -65,7 +65,7 @@ tf_custom_op_py_library( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", "//tensorflow/python/estimator", - "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/feature_column:feature_column_py", "//third_party/py/numpy", ], @@ -242,7 +242,7 @@ py_test( "//tensorflow/python:platform_benchmark", "//tensorflow/python:random_ops", "//tensorflow/python:training", - "//tensorflow/python/estimator:run_config", + "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/feature_column:feature_column_py", "//third_party/py/numpy", ], diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py index 9ffdd3ba5e8ac496533d0207f2b6846dbc92bc89..f384d761a8430074f022c973d7ec3d46cd90f70b 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans.py @@ -158,12 +158,12 @@ class _ModelFn(object): return either `features` or, equivalently, `(features, None)`. Args: - features: The input points. See @{tf.estimator.Estimator}. - mode: See @{tf.estimator.Estimator}. - config: See @{tf.estimator.Estimator}. + features: The input points. See `tf.estimator.Estimator`. + mode: See `tf.estimator.Estimator`. + config: See `tf.estimator.Estimator`. Returns: - A @{tf.estimator.EstimatorSpec} (see @{tf.estimator.Estimator}) specifying + A `tf.estimator.EstimatorSpec` (see `tf.estimator.Estimator`) specifying this behavior: * `train_op`: Execute one mini-batch or full-batch run of Lloyd's algorithm. @@ -188,7 +188,6 @@ class _ModelFn(object): # center. # is_initialized: scalar indicating whether the initial cluster centers # have been chosen; see init_op. - # cluster_centers_var: a Variable containing the cluster centers. # init_op: an op to choose the initial cluster centers. A single worker # repeatedly executes init_op until is_initialized becomes True. # training_op: an op that runs an iteration of training, either an entire @@ -394,7 +393,7 @@ class KMeansClustering(estimator.Estimator): relative_tolerance: A relative tolerance of change in the loss between iterations. Stops learning if the loss changes less than this amount. This may not work correctly if `use_mini_batch=True`. - config: See @{tf.estimator.Estimator}. + config: See `tf.estimator.Estimator`. feature_columns: An optionable iterable containing all the feature columns used by the model. All items in the set should be feature column instances that can be passed to `tf.feature_column.input_layer`. If this @@ -431,7 +430,7 @@ class KMeansClustering(estimator.Estimator): """Finds the index of the closest cluster center to each input point. Args: - input_fn: Input points. See @{tf.estimator.Estimator.predict}. + input_fn: Input points. See `tf.estimator.Estimator.predict`. Yields: The index of the closest cluster center for each input point. @@ -447,7 +446,7 @@ class KMeansClustering(estimator.Estimator): which returns the negative sum. Args: - input_fn: Input points. See @{tf.estimator.Estimator.evaluate}. Only one + input_fn: Input points. See `tf.estimator.Estimator.evaluate`. Only one batch is retrieved. Returns: @@ -465,7 +464,7 @@ class KMeansClustering(estimator.Estimator): sklearn function returns the Euclidean distance. Args: - input_fn: Input points. See @{tf.estimator.Estimator.predict}. + input_fn: Input points. See `tf.estimator.Estimator.predict`. Yields: The distances from each input point to each cluster center. diff --git a/tensorflow/contrib/ffmpeg/__init__.py b/tensorflow/contrib/ffmpeg/__init__.py index 484ffee3e7afe55c63cab2a463454353b2663e18..3a756da932b92d9ff974460773e34bcf25d04e6f 100644 --- a/tensorflow/contrib/ffmpeg/__init__.py +++ b/tensorflow/contrib/ffmpeg/__init__.py @@ -15,7 +15,7 @@ # pylint: disable=g-short-docstring-punctuation """Working with audio using FFmpeg. -See the @{$python/contrib.ffmpeg} guide. +See the [FFMPEG](https://tensorflow.org/api_guides/python/contrib.ffmpeg) guide. @@decode_audio @@encode_audio diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index dc49383c5c300e82839c478e097074b3e8776b3b..95f5ba90aba6ff8d3f1f5b93bde2211ddf1c231b 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -15,7 +15,9 @@ """Framework utilities. -See the @{$python/contrib.framework} guide. +See the +[Contrib Framework](https://tensorflow.org/api_guides/python/contrib.framework) +guide. @@assert_same_float_dtype @@assert_scalar @@ -100,6 +102,8 @@ See the @{$python/contrib.framework} guide. @@BoundedTensorSpec @@TensorSpec + +@@RecordInput """ from __future__ import absolute_import @@ -119,6 +123,7 @@ from tensorflow.python.framework.smart_cond import smart_cond from tensorflow.python.framework.smart_cond import smart_constant_value from tensorflow.python.framework.tensor_spec import BoundedTensorSpec from tensorflow.python.framework.tensor_spec import TensorSpec +from tensorflow.python.ops.data_flow_ops import RecordInput from tensorflow.python.ops.init_ops import convolutional_delta_orthogonal from tensorflow.python.ops.init_ops import convolutional_orthogonal_1d from tensorflow.python.ops.init_ops import convolutional_orthogonal_2d @@ -133,6 +138,7 @@ _nest_allowed_symbols = [ 'flatten_dict_items', 'pack_sequence_as', 'map_structure', + 'map_structure_with_paths', 'assert_shallow_structure', 'flatten_up_to', 'map_structure_up_to', diff --git a/tensorflow/contrib/framework/python/framework/checkpoint_utils.py b/tensorflow/contrib/framework/python/framework/checkpoint_utils.py index 9e356dd96562c28adec7fc28fe144394e1c2ed38..e7184a01fbf57319399fc6dd287b7387138b4058 100644 --- a/tensorflow/contrib/framework/python/framework/checkpoint_utils.py +++ b/tensorflow/contrib/framework/python/framework/checkpoint_utils.py @@ -27,7 +27,7 @@ from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import saver +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import training as train __all__ = [ @@ -40,7 +40,7 @@ __all__ = [ def _get_checkpoint_filename(filepattern): """Returns checkpoint filename given directory or specific filepattern.""" if gfile.IsDirectory(filepattern): - return saver.latest_checkpoint(filepattern) + return checkpoint_management.latest_checkpoint(filepattern) return filepattern diff --git a/tensorflow/contrib/framework/python/ops/arg_scope.py b/tensorflow/contrib/framework/python/ops/arg_scope.py index 5b150339953f961c756c0909dd1795341159b9cd..0a02e76a265c8ad25d978e7d610fb50fc0fdfdb1 100644 --- a/tensorflow/contrib/framework/python/ops/arg_scope.py +++ b/tensorflow/contrib/framework/python/ops/arg_scope.py @@ -103,9 +103,8 @@ def _kwarg_names(func): def _add_op(op): - key = arg_scope_func_key(op) - if key not in _DECORATED_OPS: - _DECORATED_OPS[key] = _kwarg_names(op) + key_op = arg_scope_func_key(op) + _DECORATED_OPS[key_op] = _kwarg_names(op) @tf_contextlib.contextmanager diff --git a/tensorflow/contrib/framework/python/ops/arg_scope_test.py b/tensorflow/contrib/framework/python/ops/arg_scope_test.py index 4c3879d4fc08b53ea8be5f1256a830a64fb39af6..bcafc1a3280ba0435f655eacb8173e4e97051154 100644 --- a/tensorflow/contrib/framework/python/ops/arg_scope_test.py +++ b/tensorflow/contrib/framework/python/ops/arg_scope_test.py @@ -38,6 +38,12 @@ def func3(args, a=None, b=1, c=2): """Some cool doc string.""" return (args, a, b, c) +@add_arg_scope +def func4(x='x', y='y'): + if x: + pass + if y: + pass def _key_op(op): return getattr(op, '_key_op', str(op)) @@ -231,6 +237,15 @@ class ArgScopeTest(test.TestCase): self.assertTupleEqual(args, func2_args) self.assertDictEqual(kwargs, func2_kwargs) + def testAddArgScopeRaceCondition(self): + func4_kwargs = ('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h') + for i in range(4): + # redefine the function with different args + @add_arg_scope + def func4(a=1, b=2, c=3, d=4, e=5, f=6, g=7, h=8): + pass + self.assertTupleEqual(arg_scoped_arguments(func4), func4_kwargs) + def testDocString(self): self.assertEqual(func3.__doc__, 'Some cool doc string.') diff --git a/tensorflow/contrib/framework/python/ops/critical_section_ops.py b/tensorflow/contrib/framework/python/ops/critical_section_ops.py index 72835c3ad86e6321eb30324c7dd0751034759ce4..71ab755aa2948c548db89b330bb93c9524412fa6 100644 --- a/tensorflow/contrib/framework/python/ops/critical_section_ops.py +++ b/tensorflow/contrib/framework/python/ops/critical_section_ops.py @@ -325,6 +325,8 @@ class CriticalSection(object): def _is_self_handle(self, x): """Check if the tensor `x` is the same Mutex as `self._handle`.""" + if isinstance(x, ops.EagerTensor): + return x is self._handle return (x.op.type == "MutexV2" # blank shared_name means the op will create a unique one. and x.op.get_attr("shared_name") @@ -365,8 +367,7 @@ class CriticalSection(object): "(CriticalSection: %s) requested exclusive resource access " "of this resource. Did you mean to call execute with keyword " "argument exclusive_resource_access=False?" % - (list(resource_intersection), self._handle.name, - sg.op.name, sg.handle.name)) + (list(resource_intersection), self._handle, sg, sg.handle)) # TODO(ebrevdo): Re-enable once CriticalSection is in core. diff --git a/tensorflow/contrib/framework/python/ops/script_ops.py b/tensorflow/contrib/framework/python/ops/script_ops.py index 5d269fefdcfae7902b35e0f29f8cd12fcc58b882..d5cb679e2c05a217f36b7abe9986227e898aacc4 100644 --- a/tensorflow/contrib/framework/python/ops/script_ops.py +++ b/tensorflow/contrib/framework/python/ops/script_ops.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -"""Script Language Operators. See the @{$python/script_ops} guide. +"""Script Language Operators. @@py_func """ diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py index 322d5c335e6a77c46c7ce5dd795e21a2d5a1f8f9..a7acae804a0c71cc19757a48d47fd9cf9022b0e2 100644 --- a/tensorflow/contrib/framework/python/ops/variables.py +++ b/tensorflow/contrib/framework/python/ops/variables.py @@ -241,13 +241,13 @@ def variable(name, use_resource: If `True` use a ResourceVariable instead of a Variable. synchronization: Indicates when a distributed a variable will be aggregated. Accepted values are constants defined in the class - @{tf.VariableSynchronization}. By default the synchronization is set to + `tf.VariableSynchronization`. By default the synchronization is set to `AUTO` and the current `DistributionStrategy` chooses when to synchronize. If `synchronization` is set to `ON_READ`, `trainable` must not be set to `True`. aggregation: Indicates how a distributed variable will be aggregated. Accepted values are constants defined in the class - @{tf.VariableAggregation}. + `tf.VariableAggregation`. Returns: The created or existing variable. @@ -320,13 +320,13 @@ def model_variable(name, use_resource: If `True` use a ResourceVariable instead of a Variable. synchronization: Indicates when a distributed a variable will be aggregated. Accepted values are constants defined in the class - @{tf.VariableSynchronization}. By default the synchronization is set to + `tf.VariableSynchronization`. By default the synchronization is set to `AUTO` and the current `DistributionStrategy` chooses when to synchronize. If `synchronization` is set to `ON_READ`, `trainable` must not be set to `True`. aggregation: Indicates how a distributed variable will be aggregated. Accepted values are constants defined in the class - @{tf.VariableAggregation}. + `tf.VariableAggregation`. Returns: The created or existing variable. diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h index 7534f5797c4f3eee3b031b2693e212749af85c6e..869e899ac873d393ff312622082c6d6076284a0f 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRDPARTY_TENSORFLOW_CONTRIB_KERNELS_FUSED_CONV2D_BIAS_ACTIVATION_OP_H_ -#define THIRDPARTY_TENSORFLOW_CONTRIB_KERNELS_FUSED_CONV2D_BIAS_ACTIVATION_OP_H_ +#ifndef TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV2D_BIAS_ACTIVATION_OP_H_ +#define TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV2D_BIAS_ACTIVATION_OP_H_ #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor_types.h" @@ -62,4 +62,4 @@ class LaunchFusedConv2DBiasActivationOp