diff --git a/tools/bazel.rc b/.bazelrc similarity index 95% rename from tools/bazel.rc rename to .bazelrc index 1fdf51f53e29c7111cf89c016400b710051cf9c6..cd7e13ddfc146208f79be900917b05b694869d72 100644 --- a/tools/bazel.rc +++ b/.bazelrc @@ -76,7 +76,6 @@ build:nonccl --define=no_nccl_support=true build --define=use_fast_cpp_protos=true build --define=allow_oversize_protos=true -build --define=grpc_no_ares=true build --spawn_strategy=standalone build --genrule_strategy=standalone @@ -93,3 +92,11 @@ build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS build --define=PREFIX=/usr build --define=LIBDIR=$(PREFIX)/lib build --define=INCLUDEDIR=$(PREFIX)/include + +# Default options should come above this line + +# Options from ./configure +try-import %workspace%/.tf_configure.bazelrc + +# Put user-specific options in .bazelrc.user +try-import %workspace%/.bazelrc.user diff --git a/.gitignore b/.gitignore index 90324058600bee46af56e49028977971848a80de..e1d352c238a1b2d4febe0f5d4a30cfa0c942f7e7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,7 @@ .DS_Store .ipynb_checkpoints node_modules -/.bazelrc +/.bazelrc.user /.tf_configure.bazelrc /bazel-* /bazel_pip diff --git a/CODEOWNERS b/CODEOWNERS index 54a61a4d72c40d297d90d53e223f64f813d9167d..cb3fa2312405ce44d5dfc30ea4164740f436e07e 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,7 +1,7 @@ # Where component owners are known, add them here. 
/tenosrflow/core/debug @caisq -/tensorflow/core/nccl/ @azaks @csigg +/tensorflow/core/nccl/ @azaks2 @chsigg /tensorflow/core/platform/windows/ @mrry /tensorflow/core/platform/s3 @yongtang /tensorflow/go @asimshankar @@ -51,13 +51,13 @@ /tensorflow/contrib/pi_examples/ @maciekcc /tensorflow/contrib/quantization/ @petewarden /tensorflow/contrib/rnn/ @ebrevdo @scottzhu -/tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh @allenl +/tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh @allenlavoie /tensorflow/contrib/seq2seq/ @ebrevdo @lmthang /tensorflow/contrib/session_bundle/ @nfiedel @sukritiramesh /tensorflow/contrib/slim/ @sguada @thenbasilmanran /tensorflow/contrib/stateless/ @girving @alextp /tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst @yupbank -/tensorflow/contrib/tensorrt/ @aaroey +/tensorflow/contrib/tensorrt/ @aaroey @smit-hinsu @azaks2 # NEED OWNER: /tensorflow/contrib/testing/ /tensorflow/contrib/timeseries/ @allenlavoie /tensorflow/contrib/tpu/ @frankchn @saeta @jhseu @sourabhbajaj diff --git a/README.md b/README.md index 044174947a094d43a51f7140dd40ec0f17801d40..519815d006cc33be10132909baf414a4bd843435 100644 --- a/README.md +++ b/README.md @@ -113,11 +113,12 @@ The TensorFlow project strives to abide by generally accepted best practices in Build Type | Status | Artifacts ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA -**IBM ppc64le CPU** | [![Build 
Status](http://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | TBA -**IBM ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) -**IBM ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) +**Linux ppc64le CPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/) +**Linux ppc64le CPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) +**Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) +**Linux ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) **Linux CPU with Intel® MKL-DNN** Nightly | [![Build 
Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) -**Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.4
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.11.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp27-cp27mu-linux_x86_64.whl)
[1.11.0 py3.4](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp34-cp34m-linux_x86_64.whl)
[1.11.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp35-cp35m-linux_x86_64.whl)
[1.11.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp36-cp36m-linux_x86_64.whl) +**Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.4
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.12.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp27-cp27mu-linux_x86_64.whl)
[1.12.0 py3.4](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp34-cp34m-linux_x86_64.whl)
[1.12.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp35-cp35m-linux_x86_64.whl)
[1.12.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp36-cp36m-linux_x86_64.whl) ## For more information diff --git a/RELEASE.md b/RELEASE.md index b13b071bd6cf4d3a260c8e248a67d23e1a688498..32abdcea497618918964174a661a6ba872598f65 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -7,6 +7,8 @@ Serving. * Keras models now support evaluating with a `tf.data.Dataset`. * TensorFlow binaries are built with XLA support linked in by default. +* Ignite Dataset added to contrib/ignite that allows working with Apache + Ignite. ## Bug Fixes and Other Changes diff --git a/WORKSPACE b/WORKSPACE index 0c7bc085b512b084b9470abe17326d7c119aa327..7057d3f149e766cd2983ecc89509f84c37075602 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -1,5 +1,7 @@ workspace(name = "org_tensorflow") +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + http_archive( name = "io_bazel_rules_closure", sha256 = "a38539c5b5c358548e75b44141b4ab637bba7c4dc02b46b1f62a96d6433f56ae", @@ -14,30 +16,27 @@ load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories") closure_repositories() -http_archive( - name = "base_images_docker", - sha256 = "e2b1b7254270bb7605e814a9dbf6d1e4ae04a11136ff1714fbfdabe3f87f7cf9", - strip_prefix = "base-images-docker-12801524f867e657fbb5d1a74f31618aff181ac6", - urls = ["https://github.com/GoogleCloudPlatform/base-images-docker/archive/12801524f867e657fbb5d1a74f31618aff181ac6.tar.gz"], -) +load("//third_party/toolchains/preconfig/generate:archives.bzl", + "bazel_toolchains_archive") -http_archive( - name = "bazel_toolchains", - sha256 = "15b5858b1b5541ec44df31b94c3b8672815b31d71215a98398761ea9f4c4eedb", - strip_prefix = "bazel-toolchains-6200b238c9c2d137c0d9a7262c80cc71d98e692b", - urls = [ - "https://github.com/bazelbuild/bazel-toolchains/archive/6200b238c9c2d137c0d9a7262c80cc71d98e692b.tar.gz", - ], +bazel_toolchains_archive() + +load( + "@bazel_toolchains//repositories:repositories.bzl", + bazel_toolchains_repositories = 
"repositories", ) -http_archive( - name = "io_bazel_rules_docker", - sha256 = "29d109605e0d6f9c892584f07275b8c9260803bf0c6fcb7de2623b2bedc910bd", - strip_prefix = "rules_docker-0.5.1", - urls = ["https://github.com/bazelbuild/rules_docker/archive/v0.5.1.tar.gz"], +bazel_toolchains_repositories() + +load( + "@io_bazel_rules_docker//container:container.bzl", + container_repositories = "repositories", ) -load("//third_party/toolchains/preconfig/generate:workspace.bzl", "remote_config_workspace") +container_repositories() + +load("//third_party/toolchains/preconfig/generate:workspace.bzl", + "remote_config_workspace") remote_config_workspace() @@ -45,7 +44,7 @@ remote_config_workspace() # files, in case the parsing of those build files depends on the bazel # version we require here. load("//tensorflow:version_check.bzl", "check_bazel_version_at_least") -check_bazel_version_at_least("0.15.0") +check_bazel_version_at_least("0.18.0") load("//tensorflow:workspace.bzl", "tf_workspace") @@ -57,9 +56,9 @@ android_workspace() # Please add all new TensorFlow dependencies in workspace.bzl. 
tf_workspace() -new_http_archive( +http_archive( name = "inception_v1", - build_file = "models.BUILD", + build_file = "//:models.BUILD", sha256 = "7efe12a8363f09bc24d7b7a450304a15655a57a7751929b2c1593a71183bb105", urls = [ "http://storage.googleapis.com/download.tensorflow.org/models/inception_v1.zip", @@ -67,9 +66,9 @@ new_http_archive( ], ) -new_http_archive( +http_archive( name = "mobile_ssd", - build_file = "models.BUILD", + build_file = "//:models.BUILD", sha256 = "bddd81ea5c80a97adfac1c9f770e6f55cbafd7cce4d3bbe15fbeb041e6b8f3e8", urls = [ "http://storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_android_export.zip", @@ -77,9 +76,9 @@ new_http_archive( ], ) -new_http_archive( +http_archive( name = "mobile_multibox", - build_file = "models.BUILD", + build_file = "//:models.BUILD", sha256 = "859edcddf84dddb974c36c36cfc1f74555148e9c9213dedacf1d6b613ad52b96", urls = [ "http://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1a.zip", @@ -87,9 +86,9 @@ new_http_archive( ], ) -new_http_archive( +http_archive( name = "stylize", - build_file = "models.BUILD", + build_file = "//:models.BUILD", sha256 = "3d374a730aef330424a356a8d4f04d8a54277c425e274ecb7d9c83aa912c6bfa", urls = [ "http://storage.googleapis.com/download.tensorflow.org/models/stylize_v1.zip", @@ -97,9 +96,9 @@ new_http_archive( ], ) -new_http_archive( +http_archive( name = "speech_commands", - build_file = "models.BUILD", + build_file = "//:models.BUILD", sha256 = "c3ec4fea3158eb111f1d932336351edfe8bd515bb6e87aad4f25dbad0a600d0c", urls = [ "http://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip", diff --git a/configure.py b/configure.py index f087da002d534e1f0f4c1598e87217168c892dbe..1e732db26404906901a9eeab97a5e75137ee8388 100644 --- a/configure.py +++ b/configure.py @@ -255,18 +255,6 @@ def setup_python(environ_cp): def reset_tf_configure_bazelrc(): """Reset file that contains customized config settings.""" 
open(_TF_BAZELRC, 'w').close() - bazelrc_path = os.path.join(_TF_WORKSPACE_ROOT, '.bazelrc') - - data = [] - if os.path.exists(bazelrc_path): - with open(bazelrc_path, 'r') as f: - data = f.read().splitlines() - with open(bazelrc_path, 'w') as f: - for l in data: - if _TF_BAZELRC_FILENAME in l: - continue - f.write('%s\n' % l) - f.write('import %%workspace%%/%s\n' % _TF_BAZELRC_FILENAME) def cleanup_makefile(): """Delete any leftover BUILD files from the Makefile build. @@ -452,11 +440,12 @@ def convert_version_to_int(version): return int(version_str) -def check_bazel_version(min_version): - """Check installed bazel version is at least min_version. +def check_bazel_version(min_version, max_version): + """Check installed bazel version is between min_version and max_version. Args: min_version: string for minimum bazel version. + max_version: string for maximum bazel version. Returns: The bazel version detected. @@ -474,6 +463,7 @@ def check_bazel_version(min_version): min_version_int = convert_version_to_int(min_version) curr_version_int = convert_version_to_int(curr_version) + max_version_int = convert_version_to_int(max_version) # Check if current bazel version can be detected properly. if not curr_version_int: @@ -486,7 +476,12 @@ def check_bazel_version(min_version): if curr_version_int < min_version_int: print('Please upgrade your bazel installation to version %s or higher to ' 'build TensorFlow!' % min_version) - sys.exit(0) + sys.exit(1) + if (curr_version_int > max_version_int and + 'TF_IGNORE_MAX_BAZEL_VERSION' not in os.environ): + print('Please downgrade your bazel installation to version %s or lower to ' + 'build TensorFlow!' % max_version) + sys.exit(1) return curr_version @@ -1559,11 +1554,9 @@ def main(): # environment variables. 
environ_cp = dict(os.environ) - check_bazel_version('0.15.0') + check_bazel_version('0.19.0', '0.20.0') reset_tf_configure_bazelrc() - # Explicitly import tools/bazel.rc, this is needed for Bazel 0.19.0 or later - write_to_bazelrc('import %workspace%/tools/bazel.rc') cleanup_makefile() setup_python(environ_cp) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index fd4b94202aad24a82abef8abd16431f61a8326f0..449a1372edb031c68786d8672e2a1499c2b3d047 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -267,6 +267,15 @@ config_setting( visibility = ["//visibility:public"], ) +# By default, XLA GPU is compiled into tensorflow when building with +# --config=cuda even when `with_xla_support` is false. The config setting +# here allows us to override the behavior if needed. +config_setting( + name = "no_xla_deps_in_cuda", + define_values = {"no_xla_deps_in_cuda": "true"}, + visibility = ["//visibility:public"], +) + config_setting( name = "with_gdr_support", define_values = {"with_gdr_support": "true"}, @@ -606,9 +615,11 @@ py_library( name = "tensorflow_py", srcs_version = "PY2AND3", visibility = ["//visibility:public"], - deps = [ + deps = select({ + "api_version_2": [], + "//conditions:default": ["//tensorflow/contrib:contrib_py"], + }) + [ ":tensorflow_py_no_contrib", - "//tensorflow/contrib:contrib_py", "//tensorflow/python/estimator:estimator_py", ], ) diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index f13623b0d57d3b59bb9455a46a9fab29fee25784..4eba763129a6aef40e3c130d56bf8ab19638b7ca 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -20,14 +20,14 @@ from __future__ import print_function as _print_function import os as _os +# API IMPORTS PLACEHOLDER + # pylint: disable=g-bad-import-order from tensorflow.python.tools import component_api_helper as _component_api_helper _component_api_helper.package_hook( parent_package_str=__name__, 
child_package_str=('tensorflow_estimator.python.estimator.api.estimator')) -# API IMPORTS PLACEHOLDER - # Make sure directory containing top level submodules is in # the __path__ so that "from tensorflow.foo import bar" works. # We're using bitwise, but there's nothing special about that. @@ -35,8 +35,9 @@ _tf_api_dir = _os.path.dirname(_os.path.dirname(bitwise.__file__)) # pylint: di if _tf_api_dir not in __path__: __path__.append(_tf_api_dir) -# Calls to enable and disable features. -enable_eager_execution() # pylint: disable=undefined-variable +# Enable TF2 behaviors +from tensorflow.python.compat import compat as _compat # pylint: disable=g-import-not-at-top +_compat.enable_v2_behavior() # These symbols appear because we import the python package which # in turn imports from tensorflow.core and tensorflow.python. They diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index 65bdb6cb1b5e6fb0656a12b932d767aeacfccd29..21b5277614667bdbd7271ac3e57f5b69d5a19264 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -23,13 +23,13 @@ import os as _os # pylint: disable=g-bad-import-order from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import +# API IMPORTS PLACEHOLDER + from tensorflow.python.tools import component_api_helper as _component_api_helper _component_api_helper.package_hook( parent_package_str=__name__, child_package_str=('tensorflow_estimator.python.estimator.api.estimator')) -# API IMPORTS PLACEHOLDER - from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') del LazyLoader diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index f653e581bf3beda9fdbf8fb7905a4f9fe170e7fb..25df970ecab0757f23465ab19e7f45de0c759458 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -175,6 +175,34 @@ tf_cuda_library( ], ) +tf_cuda_library( + name = "env", 
+ srcs = [ + "env.cc", + ], + hdrs = [ + "env.h", + ], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = select({ + "//tensorflow:android": [ + ":c_api", + ":tf_status_helper", + "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:platform_env", + "//tensorflow/core:lib", + ], + "//conditions:default": [ + ":c_api", + ":tf_status_helper", + "//tensorflow/core:framework", + "//tensorflow/core:platform_env", + "//tensorflow/core:lib", + ], + }) + [":c_api_internal"], +) + tf_cuda_library( name = "kernels", srcs = [ @@ -188,10 +216,14 @@ tf_cuda_library( deps = select({ "//tensorflow:android": [ ":c_api", + ":c_api_internal", + ":tf_status_helper", "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ ":c_api", + ":c_api_internal", + ":tf_status_helper", "//tensorflow/core:framework", ], }), @@ -330,6 +362,27 @@ tf_kernel_library( alwayslink = 1, ) +tf_cuda_cc_test( + name = "env_test", + size = "small", + srcs = ["env_test.cc"], + linkopts = select({ + "//tensorflow:darwin": ["-headerpad_max_install_names"], + "//conditions:default": [], + }), + tags = ["noasan"], + # We must ensure that the dependencies can be dynamically linked since + # the shared library must be able to use core:framework. 
+ # linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":c_api", + ":env", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cuda_cc_test( name = "kernels_test", size = "small", diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index f13e8777dff164bcd8eedf46310ae846abd0c804..9580215a317b1a6b1cdacbd430a1764af61be990 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -136,16 +136,22 @@ const char* TF_Message(const TF_Status* s) { namespace { class TF_ManagedBuffer : public TensorBuffer { public: - void* data_; - size_t len_; - void (*deallocator_)(void* data, size_t len, void* arg); - void* deallocator_arg_; + TF_ManagedBuffer(void* data, size_t len, + void (*deallocator)(void* data, size_t len, void* arg), + void* deallocator_arg) + : TensorBuffer(data), + len_(len), + deallocator_(deallocator), + deallocator_arg_(deallocator_arg) {} + + const size_t len_; + void (*const deallocator_)(void* data, size_t len, void* arg); + void* const deallocator_arg_; ~TF_ManagedBuffer() override { - (*deallocator_)(data_, len_, deallocator_arg_); + (*deallocator_)(data(), len_, deallocator_arg_); } - void* data() const override { return data_; } size_t size() const override { return len_; } TensorBuffer* root_buffer() override { return this; } void FillAllocationDescription(AllocationDescription* proto) const override { @@ -199,8 +205,7 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims, dimvec[i] = static_cast(dims[i]); } - TF_ManagedBuffer* buf = new TF_ManagedBuffer; - buf->len_ = len; + TF_ManagedBuffer* buf = nullptr; if (dtype != TF_STRING && dtype != TF_RESOURCE && tensorflow::DataTypeCanUseMemcpy(static_cast(dtype)) && reinterpret_cast(data) % std::max(1, EIGEN_MAX_ALIGN_BYTES) != @@ -212,17 +217,15 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims, // // Other types have the same representation, so copy only if it is safe to // do so. 
- buf->data_ = allocate_tensor("TF_NewTensor", len); - std::memcpy(buf->data_, data, len); - buf->deallocator_ = deallocate_buffer; - buf->deallocator_arg_ = nullptr; + buf = new TF_ManagedBuffer(allocate_tensor("TF_NewTensor", len), len, + deallocate_buffer, nullptr); + std::memcpy(buf->data(), data, len); // Free the original buffer. deallocator(data, len, deallocator_arg); } else { - buf->data_ = data; - buf->deallocator_ = deallocator; - buf->deallocator_arg_ = deallocator_arg; + buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg); } + TF_Tensor* ret = new TF_Tensor{dtype, TensorShape(dimvec), buf}; size_t elem_size = TF_DataTypeSize(dtype); if (elem_size > 0 && len < (elem_size * ret->shape.num_elements())) { @@ -477,14 +480,15 @@ static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) { CHECK_EQ(nelems, 0); static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), "64-bit int types should match in size"); - return TF_NewTensor(dtype, reinterpret_cast(dims.data()), - shape.dims(), reinterpret_cast(&empty), 0, - [](void*, size_t, void*) {}, nullptr); + return TF_NewTensor( + dtype, reinterpret_cast(dims.data()), shape.dims(), + reinterpret_cast(&empty), 0, [](void*, size_t, void*) {}, nullptr); } // Non-static for testing. 
TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); if (!src.IsInitialized()) { status->status = FailedPrecondition( "attempt to use a tensor with an uninitialized value"); @@ -1592,18 +1596,20 @@ TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper, break; \ } - LIST_CASE(s, TF_ATTR_STRING, metadata.total_size = 0; - for (int i = 0; i < attr->list().s_size(); - ++i) { metadata.total_size += attr->list().s(i).size(); }); + LIST_CASE( + s, TF_ATTR_STRING, metadata.total_size = 0; + for (int i = 0; i < attr->list().s_size(); + ++i) { metadata.total_size += attr->list().s(i).size(); }); LIST_CASE(i, TF_ATTR_INT); LIST_CASE(f, TF_ATTR_FLOAT); LIST_CASE(b, TF_ATTR_BOOL); LIST_CASE(type, TF_ATTR_TYPE); - LIST_CASE(shape, TF_ATTR_SHAPE, metadata.total_size = 0; - for (int i = 0; i < attr->list().shape_size(); ++i) { - const auto& s = attr->list().shape(i); - metadata.total_size += s.unknown_rank() ? 0 : s.dim_size(); - }); + LIST_CASE( + shape, TF_ATTR_SHAPE, metadata.total_size = 0; + for (int i = 0; i < attr->list().shape_size(); ++i) { + const auto& s = attr->list().shape(i); + metadata.total_size += s.unknown_rank() ? 0 : s.dim_size(); + }); LIST_CASE(tensor, TF_ATTR_TENSOR); LIST_CASE(tensor, TF_ATTR_FUNC); #undef LIST_CASE diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 3d56268110edbe96616201d15a69cc8c84d3115a..c7abba85521fccec07983cd5ab4f94a8368d6181 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -91,7 +91,7 @@ extern "C" { // -------------------------------------------------------------------------- // TF_Version returns a string describing version information of the // TensorFlow library. TensorFlow using semantic versioning. -TF_CAPI_EXPORT extern const char* TF_Version(); +TF_CAPI_EXPORT extern const char* TF_Version(void); // -------------------------------------------------------------------------- // TF_DataType holds the type for a scalar value. 
E.g., one slot in a tensor. @@ -157,7 +157,7 @@ typedef enum TF_Code { typedef struct TF_Status TF_Status; // Return a new status object. -TF_CAPI_EXPORT extern TF_Status* TF_NewStatus(); +TF_CAPI_EXPORT extern TF_Status* TF_NewStatus(void); // Delete a previously created status object. TF_CAPI_EXPORT extern void TF_DeleteStatus(TF_Status*); @@ -196,7 +196,7 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len); // Useful for passing *out* a protobuf. -TF_CAPI_EXPORT extern TF_Buffer* TF_NewBuffer(); +TF_CAPI_EXPORT extern TF_Buffer* TF_NewBuffer(void); TF_CAPI_EXPORT extern void TF_DeleteBuffer(TF_Buffer*); @@ -305,7 +305,7 @@ TF_CAPI_EXPORT extern size_t TF_StringEncodedSize(size_t len); typedef struct TF_SessionOptions TF_SessionOptions; // Return a new options object. -TF_CAPI_EXPORT extern TF_SessionOptions* TF_NewSessionOptions(); +TF_CAPI_EXPORT extern TF_SessionOptions* TF_NewSessionOptions(void); // Set the target in TF_SessionOptions.options. // target can be empty, a single entry, or a comma separated list of entries. @@ -338,7 +338,7 @@ TF_CAPI_EXPORT extern void TF_DeleteSessionOptions(TF_SessionOptions*); typedef struct TF_Graph TF_Graph; // Return a new graph object. -TF_CAPI_EXPORT extern TF_Graph* TF_NewGraph(); +TF_CAPI_EXPORT extern TF_Graph* TF_NewGraph(void); // Destroy an options object. Graph will be deleted once no more // TFSession's are referencing it. @@ -890,7 +890,8 @@ TF_CAPI_EXPORT extern void TF_GraphVersions(TF_Graph* graph, // TF_GraphImportGraphDef. 
typedef struct TF_ImportGraphDefOptions TF_ImportGraphDefOptions; -TF_CAPI_EXPORT extern TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions(); +TF_CAPI_EXPORT extern TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions( + void); TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefOptions( TF_ImportGraphDefOptions* opts); @@ -1611,7 +1612,7 @@ TF_CAPI_EXPORT extern void TF_DeleteLibraryHandle(TF_Library* lib_handle); // // The data in the buffer will be the serialized OpList proto for ops registered // in this address space. -TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllOpList(); +TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllOpList(void); // TF_ApiDefMap encapsulates a collection of API definitions for an operation. // diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 3693cc85996365360253c8a94c29272a16e11e9a..81343f7bc027be82d28164be51011c794715d03a 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -66,7 +66,8 @@ void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) { } TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation, - unsigned char gpu_memory_allow_growth) { + unsigned char gpu_memory_allow_growth, + unsigned int num_cpu_devices) { tensorflow::ConfigProto config; auto* optimizer_options = config.mutable_graph_options()->mutable_optimizer_options(); @@ -87,6 +88,8 @@ TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation, auto* gpu_options = config.mutable_gpu_options(); gpu_options->set_allow_growth(gpu_memory_allow_growth); + (*config.mutable_device_count())["CPU"] = num_cpu_devices; + // TODO(b/113217601): This is needed for EagerContext::runner_ to use a // threadpool, so that we avoid the possibility of running the runner_ in the // threadpool of GPU event mgr, as that can trigger more callbacks to be @@ -6530,7 +6533,7 @@ library { } } node_def { - name: "ParallelInterleaveDataset/cycle_length" + name: 
"ExperimentalParallelInterleaveDataset/cycle_length" op: "Const" attr { key: "dtype" @@ -6551,7 +6554,7 @@ library { } } node_def { - name: "ParallelInterleaveDataset/block_length" + name: "ExperimentalParallelInterleaveDataset/block_length" op: "Const" attr { key: "dtype" @@ -6572,7 +6575,7 @@ library { } } node_def { - name: "ParallelInterleaveDataset/sloppy" + name: "ExperimentalParallelInterleaveDataset/sloppy" op: "Const" attr { key: "dtype" @@ -6593,7 +6596,7 @@ library { } } node_def { - name: "ParallelInterleaveDataset/buffer_output_elements" + name: "ExperimentalParallelInterleaveDataset/buffer_output_elements" op: "Const" attr { key: "dtype" @@ -6614,7 +6617,7 @@ library { } } node_def { - name: "ParallelInterleaveDataset/prefetch_input_elements" + name: "ExperimentalParallelInterleaveDataset/prefetch_input_elements" op: "Const" attr { key: "dtype" @@ -6635,14 +6638,14 @@ library { } } node_def { - name: "ParallelInterleaveDataset" - op: "ParallelInterleaveDataset" + name: "ExperimentalParallelInterleaveDataset" + op: "ExperimentalParallelInterleaveDataset" input: "RepeatDataset:handle:0" - input: "ParallelInterleaveDataset/cycle_length:output:0" - input: "ParallelInterleaveDataset/block_length:output:0" - input: "ParallelInterleaveDataset/sloppy:output:0" - input: "ParallelInterleaveDataset/buffer_output_elements:output:0" - input: "ParallelInterleaveDataset/prefetch_input_elements:output:0" + input: "ExperimentalParallelInterleaveDataset/cycle_length:output:0" + input: "ExperimentalParallelInterleaveDataset/block_length:output:0" + input: "ExperimentalParallelInterleaveDataset/sloppy:output:0" + input: "ExperimentalParallelInterleaveDataset/buffer_output_elements:output:0" + input: "ExperimentalParallelInterleaveDataset/prefetch_input_elements:output:0" attr { key: "Targuments" value { @@ -6742,7 +6745,7 @@ library { node_def { name: "ShuffleDataset_2" op: "ShuffleDataset" - input: "ParallelInterleaveDataset:handle:0" + input: 
"ExperimentalParallelInterleaveDataset:handle:0" input: "ShuffleDataset_2/buffer_size_1:output:0" input: "ShuffleDataset_2/seed_2:output:0" input: "ShuffleDataset_2/seed2_2:output:0" @@ -8535,8 +8538,9 @@ TFE_Context* TFE_CreateContextFromSession(TF_Session* session, // Reduce GPU memory allocation, and set appropriate config options for TFE // context. - auto* config = - TF_CreateConfig(/*xla*/ false, /* gpu_memory_allow_growth */ true); + auto* config = TF_CreateConfig( + /*xla*/ false, /* gpu_memory_allow_growth */ true, /* num_cpu_devices */ + 10); TFE_ContextOptionsSetConfig(opts, config->data, config->length, status); if (!status->status.ok()) { CHECK(!config); diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 80c8bfe594c4c89606efd01bec7f50e7a86b5bda..cb7a146846ff0bdac09f4a90765f78e0ada75718 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -67,9 +67,10 @@ TF_CAPI_EXPORT extern void TF_EnableXLACompilation(TF_SessionOptions* options, // a) ConfigProto.optimizer_options.global_jit_level is set to to ON_1 if // `enable_xla_compilation` is non-zero, and OFF otherwise. // b) ConfigProto.gpu_options.allow_growth is set to `gpu_memory_allow_growth`. +// c) ConfigProto.device_count["CPU"] is set to `num_cpu_devices`. TF_CAPI_EXPORT extern TF_Buffer* TF_CreateConfig( - unsigned char enable_xla_compilation, - unsigned char gpu_memory_allow_growth); + unsigned char enable_xla_compilation, unsigned char gpu_memory_allow_growth, + unsigned int num_cpu_devices); // Create a serialized tensorflow.RunOptions proto, where RunOptions.trace_level // is set to FULL_TRACE if `enable_full_trace` is non-zero, and NO_TRACE @@ -239,7 +240,7 @@ TF_CAPI_EXPORT void TF_InitMain(const char* usage, int* argc, char*** argv); // Platform-specific implementation to return an unused port. (This should used // in tests only.) 
-TF_CAPI_EXPORT int TF_PickUnusedPortOrDie(); +TF_CAPI_EXPORT int TF_PickUnusedPortOrDie(void); // Fast path method that makes constructing a single scalar tensor require less // overhead and copies. diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 8d6c8d958d5961fce817156a14eb2b2940c1f2f0..120748ab763a3358b6e38e64bb3b6fd2ea32f7c3 100755 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -48,7 +48,7 @@ extern "C" { typedef struct TFE_ContextOptions TFE_ContextOptions; // Return a new options object. -TF_CAPI_EXPORT extern TFE_ContextOptions* TFE_NewContextOptions(); +TF_CAPI_EXPORT extern TFE_ContextOptions* TFE_NewContextOptions(void); // Set the config in TF_ContextOptions.options. // config should be a serialized tensorflow.ConfigProto proto. @@ -170,23 +170,11 @@ TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, TF_Status* status); -// Returns the device of the operation that produced `h`. -// If `h` was produced by a copy, returns the destination device of -// the copy. Note that returned device name is not always the device -// holding the tensor handle's memory. If you want the latter, use -// TFE_TensorHandleBackingDeviceName. -// This function will block till the operation that produces `h` has completed. -// -// Device on which the kernel of the operation that produced `h` ran. -// -// If `h` was produced by a copy, returns the destination device of -// the copy. -// -// Note that returned device name is not always the device that owns the memory -// that backs the tensor handle. For the latter see -// TFE_TensorHandleBackingDeviceName. -// -// This function will block till the operation that produces `h` has completed. +// Returns the device of the operation that produced `h`. If `h` was produced by +// a copy, returns the destination device of the copy. Note that the returned +// device name is not always the device holding the tensor handle's memory. 
If +// you want the latter, use TFE_TensorHandleBackingDeviceName. This function +// will block till the operation that produces `h` has completed. TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName( TFE_TensorHandle* h, TF_Status* status); diff --git a/tensorflow/c/env.cc b/tensorflow/c/env.cc new file mode 100644 index 0000000000000000000000000000000000000000..1c35ff9001d0ee1ab0fbae9e1bcc07116fab1065 --- /dev/null +++ b/tensorflow/c/env.cc @@ -0,0 +1,183 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/env.h" + +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/types.h" + +struct TF_StringStream { + std::vector<::tensorflow::string>* list; + size_t position; +}; + +void TF_CreateDir(const char* dirname, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->CreateDir(dirname)); +} + +void TF_DeleteDir(const char* dirname, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->DeleteDir(dirname)); +} + +void TF_DeleteRecursively(const char* dirname, int64_t* undeleted_file_count, + int64_t* undeleted_dir_count, TF_Status* status) { + ::tensorflow::int64 f, d; + + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->DeleteRecursively(dirname, &f, &d)); + *undeleted_file_count = f; + *undeleted_dir_count = d; +} + +void TF_FileStat(const char* filename, TF_FileStatistics* stats, + TF_Status* status) { + ::tensorflow::FileStatistics cc_stats; + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Status s = + ::tensorflow::Env::Default()->Stat(filename, &cc_stats); + ::tensorflow::Set_TF_Status_from_Status(status, s); + if (s.ok()) { + stats->length = cc_stats.length; + stats->mtime_nsec = cc_stats.mtime_nsec; + stats->is_directory = cc_stats.is_directory; + } +} + +void TF_NewWritableFile(const char* filename, TF_WritableFileHandle** handle, + TF_Status* status) { + std::unique_ptr<::tensorflow::WritableFile> f; + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Status s = + ::tensorflow::Env::Default()->NewWritableFile(filename, &f); + ::tensorflow::Set_TF_Status_from_Status(status, s); + + if (s.ok()) { + *handle = 
reinterpret_cast(f.release()); + } +} + +void TF_CloseWritableFile(TF_WritableFileHandle* handle, TF_Status* status) { + auto* cc_file = reinterpret_cast<::tensorflow::WritableFile*>(handle); + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status(status, cc_file->Close()); + delete cc_file; +} + +void TF_SyncWritableFile(TF_WritableFileHandle* handle, TF_Status* status) { + auto* cc_file = reinterpret_cast<::tensorflow::WritableFile*>(handle); + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status(status, cc_file->Sync()); +} + +void TF_FlushWritableFile(TF_WritableFileHandle* handle, TF_Status* status) { + auto* cc_file = reinterpret_cast<::tensorflow::WritableFile*>(handle); + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status(status, cc_file->Flush()); +} + +void TF_AppendWritableFile(TF_WritableFileHandle* handle, const char* data, + size_t length, TF_Status* status) { + auto* cc_file = reinterpret_cast<::tensorflow::WritableFile*>(handle); + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, cc_file->Append(::tensorflow::StringPiece{data, length})); +} + +void TF_DeleteFile(const char* filename, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->DeleteFile(filename)); +} + +bool TF_StringStreamNext(TF_StringStream* list, const char** result) { + if (list->position >= list->list->size()) { + *result = nullptr; + return false; + } + + *result = list->list->at(list->position++).c_str(); + return true; +} + +void TF_StringStreamDone(TF_StringStream* list) { + delete list->list; + delete list; +} +TF_StringStream* TF_GetChildren(const char* dirname, TF_Status* status) { + auto* children = new std::vector<::tensorflow::string>; + + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->GetChildren(dirname, children)); 
+ + auto* list = new TF_StringStream; + list->list = children; + list->position = 0; + return list; +} + +TF_StringStream* TF_GetLocalTempDirectories() { + auto* tmpdirs = new std::vector<::tensorflow::string>; + + ::tensorflow::Env::Default()->GetLocalTempDirectories(tmpdirs); + + auto* list = new TF_StringStream; + list->list = tmpdirs; + list->position = 0; + return list; +} + +TF_CAPI_EXPORT extern uint64_t TF_NowNanos(void) { + return ::tensorflow::Env::Default()->NowNanos(); +} + +// Returns the number of microseconds since the Unix epoch. +TF_CAPI_EXPORT extern uint64_t TF_NowMicros(void) { + return ::tensorflow::Env::Default()->NowMicros(); +} + +// Returns the number of seconds since the Unix epoch. +TF_CAPI_EXPORT extern uint64_t TF_NowSeconds(void) { + return ::tensorflow::Env::Default()->NowSeconds(); +} + +void TF_DefaultThreadOptions(TF_ThreadOptions* options) { + options->stack_size = 0; + options->guard_size = 0; + options->numa_node = -1; +} + +TF_Thread* TF_StartThread(const TF_ThreadOptions* options, + const char* thread_name, void (*work_func)(void*), + void* param) { + ::tensorflow::ThreadOptions cc_options; + cc_options.stack_size = options->stack_size; + cc_options.guard_size = options->guard_size; + cc_options.numa_node = options->numa_node; + return reinterpret_cast(::tensorflow::Env::Default()->StartThread( + cc_options, thread_name, [=]() { (*work_func)(param); })); +} + +void TF_JoinThread(TF_Thread* thread) { + // ::tensorflow::Thread joins on destruction + delete reinterpret_cast<::tensorflow::Thread*>(thread); +} diff --git a/tensorflow/c/env.h b/tensorflow/c/env.h new file mode 100644 index 0000000000000000000000000000000000000000..15652353cd7e1f1e7d7a4c665703c0166682d790 --- /dev/null +++ b/tensorflow/c/env.h @@ -0,0 +1,194 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#ifndef TENSORFLOW_C_ENV_H_ +#define TENSORFLOW_C_ENV_H_ + +#include "tensorflow/c/c_api.h" + +// -------------------------------------------------------------------------- +// C API for tensorflow::Env. + +struct TF_WritableFileHandle; +struct TF_StringStream; +struct TF_Thread; + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct TF_FileStatistics { + // The length of the file in bytes. + int64_t length; + // The last modified time in nanoseconds. + int64_t mtime_nsec; + // Whether the name refers to a directory. + bool is_directory; +} TF_FileStatistics; + +typedef struct TF_ThreadOptions { + // Thread stack size to use (in bytes), zero implies that the system default + // will be used. + size_t stack_size; + + // Guard area size to use near thread stacks to use (in bytes), zero implies + // that the system default will be used. + size_t guard_size; + + // The NUMA node to use, -1 implies that there should be no NUMA affinity for + // this thread. + int numa_node; +} TF_ThreadOptions; + +// Creates the specified directory. Typical status code are: +// * TF_OK - successfully created the directory +// * TF_ALREADY_EXISTS - directory already exists +// * TF_PERMISSION_DENIED - dirname is not writable +TF_CAPI_EXPORT extern void TF_CreateDir(const char* dirname, TF_Status* status); + +// Deletes the specified directory. 
Typical status codes are: +// * TF_OK - successfully deleted the directory +// * TF_FAILED_PRECONDITION - the directory is not empty +TF_CAPI_EXPORT extern void TF_DeleteDir(const char* dirname, TF_Status* status); + +// Deletes the specified directory and all subdirectories and files underneath +// it. This is accomplished by traversing the directory tree rooted at dirname +// and deleting entries as they are encountered. +// +// If dirname itself is not readable or does not exist, *undeleted_dir_count is +// set to 1, *undeleted_file_count is set to 0 and an appropriate status (e.g. +// TF_NOT_FOUND) is returned. +// +// If dirname and all its descendants were successfully deleted, TF_OK is +// returned and both error counters are set to zero. +// +// Otherwise, while traversing the tree, undeleted_file_count and +// undeleted_dir_count are updated if an entry of the corresponding type could +// not be deleted. The returned error status represents the reason that any one +// of these entries could not be deleted. +// +// Typical status codes: +// * TF_OK - dirname exists and we were able to delete everything underneath +// * TF_NOT_FOUND - dirname doesn't exist +// * TF_PERMISSION_DENIED - dirname or some descendant is not writable +// * TF_UNIMPLEMENTED - some underlying functions (like Delete) are not +// implemented +TF_CAPI_EXPORT extern void TF_DeleteRecursively(const char* dirname, + int64_t* undeleted_file_count, + int64_t* undeleted_dir_count, + TF_Status* status); + +// Obtains statistics for the given path. If status is TF_OK, *stats is +// updated, otherwise it is not touched. +TF_CAPI_EXPORT extern void TF_FileStat(const char* filename, + TF_FileStatistics* stats, + TF_Status* status); + +// Creates or truncates the given filename and returns a handle to be used for +// appending data to the file. If status is TF_OK, *handle is updated and the +// caller is responsible for freeing it (see TF_CloseWritableFile). 
+TF_CAPI_EXPORT extern void TF_NewWritableFile(const char* filename, + TF_WritableFileHandle** handle, + TF_Status* status); + +// Closes the given handle and frees its memory. If there was a problem closing +// the file, it is indicated by status. Memory is freed in any case. +TF_CAPI_EXPORT extern void TF_CloseWritableFile(TF_WritableFileHandle* handle, + TF_Status* status); + +// Syncs content of the handle to the filesystem. Blocks waiting for the +// filesystem to indicate that the content has been persisted. +TF_CAPI_EXPORT extern void TF_SyncWritableFile(TF_WritableFileHandle* handle, + TF_Status* status); + +// Flush local buffers to the filesystem. If the process terminates after a +// successful flush, the contents may still be persisted, since the underlying +// filesystem may eventually flush the contents. If the OS or machine crashes +// after a successful flush, the contents may or may not be persisted, depending +// on the implementation. +TF_CAPI_EXPORT extern void TF_FlushWritableFile(TF_WritableFileHandle* handle, + TF_Status* status); + +// Appends the given bytes to the file. Any failure to do so is indicated in +// status. +TF_CAPI_EXPORT extern void TF_AppendWritableFile(TF_WritableFileHandle* handle, + const char* data, + size_t length, + TF_Status* status); + +// Deletes the named file and indicates whether successful in *status. +TF_CAPI_EXPORT extern void TF_DeleteFile(const char* filename, + TF_Status* status); + +// Retrieves the next item from the given TF_StringStream and places a pointer +// to it in *result. If no more items are in the list, *result is set to NULL +// and false is returned. +// +// Ownership of the items retrieved with this function remains with the library. +// Item points are invalidated after a call to TF_StringStreamDone. +TF_CAPI_EXPORT extern bool TF_StringStreamNext(TF_StringStream* list, + const char** result); + +// Frees the resources associated with given string list. 
All pointers returned +// by TF_StringStreamNext are invalid after this call. +TF_CAPI_EXPORT extern void TF_StringStreamDone(TF_StringStream* list); + +// Retrieves the list of children of the given directory. You can iterate +// through the list with TF_StringStreamNext. The caller is responsible for +// freeing the list (see TF_StringStreamDone). +TF_CAPI_EXPORT extern TF_StringStream* TF_GetChildren(const char* filename, + TF_Status* status); + +// Retrieves a list of directory names on the local machine that may be used for +// temporary storage. You can iterate through the list with TF_StringStreamNext. +// The caller is responsible for freeing the list (see TF_StringStreamDone). +TF_CAPI_EXPORT extern TF_StringStream* TF_GetLocalTempDirectories(void); + +// Returns the number of nanoseconds since the Unix epoch. +TF_CAPI_EXPORT extern uint64_t TF_NowNanos(void); + +// Returns the number of microseconds since the Unix epoch. +TF_CAPI_EXPORT extern uint64_t TF_NowMicros(void); + +// Returns the number of seconds since the Unix epoch. +TF_CAPI_EXPORT extern uint64_t TF_NowSeconds(void); + +// Populates a TF_ThreadOptions struct with system-default values. +TF_CAPI_EXPORT extern void TF_DefaultThreadOptions(TF_ThreadOptions* options); + +// Returns a new thread that is running work_func and is identified +// (for debugging/performance-analysis) by thread_name. +// +// The given param (which may be null) is passed to work_func when the thread +// starts. In this way, data may be passed from the thread back to the caller. +// +// Caller takes ownership of the result and must call TF_JoinThread on it +// eventually. +TF_CAPI_EXPORT extern TF_Thread* TF_StartThread(const TF_ThreadOptions* options, + const char* thread_name, + void (*work_func)(void*), + void* param); + +// Waits for the given thread to finish execution, then deletes it. 
+TF_CAPI_EXPORT extern void TF_JoinThread(TF_Thread* thread); + +#ifdef __cplusplus +} +#endif + +#endif // TENSORFLOW_C_ENV_H_ diff --git a/tensorflow/c/env_test.cc b/tensorflow/c/env_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..687ad024137352662759ec1f43df87e89faca353 --- /dev/null +++ b/tensorflow/c/env_test.cc @@ -0,0 +1,127 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/env.h" + +#include "tensorflow/c/c_api.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x)) + +TEST(TestEnv, TestDirHandling) { + TF_StringStream* tempdirs = TF_GetLocalTempDirectories(); + const char* tempdir; + bool found = false; + while (TF_StringStreamNext(tempdirs, &tempdir)) { + found = true; + + TF_Status* s = TF_NewStatus(); + + ::tensorflow::string dirpath = + ::tensorflow::io::JoinPath(tempdir, "somedir"); + TF_CreateDir(dirpath.c_str(), s); + ASSERT_TF_OK(s) << "TF_CreateDir failed for " << dirpath << ": " + << TF_Message(s); + + ::tensorflow::string filepath = + ::tensorflow::io::JoinPath(dirpath, "somefile.txt"); + TF_WritableFileHandle* handle; + TF_NewWritableFile(filepath.c_str(), &handle, s); + ASSERT_TF_OK(s) << 
"NewWritableFile failed for " << filepath << ": " + << TF_Message(s); + + const char* data = "Hello, world!\n"; + TF_AppendWritableFile(handle, data, strlen(data), s); + ASSERT_TF_OK(s) << "TF_AppendWritableFile failed to append data to file at " + << filepath << ": " << TF_Message(s); + + TF_CloseWritableFile(handle, s); + ASSERT_TF_OK(s) << "TF_CloseWritableFile failed to close handle to " + << filepath << ": " << TF_Message(s); + + TF_StringStream* children = TF_GetChildren(dirpath.c_str(), s); + ASSERT_TF_OK(s) << "TF_GetChildren failed for " << dirpath; + const char* childpath; + ASSERT_TRUE(TF_StringStreamNext(children, &childpath)); + ASSERT_EQ(::tensorflow::string(childpath), "somefile.txt"); + // There should only be one file in this directory. + ASSERT_FALSE(TF_StringStreamNext(children, &childpath)); + ASSERT_EQ(childpath, nullptr); + TF_StringStreamDone(children); + + TF_FileStatistics stats; + TF_FileStat(filepath.c_str(), &stats, s); + ASSERT_EQ(stats.length, strlen(data)); + ASSERT_FALSE(stats.is_directory); + ASSERT_GT(stats.mtime_nsec, 0); + + // Trying to delete a non-empty directory should fail. + TF_DeleteDir(dirpath.c_str(), s); + ASSERT_NE(TF_OK, TF_GetCode(s)) + << "TF_DeleteDir unexpectedly succeeded with a non-empty directory " + << dirpath; + + TF_DeleteFile(filepath.c_str(), s); + ASSERT_TF_OK(s) << "TF_DeleteFile failed for " << filepath << ": " + << TF_Message(s); + + // Now deleting the directory should work. 
+ TF_DeleteDir(dirpath.c_str(), s); + ASSERT_TF_OK(s) << "TF_DeleteDir failed for " << dirpath << ": " + << TF_Message(s); + + TF_DeleteStatus(s); + break; + } + + ASSERT_TRUE(found) << "expected at least one temp dir"; + + TF_StringStreamDone(tempdirs); +} + +TEST(TestEnv, TestTimeFunctions) { + ASSERT_GE(TF_NowSeconds(), 946684800); // Midnight Jan 1, 2000 + ASSERT_GE(TF_NowMicros(), 946684800 * 1e6); + ASSERT_GE(TF_NowNanos(), 946684800 * 1e9); +} + +namespace { + +struct SomeThreadData { + ::tensorflow::mutex mu; + bool did_work = false; +}; + +void SomeThreadFunc(void* data) { + auto* real_data = static_cast(data); + ::tensorflow::mutex_lock l(real_data->mu); + real_data->did_work = true; +} + +} // namespace + +TEST(TestEnv, TestThreads) { + TF_ThreadOptions options; + TF_DefaultThreadOptions(&options); + SomeThreadData data; + TF_Thread* thread = + TF_StartThread(&options, "SomeThreadName", &SomeThreadFunc, &data); + TF_JoinThread(thread); + ::tensorflow::mutex_lock l(data.mu); + ASSERT_TRUE(data.did_work); +} diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc index ca69345264607ac689fb556b4f5c9bc08ea5eb88..2a4eaecb6cf2740a522b1e849d1306ebde6c4577 100644 --- a/tensorflow/c/kernels.cc +++ b/tensorflow/c/kernels.cc @@ -15,7 +15,9 @@ limitations under the License. 
#include +#include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/kernels.h" +#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -116,3 +118,43 @@ void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder, TF_SetStatus(status, TF_OK, ""); } + +int TF_NumInputs(TF_OpKernelContext* ctx) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + return cc_ctx->num_inputs(); +} + +int TF_NumOutputs(TF_OpKernelContext* ctx) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + return cc_ctx->num_outputs(); +} + +void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor, + TF_Status* status) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + if (i < 0 || i >= cc_ctx->num_inputs()) { + TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range"); + return; + } + const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i)); + TF_Tensor* result = ::tensorflow::TF_TensorFromTensor(cc_tensor, status); + if (TF_GetCode(status) == TF_OK) { + *tensor = result; + } +} + +void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor, + TF_Status* status) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + if (i < 0 || i >= cc_ctx->num_inputs()) { + TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range"); + return; + } + ::tensorflow::Tensor cc_tensor; + ::tensorflow::Status s = ::tensorflow::TF_TensorToTensor(tensor, &cc_tensor); + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status(status, s); + if (s.ok()) { + cc_ctx->set_output(i, cc_tensor); + } +} diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h index 2518789a3c141755d0b3373d53642c487331f68b..1a91aa184f11ac8e45b38a1d106c7b445747a7c1 100644 --- a/tensorflow/c/kernels.h +++ b/tensorflow/c/kernels.h @@ -85,6 +85,32 @@ TF_CAPI_EXPORT extern void 
TF_RegisterKernelBuilder(const char* kernel_name, // builder is not registered with TensorFlow via TF_RegisterKernelBuilder. TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder); +// -------------------------------------------------------------------------- +// OpKernelContext routines + +// TF_NumInputs returns the number of inputs available in ctx. +TF_CAPI_EXPORT extern int TF_NumInputs(TF_OpKernelContext* ctx); + +// TF_NumOutputs returns the number of outputs to be placed in *ctx by the +// kernel. +TF_CAPI_EXPORT extern int TF_NumOutputs(TF_OpKernelContext* ctx); + +// Retrieves the ith input from ctx. If TF_GetCode(status) is TF_OK, *tensor is +// populated and its ownership is passed to the caller. In any other case, +// *tensor is not modified. +// +// If i < 0 or i >= TF_NumInputs(ctx), *status is set to TF_OUT_OF_RANGE. +TF_CAPI_EXPORT extern void TF_GetInput(TF_OpKernelContext* ctx, int i, + TF_Tensor** tensor, TF_Status* status); + +// Sets the ith output of ctx to tensor. If TF_GetCode(status) is anything but +// TF_OK, ctx is left unmodified. +// +// If i < 0 or i >= TF_NumOutputs(ctx), *status is set to TF_OUT_OF_RANGE. +TF_CAPI_EXPORT extern void TF_SetOutput(TF_OpKernelContext* ctx, int i, + const TF_Tensor* tensor, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/kernels_test.cc b/tensorflow/c/kernels_test.cc index e706c7c1d96ee1781d8efc0f28c5e0cbcbc80861..e659ee3c3d258a626ccf03a782ec031b5a703a48 100644 --- a/tensorflow/c/kernels_test.cc +++ b/tensorflow/c/kernels_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/c/kernels.h" +#include "tensorflow/c/c_api.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/node_def.pb_text.h" #include "tensorflow/core/framework/op.h" @@ -31,7 +32,6 @@ struct MyCustomKernel { static bool delete_called = false; static void* MyCreateFunc(TF_OpKernelConstruction* ctx) { - LOG(INFO) << "Wow, actually got into creation"; struct MyCustomKernel* s = new struct MyCustomKernel; s->created = true; s->compute_called = false; @@ -51,12 +51,31 @@ static void MyDeleteFunc(void* kernel) { delete s; } +namespace tensorflow { + +static std::unique_ptr GetFakeKernel(const char* device_name, + const char* op_name, + Status* status) { + NodeDef def; + def.set_op(op_name); + def.set_device(device_name); + def.add_input("input1"); + def.add_input("input2"); + return CreateOpKernel(DeviceType(device_name), nullptr, nullptr, def, 1, + status); +} + // Tests registration of a single C kernel and checks that calls through the // C/C++ boundary are being made. 
TEST(TestKernel, TestRegisterKernelBuilder) { const char* kernel_name = "SomeKernelName"; const char* op_name = "FooOp"; - const char* device_name = "barDev"; + const char* device_name = "FakeDeviceName1"; + + REGISTER_OP(op_name) + .Input("input1: double") + .Input("input2: uint8") + .Output("output1: uint8"); TF_KernelBuilder* builder = TF_NewKernelBuilder( op_name, device_name, &MyCreateFunc, &MyComputeFunc, &MyDeleteFunc); @@ -65,35 +84,120 @@ TEST(TestKernel, TestRegisterKernelBuilder) { TF_Status* status = TF_NewStatus(); TF_RegisterKernelBuilder(kernel_name, builder, status); EXPECT_EQ(TF_OK, TF_GetCode(status)); - TF_Buffer* buf = TF_GetRegisteredKernelsForOp("FooOp", status); + TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status); EXPECT_EQ(TF_OK, TF_GetCode(status)); - ::tensorflow::KernelList list; + KernelList list; list.ParseFromArray(buf->data, buf->length); ASSERT_EQ(1, list.kernel_size()); - ASSERT_EQ("barDev", list.kernel(0).device_type()); + ASSERT_EQ(device_name, list.kernel(0).device_type()); TF_DeleteBuffer(buf); TF_DeleteStatus(status); } - REGISTER_OP("FooOp") + { + Status status; + std::unique_ptr kernel = + GetFakeKernel(device_name, op_name, &status); + TF_EXPECT_OK(status); + ASSERT_NE(nullptr, kernel.get()); + kernel->Compute(nullptr); + } + + ASSERT_TRUE(delete_called); +} + +class DummyDevice : public DeviceBase { + public: + DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {} + bool RequiresRecordingAccessedTensors() const override { return save_; } + Allocator* GetAllocator(AllocatorAttributes /*attr*/) override { + return cpu_allocator(); + } + + private: + bool save_; +}; + +TEST(TestKernel, TestInputAndOutputCount) { + const char* kernel_name = "InputOutputCounterKernel"; + const char* op_name = "BarOp"; + const char* device_name = "FakeDeviceName2"; + + REGISTER_OP(op_name) .Input("input1: double") .Input("input2: uint8") .Output("output1: uint8"); + static int num_inputs = 0; + static int num_outputs = 
0; + + // A kernel whose Compute function has a side-effect of updating num_inputs + // and num_outputs. Various functions on TF_OpKernelContext are also + // exercised. + auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) { + num_inputs = TF_NumInputs(ctx); + num_outputs = TF_NumOutputs(ctx); + + TF_Tensor* input = nullptr; + TF_Status* s = TF_NewStatus(); + TF_GetInput(ctx, 0, &input, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)) << "Failed to get input: " << TF_Message(s); + EXPECT_EQ(123, *static_cast(TF_TensorData(input))); + TF_GetInput(ctx, -1, &input, s); + EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s)); + TF_GetInput(ctx, 3, &input, s); + EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s)); + + // Copy the input tensor to output. + TF_SetOutput(ctx, 0, input, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)); + + TF_SetOutput(ctx, 24, input, s); + EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s)); + + TF_DeleteStatus(s); + if (input != nullptr) { + TF_DeleteTensor(input); + } + }; + + TF_KernelBuilder* builder = TF_NewKernelBuilder(op_name, device_name, nullptr, + my_compute_func, nullptr); + + { + TF_Status* status = TF_NewStatus(); + TF_RegisterKernelBuilder(kernel_name, builder, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + TF_DeleteStatus(status); + } + { - ::tensorflow::NodeDef def; - def.set_op("FooOp"); - def.set_device("bar"); - def.add_input("input1"); - def.add_input("input2"); - ::tensorflow::Status status; - std::unique_ptr<::tensorflow::OpKernel> kernel = - ::tensorflow::CreateOpKernel(::tensorflow::DeviceType("barDev"), - nullptr, nullptr, def, 1, &status); + OpKernelContext::Params p; + DummyDevice dummy_device(nullptr, false); + p.device = &dummy_device; + + Tensor t(tensorflow::uint8(123)); + + gtl::InlinedVector inputs; + // Simulate 2 inputs + inputs.emplace_back(&t); + inputs.emplace_back(); + p.inputs = &inputs; + + Status status; + std::unique_ptr kernel = + GetFakeKernel(device_name, op_name, &status); TF_EXPECT_OK(status); ASSERT_NE(nullptr, kernel.get()); 
- kernel->Compute(nullptr); - } - ASSERT_TRUE(delete_called); + p.op_kernel = kernel.get(); + OpKernelContext ctx(&p); + kernel->Compute(&ctx); + + ASSERT_EQ(2, num_inputs); + ASSERT_EQ(1, num_outputs); + ASSERT_EQ(123, ctx.mutable_output(0)->scalar()()); + } } + +} // namespace tensorflow diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index 247236b760dd8c07bbb08426100b6a4d34296d2e..98d8393332269ae349cf8aa5c0b612c6f17172e6 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -160,4 +160,17 @@ void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, ic->set_output_handle_shapes_and_types(output.index, shapes_and_types); } +void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst, + TF_Status* status) { + mutex_lock l(graph->mu); + status->status = graph->graph.AddWhileInputHack(&new_src.oper->node, + new_src.index, &dst->node); + if (status->status.ok()) { + // This modification only updates the destination node for + // the purposes of running this graph in a session. Thus, we don't + // record the source node as being modified. + RecordMutation(graph, *dst, "adding input tensor"); + } +} + } // namespace tensorflow diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h index 5cce84020bc68d912d259f51512341eb5f464a2c..44779ca656165dd65590cb5e9ea3ccf71165ed63 100644 --- a/tensorflow/c/python_api.h +++ b/tensorflow/c/python_api.h @@ -34,6 +34,7 @@ void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device); +// Updates 'dst' to consume 'new_src'. void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, TF_Status* status); @@ -65,6 +66,13 @@ std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output); // because I couldn't get SWIG to work otherwise. 
void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, size_t proto_len, TF_Status* status); + +// This method is used to add a new input edge to 'dst', which must be a While +// op. The While op's "T" attribute must have already been updated to include +// the new edge. This is used to construct tf.while_loop gradients. +void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst, + TF_Status* status); + } // namespace tensorflow #endif // TENSORFLOW_C_PYTHON_API_H_ diff --git a/tensorflow/compat_template_v1.__init__.py b/tensorflow/compat_template_v1.__init__.py index 7df80ec01245a7fe820c79d5879458c4cd0a93cb..d58acde09f007bc9df40b08b0ef79c6031ca7941 100644 --- a/tensorflow/compat_template_v1.__init__.py +++ b/tensorflow/compat_template_v1.__init__.py @@ -23,12 +23,12 @@ import os as _os # pylint: disable=g-bad-import-order from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import +# API IMPORTS PLACEHOLDER + from tensorflow.python.tools import component_api_helper as _component_api_helper _component_api_helper.package_hook( parent_package_str=__name__, child_package_str=('tensorflow_estimator.python.estimator.api.estimator')) -# API IMPORTS PLACEHOLDER - from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top app.flags = flags # pylint: disable=undefined-variable diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index e0ac7130a64d3928c39440c0e10a2d2e1990b9cd..ab1c1be344e2257721507543bc7647d4ff4becb2 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -178,7 +178,7 @@ Status GenArgMethods(const tf2xla::Config& config, TF_RETURN_IF_ERROR( AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites)); const string code = R"( - void set_arg{{NAME}}_data(void* data) { + void set_arg{{NAME}}_data(const void* data) { set_arg_data({{I}}, data); } {{TYPE}}* arg{{NAME}}_data() { diff --git 
a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index a2cdab5d1a8e72504ca11b789287d4efd07a59e9..968afad65ed6d4b5510687df484b7ce6743f6a85 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -114,7 +114,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { // with dim indices specifying which value. No bounds checking is performed // on dim indices. - void set_arg0_data(void* data) { + void set_arg0_data(const void* data) { set_arg_data(0, data); } float* arg0_data() { @@ -132,7 +132,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { arg_data(0)))[dim0][dim1]; } - void set_arg_myfeed_data(void* data) { + void set_arg_myfeed_data(const void* data) { set_arg_data(0, data); } float* arg_myfeed_data() { @@ -150,7 +150,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { arg_data(0)))[dim0][dim1]; } - void set_arg1_data(void* data) { + void set_arg1_data(const void* data) { set_arg_data(1, data); } tensorflow::int64* arg1_data() { diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 2dc3e8c9113b37bf9d575ad66783f4ab49478af4..4051664c24cacad4a2d151ad3ac9009015900609 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -283,7 +283,7 @@ def tf_library( ) # Variables used for gen_test and gen_benchmark. 
- cpp_class_split = cpp_class.rsplit("::", maxsplit = 2) + cpp_class_split = cpp_class.rsplit("::", 2) if len(cpp_class_split) == 1: no_ns_name = cpp_class_split[0] else: diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index be91ed4f432b1890c22900f293fd4196e5c9d970..d8c88a9fca2db74265b4962e07a66ab214b1d994 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -76,6 +76,7 @@ cc_library( srcs = ["xla_cpu_device.cc"], visibility = [":friends"], deps = [ + ":create_xla_launch_op", # buildcleaner: keep ":flags", ":jit_compilation_passes", ":xla_device", @@ -95,6 +96,7 @@ cc_library( srcs = ["xla_gpu_device.cc"], visibility = [":friends"], deps = [ + ":create_xla_launch_op", # buildcleaner: keep ":jit_compilation_passes", ":xla_device", "//tensorflow/compiler/jit/kernels:xla_ops", @@ -104,6 +106,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], alwayslink = 1, ) @@ -512,6 +515,7 @@ cc_library( "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:resource_operation_table", + "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/cc:xla_jit_ops", @@ -610,6 +614,7 @@ tf_cc_test( "//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", + "//tensorflow/cc:functional_ops", "//tensorflow/cc:ops", "//tensorflow/cc:resource_variable_ops", "//tensorflow/cc:scope", @@ -622,6 +627,7 @@ tf_cc_test( "//tensorflow/compiler/tf2xla/cc:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/compiler/xla:test", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", diff --git 
a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index f478832781cb1dc045d9163d4a6f5e5f64a8a705..03aba97bbe81a11f6366d118ee5bc573d0c6b31b 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -779,7 +779,8 @@ Status Encapsulator::Subgraph::RecordArg( if (inserted) { NodeDef arg_def; NodeDefBuilder builder( - absl::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp); + absl::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp, + NodeDebugInfo(src_node->def())); DataType dtype = edge->dst()->input_type(edge->dst_input()); builder.Attr("T", dtype); builder.Attr("index", arg_index); @@ -814,7 +815,8 @@ Status Encapsulator::Subgraph::RecordResult( if (inserted) { NodeDef ret_def; NodeDefBuilder builder( - absl::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp); + absl::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp, + NodeDebugInfo(src_node->def())); DataType dtype = src_node->output_type(src_slot); builder.Attr("T", dtype); builder.Attr("index", ret_index); @@ -974,6 +976,7 @@ Status Encapsulator::Subgraph::AddHostComputes( } NodeDef host_compute_def; + // TODO(shikharagarwal): What source node should we use for errors? NodeDefBuilder builder(absl::StrCat("outside_compilation_", oc_subgraph_name, "_host_compute"), kHostComputeOp); @@ -1040,6 +1043,7 @@ Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name, Graph* graph_out) { if (sequencer_ == nullptr) { NodeDef seq_def; + // TODO(shikharagarwal): What source node should we use for errors? 
NodeDefBuilder builder(absl::StrCat(subgraph_name, "_sequencer"), "NoOp"); builder.Attr(kXlaHostTransferSequencerAttr, subgraph_name); builder.Device(device_); @@ -1214,7 +1218,8 @@ Status Encapsulator::Subgraph::AddHostComputeKeyPlaceholder( GraphDefBuilder::Options options(graph_out, /*status=*/nullptr); NodeDef key_def; NodeDefBuilder builder( - absl::StrCat(call_node_def_.name(), "_key_placeholder"), "Placeholder"); + absl::StrCat(call_node_def_.name(), "_key_placeholder"), "Placeholder", + NodeDebugInfo(call_node_def_)); builder.Attr("dtype", DT_STRING); builder.Attr("shape", shape_proto); builder.Attr("_host_compute_call_node", call_node_def_.name()); @@ -1248,6 +1253,7 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode( } NodeDef recv_def; + // TODO(shikharagarwal): What source node should we use for errors? NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name, "_", oc_subgraph_name, "_recv"), kRecvAtHostOp); @@ -1303,6 +1309,7 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( } NodeDef send_def; + // TODO(shikharagarwal): What source node should we use for errors? NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name, "_", oc_subgraph_name, "_send"), kSendFromHostOp); @@ -1833,8 +1840,9 @@ Node* AddDummyShapedNode(const Node* src_node, int src_port, // Add any Enter nodes required to bring the constant to the correct control // flow frame. 
while (!control_flow_info[src_node->id()].frame_name.empty()) { + NodeDebugInfo debug_info(*src_node); NodeBuilder enter_builder(options.GetNameForOp("Enter"), "Enter", - options.op_registry()); + options.op_registry(), &debug_info); enter_builder.Attr("frame_name", control_flow_info[src_node->id()].frame_name); enter_builder.Attr("is_constant", true); @@ -2018,7 +2026,8 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( return errors::InvalidArgument( "Shape inference is not possible for outside_compilation " "SendFromHost node ", - send_node->name(), " because shape of node ", n->name(), + send_node->name(), " because shape of node ", + FormatNodeForError(*n), " will not be known at compilation time."); } } @@ -2047,8 +2056,7 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( return errors::Internal( "Internal assumption failed while rewriting an outside_compilation " "cluster that contains a while loop. Logic assumes back-edge is to " - "port 1 of a 2-input " - "Merge node."); + "port 1 of a 2-input Merge node."); } // Connect the existing edge to both inputs of the Merge node so that the // graph will be well-formed. 
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index de89be9a3555960dabe7bacd17226c15ae888ae6..8617beec004d0fe912155f054442c5b6249bb6b5 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -299,7 +299,7 @@ REGISTER_OP("XlaHostCompute") .Attr("Toutputs: list(type) >= 0") .Attr("ancestors: list(string) >= 0") .Attr("key: string") - .Attr("shape_inference_graph: string = ''") + .Attr("shape_inference_graph: func") .Attr("shapes: list(shape) >= 0") .SetShapeFn(::tensorflow::shape_inference::UnknownShape); @@ -510,11 +510,7 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, s = ConvertGraphDefToGraph(options, *graphdef, graph.get()); if (!s.ok()) return s; - s = PerformStaticShapeInferenceBeforeEncapsulation( - graph.get(), "_encapsulate", "_outside"); - if (!s.ok()) return s; - - s = PreprocessForEncapsulation(graph.get(), "_encapsulate", "_outside"); + s = PerformStaticShapeInferenceBeforeEncapsulation(graph.get()); if (!s.ok()) return s; std::unique_ptr graph_out; @@ -550,6 +546,14 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, graphdef->Swap(&graphdef_out); *library = lib_def->ToProto(); + // Remove "_xla_inferred_shapes" attr. They are added by + // `PerformStaticShapeInferenceBeforeEncapsulation`. 
+ for (FunctionDef& fdef : *library->mutable_function()) { + for (NodeDef& node_def : *fdef.mutable_node_def()) { + node_def.mutable_attr()->erase("_xla_inferred_shapes"); + } + } + return s; } @@ -901,18 +905,22 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { { GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape.opts()); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, shape.opts()); + Node* recv = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + shape.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), shape.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected)); } + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = test::function::XTimesTwo(); *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {}, @@ -931,8 +939,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O1"}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}, {"c"}}, @@ -948,16 +955,18 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv = 
RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, b2.opts()); + Node* recv = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), b2.opts() .WithName("E") - .WithControlInputs({recv, b}) + .WithControlInputs({recv}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, - b2.opts().WithControlInput(e)); + b2.opts().WithControlInput(e).WithAttr( + kXlaHasHostTransferAttrName, true)); Node* s = Sequencer( b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), @@ -966,9 +975,9 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { NodeBuilder node_builder("F1", "F1", lib_def.get()); node_builder.Input(a).Input(b); Node* call = - b2.opts().WithControlInputs({s}).FinalizeBuilder(&node_builder); + b2.opts().WithControlInputs({s, b}).FinalizeBuilder(&node_builder); - Binary(a, call, b2.opts().WithName("G").WithControlInputs({e})); + Binary(a, call, b2.opts().WithName("G").WithControlInputs({call})); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1022,14 +1031,16 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { { GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape1.opts()); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, shape1.opts()); + Node* recv = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), shape1.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + 
shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); } @@ -1037,33 +1048,45 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { { GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape2.opts()); - Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, shape2.opts()); + Node* recv1 = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), shape2.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", - {DT_FLOAT, DT_FLOAT}, shape2.opts()); + Node* recv2 = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT, DT_FLOAT}, + shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* g = Binary(e, ops::NodeOut(recv2, 0), + shape2.opts() + .WithName("G") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2")); Node* h = Binary(ops::NodeOut(recv2, 1), e, shape2.opts() .WithName("H") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O2")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {h}, shape2.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g, h}, + shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape2, "F1_O2", &library_expected)); } + NameAttrList shape_inference_graph1, shape_inference_graph2; + shape_inference_graph1.set_name("_outside_compilation_shape_inference_F1_O1"); + shape_inference_graph2.set_name("_outside_compilation_shape_inference_F1_O2"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", 
"b_0_arg:float"}, + {"g_0_retval_retval:float", "i_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}}, {{"I"}, "UnaryTest", - {"outside_compilation_O2_host_compute:outputs:0"}}, + {"outside_compilation_O2_host_compute:outputs:1"}}, {{"F"}, "BinaryTest", {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"}, @@ -1073,11 +1096,10 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { "XlaHostCompute", {"F:o:0", "D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, - {"Toutputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O2"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O2"}, + {"shape_inference_graph", shape_inference_graph2}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}}, {"F"}}, @@ -1088,13 +1110,13 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O1"}, + {"shape_inference_graph", shape_inference_graph1}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, - {{"i_0_retval_retval", "I:o:0"}}); + {{"g_0_retval_retval", "outside_compilation_O2_host_compute:outputs:0"}, + {"i_0_retval_retval", "I:o:0"}}); { std::unique_ptr lib_def( @@ -1105,19 +1127,22 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, b2.opts()); + Node* recv1 = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), b2.opts() .WithName("E") 
- .WithControlInputs({recv1, b}) + .WithControlInputs({recv1}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, - b2.opts().WithControlInput(e)); + b2.opts().WithControlInput(e).WithAttr( + kXlaHasHostTransferAttrName, true)); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", - {DT_FLOAT, DT_FLOAT}, b2.opts()); + Node* recv2 = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* g = Binary(e, ops::NodeOut(recv2, 0), b2.opts() .WithName("G") @@ -1130,7 +1155,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O2")); Node* send2 = - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {h}, b2.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g, h}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* s = Sequencer(b2.opts() .WithName("F1_sequencer") @@ -1139,12 +1165,13 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { NodeBuilder node_builder("F1", "F1", lib_def.get()); node_builder.Input(a).Input(b); - Node* call = b2.opts().WithControlInput(s).FinalizeBuilder(&node_builder); + Node* call = + b2.opts().WithControlInputs({s, b}).FinalizeBuilder(&node_builder); - Binary(g, call, b2.opts().WithName("J")); + Binary(ops::NodeOut(call, 0), ops::NodeOut(call, 1), + b2.opts().WithName("J")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } - TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); } @@ -1196,7 +1223,9 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"a_0_arg:float", "b_0_arg:float"}, - {"f_0_retval_retval:float", "d_0_retval_retval:float"}, {}, + {"e_0_retval_retval:float", "f_0_retval_retval:float", + 
"d_0_retval_retval:float"}, + {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1212,35 +1241,37 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, - {{"d_0_retval_retval", "D:o:0"}, {"f_0_retval_retval", "F:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"d_0_retval_retval", "D:o:0"}, + {"f_0_retval_retval", "F:o:0"}}); *library_expected.add_function() = FunctionDefHelper::Create( - "F2", {"f_0_arg:float", "bridge_e_g_0_arg:float"}, - {"i_0_retval_retval:float", "g_0_retval_retval:float"}, {}, + "F2", {"e_0_arg:float", "f_0_arg:float", "d_0_arg:float"}, + {"g_0_retval_retval:float", "i_0_retval_retval:float"}, {}, { - {{"G"}, "BinaryTest", {"bridge_e_g_0_arg", "f_0_arg"}}, + {{"G"}, "BinaryTest", {"e_0_arg", "f_0_arg"}}, {{"I"}, "BinaryTest", {"f_0_arg", "outside_compilation_O1_host_compute:outputs:0"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", - {"G:o:0"}, - {{"Tinputs", absl::Span({DT_FLOAT})}, + {"d_0_arg", "G:o:0"}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F2_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}}, }, - {{"i_0_retval_retval", "I:o:0"}, {"g_0_retval_retval", "G:o:0"}}); + {{"g_0_retval_retval", "G:o:0"}, {"i_0_retval_retval", "I:o:0"}}); { std::unique_ptr lib_def( @@ -1251,16 +1282,18 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { Node* key_constant1 = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* 
recv1 = RecvAtHost(ops::NodeOut(key_constant1, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, b2.opts()); + Node* recv1 = RecvAtHost( + ops::NodeOut(key_constant1, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), b2.opts() .WithName("E") - .WithControlInputs({recv1, b}) + .WithControlInputs({recv1}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send1 = SendFromHost(ops::NodeOut(key_constant1, 0), "F1", "O1", {e}, - b2.opts().WithControlInput(e)); + b2.opts().WithControlInput(e).WithAttr( + kXlaHasHostTransferAttrName, true)); Node* s1 = Sequencer( b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}), "F1"); @@ -1268,29 +1301,33 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); Node* call1 = - b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); + b2.opts().WithControlInputs({s1, b}).FinalizeBuilder(&node_builder1); Node* key_constant2 = KeyPlaceholder("F2", b2.opts().WithName("F2_key_placeholder")); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant2, 0), "F2", "O1", - {DT_FLOAT}, b2.opts()); - Node* h = Binary(ops::NodeOut(call1, 1), recv2, + Node* recv2 = RecvAtHost( + ops::NodeOut(key_constant2, 0), "F2", "O1", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* h = Binary(recv2, ops::NodeOut(recv2, 1), b2.opts() .WithName("H") .WithAttr("_encapsulate", "F2") .WithAttr("_outside", "O1")); - Node* send2 = SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "O1", {h}, - b2.opts()); + Node* send2 = + SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "O1", {h}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* s2 = Sequencer( b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2}), "F2"); NodeBuilder node_builder2("F2", "F2", lib_def.get()); - 
node_builder2.Input(call1).Input(e); + node_builder2.Input(call1) + .Input(ops::NodeOut(call1, 1)) + .Input(ops::NodeOut(call1, 2)); Node* call2 = b2.opts() - .WithControlInputs({s2, e, call1}) + .WithControlInputs({s2, call1}) .FinalizeBuilder(&node_builder2); - Binary(ops::NodeOut(call2, 1), call2, b2.opts().WithName("J")); + Binary(call2, ops::NodeOut(call2, 1), b2.opts().WithName("J")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1326,8 +1363,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { Node* h = Unary(g, b1.opts() .WithName("H") .WithAttr("_encapsulate", "F2") - .WithAttr("_outside", "O1") - .WithControlInput(e)); + .WithAttr("_outside", "O1")); Node* i = Unary(h, b1.opts().WithName("I").WithAttr("_encapsulate", "F2")); Binary(f, i, b1.opts().WithName("J")); TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); @@ -1358,7 +1394,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}, @@ -1380,7 +1416,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F2_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}}, @@ -1401,7 +1437,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), b2.opts() .WithName("E") - .WithControlInputs({recv1, b}) + .WithControlInputs({recv1}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send1 = SendFromHost(ops::NodeOut(key_constant1, 0), "F1", 
"O1", {e}, @@ -1413,7 +1449,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); Node* call1 = - b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); + b2.opts().WithControlInputs({s1, b}).FinalizeBuilder(&node_builder1); Node* key_constant2 = KeyPlaceholder("F2", b2.opts().WithName("F2_key_placeholder")); @@ -1422,8 +1458,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { Node* h = Unary(recv2, b2.opts() .WithName("H") .WithAttr("_encapsulate", "F2") - .WithAttr("_outside", "O1") - .WithControlInput(e)); + .WithAttr("_outside", "O1")); Node* send2 = SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "O1", {h}, b2.opts()); @@ -1484,12 +1519,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", - {}, - {{"Tinputs", absl::Span({})}, + {"a_0_arg"}, + {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}}, @@ -1503,16 +1538,19 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { Node* a = InputShaped(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - Node* e = Unary(a, b2.opts() - .WithName("E") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O1")); Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, b2.opts()); + Node* e = Unary(recv1, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); Node* send1 = 
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); Node* s1 = Sequencer( - b2.opts().WithName("F1_sequencer").WithControlInput(send1), "F1"); + b2.opts().WithName("F1_sequencer").WithControlInputs({send1, recv1}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); Node* call1 = @@ -1569,12 +1607,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", - {}, - {{"Tinputs", absl::Span({})}, + {"a_0_arg"}, + {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}, @@ -1591,13 +1629,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv1 = - RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {}, b2.opts()); - Node* e = Unary(a, b2.opts() - .WithName("E") - .WithControlInput(recv1) - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O1")); + Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, b2.opts()); + Node* e = Unary(recv1, b2.opts() + .WithName("E") + .WithControlInput(recv1) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); Node* s1 = Sequencer( @@ -1644,8 +1682,27 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + { + GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); + Node* key_constant = KeyPlaceholder("F1", shape1.opts()); + Node* recv1 = + 
RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); + } + + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"e_0_retval_retval:float", "f_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1654,14 +1711,15 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { "XlaHostCompute", {"D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT})}, - {"Toutputs", absl::Span({})}, + {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, }, - {{"f_0_retval_retval", "F:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"f_0_retval_retval", "F:o:0"}}); { std::unique_ptr lib_def( @@ -1678,14 +1736,17 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); + Node* send1 = + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); Node* s1 = Sequencer( - b2.opts().WithName("F1_sequencer").WithControlInput(recv1), "F1"); + b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}), + "F1"); NodeBuilder 
node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); Node* call1 = b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("G")); + Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1722,8 +1783,27 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + { + GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); + Node* key_constant = KeyPlaceholder("F1", shape1.opts()); + Node* recv1 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); + } + + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"e_0_retval_retval:float", "f_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1736,14 +1816,15 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { "XlaHostCompute", {"D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT})}, - {"Toutputs", absl::Span({})}, + {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, }, - 
{{"f_0_retval_retval", "F:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"f_0_retval_retval", "F:o:0"}}); { std::unique_ptr lib_def( @@ -1760,7 +1841,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {}, + Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts().WithControlInput(e)); Node* s1 = Sequencer( b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}), @@ -1770,7 +1851,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { Node* call1 = b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("G")); + Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1813,22 +1894,45 @@ TEST(EncapsulateSubgraphsTest, FunctionDefLibrary library_expected; GraphDef graphdef_expected; + { + GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); + Node* key_constant = KeyPlaceholder("F1", shape1.opts()); + Node* recv1 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); + } + { GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape2.opts()); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", - {DT_FLOAT}, shape2.opts()); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT}, + 
shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* g = Unary(ops::NodeOut(recv2, 0), shape2.opts() .WithName("G") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O2")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, shape2.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, + shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape2, "F1_O2", &library_expected)); } + NameAttrList shape_inference_graph1; + shape_inference_graph1.set_name("_outside_compilation_shape_inference_F1_O1"); + NameAttrList shape_inference_graph2; + shape_inference_graph2.set_name("_outside_compilation_shape_inference_F1_O2"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"e_0_retval_retval:float", "h_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1836,6 +1940,16 @@ TEST(EncapsulateSubgraphsTest, {{"H"}, "UnaryTest", {"outside_compilation_O2_host_compute:outputs:0"}}, + {{"outside_compilation_O1_host_compute"}, + "XlaHostCompute", + {"a_0_arg"}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", shape_inference_graph1}, + {"shapes", absl::Span({})}, + {"_outside_compilation_subgraph", "O1"}}}, {{"outside_compilation_O2_host_compute"}, "XlaHostCompute", {"F:o:0"}, @@ -1843,12 +1957,12 @@ TEST(EncapsulateSubgraphsTest, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O2"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O2"}, + {"shape_inference_graph", shape_inference_graph2}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}}}, }, - {{"h_0_retval_retval", 
"H:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"h_0_retval_retval", "H:o:0"}}); { std::unique_ptr lib_def( @@ -1856,30 +1970,39 @@ TEST(EncapsulateSubgraphsTest, GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - - Node* e = Unary(a, b2.opts() - .WithName("E") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O1")); Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", - {DT_FLOAT}, b2.opts()); - Node* g = Unary(recv, b2.opts() - .WithName("G") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O2") - .WithControlInput(e)); - Node* send = - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, b2.opts()); - Node* s1 = Sequencer( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), - "F1"); + Node* recv1 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + + Node* e = Unary(recv1, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send1 = + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* g = Unary(recv2, b2.opts() + .WithName("G") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2") + .WithControlInput(e)); + Node* send2 = + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* s1 = Sequencer(b2.opts() + .WithName("F1_sequencer") + .WithControlInputs({recv1, send1, recv2, send2}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); 
node_builder1.Input(a).Input(b).ControlInput(s1); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("I")); + Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("I")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1925,19 +2048,24 @@ TEST(EncapsulateSubgraphsTest, { GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape1.opts()); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, shape1.opts()); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); } + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"e_0_retval_retval:float", "h_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1945,6 +2073,16 @@ TEST(EncapsulateSubgraphsTest, "UnaryTest", {"outside_compilation_O1_host_compute:outputs:0"}}, {{"H"}, "UnaryTest", {"F:o:0"}}, + {{"outside_compilation_O2_host_compute"}, + "XlaHostCompute", + {"a_0_arg"}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({})}, + {"ancestors", absl::Span({})}, + {"key", "host_compute_channel_F1_O2"}, + {"shape_inference_graph", NameAttrList()}, + {"shapes", 
absl::Span({})}, + {"_outside_compilation_subgraph", "O2"}}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"D:o:0"}, @@ -1952,12 +2090,12 @@ TEST(EncapsulateSubgraphsTest, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O1"}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, }, - {{"h_0_retval_retval", "H:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"h_0_retval_retval", "H:o:0"}}); { std::unique_ptr lib_def( @@ -1968,27 +2106,33 @@ TEST(EncapsulateSubgraphsTest, Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, b2.opts()); - Node* e = Unary(recv, b2.opts() - .WithName("E") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O1")); + Node* recv1 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(recv1, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); Node* send = - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); - /*Node* g =*/Unary(a, b2.opts() - .WithName("G") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O2") - .WithControlInput(e)); - Node* s1 = Sequencer( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), - "F1"); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + /*Node* g =*/Unary(recv2, b2.opts() + .WithName("G") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2") 
+ .WithControlInput(e)); + Node* s1 = Sequencer(b2.opts() + .WithName("F1_sequencer") + .WithControlInputs({recv1, recv2, send}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b).ControlInput(s1); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("I")); + Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("I")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -2039,19 +2183,24 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { { GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape1.opts()); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, shape1.opts()); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); } + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"e_0_retval_retval:float", "h_0_retval_retval:float"}, {}, {{{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, {{"F"}, "UnaryTest", {"outside_compilation_O1_host_compute:outputs:0"}}, @@ -2063,8 +2212,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", 
absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O1"}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, {{"outside_compilation_O2_host_compute"}, @@ -2074,7 +2222,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"Toutputs", absl::Span({})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O2"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}}, {}}, @@ -2085,11 +2233,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"Toutputs", absl::Span({})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O3"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O3"}}, {}}}, - {{"h_0_retval_retval", "H:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"h_0_retval_retval", "H:o:0"}}); { std::unique_ptr lib_def( @@ -2100,23 +2249,27 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, b2.opts()); + Node* recv1 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Unary(recv1, b2.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send = - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", - {DT_FLOAT}, b2.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + 
Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* g = Unary(recv2, b2.opts() .WithName("G") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O2") .WithControlInput(e)); - Node* recv3 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O3", - {DT_FLOAT}, b2.opts()); + Node* recv3 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O3", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); /*Node* i =*/Binary(recv3, e, b2.opts() .WithName("I") @@ -2131,7 +2284,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { node_builder1.Input(a).Input(b).ControlInput(s1); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("J")); + Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("J")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -2167,14 +2320,44 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + { + GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); + Node* key_constant = KeyPlaceholder("F1", shape1.opts()); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); + } + + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", 
"b_0_arg:float"}, + {"e_0_retval_retval:float", "f_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, {{"F"}, "UnaryTest", {"D:o:0"}}, + {{"outside_compilation_O1_host_compute"}, + "XlaHostCompute", + {"a_0_arg"}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", shape_inference_graph}, + {"shapes", absl::Span({})}, + {"_outside_compilation_subgraph", "O1"}}}, }, - {{"f_0_retval_retval", "F:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"f_0_retval_retval", "F:o:0"}}); { std::unique_ptr lib_def( @@ -2183,15 +2366,26 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - Node* e = Unary(a, b2.opts() - .WithName("E") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O1")); + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(recv, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send = + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* s = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); - node_builder1.Input(a).Input(b); + node_builder1.Input(a).Input(b).ControlInput(s); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("G")); + Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -2236,20 +2430,22 @@ 
TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { { GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape.opts()); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, shape.opts()); - Node* a = InputShaped(shape.opts().WithName("A")); - Node* c = Unary(a, shape.opts().WithName("C")); - Node* e = BinaryUnknownShape(c, recv, + Node* recv = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + shape.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = BinaryUnknownShape(recv, ops::NodeOut(recv, 1), shape.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected)); } + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = test::function::XTimesTwo(); *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"b_0_arg:float", "c_0_arg:float"}, {"f_0_retval_retval:float"}, {}, @@ -2262,13 +2458,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", - {"c:o:0"}, - {{"Tinputs", absl::Span({DT_FLOAT})}, + {"c_0_arg", "c:o:0"}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O1"}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}, {"c"}}, @@ -2285,16 +2480,18 @@ 
TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, b2.opts()); - Node* e = BinaryUnknownShape(c, ops::NodeOut(recv, 0), + Node* recv = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = BinaryUnknownShape(recv, ops::NodeOut(recv, 1), b2.opts() .WithName("E") - .WithControlInputs({recv, b}) + .WithControlInputs({recv}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, - b2.opts().WithControlInput(e)); + b2.opts().WithControlInput(e).WithAttr( + kXlaHasHostTransferAttrName, true)); Node* s = Sequencer( b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), @@ -2303,9 +2500,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { NodeBuilder node_builder("F1", "F1", lib_def.get()); node_builder.Input(b).Input(c); Node* call = - b2.opts().WithControlInputs({s, c}).FinalizeBuilder(&node_builder); + b2.opts().WithControlInputs({s, b, c}).FinalizeBuilder(&node_builder); - Binary(a, call, b2.opts().WithName("G").WithControlInputs({e})); + Binary(a, call, b2.opts().WithName("G").WithControlInputs({call})); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } diff --git a/tensorflow/compiler/jit/encapsulate_util.cc b/tensorflow/compiler/jit/encapsulate_util.cc index bcc3213285bee2a2094bd6c39b37ba95874d90ed..2264806d6bdabd9f26d9f83b681524399f996317 100644 --- a/tensorflow/compiler/jit/encapsulate_util.cc +++ b/tensorflow/compiler/jit/encapsulate_util.cc @@ -62,516 +62,6 @@ void ReplaceAttr(Node* n, const string& attr_name, const T& value) { n->AddAttr(attr_name, value); } -// Step 1a ~ 1d for PreprocessForEncapsulation(). See comments of -// PreprocessForEncapsulation() for details. 
-Status ProcessControlEdges(Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name) { - // Gather edges to remove. We should not remove the edge while iterating. - std::vector edges_to_remove; - for (const Edge* e : g->edges()) { - if (!e->IsControlEdge()) { - continue; - } - - auto src_xla_computation = - GetStringAttr(*e->src(), xla_computation_attr_name); - auto dst_xla_computation = - GetStringAttr(*e->dst(), xla_computation_attr_name); - auto src_outside_compilation = - GetStringAttr(*e->src(), outside_compilation_attr_name); - auto dst_outside_compilation = - GetStringAttr(*e->dst(), outside_compilation_attr_name); - - if (!src_xla_computation && !dst_xla_computation) { - continue; - } else if (src_xla_computation && !dst_xla_computation) { - if (src_outside_compilation) { - // Case 1c: outside compilation to host computation control edge. - edges_to_remove.push_back(e); - - TF_RETURN_IF_ERROR(AppendToListAttr( - e->dst(), kXlaControlDependenciesAttrName, e->src()->name())); - } - } else if (!src_xla_computation && dst_xla_computation) { - if (dst_outside_compilation) { - // Case 1c: host computation control to outside compilation edge. - edges_to_remove.push_back(e); - - TF_RETURN_IF_ERROR(AppendToListAttr( - e->dst(), kXlaControlDependenciesAttrName, e->src()->name())); - } - } else { // src_xla_computation && dst_xla_computation - if (*src_xla_computation != *dst_xla_computation) { - if (src_outside_compilation && dst_outside_compilation) { - // Case 1b: outside compilation to outside compilation control edge. - edges_to_remove.push_back(e); - - TF_RETURN_IF_ERROR(AppendToListAttr( - e->dst(), kXlaControlDependenciesAttrName, e->src()->name())); - } else if (src_outside_compilation && !dst_outside_compilation) { - // Case 1a: outside compilation to another XLA computaition control - // edge. 
- TF_RETURN_IF_ERROR(AppendToListAttr( - e->src(), kXlaConnectedToOtherXlaComputationAttrName, - *dst_xla_computation)); - } else if (!src_outside_compilation && dst_outside_compilation) { - // Case 1a: another XLA computaition to outside compilation control - // edge. - TF_RETURN_IF_ERROR(AppendToListAttr( - e->dst(), kXlaConnectedFromOtherXlaComputationAttrName, - *src_xla_computation)); - } - } - } - } - - for (auto e : edges_to_remove) { - g->RemoveEdge(e); - } - return Status::OK(); -} - -// Step 2 for PreprocessForEncapsulation(). See comments of -// PreprocessForEncapsulation() for details. -Status ProcessXlaToXlaDataEdges(Graph* g, - const string& xla_computation_attr_name, - const string& outside_compilation_attr_name) { - // Gather edges between XLA computations. Notice that we do not store `Edge*` - // directly because we remove some nodes while adding Identity nodes, and - // those Edge pointers might be invalidated. - struct EdgeInfo { - int dst_input, dst_node_id; - }; - std::vector edges; - for (const Edge* e : g->edges()) { - if (e->IsControlEdge()) { - continue; - } - - auto src_xla_computation = - GetStringAttr(*e->src(), xla_computation_attr_name); - auto dst_xla_computation = - GetStringAttr(*e->dst(), xla_computation_attr_name); - auto src_outside_compilation = - GetStringAttr(*e->src(), outside_compilation_attr_name); - auto dst_outside_compilation = - GetStringAttr(*e->dst(), outside_compilation_attr_name); - if (!src_xla_computation || !dst_xla_computation) { - continue; - } - - if (*src_xla_computation != *dst_xla_computation) { - if (src_outside_compilation || dst_outside_compilation) { - edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()}); - VLOG(4) << "XLA -> XLA edge: " << e->DebugString(); - } - } - } - - // For each XLA -> XLA edge, add an Identity node between src and dst. 
- for (int i = 0; i < edges.size(); i++) { - Node* dst = g->FindNodeId(edges[i].dst_node_id); - const Edge* e; - TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e)); - Node* src = e->src(); - int src_output = e->src_output(), dst_input = e->dst_input(); - g->RemoveEdge(e); - - // Create Identity node, and connect it between `src` and `dst`. - string identity_node_name = - absl::StrCat("bridge_", src->name(), "_", dst->name()); - DataType dtype = src->output_type(src_output); - TF_ASSIGN_OR_RETURN(Node * identity_node, - BuildIdentityNode(g, identity_node_name, dtype, src, - /*requested_device=*/absl::nullopt)); - identity_node->AddAttr(kBridgeSourceNodeAttrName, src->name()); - g->AddEdge(src, src_output, identity_node, 0); - g->AddEdge(identity_node, 0, dst, dst_input); - - // Replace `e->dst()` because its input node changed. - NodeDef new_def = dst->def(); - *new_def.mutable_input(dst_input) = identity_node->name(); - TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def)); - - // Other edge in `edges` might have `e->dst()` as src or dst - // node. Before removing `e->dst()`, replace those edges with corresponding - // edges for `dst_replace_node`. - for (int j = i + 1; j < edges.size(); j++) { - if (edges[j].dst_node_id == edges[i].dst_node_id) { - edges[j].dst_node_id = dst_replace_node->id(); - } - } - } - return Status::OK(); -} - -// Step 3 for PreprocessForEncapsulation(). See comments of -// PreprocessForEncapsulation() for details. -Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation( - Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name) { - // Gather edges between outside compilation and host computation. Notice that - // we do not store `Edge*` directly because we remove some nodes while adding - // Identity nodes, and those Edge pointers might be invalidated. 
- struct EdgeInfo { - int dst_input, dst_node_id; - bool is_host_to_outside_compilation; - }; - std::vector edges; - for (const Edge* e : g->edges()) { - if (e->IsControlEdge()) { - continue; - } - - if (e->src()->attrs().Find(xla_computation_attr_name) == nullptr && - e->dst()->attrs().Find(xla_computation_attr_name) != nullptr && - e->dst()->attrs().Find(outside_compilation_attr_name) != nullptr) { - edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id(), - /*is_host_to_outside_compilation=*/true}); - VLOG(4) << "Host -> oc edge: " << e->DebugString(); - } else if (e->dst()->attrs().Find(xla_computation_attr_name) == nullptr && - e->src()->attrs().Find(xla_computation_attr_name) != nullptr && - e->src()->attrs().Find(outside_compilation_attr_name) != - nullptr) { - edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id(), - /*is_host_to_outside_compilation=*/false}); - VLOG(4) << "Oc -> host edge: " << e->DebugString(); - } - } - - // Remove the edge from host to outside compilation. Add a placeholder as - // outside compilation node input. - std::map placeholders; - for (int i = 0; i < edges.size(); i++) { - Node* dst = g->FindNodeId(edges[i].dst_node_id); - const Edge* e; - TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e)); - Node* src = e->src(); - int src_output = e->src_output(), dst_input = e->dst_input(); - g->RemoveEdge(e); - - // Find or create placeholder node. - string new_name = - edges[i].is_host_to_outside_compilation - ? 
absl::StrCat(src->name(), "_host_to_oc_placeholder") - : absl::StrCat(src->name(), "_oc_to_host_placeholder"); - auto iter = placeholders.find(new_name); - Node* placeholder_node; - if (iter == placeholders.end()) { - NodeDefBuilder placeholder_builder(new_name, "Placeholder"); - placeholder_builder.Attr("dtype", src->output_type(src_output)); - if (edges[i].is_host_to_outside_compilation) { - placeholder_builder.Attr(kHostToOutsideCompilationOriginalNodeAttrName, - src->name()); - placeholder_builder.Attr(kHostToOutsideCompilationSrcOutputAttrName, - src_output); - // If this placeholder node is in outside compilation, we need to set - // `xla_computation_attr_name` and `outside_compilation_attr_name`. - string xla_computation_attr, outside_compilation_attr; - TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(), xla_computation_attr_name, - &xla_computation_attr)); - TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(), - outside_compilation_attr_name, - &outside_compilation_attr)); - placeholder_builder.Attr(xla_computation_attr_name, - xla_computation_attr); - placeholder_builder.Attr(outside_compilation_attr_name, - outside_compilation_attr); - } else { - placeholder_builder.Attr(kOutsideCompilationToHostOriginalNodeAttrName, - src->name()); - placeholder_builder.Attr(kOutsideCompilationToHostSrcOutputAttrName, - src_output); - } - NodeDef placeholder_def; - TF_RETURN_IF_ERROR(placeholder_builder.Finalize(&placeholder_def)); - Status s; - placeholder_node = g->AddNode(placeholder_def, &s); - TF_RETURN_IF_ERROR(s); - placeholders[new_name] = placeholder_node; - } else { - placeholder_node = iter->second; - } - g->AddEdge(placeholder_node, 0, dst, dst_input); - - // Replace `e->dst()` because its input node changed. - NodeDef new_def = dst->def(); - *new_def.mutable_input(dst_input) = placeholder_node->name(); - TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def)); - - // Other edge in `edges` might have `e->dst()` as src or dst - // node. 
Before removing `e->dst()`, replace those edges with corresponding - // edges for `dst_replace_node`. - for (int j = i + 1; j < edges.size(); j++) { - if (edges[j].dst_node_id == edges[i].dst_node_id) { - edges[j].dst_node_id = dst_replace_node->id(); - } - } - } - return Status::OK(); -} - -// Step 1 for `PostprocessForEncapsulation`. See comments of -// `PostprocessForEncapsulation` for details. -Status RemovePlaceholderBetweenOutsideCompilationAndHostComputation(Graph* g) { - // Gather all outside compilation to host computation nodes. - struct PlaceHolderNodeInfo { - Node* n; - bool is_host_to_oc; - }; - std::vector placeholder_nodes; - for (Node* n : g->nodes()) { - if (n->type_string() == "Placeholder") { - if (HasNodeAttr(n->def(), - kOutsideCompilationToHostOriginalNodeAttrName)) { - placeholder_nodes.push_back({n, false}); - } else if (HasNodeAttr(n->def(), - kHostToOutsideCompilationOriginalNodeAttrName)) { - placeholder_nodes.push_back({n, true}); - } - } - } - - // Remove the placeholder nodes, and reconnect original edge. - auto node_name_index = g->BuildNodeNameIndex(); - for (auto placeholder_iter : placeholder_nodes) { - Node* n = placeholder_iter.n; - - string node_name; - int node_src_output; - if (placeholder_iter.is_host_to_oc) { - TF_RETURN_IF_ERROR( - GetNodeAttr(n->attrs(), kHostToOutsideCompilationOriginalNodeAttrName, - &node_name)); - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), - kHostToOutsideCompilationSrcOutputAttrName, - &node_src_output)); - } else { - TF_RETURN_IF_ERROR( - GetNodeAttr(n->attrs(), kOutsideCompilationToHostOriginalNodeAttrName, - &node_name)); - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), - kOutsideCompilationToHostSrcOutputAttrName, - &node_src_output)); - } - auto iter = node_name_index.find(node_name); - if (iter == node_name_index.end()) { - return errors::Internal( - "Cannot find original node for oc -> host placeholder node ", - node_name); - } - - // Change all usage node to use the original node instead. 
- Node* original_node = iter->second; - std::vector control_edges; - std::vector data_edges; - for (auto e : n->out_edges()) { - if (e->IsControlEdge()) { - control_edges.push_back(e); - } else { - data_edges.push_back({e->dst(), e->src_output(), e->dst_input()}); - } - } - for (const Edge* e : control_edges) { - g->AddControlEdge(original_node, e->dst()); - g->RemoveEdge(e); - } - for (int i = 0; i < data_edges.size(); i++) { - Node* dst = data_edges[i].dst; - NodeDef new_def = dst->def(); - int dst_input = data_edges[i].dst_input; - *new_def.mutable_input(dst_input) = - absl::StrCat(original_node->name(), ":", node_src_output); - TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def)); - - const Edge* edge_to_replace = nullptr; - TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace)); - g->RemoveEdge(edge_to_replace); - g->AddEdge(original_node, node_src_output, replace_node, dst_input); - - // Other edges might have `dst` as dst node. Update those edges with - // `replace_node`. - for (int j = i + 1; j < data_edges.size(); j++) { - if (data_edges[j].dst == dst) { - data_edges[j].dst = replace_node; - } - } - - // Other placeholder node might have `dst` as original node. Update - // `node_name_index` with `replace_node`. - node_name_index[replace_node->name()] = replace_node; - } - - // Remove placeholder node. - g->RemoveNode(n); - } - return Status::OK(); -} - -// Step 2 for `PostprocessForEncapsulation`. See comments of -// `PostprocessForEncapsulation` for details. -Status RemoveIdentityBetweenDifferentXlaComputation(Graph* g) { - // Gather Identity nodes to remove. - std::vector bridge_nodes; - for (Node* n : g->nodes()) { - if (n->type_string() == "Identity" && - HasNodeAttr(n->def(), kBridgeSourceNodeAttrName)) { - bridge_nodes.push_back(n); - } - } - - // Remove the identity nodes, and reconnect the original edge. 
- for (int i = 0; i < bridge_nodes.size(); i++) { - Node* n = bridge_nodes[i]; - const Edge* src_edge = nullptr; - TF_RETURN_IF_ERROR(n->input_edge(0, &src_edge)); - - // Change all usage node to use the original node instead. - std::vector control_edges; - std::vector data_edges; - for (auto e : n->out_edges()) { - if (e->IsControlEdge()) { - control_edges.push_back(e); - } else { - data_edges.push_back({e->dst(), e->src_output(), e->dst_input()}); - } - } - for (const Edge* e : control_edges) { - g->AddControlEdge(src_edge->src(), e->dst()); - g->RemoveEdge(e); - } - for (int j = 0; j < data_edges.size(); j++) { - Node* dst = data_edges[j].dst; - NodeDef new_def = dst->def(); - int dst_input = data_edges[j].dst_input; - *new_def.mutable_input(dst_input) = - absl::StrCat(src_edge->src()->name(), ":", src_edge->src_output()); - TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def)); - - const Edge* edge_to_replace = nullptr; - TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace)); - g->RemoveEdge(edge_to_replace); - g->AddEdge(src_edge->src(), src_edge->src_output(), replace_node, - dst_input); - - // Other edges might have `dst` as dst node. Update those edges with - // `replace_node`. - for (int k = j + 1; k < data_edges.size(); k++) { - if (data_edges[k].dst == dst) { - data_edges[k].dst = replace_node; - } - } - - // The node we replaced might be in `bridge_nodes`. If so, update - // `bridge_nodes` to use the replaced node. - for (int k = i + 1; k < bridge_nodes.size(); k++) { - if (bridge_nodes[k] == dst) { - bridge_nodes[k] = replace_node; - } - } - } - - // Remove Identity node. - g->RemoveNode(n); - } - return Status::OK(); -} - -// Step 3 for `PostprocessForEncapsulation`. See comments of -// `PostprocessForEncapsulation` for details. -// We do not need to worry about removed nodes in step 1 and 2; -// `PreprocessForEncapsulation` will not record control dependencies for those -// remvoed nodes in the first place. 
-Status AddControlDependencies( - Graph* g, const std::unordered_map& cluster_node_names) { - auto node_name_index = g->BuildNodeNameIndex(); - - // Reconnect outside compilation to outside compilation control edge. - for (Node* n : g->nodes()) { - std::vector control_deps; - Status s = - GetNodeAttr(n->attrs(), kXlaControlDependenciesAttrName, &control_deps); - if (!s.ok()) { - if (s.code() != error::NOT_FOUND) { - return s; - } else { - continue; - } - } else { - n->ClearAttr(kXlaControlDependenciesAttrName); - for (const string& control_input : control_deps) { - auto iter = node_name_index.find(control_input); - if (iter == node_name_index.end()) { - return errors::Internal("Cannot find original node for ", - control_input); - } - g->AddControlEdge(iter->second, n); - } - } - } - - // Reconnect outside compilation to XLA computation control edge. - for (Node* n : g->nodes()) { - std::vector control_deps; - Status s = GetNodeAttr( - n->attrs(), kXlaConnectedToOtherXlaComputationAttrName, &control_deps); - if (!s.ok()) { - if (s.code() != error::NOT_FOUND) { - return s; - } else { - continue; - } - } else { - n->ClearAttr(kXlaConnectedToOtherXlaComputationAttrName); - for (const string& control_input : control_deps) { - auto iter = cluster_node_names.find(control_input); - if (iter == cluster_node_names.end()) { - return errors::Internal("Cannot find cluster node for ", - control_input); - } - auto iter2 = node_name_index.find(iter->second); - if (iter2 == node_name_index.end()) { - return errors::Internal("Cannot find cluster node for ", - iter->second); - } - g->AddControlEdge(n, iter2->second); - } - } - } - - // Reconnect XLA computation to outside compilation control edge. 
- for (Node* n : g->nodes()) { - std::vector control_deps; - Status s = - GetNodeAttr(n->attrs(), kXlaConnectedFromOtherXlaComputationAttrName, - &control_deps); - if (!s.ok()) { - if (s.code() != error::NOT_FOUND) { - return s; - } else { - continue; - } - } else { - n->ClearAttr(kXlaConnectedFromOtherXlaComputationAttrName); - for (const string& control_input : control_deps) { - auto iter = cluster_node_names.find(control_input); - if (iter == cluster_node_names.end()) { - return errors::Internal("Cannot find cluster node for ", - control_input); - } - auto iter2 = node_name_index.find(iter->second); - if (iter2 == node_name_index.end()) { - return errors::Internal("Cannot find cluster node for ", - iter->second); - } - g->AddControlEdge(iter2->second, n); - } - } - } - - return Status::OK(); -} - // Step 1 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of // `PreprocessEdgesBetweenOutsideCompilations` for details. Status PreprocessControlEdgesBetweenOutsideCompilations( @@ -642,7 +132,7 @@ Status PreprocessDataEdgesBetweenOutsideCompilations( // Remove the edge from host to outside compilation. Add a placeholder as // outside compilation node input. - std::map placeholders; + std::map, Node*> placeholders; for (int i = 0; i < edges.size(); i++) { Node* dst = g->FindNodeId(edges[i].dst_node_id); const Edge* e; @@ -652,8 +142,10 @@ Status PreprocessDataEdgesBetweenOutsideCompilations( g->RemoveEdge(e); // Find or create placeholder node. 
- string new_name = absl::StrCat(src->name(), "_oc_to_oc_placeholder"); - auto iter = placeholders.find(new_name); + string new_name = + absl::StrCat(src->name(), "_oc_to_oc_placeholder_", src_output); + auto placeholder_index = std::make_pair(src->name(), src_output); + auto iter = placeholders.find(placeholder_index); Node* placeholder_node; if (iter == placeholders.end()) { NodeDefBuilder placeholder_builder(new_name, "Placeholder"); @@ -673,7 +165,7 @@ Status PreprocessDataEdgesBetweenOutsideCompilations( Status s; placeholder_node = g->AddNode(placeholder_def, &s); TF_RETURN_IF_ERROR(s); - placeholders[new_name] = placeholder_node; + placeholders[placeholder_index] = placeholder_node; } else { placeholder_node = iter->second; } @@ -808,20 +300,6 @@ Status PostprocessControlEdgesBetweenOutsideCompilations( const char kXlaInferredShapesAttrName[] = "_xla_inferred_shapes"; -const char kXlaConnectedToOtherXlaComputationAttrName[] = - "_xla_connected_to_other_xla_computation"; -const char kXlaConnectedFromOtherXlaComputationAttrName[] = - "_xla_connected_from_other_xla_computation"; -const char kXlaControlDependenciesAttrName[] = "_xla_control_dependencies"; -const char kBridgeSourceNodeAttrName[] = "_xla_bridge_src"; -const char kOutsideCompilationToHostOriginalNodeAttrName[] = - "_xla_oc_to_host_node_name"; -const char kOutsideCompilationToHostSrcOutputAttrName[] = - "_xla_oc_to_host_src_output"; -const char kHostToOutsideCompilationOriginalNodeAttrName[] = - "_xla_host_to_oc_node_name"; -const char kHostToOutsideCompilationSrcOutputAttrName[] = - "_xla_host_to_oc_src_output"; const char kXlaConnectedToXlaComputationAttrName[] = "_xla_connected_to_xla_computation"; const char kXlaConnectedFromXlaComputationAttrName[] = @@ -832,32 +310,7 @@ const char kOutsideCompilationSrcOutputAttrName[] = "_xla_oc_to_oc_src_output"; const char kXlaControlDependenciesWithinXlaClusterAttrName[] = "_xla_control_dependencies_within_xla_cluster"; -Status 
PerformStaticShapeInferenceBeforeEncapsulation( - Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name) { - // Find all outside compilation to XLA computation data edges. - std::unordered_set outside_compilation_send_nodes; - for (auto e : g->edges()) { - if (e->IsControlEdge()) { - continue; - } - - auto src_computation = GetStringAttr(*e->src(), xla_computation_attr_name); - auto dst_computation = GetStringAttr(*e->dst(), xla_computation_attr_name); - if (!src_computation || !dst_computation || - *src_computation != *dst_computation) { - continue; - } - - auto src_outside_compilation = - GetStringAttr(*e->src(), outside_compilation_attr_name); - auto dst_outside_compilation = - GetStringAttr(*e->dst(), outside_compilation_attr_name); - if (src_outside_compilation && !dst_outside_compilation) { - outside_compilation_send_nodes.insert(e->src()); - } - } - +Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g) { // Perform shape inference. std::map arg_shapes; GraphShapeInfo shape_info; @@ -865,55 +318,21 @@ Status PerformStaticShapeInferenceBeforeEncapsulation( InferShapes(g, arg_shapes, /*fnlib_def=*/nullptr, &shape_info)); // Add attribute for output shapes. 
- for (Node* n : outside_compilation_send_nodes) { - auto iter = shape_info.find(n->name()); - if (iter == shape_info.end()) { - continue; - } - + auto node_name_index = g->BuildNodeNameIndex(); + for (auto iter : shape_info) { std::vector output_shapes; - std::transform(iter->second.begin(), iter->second.end(), + std::transform(iter.second.begin(), iter.second.end(), std::back_inserter(output_shapes), [](const InferredShape& inferred_shape) { return inferred_shape.shape; }); + Node* n = node_name_index[iter.first]; n->AddAttr(kXlaInferredShapesAttrName, output_shapes); } return Status::OK(); } -Status PreprocessForEncapsulation(Graph* g, - const string& xla_computation_attr_name, - const string& outside_compilation_attr_name) { - TF_RETURN_IF_ERROR(ProcessControlEdges(g, xla_computation_attr_name, - outside_compilation_attr_name)); - TF_RETURN_IF_ERROR(ProcessXlaToXlaDataEdges(g, xla_computation_attr_name, - outside_compilation_attr_name)); - TF_RETURN_IF_ERROR(ProcessDataEdgeBetweenOutsideCompilationAndHostComputation( - g, xla_computation_attr_name, outside_compilation_attr_name)); - return Status::OK(); -} - -Status PostprocessForEncapsulation( - Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name, - const std::unordered_map& clusters) { - // The `node` pointer in `XlaClusterInfo` might be invalidated in step 1/2, - // but the node name won't change. Record cluster node name for - // `AddControlDependencies`. 
- std::unordered_map cluster_node_names; - for (const auto& iter : clusters) { - cluster_node_names[iter.first] = iter.second.node->name(); - } - - TF_RETURN_IF_ERROR( - RemovePlaceholderBetweenOutsideCompilationAndHostComputation(g)); - TF_RETURN_IF_ERROR(RemoveIdentityBetweenDifferentXlaComputation(g)); - TF_RETURN_IF_ERROR(AddControlDependencies(g, cluster_node_names)); - return Status::OK(); -} - Status PreprocessEdgesBetweenOutsideCompilations( Graph* g, const string& outside_compilation_attr_name) { // Remove edges from source node to outside compilation nodes, and edges diff --git a/tensorflow/compiler/jit/encapsulate_util.h b/tensorflow/compiler/jit/encapsulate_util.h index e363bc5754ac395bae262dc67a780a0173efaf5e..c9f16d14168163e11bb19092f566f1de8724aca3 100644 --- a/tensorflow/compiler/jit/encapsulate_util.h +++ b/tensorflow/compiler/jit/encapsulate_util.h @@ -27,51 +27,13 @@ namespace tensorflow { // a list of PartialTensorShape objects. extern const char kXlaInferredShapesAttrName[]; -// Infer output shapes for outside compilation nodes which have output data -// edges to XLA computation nodes. These shapes will be used later by XLA -// compiler as output shapes of the outside compilation's XlaHostCompute op. -// XLA computation nodes will be mark by attr `xla_computation_attr_name`; -// outside compilation nodes will be marked by both attr -// `xla_computation_attr_name` and `outside_compilation_attr_name`. -// -// Those outside compilation nodes will be marked with attribute -// `kXlaInferredShapesAttrName`. +// Infers output shapes for all nodes in graph `g`. The output shapes will be +// stored in node attribute `kXlaInferredShapesAttrName`. // // We have to perform shape inference before encapsulation because after // encapsulation, some nodes will be encapsulated into function call, and shape // inference does not handle function call at the moment. 
-Status PerformStaticShapeInferenceBeforeEncapsulation( - Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name); - -// Attribute indicating that some ops in other XLA computation has control -// dependency on this node. Attribute value will be a list of string (XLA -// computation names). -extern const char kXlaConnectedToOtherXlaComputationAttrName[]; - -// Attribute indicating that this node has control dependency on some ops in -// other XLA computation. Attribute value will be a list of string (XLA -// computation names). -extern const char kXlaConnectedFromOtherXlaComputationAttrName[]; - -// Attribute indicating that this node has control dependencies on some other -// nodes. Attribute value will be a list of string (node names). -extern const char kXlaControlDependenciesAttrName[]; - -// Attribute indicating that this is an Identity node added to act as a bridge -// between different XLA computations. Attribute value will be string (source -// node name). -extern const char kBridgeSourceNodeAttrName[]; - -// Attribute indicating that this is an Placeholder node added to act as a -// temporary input node for an outside compilation node. Attribute value will be -// string (original input node name). -extern const char kOutsideCompilationToHostOriginalNodeAttrName[]; - -// Attribute indicating that this is an Placeholder node added to act as a -// temporary input node for an outside compilation node. Attribute value will be -// int (src_output for original edge). -extern const char kOutsideCompilationToHostSrcOutputAttrName[]; +Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g); // Attribute indicating that some ops in this node's XLA computation has control // dependency on this node. Attribute value will always be "true". @@ -81,16 +43,6 @@ extern const char kXlaConnectedToXlaComputationAttrName[]; // this node's XLA computation. Attribute value will always be "true". 
extern const char kXlaConnectedFromXlaComputationAttrName[]; -// Attribute indicating that this is an Placeholder node added to act as a -// temporary input node for an host node. Attribute value will be string -// (original input node name). -extern const char kHostToOutsideCompilationOriginalNodeAttrName[]; - -// Attribute indicating that this is an Placeholder node added to act as a -// temporary input node for a host node. Attribute value will be int (src_output -// for original edge). -extern const char kHostToOutsideCompilationSrcOutputAttrName[]; - // Attribute indicating that this is an Placeholder node added to act as a // temporary input node for an outside compilation node. Attribute value will be // string (original input node name). @@ -106,27 +58,6 @@ extern const char kOutsideCompilationSrcOutputAttrName[]; // (node names). extern const char kXlaControlDependenciesWithinXlaClusterAttrName[]; -// Preprocesses edges between different XLA clusters for encapsulation. It will -// perform the following operations in order: -// -// 1a. For control edges between outside compilation and another XLA -// computation, add attr "kXlaConnected{From, To}OtherXlaComputationAttrName -// = XLA computation node name" to the outside compilation node. -// 1b. For control edges between different outside compilations (in different -// XLA computations), remove the edge and add attr -// "kXlaControlDependenciesAttrName = src node name" to dst node. -// 1c. For control edges between outside compilation and host computation, -// remove the edge and add attr "kXlaControlDependenciesAttrName = src node -// name" to dst node. -// 2. For data edges between different XLA computations, if either src or dst -// is outside compilation, add an Identity node in between the edge. The -// identity node will have attr kBridgeSourceNodeAttrName. -// 3. For data edges between outside compilation and host computation, remove -// the edge and create a Placeholder node as dst node's input. 
-Status PreprocessForEncapsulation(Graph* g, - const string& xla_computation_attr_name, - const string& outside_compilation_attr_name); - // Information for XLA computation. struct XlaClusterInfo { // Add an explicitly-defined default constructor for this class. @@ -158,24 +89,6 @@ struct XlaClusterInfo { const std::map host_compute_core; }; -// Postprocesses edges between different XLA clusters for encapsulation. This -// function reverts what `PreprocessForEncapsulation` did. It will perform the -// following operations in order: -// -// 1. Remove Placeholder nodes between outside compilation and host computation -// (created in `PreprocessForEncapsulation` step 3). -// 2. Remove Identity nodes created in `PreprocessForEncapsulation` step 2. -// 3a. Reconnect control edges between outside compilation and another XLA -// computation (marked by `PreprocessForEncapsulation` step 1a). -// 3b. Reconnect control edges between different outside compilations (marked by -// `PreprocessForEncapsulation` step 1b). -// 3c. Reconnect control edges between outside compilation and host computation -// (marked by `PreprocessForEncapsulation` step 1c). -Status PostprocessForEncapsulation( - Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name, - const std::unordered_map& clusters); - // Preprocesses edges within the same XLA cluster. It will perform the following // operations in order: // diff --git a/tensorflow/compiler/jit/encapsulate_util_test.cc b/tensorflow/compiler/jit/encapsulate_util_test.cc index 25c32cef01d7f9877a35001457539f2ad189192f..3bb979e0698d2d6be42ed5bae66c25267928192c 100644 --- a/tensorflow/compiler/jit/encapsulate_util_test.cc +++ b/tensorflow/compiler/jit/encapsulate_util_test.cc @@ -38,24 +38,11 @@ TEST(PerformStaticShapeInferenceBeforeEncapsulationTest, Basic) { Graph g(OpRegistry::Global()); TF_CHECK_OK(s.ToGraph(&g)); - // "add" node is outside compilation node, "identity" node is XLA node. 
- auto node_index = g.BuildNodeNameIndex(); - Node *add_node = node_index["add"], *identity_node = node_index["identity"]; - add_node->AddAttr("_xla", "cluster"); - add_node->AddAttr("_oc", "cluster"); - identity_node->AddAttr("_xla", "cluster"); - TF_CHECK_OK( - PerformStaticShapeInferenceBeforeEncapsulation(&g, "_xla", "_oc")); + TF_CHECK_OK(PerformStaticShapeInferenceBeforeEncapsulation(&g)); - // Check that only "add" node now has _xla_inferred_shapes attr. - std::vector nodes_with_inferred_shape; - for (Node *n : g.nodes()) { - if (HasNodeAttr(n->def(), kXlaInferredShapesAttrName)) { - nodes_with_inferred_shape.push_back(n); - } - } - EXPECT_EQ(nodes_with_inferred_shape.size(), 1); - EXPECT_EQ(nodes_with_inferred_shape[0], add_node); + // Check that "add" node now has _xla_inferred_shapes attr. + auto node_index = g.BuildNodeNameIndex(); + Node *add_node = node_index["add"]; std::vector output_shapes; TF_CHECK_OK(GetNodeAttr(add_node->attrs(), kXlaInferredShapesAttrName, &output_shapes)); @@ -66,293 +53,4 @@ TEST(PerformStaticShapeInferenceBeforeEncapsulationTest, Basic) { EXPECT_EQ(shape_proto.dim(0).size(), 2); } -TEST(PreprocessForEncapsulationTest, ControlEdges) { - // Build the graph: - // "const_0" and "const_1" in host computation - // "add" = "const_0" + "const_1" in XLA computation 0 - // "identity0" = "add" in XLA computation 0 & outside compilation 0 - // "identity1" = "identity0" in XLA computation 0 - // "identity2" = "identity1" in host computation - // "identity3" = "identity2" in XLA computation 1 - // "identity4" = "identity3" in XLA computation 1 & outside compilation 1 - // "identity5" = "identity4" in XLA computation 1 - // "identity6" = "identity5" in host computation - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {}); - Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {}); - Output add = ops::Add(s.WithOpName("add"), const_0, const_1); - Output identity0 = 
ops::Identity(s.WithOpName("identity0"), add); - Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0); - Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1); - Output identity3 = ops::Identity(s.WithOpName("identity3"), identity2); - Output identity4 = ops::Identity(s.WithOpName("identity4"), identity3); - Output identity5 = ops::Identity(s.WithOpName("identity5"), identity4); - Graph g(OpRegistry::Global()); - TF_CHECK_OK(s.ToGraph(&g)); - auto node_index = g.BuildNodeNameIndex(); - - // Set XLA computation/outside compilation attr, and add control edges. - Node *const0_node = node_index["const_0"], *add_node = node_index["add"], - *identity0_node = node_index["identity0"], - *identity1_node = node_index["identity1"], - *identity2_node = node_index["identity2"], - *identity3_node = node_index["identity3"], - *identity4_node = node_index["identity4"], - *identity5_node = node_index["identity5"]; - add_node->AddAttr("_xla", "0"); - identity0_node->AddAttr("_xla", "0"); - identity0_node->AddAttr("_oc", "0"); - identity1_node->AddAttr("_xla", "0"); - identity3_node->AddAttr("_xla", "1"); - identity4_node->AddAttr("_xla", "1"); - identity4_node->AddAttr("_oc", "0"); - identity5_node->AddAttr("_xla", "1"); - // Case 1a: control edges between outside compilation and another XLA - // computation. - g.AddControlEdge(identity0_node, identity3_node); - g.AddControlEdge(identity1_node, identity4_node); - // Case 1b: control edges between different outside compilations. - g.AddControlEdge(identity0_node, identity4_node); - // Case 1c: control edges between outside compilation and host computation. - g.AddControlEdge(const0_node, identity0_node); - g.AddControlEdge(identity0_node, identity2_node); - - TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc")); - - // Case 1a: add attr "_xla_control_deps_{from/to} = XLA computation node name" - // to the outside compilation node. 
- std::vector attr; - TF_CHECK_OK(GetNodeAttr(identity0_node->def(), - kXlaConnectedToOtherXlaComputationAttrName, &attr)); - EXPECT_EQ(attr.size(), 1); - EXPECT_EQ(attr[0], "1"); - attr.clear(); - TF_CHECK_OK(GetNodeAttr(identity4_node->def(), - kXlaConnectedFromOtherXlaComputationAttrName, &attr)); - EXPECT_EQ(attr.size(), 1); - EXPECT_EQ(attr[0], "0"); - // Case 1b: add attr "_xla_control_deps = src node name" to dst node. - attr.clear(); - TF_CHECK_OK(GetNodeAttr(identity4_node->def(), - kXlaControlDependenciesAttrName, &attr)); - EXPECT_EQ(attr.size(), 1); - EXPECT_EQ(attr[0], "identity0"); - // Case 1c: add attr "_xla_control_deps = src node name" to dst node. - attr.clear(); - TF_CHECK_OK(GetNodeAttr(identity0_node->def(), - kXlaControlDependenciesAttrName, &attr)); - EXPECT_EQ(attr.size(), 1); - EXPECT_EQ(attr[0], "const_0"); - attr.clear(); - TF_CHECK_OK(GetNodeAttr(identity2_node->def(), - kXlaControlDependenciesAttrName, &attr)); - EXPECT_EQ(attr.size(), 1); - EXPECT_EQ(attr[0], "identity0"); -} - -TEST(PreprocessForEncapsulationTest, DataEdges) { - // Build the graph: - // "const_0" and "const_1" in host computation - // "add0" = "const_0" + "const_1" in XLA computation 0 - // "add1" = "add0" + "const_0" in XLA computation 0 & outside compilation 0 - // "identity0" = "add1" in XLA computation 0 - // "add2" = "add1" + "identity0" in host computation - // "add3" = "add1" + "add2" in XLA computation 1 - // "add4" = "identity0" + "add2" in XLA computation 1 & outside compilation 1 - // "identity1" = "add4" in XLA computation 1 - // "identity2" = "identity1" in host computation - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {}); - Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {}); - Output add0 = ops::Add(s.WithOpName("add0"), const_0, const_1); - Output add1 = ops::Add(s.WithOpName("add1"), add0, const_0); - Output identity0 = ops::Identity(s.WithOpName("identity0"), add1); - 
Output add2 = ops::Add(s.WithOpName("add2"), add1, identity0); - Output add3 = ops::Add(s.WithOpName("add3"), add1, add2); - Output add4 = ops::Add(s.WithOpName("add4"), identity0, add2); - Output identity1 = ops::Identity(s.WithOpName("identity1"), add4); - Output identity2 = ops::Identity(s.WithOpName("identity2"), add4); - Graph g(OpRegistry::Global()); - TF_CHECK_OK(s.ToGraph(&g)); - auto node_index = g.BuildNodeNameIndex(); - - // Set XLA computation/outside compilation attr. - Node *add0_node = node_index["add0"], *add1_node = node_index["add1"], - *identity0_node = node_index["identity0"], - *add3_node = node_index["add3"], *add4_node = node_index["add4"], - *identity1_node = node_index["identity1"]; - add0_node->AddAttr("_xla", "0"); - add1_node->AddAttr("_xla", "0"); - add1_node->AddAttr("_oc", "0"); - identity0_node->AddAttr("_xla", "0"); - add3_node->AddAttr("_xla", "1"); - add4_node->AddAttr("_xla", "1"); - add4_node->AddAttr("_oc", "0"); - identity1_node->AddAttr("_xla", "1"); - - TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc")); - - // Check input nodes for related data edges. - node_index = g.BuildNodeNameIndex(); - // Step 2: add an Identity node between different XLA computations. - Node *bridge_add1_add3 = node_index["bridge_add1_add3"]; - EXPECT_NE(bridge_add1_add3, nullptr); - string str; - TF_CHECK_OK( - GetNodeAttr(bridge_add1_add3->attrs(), kBridgeSourceNodeAttrName, &str)); - EXPECT_EQ(str, "add1"); - Node *bridge_identity0_add4 = node_index["bridge_identity0_add4"]; - EXPECT_NE(bridge_identity0_add4, nullptr); - // Step 3: add placeholder for edges between host computation and outside - // compilation. 
- EXPECT_EQ(bridge_add1_add3->def().input(0), "add1_oc_to_host_placeholder"); - Node *add1_oc_to_host_placeholder = node_index["add1_oc_to_host_placeholder"]; - TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(), - kOutsideCompilationToHostOriginalNodeAttrName, &str)); - EXPECT_EQ(str, "add1"); - int i; - TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(), - kOutsideCompilationToHostSrcOutputAttrName, &i)); - EXPECT_EQ(i, 0); - add4_node = node_index["add4"]; - ASSERT_NE(add4_node, nullptr); - EXPECT_EQ(add4_node->def().input(0), - "bridge_identity0_add4_host_to_oc_placeholder"); - Node *identity0_host_to_oc_placeholder = - node_index["bridge_identity0_add4_host_to_oc_placeholder"]; - TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(), - kHostToOutsideCompilationOriginalNodeAttrName, &str)); - EXPECT_EQ(str, "bridge_identity0_add4"); - TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(), - kHostToOutsideCompilationSrcOutputAttrName, &i)); - EXPECT_EQ(i, 0); -} - -TEST(PostprocessForEncapsulationTest, ControlEdges) { - // Build the graph: - // "const0" - // "identity0" = "const0" (XLA computation 0) - // "identity1" = "identity0" - // "identity2" = "identity1" (XLA computation 1) - // "identity3" = "identity2" - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output const0 = ops::Const(s.WithOpName("const0"), 1, {}); - Output identity0 = ops::Identity(s.WithOpName("identity0"), const0); - Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0); - Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1); - Output identity3 = ops::Identity(s.WithOpName("identity3"), identity2); - Graph g(OpRegistry::Global()); - TF_CHECK_OK(s.ToGraph(&g)); - auto node_index = g.BuildNodeNameIndex(); - - // Set XLA computation/outside compilation attr, and add control edges. 
- Node *const0_node = node_index["const0"], - *identity0_node = node_index["identity0"], - *identity1_node = node_index["identity1"], - *identity2_node = node_index["identity2"], - *identity3_node = node_index["identity3"]; - identity1_node->AddAttr(kXlaConnectedFromOtherXlaComputationAttrName, - std::vector{"0"}); - identity1_node->AddAttr(kXlaConnectedToOtherXlaComputationAttrName, - std::vector{"1"}); - identity3_node->AddAttr(kXlaControlDependenciesAttrName, - std::vector{"const0", "identity1"}); - - std::unordered_map clusters; - clusters["0"].node = identity0_node; - clusters["1"].node = identity2_node; - TF_CHECK_OK(PostprocessForEncapsulation(&g, "_xla", "_oc", clusters)); - - // Case 3a: we have control edge identity0 -> identity1, and identity1 -> - // identity2. - bool edge_identity0_identity1 = false, edge_identity1_identity2 = false; - for (const Edge *e : g.edges()) { - if (!e->IsControlEdge()) { - continue; - } - if (e->src() == identity0_node && e->dst() == identity1_node) { - edge_identity0_identity1 = true; - } else if (e->src() == identity1_node && e->dst() == identity2_node) { - edge_identity1_identity2 = true; - } - } - EXPECT_TRUE(edge_identity0_identity1); - EXPECT_TRUE(edge_identity1_identity2); - // Case 3b: we have control edge const0 -> identity3, and identity1 -> - // identity3. 
- bool edge_const0_identity3 = false, edge_identity1_identity3 = false; - for (const Edge *e : g.edges()) { - if (!e->IsControlEdge()) { - continue; - } - if (e->src() == const0_node && e->dst() == identity3_node) { - edge_const0_identity3 = true; - } else if (e->src() == identity1_node && e->dst() == identity3_node) { - edge_identity1_identity3 = true; - } - } - EXPECT_TRUE(edge_const0_identity3); - EXPECT_TRUE(edge_identity1_identity3); -} - -TEST(PostprocessForEncapsulationTest, DataEdges) { - // Build the graph: - // "const0" in outside compilation "0" - // "placeholder0" (for "const0") in host computation - // "add0" = "placeholder0" + "placeholder0" in host computation - // "placeholder1" (for "add0") in outside compilation 1 - // "add1" = "placeholder1" + "placeholder1" in outside compilation 1 - // - // "bridge" = "placeholder0" in host computation - // "placeholder2" (for "bridge") in outside compilation 1 - // "add2" = "placeholder2" + "placeholder2" in outside compilation 1 - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output const0 = ops::Const(s.WithOpName("const0"), 1, {}); - Output placeholder0 = - ops::Placeholder(s.WithOpName("placeholder0"), DT_INT32); - Output add0 = ops::Add(s.WithOpName("add0"), placeholder0, placeholder0); - Output placeholder1 = - ops::Placeholder(s.WithOpName("placeholder1"), DT_INT32); - Output add1 = ops::Add(s.WithOpName("add1"), placeholder1, placeholder1); - Output bridge = ops::Identity(s.WithOpName("bridge"), placeholder0); - Output placeholder2 = - ops::Placeholder(s.WithOpName("placeholder2"), DT_INT32); - Output add2 = ops::Add(s.WithOpName("add2"), placeholder2, placeholder2); - Graph g(OpRegistry::Global()); - TF_CHECK_OK(s.ToGraph(&g)); - auto node_index = g.BuildNodeNameIndex(); - - // Set related attributes. 
- Node *placeholder0_node = node_index["placeholder0"]; - placeholder0_node->AddAttr(kOutsideCompilationToHostOriginalNodeAttrName, - "const0"); - placeholder0_node->AddAttr(kOutsideCompilationToHostSrcOutputAttrName, 0); - Node *placeholder1_node = node_index["placeholder1"]; - placeholder1_node->AddAttr(kHostToOutsideCompilationOriginalNodeAttrName, - "add0"); - placeholder1_node->AddAttr(kHostToOutsideCompilationSrcOutputAttrName, 0); - Node *bridge_node = node_index["bridge"]; - bridge_node->AddAttr(kBridgeSourceNodeAttrName, "const0"); - Node *placeholder2_node = node_index["placeholder2"]; - placeholder2_node->AddAttr(kHostToOutsideCompilationOriginalNodeAttrName, - "bridge"); - placeholder2_node->AddAttr(kHostToOutsideCompilationSrcOutputAttrName, 0); - - std::unordered_map clusters; - TF_CHECK_OK(PostprocessForEncapsulation(&g, "_xla", "_oc", clusters)); - - // Result graph should be: - // "add0" = "const0" + "const0" - // "add1" = "add0" + "add0" - // "add2" = "const0" + "const0" - node_index = g.BuildNodeNameIndex(); - EXPECT_EQ(node_index.size(), 6); - EXPECT_EQ(node_index["add0"]->def().input(0), "const0:0"); - EXPECT_EQ(node_index["add0"]->def().input(1), "const0:0"); - EXPECT_EQ(node_index["add1"]->def().input(0), "add0:0"); - EXPECT_EQ(node_index["add1"]->def().input(1), "add0:0"); - EXPECT_EQ(node_index["add2"]->def().input(0), "const0:0"); - EXPECT_EQ(node_index["add2"]->def().input(1), "const0:0"); -} - } // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index d334100aa4a915a87fb05d371e0e3379a7ee05f2..ec745cdbb7e237f8b4935dd41e9791fc75f5355d 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -297,6 +297,7 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, NodeDef def; def.set_name(launch->name()); + 
MergeDebugInfo(NodeDebugInfo(launch->def()), &def); // Target the XLA CPU/GPU backends. VLOG(2) << "Replacing with XlaLaunch"; diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc index e3c7e2f89be9b37b51a633dabb099969c181013f..8b01768c49422b331b52a8ba31bade000c95722e 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc @@ -20,8 +20,10 @@ limitations under the License. #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/encapsulate_util.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" @@ -98,9 +100,12 @@ xla::StatusOr BuildRecvAtHostNode( recv_at_host_builder.Attr("Toutputs", recv_at_host_dtypes); // The correct device_ordinal will be inserted during replication in a // subsequent rewrite. 
- recv_at_host_builder.Attr("device_ordinal", 0); + AttrValue device_ordinal_value; + device_ordinal_value.set_placeholder("device_ordinal"); + recv_at_host_builder.Attr("device_ordinal", device_ordinal_value); recv_at_host_builder.Attr( "key", absl::StrCat("host_compute_channel_", oc_cluster_name)); + recv_at_host_builder.Attr(kXlaHasHostTransferAttrName, true); recv_at_host_builder.Input(key_placeholder->name(), 0, DT_STRING); TF_RETURN_IF_ERROR(recv_at_host_builder.Finalize(&recv_at_host_def)); Status s; @@ -197,9 +202,12 @@ xla::StatusOr BuildSendFromHostNode( send_from_host_builder.Attr("Tinputs", send_from_host_dtypes); // The correct device_ordinal will be inserted during replication in a // subsequent rewrite. - send_from_host_builder.Attr("device_ordinal", 0); + AttrValue device_ordinal_value; + device_ordinal_value.set_placeholder("device_ordinal"); + send_from_host_builder.Attr("device_ordinal", device_ordinal_value); send_from_host_builder.Attr( "key", absl::StrCat("host_compute_channel_", oc_cluster_name)); + send_from_host_builder.Attr(kXlaHasHostTransferAttrName, true); std::vector inputs(send_from_host_dtypes.size()); for (auto* n : ret_nodes) { int index; @@ -322,6 +330,38 @@ xla::StatusOr BuildXlaHostComputeNodeDef( return new_def; } +Status ValidateOutsideCompilationCallNode(Node* call_node) { + // DT_INT64 as input/output for outside compilation is not supported yet: + // b/120809951. + for (const Edge* e : call_node->in_edges()) { + if (e->IsControlEdge()) { + continue; + } + DataType dtype = e->src()->output_type(e->src_output()); + if (dtype == DT_INT64) { + return errors::Unimplemented( + "int64 input for outside compilation is not supported yet: " + "b/120809951. 
Please cast output of node ", + e->src()->DebugString(), + " to int32 before feeding it into outside compilation."); + } + } + for (const Edge* e : call_node->out_edges()) { + if (e->IsControlEdge()) { + continue; + } + DataType dtype = e->dst()->input_type(e->dst_input()); + if (dtype == DT_INT64) { + return errors::Unimplemented( + "int64 output for outside compilation is not supported yet: " + "b/120809951. Please cast input of node ", + e->dst()->DebugString(), + " to int32 before returning it from outside compilation."); + } + } + return Status::OK(); +} + // Replace outside compilation function call node with XlaHostCompute node. // If the function call node has no input/output edges, we will just remove it // and not create a XlaHostCompute node. @@ -357,6 +397,47 @@ Status ReplaceOrRemoveOutsideCompilationCallNode( return Status::OK(); } +// Resets "device_ordinal" attr to placeholder value for related nodes +// (XlaRecvAtHost nodes; XlaSendFromHost nodes; If nodes containing +// XlaRecvAtHost/XlaSendFromHost). 
+Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) { + AttrValue device_ordinal_value; + device_ordinal_value.set_placeholder("device_ordinal"); + for (Node* n : g->nodes()) { + if (!HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) { + continue; + } + + if (n->type_string() == "_XlaRecvAtHost" || + n->type_string() == "_XlaSendFromHost") { + n->ClearAttr("device_ordinal"); + n->AddAttr("device_ordinal", device_ordinal_value); + } else if (n->type_string() == "If") { + for (const string& attr_name : + std::vector{"then_branch", "else_branch"}) { + NameAttrList branch_func; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func)); + (*branch_func.mutable_attr())["device_ordinal"] = device_ordinal_value; + n->ClearAttr(attr_name); + n->AddAttr(attr_name, branch_func); + } + } else if (n->type_string() == "While") { + for (const string& attr_name : std::vector{"cond", "body"}) { + NameAttrList branch_func; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func)); + (*branch_func.mutable_attr())["device_ordinal"] = device_ordinal_value; + n->ClearAttr(attr_name); + n->AddAttr(attr_name, branch_func); + } + } else { + return errors::Internal("Unknown node marked with ", + kXlaHasHostTransferAttrName, ": ", + n->DebugString()); + } + } + return Status::OK(); +} + // For an XLA computation, builds host side graph given all outside compilation // graphs inside it. The host side graph contains: // 1) a "sequencer" node (we will add control edge between XlaRecvAtHost and @@ -368,8 +449,8 @@ Status ReplaceOrRemoveOutsideCompilationCallNode( Status ConstructHostGraph( const string& xla_cluster_name, const string& outside_compilation_attr_name, const std::vector& outside_compilation_host_graphs, - FunctionLibraryDefinition* fld, std::unique_ptr* host_graph) { - host_graph->reset(new Graph(fld)); + FunctionLibraryDefinition* fld, const string& host_graph_func_name) { + Graph host_graph(fld); // Create sequencer node in host graph. 
NodeDefBuilder sequencer_builder(absl::StrCat(xla_cluster_name, "_sequencer"), @@ -378,24 +459,34 @@ Status ConstructHostGraph( NodeDef sequencer_def; TF_RETURN_IF_ERROR(sequencer_builder.Finalize(&sequencer_def)); Status s; - Node* sequencer = (*host_graph)->AddNode(sequencer_def, &s); + Node* sequencer = host_graph.AddNode(sequencer_def, &s); TF_RETURN_IF_ERROR(s); // Create key placeholder in host graph. TF_ASSIGN_OR_RETURN( Node * key_placeholder, - AddHostComputeKeyPlaceholder(xla_cluster_name, host_graph->get())); + AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph)); // For each outside compilation graph, copy them to host graph with the // following changes: // a) Use key_placeholder in host graph instead of its own. - // b) Add control edge from RecvAtHost/SendFromHost to sequencer. + // b) Add control edge from host transfer nodes (XlaRecvAtHost, + // XlaSendFromHost, If/While nodes containing + // XlaRecvAtHost/XlaSendFromHost) to sequencer node. // c) Clear node_def.device(), so device placer won't get confused. for (const string& host_func : outside_compilation_host_graphs) { VLOG(4) << "Expanding host graph " << host_func; + // Temporarily use "0" as "device_ordinal". It will be reset to placeholder + // value after we expanded all host graphs. We cannot just use placeholder + // value here because FunctionDef instantiation does not allow placeholder + // value for attributes. 
+ AttrValue device_ordinal_attr; + device_ordinal_attr.set_i(0); + protobuf::Map attrs; + attrs["device_ordinal"] = device_ordinal_attr; FunctionBody* host_fbody = nullptr; TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( - *fld->Find(host_func), AttrSlice(), fld, + *fld->Find(host_func), AttrSlice(&attrs), fld, [&](const string& op, const OpDef** sig) { return fld->LookUpOpDef(op, sig); }, @@ -408,8 +499,8 @@ Status ConstructHostGraph( FixupSourceAndSinkEdges(host_fbody->graph); std::map node_map; - node_map[host_fbody->graph->source_node()] = (*host_graph)->source_node(); - node_map[host_fbody->graph->sink_node()] = (*host_graph)->sink_node(); + node_map[host_fbody->graph->source_node()] = host_graph.source_node(); + node_map[host_fbody->graph->sink_node()] = host_graph.sink_node(); Status s; ReverseDFS( *host_fbody->graph, /*enter=*/nullptr, @@ -431,7 +522,7 @@ Status ConstructHostGraph( NodeDef copy_def = n->def(); // Change c). copy_def.clear_device(); - copy = (*host_graph)->AddNode(copy_def, &s); + copy = host_graph.AddNode(copy_def, &s); if (!s.ok()) { return; } @@ -446,22 +537,23 @@ Status ConstructHostGraph( e->src()->DebugString()); return; } - (*host_graph) - ->AddEdge(node_map[e->src()], e->src_output(), copy, - e->dst_input()); + host_graph.AddEdge(node_map[e->src()], e->src_output(), copy, + e->dst_input()); } // Change b). - if (copy->type_string() == "_XlaRecvAtHost" || - copy->type_string() == "_XlaSendFromHost") { - (*host_graph)->AddControlEdge(copy, sequencer); + if (HasNodeAttr(copy->def(), kXlaHasHostTransferAttrName)) { + host_graph.AddControlEdge(copy, sequencer); } }, NodeComparatorID()); + if (!s.ok()) { return s; } } + // Reset "device_ordinal" to placeholder value. + TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(&host_graph)); // sequencer and key_placeholder might be dead nodes. Prune them if necessary. 
// - sequencer should be pruned iff it has no input control edges from @@ -470,21 +562,30 @@ Status ConstructHostGraph( // - key_placeholder should be pruned iff there's no RecvAtHost/SendFromHost. // We don't need to do anything special. if (!sequencer->in_edges().empty()) { - (*host_graph)->AddControlEdge(sequencer, (*host_graph)->sink_node()); + host_graph.AddControlEdge(sequencer, host_graph.sink_node()); } PruneForReverseReachability( - host_graph->get(), - std::unordered_set{(*host_graph)->sink_node()}); + &host_graph, std::unordered_set{host_graph.sink_node()}); // Postprocess edges between different outside compilations. TF_RETURN_IF_ERROR(PostprocessEdgesBetweenOutsideCompilations( - host_graph->get(), outside_compilation_attr_name)); + &host_graph, outside_compilation_attr_name)); if (VLOG_IS_ON(4)) { dump_graph::DumpGraphToFile( absl::StrCat("extract_outside_compilation_host_graph_for_", xla_cluster_name), - **host_graph, fld); + host_graph, fld); + } + + FunctionDef host_graph_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(host_graph, host_graph_func_name, &host_graph_fdef)); + if (fld->Find(host_graph_func_name)) { + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(host_graph_func_name, host_graph_fdef)); + } else { + TF_RETURN_IF_ERROR(fld->AddFunctionDef(host_graph_fdef)); } return Status::OK(); @@ -492,8 +593,28 @@ Status ConstructHostGraph( // Expand XLA computation's outside compilation host side graph into main graph. // Add a control edge between sequencer node and the XLA computation node. -Status ExpandHostGraphIntoMainGraph(Graph* main_graph, Graph* host_graph, +Status ExpandHostGraphIntoMainGraph(Graph* main_graph, + FunctionLibraryDefinition* fld, + const string& host_graph_func_name, Node* xla_computation_node) { + // Temporarily use "0" as "device_ordinal". It will be rewritten with the + // correct value in a later pass. 
We cannot just use placeholder value here + // because FunctionDef instantiation does not allow placeholder value for + // attributes. + AttrValue device_ordinal_attr; + device_ordinal_attr.set_i(0); + protobuf::Map attrs; + attrs["device_ordinal"] = device_ordinal_attr; + FunctionBody* fbody = nullptr; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fld->Find(host_graph_func_name), AttrSlice(&attrs), fld, + [&](const string& op, const OpDef** sig) { + return fld->LookUpOpDef(op, sig); + }, + &fbody)); + std::unique_ptr fbody_deleter(fbody); + Graph* host_graph = fbody->graph; + // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse // reachable from sink node so all nodes will be copied. // TODO(b/77601805): consolidate copy graph functions. @@ -545,23 +666,25 @@ Status ExpandHostGraphIntoMainGraph(Graph* main_graph, Graph* host_graph, return s; } -// Rewrites shape inference graph for outside compilation. -// 1. If the outside compilation is a "top-level" one (not in a function of any -// If/While/etc.), this shape inference graph might have host computation to -// outside compilation placeholder nodes, which will cause shape inference to -// fail. However, those nodes are not in `host_graph` any more (because we -// have executed `PostprocessForEncapsultion`). In this case, we clear the -// graph, and copy SendFromHost with all its predecessors from `host_graph`. -// This case is detected by whether the SendFromHost node exists in -// `host_graph` as well. -// 2. Remove control edges, and prune nodes that are not useful for shape -// inference. +// Rewrites shape inference graph for outside compilation: +// 1) If XlaSendFromHost also exists in `host_graph`, copy nodes from +// `host_graph`. Because we might still have outside compilation to outside +// compilation placeholder nodes in shape inference graph, which will prevent +// us from inferring XlaSendFromHost shape. But in `host_graph`, we already +// removed those placeholder nodes. 
+// 2) Remove control edges. +// 3) Prune nodes that are not useful for shape inference. Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name, Graph* host_graph, FunctionLibraryDefinition* fld) { + // Use "0" as "device_ordinal". It does not matter for shape inference. + AttrValue device_ordinal_attr; + device_ordinal_attr.set_i(0); + protobuf::Map attrs; + attrs["device_ordinal"] = device_ordinal_attr; FunctionBody* fbody = nullptr; TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( - *fld->Find(shape_inference_graph_name), AttrSlice(), fld, + *fld->Find(shape_inference_graph_name), AttrSlice(&attrs), fld, [&](const string& op, const OpDef** sig) { return fld->LookUpOpDef(op, sig); }, @@ -650,6 +773,7 @@ Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name, g->RemoveEdge(e); } } + // Nodes that are not reverse reachable from SendFromHost are not useful for // shape inference. Prune them. PruneForReverseReachability(g, @@ -669,6 +793,572 @@ Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name, return Status::OK(); } +// Builds XlaSendToHost node which sends cond predicate to host. +xla::StatusOr BuildSendIfPredNode(const string& name, + const string& host_transfer_key, + Node* pred_node, Graph* g) { + NodeDefBuilder send_pred_builder(name, "XlaSendToHost"); + send_pred_builder.Attr("Tinput", DT_BOOL); + send_pred_builder.Attr("key", absl::StrCat(host_transfer_key, "_dtoh_0")); + send_pred_builder.Attr(kXlaTokenInputNodesAttrName, + std::vector{kXlaTokenArgNodeName}); + send_pred_builder.Input(pred_node->name(), 0, DT_BOOL); + NodeDef send_pred_def; + TF_RETURN_IF_ERROR(send_pred_builder.Finalize(&send_pred_def)); + Status s; + Node* send_pred_node = g->AddNode(send_pred_def, &s); + TF_RETURN_IF_ERROR(s); + g->AddEdge(pred_node, 0, send_pred_node, 0); + return send_pred_node; +} + +// Replaces key placeholder node with an _Arg node. 
+Status ReplaceKeyPlaceholderWithArgNode(const string& xla_cluster_name, + const string& func_name, + FunctionLibraryDefinition* fld) { + // Temporarily use "0" as "device_ordinal". It will be reset to placeholder + // value after rewriting. + AttrValue device_ordinal_attr; + device_ordinal_attr.set_i(0); + protobuf::Map attrs; + attrs["device_ordinal"] = device_ordinal_attr; + FunctionBody* fbody = nullptr; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fld->Find(func_name), AttrSlice(&attrs), fld, + [&](const string& op, const OpDef** sig) { + return fld->LookUpOpDef(op, sig); + }, + &fbody)); + std::unique_ptr fbody_deleter(fbody); + Graph* g = fbody->graph; + + // Find or create the key placeholder node. + Node* key_placeholder = nullptr; + for (Node* n : g->nodes()) { + if (IsKeyPlaceholderNode(*n)) { + key_placeholder = n; + break; + } + } + if (!key_placeholder) { + TF_ASSIGN_OR_RETURN(key_placeholder, + AddHostComputeKeyPlaceholder(xla_cluster_name, g)); + } + + // Build the _Arg node, and replace key placeholder node with it. + NodeDefBuilder arg_builder("key_arg", FunctionLibraryDefinition::kArgOp); + arg_builder.Attr("T", DT_STRING); + arg_builder.Attr("index", 0); + NodeDef arg_def; + TF_RETURN_IF_ERROR(arg_builder.Finalize(&arg_def)); + TF_RETURN_IF_ERROR(ReplaceNode(g, key_placeholder, arg_def).status()); + + // Reset "device_ordinal" to placeholder value. + TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(g)); + + FunctionDef replace_fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(*g, func_name, &replace_fdef)); + TF_RETURN_IF_ERROR(fld->ReplaceFunction(func_name, replace_fdef)); + return Status::OK(); +} + +// Builds host side graph for If node. 
+Status BuildHostGraphForIfNode(const string& xla_cluster_attr_name, + const string& outside_compilation_attr_name, + const string& xla_cluster_name, + const string& if_node_name, + const string& host_transfer_key, + const string& host_graph_func_name, + FunctionLibraryDefinition* fld, + const string& then_branch_host_func_name, + const string& else_branch_host_func_name) { + Graph host_graph(fld); + string outside_compilation_name = absl::StrCat("oc_if_", if_node_name); + AttrValue device_ordinal_value; + device_ordinal_value.set_placeholder("device_ordinal"); + + // Step 1: add key placeholder node. + TF_ASSIGN_OR_RETURN( + Node * key_placeholder, + AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph)); + + // Step 2: build XlaRecvAtHost node to recv predicate. + NodeDefBuilder recv_pred_builder( + absl::StrCat("recv_oc_if_pred_", if_node_name), "_XlaRecvAtHost"); + recv_pred_builder.Attr("Toutputs", std::vector{DT_BOOL}); + recv_pred_builder.Attr("key", host_transfer_key); + recv_pred_builder.Attr("device_ordinal", device_ordinal_value); + recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name); + recv_pred_builder.Attr(outside_compilation_attr_name, + outside_compilation_name); + recv_pred_builder.Attr(kXlaHasHostTransferAttrName, true); + recv_pred_builder.Input(key_placeholder->name(), 0, DT_STRING); + NodeDef recv_pred_def; + TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def)); + Status s; + Node* recv_pred_node = host_graph.AddNode(recv_pred_def, &s); + TF_RETURN_IF_ERROR(s); + host_graph.AddEdge(key_placeholder, 0, recv_pred_node, 0); + + // Step 3: rewrite `{then, else}_branch_host_func_name`, replace key + // placeholder with an _Arg node. 
+ TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( + xla_cluster_name, then_branch_host_func_name, fld)); + TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( + xla_cluster_name, else_branch_host_func_name, fld)); + + // Step 4: build If node to choose between `{then, else}_branch_host_graph`. + NodeDefBuilder if_builder(absl::StrCat("oc_if_", if_node_name), "If"); + if_builder.Attr("Tcond", DT_BOOL); + if_builder.Attr("Tin", std::vector{DT_STRING}); + if_builder.Attr("Tout", std::vector{}); + NameAttrList host_then_branch, host_else_branch; + host_then_branch.set_name(then_branch_host_func_name); + (*host_then_branch.mutable_attr())["device_ordinal"] = device_ordinal_value; + host_else_branch.set_name(else_branch_host_func_name); + (*host_else_branch.mutable_attr())["device_ordinal"] = device_ordinal_value; + if_builder.Attr("then_branch", host_then_branch); + if_builder.Attr("else_branch", host_else_branch); + if_builder.Attr(kXlaHasHostTransferAttrName, true); + if_builder.Attr(xla_cluster_attr_name, xla_cluster_name); + if_builder.Attr(outside_compilation_attr_name, outside_compilation_name); + if_builder.Input(recv_pred_node->name(), 0, DT_BOOL); + std::vector if_inputs{ + {key_placeholder->name(), 0, DT_STRING}}; + if_builder.Input(if_inputs); + NodeDef if_def; + TF_RETURN_IF_ERROR(if_builder.Finalize(&if_def)); + Node* if_node = host_graph.AddNode(if_def, &s); + TF_RETURN_IF_ERROR(s); + host_graph.AddEdge(recv_pred_node, 0, if_node, 0); + host_graph.AddEdge(key_placeholder, 0, if_node, 1); + + // Convert `host_graph` to function, and add a "device_ordinal" attr. 
+ FunctionDef oc_host_graph_fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name, + &oc_host_graph_fdef)); + if (fld->Find(host_graph_func_name)) { + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef)); + } else { + TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef)); + } + + return Status::OK(); +} + +// Rewrites loop cond to add a node which sends loop cond to host. +Status AddSendLoopPredToLoopCond(FunctionLibraryDefinition* fld, + const NameAttrList& loop_cond_func, + const string& while_node_name, + const string& host_transfer_key) { + // Instantiate the loop cond function. + FunctionBody* fbody = nullptr; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fld->Find(loop_cond_func.name()), AttrSlice(&loop_cond_func.attr()), fld, + [&](const string& op, const OpDef** sig) { + return fld->LookUpOpDef(op, sig); + }, + &fbody)); + std::unique_ptr fbody_deleter(fbody); + Graph* g = fbody->graph; + + // Find the _Retval node and the loop cond node. + Node* ret_node = nullptr; + for (Node* n : g->nodes()) { + if (n->type_string() == "_Retval") { + if (ret_node) { + return errors::Internal("Multiple return node for loop cond function ", + loop_cond_func.name(), ": ", + ret_node->DebugString(), " and ", + n->DebugString()); + } else { + ret_node = n; + } + } + } + if (!ret_node) { + return errors::Internal("No _Retval node for loop cond function ", + loop_cond_func.name()); + } + Node* loop_cond; + TF_RETURN_IF_ERROR(ret_node->input_node(0, &loop_cond)); + + // Build the XlaSendToHost node. 
+ NodeDefBuilder send_loop_cond_builder( + absl::StrCat("send_oc_while_cond_", while_node_name), "XlaSendToHost"); + send_loop_cond_builder.Attr("Tinput", DT_BOOL); + send_loop_cond_builder.Attr("key", + absl::StrCat(host_transfer_key, "_dtoh_0")); + send_loop_cond_builder.Attr(kXlaTokenInputNodesAttrName, + std::vector{kXlaTokenArgNodeName}); + send_loop_cond_builder.Input(loop_cond->name(), 0, DT_BOOL); + NodeDef send_loop_cond_def; + TF_RETURN_IF_ERROR(send_loop_cond_builder.Finalize(&send_loop_cond_def)); + Status s; + Node* send_loop_cond_node = g->AddNode(send_loop_cond_def, &s); + TF_RETURN_IF_ERROR(s); + g->AddEdge(loop_cond, 0, send_loop_cond_node, 0); + + // Replace original function. + FunctionDef replace_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*g, loop_cond_func.name(), &replace_fdef)); + TF_RETURN_IF_ERROR(fld->ReplaceFunction(loop_cond_func.name(), replace_fdef)); + + return Status::OK(); +} + +// Rewrites while loop cond function for host. +Status RewriteHostWhileLoopCond( + const string& cond_host_func_name, const string& while_node_name, + const string& host_transfer_key, const string& xla_cluster_attr_name, + const string& xla_cluster_name, const string& outside_compilation_attr_name, + const string& outside_compilation_name, FunctionLibraryDefinition* fld) { + // Replace key placeholder node with _Arg node. + TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( + xla_cluster_name, cond_host_func_name, fld)); + + // Instantiate cond function. 
+ AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map attrs; + attrs["device_ordinal"] = device_ordinal_temp_value; + FunctionBody* cond_fbody = nullptr; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fld->Find(cond_host_func_name), AttrSlice(&attrs), fld, + [&](const string& op, const OpDef** sig) { + return fld->LookUpOpDef(op, sig); + }, + &cond_fbody)); + std::unique_ptr cond_fbody_deleter(cond_fbody); + Graph* cond_graph = cond_fbody->graph; + Node* key_arg = nullptr; + for (Node* n : cond_graph->nodes()) { + if (n->type_string() == "_Arg") { + key_arg = n; + } + } + if (!key_arg) { + return errors::Internal( + "No _Arg node found for host compute key in function ", + cond_host_func_name); + } + + // Add an XlaRecvAtHost node to use as cond function return value. + // We don't need to set kXlaHasHostTransferAttrName for this node, because + // it's already added for the "While" node on the host. + NodeDefBuilder recv_pred_builder( + absl::StrCat("recv_oc_while_cond_", while_node_name), "_XlaRecvAtHost"); + recv_pred_builder.Attr("Toutputs", std::vector{DT_BOOL}); + recv_pred_builder.Attr("key", host_transfer_key); + AttrValue device_ordinal_value; + device_ordinal_value.set_placeholder("device_ordinal"); + recv_pred_builder.Attr("device_ordinal", device_ordinal_value); + recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name); + recv_pred_builder.Attr(outside_compilation_attr_name, + outside_compilation_name); + recv_pred_builder.Input(key_arg->name(), 0, DT_STRING); + NodeDef recv_pred_def; + TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def)); + Status s; + Node* recv_pred_node = cond_graph->AddNode(recv_pred_def, &s); + TF_RETURN_IF_ERROR(s); + cond_graph->AddEdge(key_arg, 0, recv_pred_node, 0); + NodeDefBuilder ret_builder( + absl::StrCat("recv_oc_while_cond_ret_", while_node_name), "_Retval"); + ret_builder.Attr("T", DT_BOOL); + ret_builder.Attr("index", 0); + 
ret_builder.Input(recv_pred_node->name(), 0, DT_BOOL); + NodeDef ret_def; + TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def)); + Node* ret_node = cond_graph->AddNode(ret_def, &s); + TF_RETURN_IF_ERROR(s); + cond_graph->AddEdge(recv_pred_node, 0, ret_node, 0); + + // Reset device_ordinal to placeholder value. + TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(cond_graph)); + + // Replace original function. + FunctionDef cond_replace_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*cond_graph, cond_host_func_name, &cond_replace_fdef)); + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(cond_host_func_name, cond_replace_fdef)); + + return Status::OK(); +} + +// Rewrites while loop body function for host. +Status RewriteHostWhileLoopBody( + const string& body_host_func_name, const string& while_node_name, + const string& host_transfer_key, const string& xla_cluster_attr_name, + const string& xla_cluster_name, const string& outside_compilation_attr_name, + const string& outside_compilation_name, FunctionLibraryDefinition* fld) { + // Replace key placeholder node with _Arg node. + TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( + xla_cluster_name, body_host_func_name, fld)); + + // Instantiate body function. 
+ AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map attrs; + attrs["device_ordinal"] = device_ordinal_temp_value; + FunctionBody* body_fbody = nullptr; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fld->Find(body_host_func_name), AttrSlice(&attrs), fld, + [&](const string& op, const OpDef** sig) { + return fld->LookUpOpDef(op, sig); + }, + &body_fbody)); + std::unique_ptr body_fbody_deleter(body_fbody); + Graph* body_graph = body_fbody->graph; + Node* key_arg = nullptr; + for (Node* n : body_graph->nodes()) { + if (n->type_string() == "_Arg") { + key_arg = n; + } + } + if (!key_arg) { + return errors::Internal( + "No _Arg node found for host compute key in function ", + body_host_func_name); + } + + // Add a _Retval node to loop body. + NodeDefBuilder ret_builder( + absl::StrCat("recv_oc_while_body_ret_", while_node_name), "_Retval"); + ret_builder.Attr("T", DT_STRING); + ret_builder.Attr("index", 0); + ret_builder.Input(key_arg->name(), 0, DT_STRING); + NodeDef ret_def; + TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def)); + Status s; + Node* ret_node = body_graph->AddNode(ret_def, &s); + TF_RETURN_IF_ERROR(s); + body_graph->AddEdge(key_arg, 0, ret_node, 0); + + // Reset device_ordinal to placeholder value. + TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(body_graph)); + + // Replace original function. + FunctionDef body_replace_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*body_graph, body_host_func_name, &body_replace_fdef)); + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(body_host_func_name, body_replace_fdef)); + + return Status::OK(); +} + +// Builds host side graph for while node. 
+Status BuildHostGraphForWhileNode( + const string& xla_cluster_attr_name, + const string& outside_compilation_attr_name, const string& xla_cluster_name, + const string& while_node_name, const string& host_transfer_key, + const string& host_graph_func_name, FunctionLibraryDefinition* fld, + const string& cond_host_func_name, const string& body_host_func_name) { + Graph host_graph(fld); + string outside_compilation_name = absl::StrCat("oc_while_", while_node_name); + + // Step 1: add key placeholder node. + TF_ASSIGN_OR_RETURN( + Node * key_placeholder, + AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph)); + + // Step 2: rewrite cond function. + TF_RETURN_IF_ERROR(RewriteHostWhileLoopCond( + cond_host_func_name, while_node_name, host_transfer_key, + xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name, + outside_compilation_name, fld)); + + // Step 3: rewrite body function. + TF_RETURN_IF_ERROR(RewriteHostWhileLoopBody( + body_host_func_name, while_node_name, host_transfer_key, + xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name, + outside_compilation_name, fld)); + + // Step 4: build While node. 
+ NodeDefBuilder while_builder(absl::StrCat("oc_while_", while_node_name), + "While"); + while_builder.Attr("T", std::vector{DT_STRING}); + NameAttrList func; + AttrValue device_ordinal_value; + device_ordinal_value.set_placeholder("device_ordinal"); + (*func.mutable_attr())["device_ordinal"] = device_ordinal_value; + func.set_name(cond_host_func_name); + while_builder.Attr("cond", func); + func.set_name(body_host_func_name); + while_builder.Attr("body", func); + while_builder.Attr(kXlaHasHostTransferAttrName, true); + while_builder.Attr(xla_cluster_attr_name, xla_cluster_name); + while_builder.Attr(outside_compilation_attr_name, outside_compilation_name); + std::vector while_inputs{ + {key_placeholder->name(), 0, DT_STRING}}; + while_builder.Input(while_inputs); + NodeDef while_def; + TF_RETURN_IF_ERROR(while_builder.Finalize(&while_def)); + Status s; + Node* while_node = host_graph.AddNode(while_def, &s); + TF_RETURN_IF_ERROR(s); + host_graph.AddEdge(key_placeholder, 0, while_node, 0); + + // Convert `host_graph` to function. 
+ FunctionDef oc_host_graph_fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name, + &oc_host_graph_fdef)); + if (fld->Find(host_graph_func_name)) { + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef)); + } else { + TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef)); + } + + return Status::OK(); +} + +Status ExtractOutsideCompilationForNodesWithAssociatedFunctions( + Graph* g, const string& xla_cluster_attr_name, + const string& outside_compilation_attr_name, const string& xla_cluster_name, + const std::map& host_compute_core, + FunctionLibraryDefinition* fld, std::vector* host_graphs, + std::vector* shape_inference_graphs, + bool* has_outside_compilation) { + std::vector if_nodes, while_nodes; + for (Node* n : g->nodes()) { + if (n->type_string() == "If") { + if_nodes.push_back(n); + } else if (n->type_string() == "While") { + while_nodes.push_back(n); + } + } + + for (Node* n : if_nodes) { + // Instantiate "then_branch" and "else_branch". + NameAttrList then_branch, else_branch; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "then_branch", &then_branch)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "else_branch", &else_branch)); + + // Extract outside compilation for then_branch and else_branch. 
+ bool then_branch_has_outside_compilation = false; + bool else_branch_has_outside_compilation = false; + string then_branch_host_func_name = + absl::StrCat("oc_then_branch_host_if_", n->name()), + else_branch_host_func_name = + absl::StrCat("oc_else_branch_host_if_", n->name()); + string then_branch_xla_func_name = absl::StrCat(then_branch.name(), "_oc"), + else_branch_xla_func_name = absl::StrCat(else_branch.name(), "_oc"); + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + then_branch, then_branch_xla_func_name, then_branch_host_func_name, + host_compute_core, fld, shape_inference_graphs, + &then_branch_has_outside_compilation)); + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + else_branch, else_branch_xla_func_name, else_branch_host_func_name, + host_compute_core, fld, shape_inference_graphs, + &else_branch_has_outside_compilation)); + + // If then/else branch do not have outside compilation, nothing to do. + if (!then_branch_has_outside_compilation && + !else_branch_has_outside_compilation) { + continue; + } + + *has_outside_compilation = true; + + // Change If node to call the new functions. + then_branch.set_name(then_branch_xla_func_name); + n->ClearAttr("then_branch"); + n->AddAttr("then_branch", then_branch); + else_branch.set_name(else_branch_xla_func_name); + n->ClearAttr("else_branch"); + n->AddAttr("else_branch", else_branch); + + string host_transfer_key = absl::StrCat("oc_if_pred_", n->name()); + + // XLA computation: add a SendToHost node to send cond predicate. 
+ Node* pred_node; + TF_RETURN_IF_ERROR(n->input_node(0, &pred_node)); + TF_ASSIGN_OR_RETURN( + Node * send_pred_node, + BuildSendIfPredNode(absl::StrCat("send_oc_if_pred_", n->name()), + host_transfer_key, pred_node, g)); + n->AddAttr(kXlaTokenInputNodesAttrName, + std::vector{send_pred_node->name()}); + + // Add a control edge from `send_pred_node` to If node, so XlaCompiler will + // visit If node after `send_pred_node`, thus the token output for + // `send_pred_node` has been generated. + g->AddControlEdge(send_pred_node, n); + + // Build host side graph for the "If" node. + string oc_host_graph_name = absl::StrCat("oc_if_host_graph_", n->name()); + TF_RETURN_IF_ERROR(BuildHostGraphForIfNode( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + n->name(), host_transfer_key, oc_host_graph_name, fld, + then_branch_host_func_name, else_branch_host_func_name)); + host_graphs->push_back(oc_host_graph_name); + } + + for (Node* n : while_nodes) { + // Instantiate "cond" and "body". + NameAttrList cond, body; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "cond", &cond)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "body", &body)); + + // Extract outside compilation for cond and body. 
+ bool cond_has_outside_compilation = false; + bool body_has_outside_compilation = false; + string cond_host_func_name = absl::StrCat("oc_cond_host_while_", n->name()), + body_host_func_name = absl::StrCat("oc_body_host_while_", n->name()); + string cond_xla_func_name = absl::StrCat(cond.name(), "_oc"), + body_xla_func_name = absl::StrCat(body.name(), "_oc"); + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + cond, cond_xla_func_name, cond_host_func_name, host_compute_core, fld, + shape_inference_graphs, &cond_has_outside_compilation)); + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + body, body_xla_func_name, body_host_func_name, host_compute_core, fld, + shape_inference_graphs, &body_has_outside_compilation)); + + // If cond/body do not have outside compilation, nothing to do. + if (!cond_has_outside_compilation && !body_has_outside_compilation) { + continue; + } + + *has_outside_compilation = true; + + // Change While node to call the new functions. + cond.set_name(cond_xla_func_name); + n->ClearAttr("cond"); + n->AddAttr("cond", cond); + body.set_name(body_xla_func_name); + n->ClearAttr("body"); + n->AddAttr("body", body); + + string host_transfer_key = absl::StrCat("oc_while_pred_", n->name()); + + // XLA computation: rewrite cond function to add a SendToHost node to send + // loop predicate. + TF_RETURN_IF_ERROR( + AddSendLoopPredToLoopCond(fld, cond, n->name(), host_transfer_key)); + n->AddAttr(kXlaTokenInputNodesAttrName, + std::vector{kXlaTokenArgNodeName}); + + // Build host side graph for the "While" node. 
+ string oc_host_graph_name = absl::StrCat("oc_while_host_graph_", n->name()); + TF_RETURN_IF_ERROR(BuildHostGraphForWhileNode( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + n->name(), host_transfer_key, oc_host_graph_name, fld, + cond_host_func_name, body_host_func_name)); + host_graphs->push_back(oc_host_graph_name); + } + + return Status::OK(); +} + } // namespace Status RewriteOutsideCompilationSubgraphFn::operator()( @@ -755,12 +1445,15 @@ Status RewriteOutsideCompilationSubgraphFn::operator()( // it with HostCompute node later. AddNodeAttr("_outside_compilation_subgraph", old_name, node_def); if (shapes) { - AddNodeAttr("shape_inference_graph", "", node_def); + NameAttrList shape_inference_graph; + AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def); AddNodeAttr("shapes", *shapes, node_def); } else { string shape_inference_func_name = absl::StrCat("_outside_compilation_shape_inference_", new_name); - AddNodeAttr("shape_inference_graph", shape_inference_func_name, node_def); + NameAttrList shape_inference_graph; + shape_inference_graph.set_name(shape_inference_func_name); + AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def); AddNodeAttr("shapes", std::vector{}, node_def); } AddNodeAttr("ancestors", std::vector{}, node_def); @@ -775,11 +1468,10 @@ Status ExtractOutsideCompilationForFunction( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const string& xla_cluster_name, const NameAttrList& func_name_attrs, const string& new_func_name, + const string& host_graph_func_name, const std::map& host_compute_core, - FunctionLibraryDefinition* fld, std::unique_ptr* host_graph, - std::vector* shape_inference_graphs, + FunctionLibraryDefinition* fld, std::vector* shape_inference_graphs, bool* has_outside_compilation) { - // Early return if function does not have any outside compilation nodes. 
const string& func_name = func_name_attrs.name(); const FunctionDef* fdef = fld->Find(func_name); if (!fdef) { @@ -792,9 +1484,8 @@ Status ExtractOutsideCompilationForFunction( break; } } - if (!has_outside_compilation) { - return Status::OK(); - } + // We cannot early return here, because we might have outside compilation in + // If/While function body. // Convert the function to graph. FunctionBody* fbody = nullptr; @@ -835,11 +1526,11 @@ Status ExtractOutsideCompilationForFunction( // If we could not infer shapes for XlaSendFromHost inputs statically, we // will set the "shape_inference_graph" attribute. In that case, copy // outside compilation subgraph as shape inference graph in `fld`. - string shape_inference_graph; + NameAttrList shape_inference_graph; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "shape_inference_graph", &shape_inference_graph)); - if (!shape_inference_graph.empty()) { - shape_inference_graphs->push_back(shape_inference_graph); + if (!shape_inference_graph.name().empty()) { + shape_inference_graphs->push_back(shape_inference_graph.name()); const FunctionDef* xla_fdef = fld->Find(n->name()); if (!xla_fdef) { @@ -847,9 +1538,9 @@ Status ExtractOutsideCompilationForFunction( } FunctionDef shape_inference_fdef = *xla_fdef; shape_inference_fdef.mutable_signature()->set_name( - shape_inference_graph); - if (fld->Find(shape_inference_graph)) { - TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph, + shape_inference_graph.name()); + if (fld->Find(shape_inference_graph.name())) { + TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph.name(), shape_inference_fdef)); } else { TF_RETURN_IF_ERROR(fld->AddFunctionDef(shape_inference_fdef)); @@ -858,6 +1549,7 @@ Status ExtractOutsideCompilationForFunction( } } for (Node* n : outside_compilation_nodes) { + TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n)); TF_RETURN_IF_ERROR(ReplaceOrRemoveOutsideCompilationCallNode( graph_out.get(), n, host_compute_core)); } @@ -867,12 +1559,17 @@ 
Status ExtractOutsideCompilationForFunction( *graph_out, fld); } + // Handle nodes with associated functions. + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForNodesWithAssociatedFunctions( + graph_out.get(), xla_cluster_attr_name, outside_compilation_attr_name, + xla_cluster_name, host_compute_core, fld, + &outside_compilation_host_graphs, shape_inference_graphs, + has_outside_compilation)); + // Construct host graph. - if (!outside_compilation_host_graphs.empty()) { - TF_RETURN_IF_ERROR( - ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name, - outside_compilation_host_graphs, fld, host_graph)); - } + TF_RETURN_IF_ERROR(ConstructHostGraph( + xla_cluster_name, outside_compilation_attr_name, + outside_compilation_host_graphs, fld, host_graph_func_name)); // Remove the outside compilation graphs from function library. for (const string& func : outside_compilation_host_graphs) { @@ -909,24 +1606,17 @@ Status ExtractOutsideCompilation( auto const& host_compute_core = iter.second.host_compute_core; bool has_outside_compilation; - std::unique_ptr host_graph; + string host_graph_func_name = absl::StrCat("oc_host_graph_", n->name()); TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, - func_name_attrs, func_name_attrs.name(), host_compute_core, fld, - &host_graph, &shape_inference_graphs, &has_outside_compilation)); - if (host_graph) { - TF_RETURN_IF_ERROR(ExpandHostGraphIntoMainGraph(g, host_graph.get(), n)); - } - } - - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile("extract_outside_compilation_expanded", *g, - fld); + func_name_attrs, func_name_attrs.name(), host_graph_func_name, + host_compute_core, fld, &shape_inference_graphs, + &has_outside_compilation)); + TF_RETURN_IF_ERROR( + ExpandHostGraphIntoMainGraph(g, fld, host_graph_func_name, n)); + TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name)); } - TF_RETURN_IF_ERROR(PostprocessForEncapsulation( - g, 
xla_cluster_attr_name, outside_compilation_attr_name, clusters)); - for (auto shape_inference_graph_name : shape_inference_graphs) { TF_RETURN_IF_ERROR( RewriteShapeInferenceGraph(shape_inference_graph_name, g, fld)); diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.h b/tensorflow/compiler/jit/extract_outside_compilation_pass.h index 2a4f07cca213d999202024294f5d8f94527059c3..e07e7c5dd0cd42ddd4d643d8b36583c82056bbb2 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.h +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.h @@ -88,9 +88,10 @@ Status ExtractOutsideCompilationForFunction( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const string& xla_cluster_name, const NameAttrList& func_name_attrs, const string& new_func_name, + const string& host_graph_func_name, const std::map& host_compute_core, - FunctionLibraryDefinition* fld, std::unique_ptr* host_graph, - std::vector* shape_inference_graphs, bool* has_outside_compilation); + FunctionLibraryDefinition* fld, std::vector* shape_inference_graphs, + bool* has_outside_compilation); // Rewrites XLA computation in `clusters` to replace outside compilation nodes // with XlaHostCompute, and moves those outside compilations into `g`. If shapes diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc index bff956100da661b679b4557fce53671e6cef88c5..e9a89e34e0c7b04b4be34e367b2d0bf627c0061a 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc @@ -19,8 +19,10 @@ limitations under the License. 
#include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/functional_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/encapsulate_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/function.h" @@ -109,10 +111,10 @@ TEST(RewriteOutsideCompilationSubgraphFnTest, Basic) { } EXPECT_TRUE(has_control_edge_to_send_from_host); // Verify step 7: necessary attrs added to call_node_def. - string shape_inference_graph; + NameAttrList shape_inference_graph; TF_CHECK_OK(GetNodeAttr(AttrSlice(&call_node_def.attr()), "shape_inference_graph", &shape_inference_graph)); - EXPECT_EQ(shape_inference_graph, + EXPECT_EQ(shape_inference_graph.name(), "_outside_compilation_shape_inference_cluster_0"); } @@ -249,27 +251,26 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) { protobuf::Map attrs; std::map host_compute_core = {{"0", 1}, {"1", 0}}; - std::unique_ptr host_graph; std::vector shape_inference_graphs; bool has_outside_compilation; NameAttrList name_attrs; name_attrs.set_name("cluster"); *name_attrs.mutable_attr() = attrs; TF_CHECK_OK(ExtractOutsideCompilationForFunction( - "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", - host_compute_core, &fld, &host_graph, &shape_inference_graphs, + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, &has_outside_compilation)); // Get rewritten XLA computation function. 
- FunctionBody *fbody = nullptr; - TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"), - AttrSlice(), &fld, - [&](const string &op, const OpDef **sig) { - return fld.LookUpOpDef(op, sig); - }, - &fbody)); - std::unique_ptr fbody_deleter(fbody); - auto node_name_index = fbody->graph->BuildNodeNameIndex(); + FunctionBody *xla_fbody = nullptr; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("cluster_rewritten"), AttrSlice(), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &xla_fbody)); + std::unique_ptr xla_fbody_deleter(xla_fbody); + auto node_name_index = xla_fbody->graph->BuildNodeNameIndex(); // Check XlaHostCompute nodes. Node *host_compute_0 = node_name_index["outside_compilation_0_host_compute"]; @@ -292,18 +293,31 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) { EXPECT_EQ(shapes[0].dim_size(), 1); // Check XlaHostCompute nodes' "shape_inference_graph" attr. Both should have // empty values. - string shape_inference_graph; + NameAttrList shape_inference_graph; TF_CHECK_OK(GetNodeAttr(host_compute_0->attrs(), "shape_inference_graph", &shape_inference_graph)); - EXPECT_EQ(shape_inference_graph, ""); + EXPECT_EQ(shape_inference_graph.name(), ""); TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "shape_inference_graph", &shape_inference_graph)); - EXPECT_EQ(shape_inference_graph, ""); + EXPECT_EQ(shape_inference_graph.name(), ""); // Check `shape_inference_graphs`. EXPECT_EQ(shape_inference_graphs.size(), 0); - // Check `host_graph`: verify we have key placeholder and sequencer. + // Check host graph: verify we have key placeholder and sequencer. 
+ FunctionBody *host_fbody = nullptr; + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map host_func_attrs; + host_func_attrs["device_ordinal"] = device_ordinal_temp_value; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + Graph *host_graph = host_fbody->graph; Node *key_placeholder = nullptr, *sequencer = nullptr; for (Node *n : host_graph->nodes()) { if (n->type_string() == "Placeholder" && @@ -365,25 +379,37 @@ TEST(ExtractOutsideCompilationForFunctionTest, NoHostGraph) { protobuf::Map attrs; std::map host_compute_core = {{"0", 1}, {"1", 0}}; - std::unique_ptr host_graph; std::vector shape_inference_graphs; bool has_outside_compilation; NameAttrList name_attrs; name_attrs.set_name("cluster"); *name_attrs.mutable_attr() = attrs; TF_CHECK_OK(ExtractOutsideCompilationForFunction( - "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", - host_compute_core, &fld, &host_graph, &shape_inference_graphs, + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, &has_outside_compilation)); - // Check `host_graph` is empty. - EXPECT_FALSE(host_graph); + // Check host graph is empty. 
+ FunctionBody *host_fbody = nullptr; + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map host_func_attrs; + host_func_attrs["device_ordinal"] = device_ordinal_temp_value; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + Graph *host_graph = host_fbody->graph; + EXPECT_EQ(host_graph->num_nodes(), 2); } TEST(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) { // Build the XLA computation func. // "const0" - // "const1" (outside compilation clsuter "0") + // "const1" (outside compilation cluster "0") FunctionDefLibrary fdl; { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); @@ -401,31 +427,43 @@ TEST(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) { protobuf::Map attrs; std::map host_compute_core = {{"0", 1}, {"1", 0}}; - std::unique_ptr host_graph; std::vector shape_inference_graphs; bool has_outside_compilation; NameAttrList name_attrs; name_attrs.set_name("cluster"); *name_attrs.mutable_attr() = attrs; TF_CHECK_OK(ExtractOutsideCompilationForFunction( - "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", - host_compute_core, &fld, &host_graph, &shape_inference_graphs, + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, &has_outside_compilation)); // Check rewritten XLA graph: verify that we have no XlaHostCompute. 
- FunctionBody *fbody = nullptr; - TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"), - AttrSlice(), &fld, - [&](const string &op, const OpDef **sig) { - return fld.LookUpOpDef(op, sig); - }, - &fbody)); - std::unique_ptr fbody_deleter(fbody); - for (Node *n : fbody->graph->nodes()) { + FunctionBody *xla_fbody = nullptr; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("cluster_rewritten"), AttrSlice(), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &xla_fbody)); + std::unique_ptr xla_fbody_deleter(xla_fbody); + for (Node *n : xla_fbody->graph->nodes()) { EXPECT_NE(n->type_string(), "XlaHostCompute"); } - // Check `host_graph`: verify we have no placeholder, but we have "const1". + // Check host graph: verify we have no placeholder, but we have "const1". + FunctionBody *host_fbody = nullptr; + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map host_func_attrs; + host_func_attrs["device_ordinal"] = device_ordinal_temp_value; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + Graph *host_graph = host_fbody->graph; int num_key_placeholders = 0; for (Node *n : host_graph->nodes()) { if (n->type_string() == "Placeholder" && @@ -438,4 +476,310 @@ TEST(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) { EXPECT_NE(node_name_index.find("const1"), node_name_index.end()); } +REGISTER_OP("XlaSendToHost") + .Input("input: Tinput") + .Attr("Tinput: type") + .Attr("key: string") + .SetIsStateful(); + +REGISTER_OP("XlaRecvFromHost") + .Output("output: Toutput") + .Attr("Toutput: type") + .Attr("shape: shape") + .Attr("key: string") + .SetIsStateful(); + +TEST(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInIf) { + // Build the XLA computation func. 
+ // "const0" (bool) + // "const1" (int32) + // "if0" (pred = "const0", input = "const1", then_branch = "true_fn", + // else_branch = "false_fn") + FunctionDefLibrary fdl; + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg = ops::_Arg(s.WithOpName("arg"), DT_INT32, 0); + Output identity = ops::Identity(s.WithOpName("identity_true_fn"), arg); + ops::_Retval retval(s.WithOpName("retval"), identity, 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + auto node_name_image = g->BuildNodeNameIndex(); + node_name_image["identity_true_fn"]->AddAttr("_oc", "0"); + PartialTensorShape shape({2}); + node_name_image["identity_true_fn"]->AddAttr( + kXlaInferredShapesAttrName, std::vector{shape}); + + FunctionDef *true_fn_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "true_fn", true_fn_fdef)); + } + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg = ops::_Arg(s.WithOpName("arg"), DT_INT32, 0); + Output identity = ops::Identity(s.WithOpName("identity_false_fn"), arg); + ops::_Retval retval(s.WithOpName("retval"), identity, 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + auto node_name_image = g->BuildNodeNameIndex(); + node_name_image["identity_false_fn"]->AddAttr("_oc", "0"); + PartialTensorShape shape({2}); + node_name_image["identity_false_fn"]->AddAttr( + kXlaInferredShapesAttrName, std::vector{shape}); + + FunctionDef *false_fn_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "false_fn", false_fn_fdef)); + } + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output cond = ops::Const(s.WithOpName("const0"), true, {2}); + Output input = ops::Const(s.WithOpName("const1"), 1, {2}); + NameAttrList true_fn; + true_fn.set_name("true_fn"); + NameAttrList false_fn; + false_fn.set_name("false_fn"); + auto if_op = ops::If(s.WithOpName("if"), cond, + std::initializer_list{cond, input}, {DT_INT32}, + 
true_fn, false_fn); + ops::_Retval retval(s.WithOpName("retval"), if_op.output[0], 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + + FunctionDef *xla_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef)); + } + FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); + + protobuf::Map attrs; + std::map host_compute_core; + std::vector shape_inference_graphs; + bool has_outside_compilation; + NameAttrList name_attrs; + name_attrs.set_name("cluster"); + *name_attrs.mutable_attr() = attrs; + TF_CHECK_OK(ExtractOutsideCompilationForFunction( + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, + &has_outside_compilation)); + + // Check host graph. + { + FunctionBody *host_fbody = nullptr; + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map host_func_attrs; + host_func_attrs["device_ordinal"] = device_ordinal_temp_value; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + Graph *host_graph = host_fbody->graph; + auto node_name_index = host_graph->BuildNodeNameIndex(); + + // Verify we have XlaRecvAtHost to receive "If" predicate. + Node *recv_if_pred_node = node_name_index["recv_oc_if_pred_if"]; + EXPECT_NE(recv_if_pred_node, nullptr); + + // Verify we have an "If" to choose outside compilation between then_branch + // and else_branch, and it has `recv_if_pred_node` as cond input. 
+ Node *if_oc_node = node_name_index["oc_if_if"]; + EXPECT_NE(if_oc_node, nullptr); + Node *if_oc_node_cond_input; + TF_CHECK_OK(if_oc_node->input_node(0, &if_oc_node_cond_input)); + EXPECT_EQ(if_oc_node_cond_input, recv_if_pred_node); + + // Check that then_branch outside compilation has node "identity_true_fn". + const FunctionDef *true_def = fld.Find("oc_then_branch_host_if_if"); + EXPECT_NE(true_def, nullptr); + bool has_identity_true_fn_node = false; + for (const auto &node_def : true_def->node_def()) { + if (node_def.name() == "identity_true_fn") { + has_identity_true_fn_node = true; + break; + } + } + EXPECT_TRUE(has_identity_true_fn_node); + + // Check that else_branch outside compilation has node "identity_false_fn". + const FunctionDef *false_def = fld.Find("oc_else_branch_host_if_if"); + EXPECT_NE(false_def, nullptr); + bool has_identity_false_fn_node = false; + for (const auto &node_def : false_def->node_def()) { + if (node_def.name() == "identity_false_fn") { + has_identity_false_fn_node = true; + break; + } + } + EXPECT_TRUE(has_identity_false_fn_node); + } + + // Check XLA graph. + { + FunctionBody *xla_fbody = nullptr; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("cluster_rewritten"), AttrSlice(), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &xla_fbody)); + std::unique_ptr xla_fbody_deleter(xla_fbody); + Graph *xla_graph = xla_fbody->graph; + auto node_name_index = xla_graph->BuildNodeNameIndex(); + + // Check that we have XlaSendToHost to send cond predicate to host, and + // there is a control edge to If node. 
+ Node *send_if_pred_node = node_name_index["send_oc_if_pred_if"]; + EXPECT_NE(send_if_pred_node, nullptr); + bool has_control_edge_to_if = false; + for (const Edge *e : send_if_pred_node->out_edges()) { + if (e->IsControlEdge() && e->dst()->name() == "if") { + has_control_edge_to_if = true; + break; + } + } + EXPECT_TRUE(has_control_edge_to_if); + + // Check that the "If" node now has `send_if_pred_node` as attribute + // _xla_token_input_nodes. + Node *if_node = node_name_index["if"]; + EXPECT_NE(if_node, nullptr); + std::vector token_inputs; + TF_CHECK_OK( + GetNodeAttr(if_node->def(), "_xla_token_input_nodes", &token_inputs)); + EXPECT_THAT(token_inputs, ::testing::ElementsAre("send_oc_if_pred_if")); + } +} + +TEST(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInWhile) { + // Build the XLA computation func. + // "const0" (bool) + // "while0" (input = "const0", cond = "cond_fn", body = "body_fn") + FunctionDefLibrary fdl; + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg = ops::_Arg(s.WithOpName("arg"), DT_BOOL, 0); + Output identity = ops::Identity(s.WithOpName("identity_cond_fn"), arg); + ops::_Retval retval(s.WithOpName("retval"), identity, 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + auto node_name_image = g->BuildNodeNameIndex(); + node_name_image["identity_cond_fn"]->AddAttr("_oc", "0"); + PartialTensorShape shape({2}); + node_name_image["identity_cond_fn"]->AddAttr( + kXlaInferredShapesAttrName, std::vector{shape}); + + FunctionDef *cond_fn_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "cond_fn", cond_fn_fdef)); + } + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg = ops::_Arg(s.WithOpName("arg"), DT_BOOL, 0); + Output identity = ops::Identity(s.WithOpName("identity_body_fn"), arg); + ops::_Retval retval(s.WithOpName("retval"), identity, 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + 
TF_CHECK_OK(s.ToGraph(g.get())); + auto node_name_image = g->BuildNodeNameIndex(); + node_name_image["identity_body_fn"]->AddAttr("_oc", "0"); + PartialTensorShape shape({2}); + node_name_image["identity_body_fn"]->AddAttr( + kXlaInferredShapesAttrName, std::vector{shape}); + + FunctionDef *body_fn_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "body_fn", body_fn_fdef)); + } + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output input = ops::Const(s.WithOpName("const0"), true, {2}); + NameAttrList cond_fn; + cond_fn.set_name("cond_fn"); + NameAttrList body_fn; + body_fn.set_name("body_fn"); + auto while_op = + ops::While(s.WithOpName("while"), std::initializer_list{input}, + cond_fn, body_fn); + ops::_Retval retval(s.WithOpName("retval"), while_op.output[0], 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + + FunctionDef *xla_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef)); + } + FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); + + protobuf::Map attrs; + std::map host_compute_core; + std::vector shape_inference_graphs; + bool has_outside_compilation; + NameAttrList name_attrs; + name_attrs.set_name("cluster"); + *name_attrs.mutable_attr() = attrs; + TF_CHECK_OK(ExtractOutsideCompilationForFunction( + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, + &has_outside_compilation)); + + // Check host graph. 
+ { + FunctionBody *host_fbody = nullptr; + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map host_func_attrs; + host_func_attrs["device_ordinal"] = device_ordinal_temp_value; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + Graph *host_graph = host_fbody->graph; + auto node_name_index = host_graph->BuildNodeNameIndex(); + + // Verify we have an "While" to execute outside compilation. + Node *while_oc_node = node_name_index["oc_while_while"]; + EXPECT_NE(while_oc_node, nullptr); + + // Check that cond outside compilation has node "identity_cond_fn". + const FunctionDef *cond_def = fld.Find("oc_cond_host_while_while"); + EXPECT_NE(cond_def, nullptr); + bool has_identity_cond_fn_node = false; + for (const auto &node_def : cond_def->node_def()) { + if (node_def.name() == "identity_cond_fn") { + has_identity_cond_fn_node = true; + break; + } + } + EXPECT_TRUE(has_identity_cond_fn_node); + + // Check that body outside compilation has node "identity_body_fn". + const FunctionDef *body_def = fld.Find("oc_body_host_while_while"); + EXPECT_NE(body_def, nullptr); + bool has_identity_body_fn_node = false; + for (const auto &node_def : body_def->node_def()) { + if (node_def.name() == "identity_body_fn") { + has_identity_body_fn_node = true; + break; + } + } + EXPECT_TRUE(has_identity_body_fn_node); + } + + // Check XLA graph. + { + // Verify that rewritten cond fn has XlaSendToHost to send loop predicate to + // host. 
+ const FunctionDef *cond_def = fld.Find("cond_fn_oc"); + EXPECT_NE(cond_def, nullptr); + bool has_send_oc_while_cond_node = false; + for (const auto &node_def : cond_def->node_def()) { + if (node_def.name() == "send_oc_while_cond_while") { + has_send_oc_while_cond_node = true; + break; + } + } + EXPECT_TRUE(has_send_oc_while_cond_node); + } +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 25796435a5c87af5e252981abf96833f4cda9a5e..6618e3a58ab7b6374ed775cd6e4e18a6a4975588 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -86,7 +86,7 @@ bool IsDummyImplOp(absl::string_view op_name) { bool IsStatefulRandomOp(absl::string_view op_name) { return op_name == "RandomUniform" || op_name == "RandomShuffle" || op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" || - op_name == "TruncatedNormal"; + op_name == "TruncatedNormal" || op_name == "Multinomial"; } bool OpProducesOrConsumesVariant(const Node& node) { diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index 42ea3926e16ae791dbe1bede3b8742383db7667c..e1fd2aaee2822daeffb415d053c9c4f56002a856 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -120,6 +120,7 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) { NodeDef ndef = n->def(); ndef.set_name(absl::StrCat(n->name(), "/declustered")); + MergeDebugInfo(NodeDebugInfo(n->def()), &ndef); RemoveFromXlaCluster(&ndef); Status s; Node* cloned_node = graph->AddNode(ndef, &s); diff --git a/tensorflow/compiler/jit/shape_inference.cc b/tensorflow/compiler/jit/shape_inference.cc index 80c691fe490c1092315708a2da754d367d585300..a27e0d9f2a6ecddfdbdb29be673084d77a178d8a 100644 --- a/tensorflow/compiler/jit/shape_inference.cc +++ 
b/tensorflow/compiler/jit/shape_inference.cc @@ -53,7 +53,15 @@ Status PropagateShapes(const Graph& graph, // shapes, even if no shape function is registered for a node. Status status = shape_refiner->AddNode(n); if (!status.ok()) { - VLOG(1) << "Shape inference failed for node: " << status; + VLOG(1) << "Shape inference failed for node " << n->name() << ": " + << status; + } else { + shape_inference::InferenceContext* context = shape_refiner->GetContext(n); + for (int i = 0; i < n->num_outputs(); i++) { + shape_inference::ShapeHandle handle = context->output(i); + VLOG(4) << "Output " << i << " for node " << n->name() << ": " + << context->DebugString(handle); + } } if (n->type_string() == "_Arg") { diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index 1fe612d43d10030675cf307b109e4dcc89cb2d79..c7e8d61d280a33a83c3386d8ef801018634d31ec 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -142,11 +142,22 @@ Status XlaCompileOnDemandOp::Compile( TF_RETURN_IF_ERROR(ctx->allocate_temp( device_tensor.dtype(), device_tensor.shape(), &host_tensor, attrs)); Notification n; + Status status; ctx->op_device_context()->CopyDeviceTensorToCPU( &device_tensor, "ConstantArgument", reinterpret_cast(ctx->device()), &host_tensor, - [&](Status status) { n.Notify(); }); + [&](Status s) { + status = s; + n.Notify(); + }); n.WaitForNotification(); + if (!status.ok()) { + LOG(ERROR) << "Copying tensor of shape " + << device_tensor.shape().DebugString() << " from " + << ctx->device()->name() << " to CPU failed with " + << status.ToString(); + return status; + } constant_arguments[i] = host_tensor; } } @@ -189,6 +200,7 @@ Status XlaCompileOnDemandOp::Compile( std::map variable_args = GetVariables(ctx); std::vector args; + TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments( constant_arguments, variable_args, ctx, &args)); diff 
--git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index 7df898ad12a15345f45fc96e0ec3d42b6e51731b..e9770647e7ba96cc1db026d12d5f11f52ce98d35 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -63,7 +63,19 @@ Status XlaCpuDeviceFactory::CreateDevices( options.device_ordinal = 0; options.compilation_device_name = DEVICE_CPU_XLA_JIT; options.use_multiple_streams = false; - devices->push_back(absl::make_unique(session_options, options)); + auto device = absl::make_unique(session_options, options); + + // Setting GpuDeviceInfo because eager runtime relies on the device + // context in tensorflow_gpu_device_info(). Also, + // tensorflow_gpu_device_info() == nullptr is used as an IsCPU test. + // We need XlaCpuDevice to be treated not as CPU because it allocates + // XlaTensors, not regular Tensors. + Status status = device->UseGpuDeviceInfo(); + if (!status.ok()) { + errors::AppendToMessage(&status, "while setting up ", DEVICE_CPU_XLA_JIT); + return status; + } + devices->push_back(std::move(device)); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 6e6532731e64bd42ee56aa719748988f321e0f17..1f3afe8822d441a5ce37617fe18d7767e9bc72e4 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -79,6 +79,13 @@ XlaDeviceContext::XlaDeviceContext( } } +void XlaDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor, + Device* device, + Tensor* output_tensor, + StatusCallback done) const { + done(errors::Unimplemented("XLA->XLA same-device copies not implemented.")); +} + void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index 
1e18df197a2dd65590c5181b4dae4481dca36641..e45db989fac720df6c3458c93a6b8dbb0919f930 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -62,6 +62,9 @@ class XlaDeviceContext : public DeviceContext { void CopyDeviceTensorToCPU(const Tensor* device_tensor, absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) override; + void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device, + Tensor* output_tensor, + StatusCallback done) const override; xla::LocalClient* client() const { return client_; } se::Stream* stream() const { return stream_.get(); } diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index adf0f994b84d9fbf918a5b2478aa7d106853e038..927f983ba9ef23c8509523f42366c0c89c29db9f 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -203,6 +203,8 @@ class XlaAssignVariableOp : public OpKernel { .HostMemory("output") \ .TypeConstraint("T"), \ ArgOp); \ + REGISTER_KERNEL_BUILDER( \ + Name(kArgOp).Device(DEVICE).TypeConstraint("T"), ArgOp); \ \ REGISTER_KERNEL_BUILDER(Name(kRetOp) \ .Device(DEVICE) \ diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 944f732b99c0924a08932eda0aedd8c815cc51d0..0191315a66f4d331e54fadc9dc6a073a05fd67ef 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -16,7 +16,10 @@ limitations under the License. // Registers the XLA_GPU device, which is an XlaDevice instantiation that runs // operators using XLA via the XLA "CUDA" (GPU) backend. 
+#include #include "absl/memory/memory.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" #include "tensorflow/compiler/jit/kernels/xla_ops.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device_ops.h" @@ -52,8 +55,35 @@ Status XlaGpuDeviceFactory::CreateDevices( VLOG(1) << "Failed to create XLA_GPU device: " << platform.status(); return Status::OK(); } - - for (int i = 0; i < platform.ValueOrDie()->VisibleDeviceCount(); ++i) { + string allowed_gpus = + session_options.config.gpu_options().visible_device_list(); + std::set gpu_ids; + int num_visible_devices = platform.ValueOrDie()->VisibleDeviceCount(); + if (allowed_gpus.empty()) { + for (int i = 0; i < num_visible_devices; ++i) { + gpu_ids.insert(i); + } + } else { + // For loop below is copied from gpu/gpu_device.cc. It validates + // the visible_device_list and populates gpu_ids set. + const std::vector visible_devices = + absl::StrSplit(allowed_gpus, ','); + for (const string& platform_gpu_id_str : visible_devices) { + int32 platform_gpu_id; + if (!absl::SimpleAtoi(platform_gpu_id_str, &platform_gpu_id)) { + return errors::InvalidArgument( + "Could not parse entry in 'visible_device_list': '", + platform_gpu_id_str, "'. 
visible_device_list = ", allowed_gpus); + } + if (platform_gpu_id < 0 || platform_gpu_id >= num_visible_devices) { + return errors::InvalidArgument( + "'visible_device_list' listed an invalid GPU id '", platform_gpu_id, + "' but visible device count is ", num_visible_devices); + } + gpu_ids.insert(platform_gpu_id); + } + } + for (int i : gpu_ids) { XlaDevice::Options options; options.platform = platform.ValueOrDie(); options.device_name_prefix = name_prefix; diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 437db019a0eabe66417725148d8b121842e90479..554227f09de0ab4d9e07f199b957657f3121ff06 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -199,19 +199,17 @@ class XlaTensorBuffer : public TensorBuffer { public: XlaTensorBuffer(const void* ptr, size_t expected_size, size_t actual_size, Allocator* allocator) - : expected_size_(expected_size), + : TensorBuffer(const_cast(ptr)), + expected_size_(expected_size), actual_size_(actual_size), - allocator_(allocator) { - data_ = const_cast(ptr); - } + allocator_(allocator) {} ~XlaTensorBuffer() override { - if (data_) { - allocator_->DeallocateRaw(data_); + if (data()) { + allocator_->DeallocateRaw(data()); } } - void* data() const override { return data_; } size_t size() const override { return expected_size_; } TensorBuffer* root_buffer() override { return this; } @@ -231,7 +229,6 @@ class XlaTensorBuffer : public TensorBuffer { } private: - void* data_; size_t expected_size_; size_t actual_size_; Allocator* allocator_; diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index bc3d60b90e58b4018f1c52b09941dedba7ef348a..093b61629cd0b04d5d8488139b8d7262b739f86d 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -408,13 +408,6 @@ tf_xla_py_test( name = "eager_test", size = "large", srcs = ["eager_test.py"], - disabled_backends = [ - # TODO(b/78199195) Support 
XLA CPU devices in eager runtime - "cpu", - "cpu_ondemand", - # TODO(b/78468222) Enable GPU backend - "gpu", - ], deps = [ ":xla_test", "//tensorflow/python:array_ops", diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py index 174bfa9efbcd7dcb4f895237eb01c17bc4a3a6b4..90146e6b27ca31304a2549ec247412341efe390c 100644 --- a/tensorflow/compiler/tests/depthwise_conv_op_test.py +++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py @@ -350,8 +350,13 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): self._CompareBackpropInput(input_size, filter_size, output_size, stride, padding) - def _CompareBackpropFilter(self, input_sizes, filter_sizes, output_sizes, - stride, padding): + def _CompareBackpropFilter(self, + input_sizes, + filter_sizes, + output_sizes, + stride, + padding, + data_format="NHWC"): x0 = np.random.rand(*input_sizes).astype(np.float32) x2 = np.random.rand(*output_sizes).astype(np.float32) @@ -360,13 +365,30 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): t0 = array_ops.placeholder(np.float32, shape=input_sizes) t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)]) t2 = array_ops.placeholder(np.float32, shape=output_sizes) + native_t0 = t0 + native_t2 = t2 + strides = [1, stride, stride, 1] + if use_xla: + if data_format == "NCHW": + # Transpose from NHWC input to NCHW + # Ex. [4, 5, 5, 48] to [4, 48, 5, 5] + native_t0 = array_ops.transpose(t0, [0, 3, 1, 2]) + native_t2 = array_ops.transpose(t2, [0, 3, 1, 2]) + strides = [1, 1, stride, stride] with self.test_scope(): backprop = nn_ops.depthwise_conv2d_native_backprop_filter( - t0, t1, t2, strides=[1, stride, stride, 1], padding=padding) + native_t0, + t1, + native_t2, + strides=strides, + padding=padding, + data_format=data_format) else: + # For CPU, the format NCHW is not supported. Therefore we always use + # NHWC here. 
backprop = nn_ops.depthwise_conv2d_native_backprop_filter( - t0, t1, t2, strides=[1, stride, stride, 1], padding=padding) + native_t0, t1, native_t2, strides=strides, padding=padding) ret = backprop.eval({t0: x0, t2: x2}) self.assertShapeEqual(ret, backprop) return ret @@ -379,11 +401,24 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): for index, (input_size, filter_size, output_size, stride, padding) in enumerate(ConfigsToTest()): print("Testing DepthwiseConv2DFilterGradCompare,", index, "th config:", - input_size, "*", filter_size, "stride:", stride, "padding:", - padding) + input_size, "*", filter_size, "producing output", output_size, + "stride:", stride, "padding:", padding) self._CompareBackpropFilter(input_size, filter_size, output_size, stride, padding) + def testDepthwiseConv2DFilterGradFormatNCHWCompare(self): + for index, (input_size, filter_size, output_size, stride, + padding) in enumerate(ConfigsToTest()): + print("Testing DepthwiseConv2DFilterGradFormatNCHWCompare,", index, + "th config:", input_size, "*", filter_size, "producing output", + output_size, "stride:", stride, "padding:", padding) + self._CompareBackpropFilter( + input_size, + filter_size, + output_size, + stride, + padding, + data_format="NCHW") if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index 4cf88fc523735cc2d22e085afb83790c7ebb48e4..28274ff799de2c85e1e80512cadbe0206cb640a4 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -319,7 +319,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): session.run(output) self.assertRegexpMatches( invalid_arg_error.exception.message, - (r'^start_indices must be a vector with length equal to input rank, ' + (r'start_indices must be a vector with length equal to input rank, ' r'but input rank is 3 and start_indices has shape \[2\].*')) def 
testDynamicSliceWithIncorrectSizeIndicesShape(self): @@ -332,7 +332,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): session.run(output) self.assertRegexpMatches( invalid_arg_error.exception.message, - (r'^size_indices must be a vector with length equal to input rank, ' + (r'size_indices must be a vector with length equal to input rank, ' r'but input rank is 3 and size_indices has shape \[2\].*')) diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 25a84fb1b6609106213231db1ca1ce54da8bd960..5a0d9b9af9d55a8dee809d3cf909bce39c3b8b6c 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -445,14 +445,9 @@ cc_library( ], deps = [ "//tensorflow/compiler/jit:flags", - "//tensorflow/compiler/xla:parse_flags_from_env", - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", + "//tensorflow/core:graph", "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc index 1de85004a51bea464f8f0166511402e5dd85ac14..64fdbbebc65bff4ed0b965fcdd534cc9696472b6 100644 --- a/tensorflow/compiler/tf2xla/dump_graph.cc +++ b/tensorflow/compiler/tf2xla/dump_graph.cc @@ -18,86 +18,26 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/flags.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { namespace dump_graph { -namespace { - -struct NameCounts { - mutex counts_mutex; - std::unordered_map counts; -}; - -string MakeUniqueFilename(string name) { - static NameCounts& instance = *new NameCounts; - - // Remove illegal characters from `name`. 
- for (int i = 0; i < name.size(); ++i) { - char ch = name[i]; - if (ch == '/' || ch == '[' || ch == ']' || ch == '*' || ch == '?') { - name[i] = '_'; - } - } - - int count; - { - mutex_lock lock(instance.counts_mutex); - count = instance.counts[name]++; - } - - string filename = name; - if (count > 0) { - absl::StrAppend(&filename, "_", count); - } - absl::StrAppend(&filename, ".pbtxt"); - return filename; -} - -string WriteTextProtoToUniqueFile( - Env* env, const string& name, const char* proto_type, - const ::tensorflow::protobuf::Message& proto) { - const string& dirname = GetDumpGraphFlags()->tf_dump_graph_prefix; - Status status = env->RecursivelyCreateDir(dirname); - if (!status.ok()) { - LOG(WARNING) << "Failed to create " << dirname << " for dumping " - << proto_type << ": " << status; - return "(unavailable)"; - } - string filepath = absl::StrCat(dirname, "/", MakeUniqueFilename(name)); - status = WriteTextProto(Env::Default(), filepath, proto); - if (!status.ok()) { - LOG(WARNING) << "Failed to dump " << proto_type << " to file: " << filepath - << " : " << status; - return "(unavailable)"; - } - LOG(INFO) << "Dumped " << proto_type << " to " << filepath; - return filepath; -} - -} // anonymous namespace - string DumpGraphDefToFile(const string& name, GraphDef const& graph_def) { - return WriteTextProtoToUniqueFile(Env::Default(), name, "GraphDef", - graph_def); + return tensorflow::DumpGraphDefToFile( + name, graph_def, GetDumpGraphFlags()->tf_dump_graph_prefix); } string DumpGraphToFile(const string& name, Graph const& graph, const FunctionLibraryDefinition* flib_def) { - GraphDef graph_def; - graph.ToGraphDef(&graph_def); - if (flib_def) { - *graph_def.mutable_library() = flib_def->ToProto(); - } - return DumpGraphDefToFile(name, graph_def); + return tensorflow::DumpGraphToFile(name, graph, flib_def, + GetDumpGraphFlags()->tf_dump_graph_prefix); } string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef) { - return 
WriteTextProtoToUniqueFile(Env::Default(), name, "FunctionDef", fdef); + return tensorflow::DumpFunctionDefToFile( + name, fdef, GetDumpGraphFlags()->tf_dump_graph_prefix); } } // namespace dump_graph diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index c693e42d26712d55852f45c806215fc1f1b9a030..7ae96e1d484900e28e8c23c3bb2232401144ad82 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -640,7 +640,8 @@ Status Conditional::ExtractBodies(Graph* graph) { Status Conditional::BuildIfNode(Graph* graph, FunctionLibraryDefinition* library) { VLOG(2) << "Build cond function for " << name(); - NodeDefBuilder builder(name(), "If", library); + NodeDebugInfo debug_info((*merges_.begin())->def()); + NodeDefBuilder builder(name(), "If", library, &debug_info); const string branch_name[] = {"else_branch", "then_branch"}; for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { int branch_index = static_cast(branch); diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index d85b4f5ae0cb9c7d2476158a5830f921742ae980..a18a4e92d62787051f6ab92e72ee8bf0d1060dca 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -1,16 +1,11 @@ +load("//tensorflow:tensorflow.bzl", "tf_copts", "tf_kernel_library") + licenses(["notice"]) # Apache 2.0 package( default_visibility = ["//tensorflow/compiler/tf2xla:internal"], ) -load("//tensorflow:tensorflow.bzl", "tf_copts") -load("//tensorflow:tensorflow.bzl", "tf_kernel_library") -load( - "//third_party/mkl:build_defs.bzl", - "if_mkl", -) - tf_kernel_library( name = "xla_ops", srcs = [ @@ -121,15 +116,10 @@ tf_kernel_library( ":while_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla/lib:batch_dot", "//tensorflow/compiler/tf2xla/lib:broadcast", - 
"//tensorflow/compiler/tf2xla/lib:cholesky", - "//tensorflow/compiler/tf2xla/lib:qr", "//tensorflow/compiler/tf2xla/lib:random", "//tensorflow/compiler/tf2xla/lib:scatter", - "//tensorflow/compiler/tf2xla/lib:triangular_solve", "//tensorflow/compiler/tf2xla/lib:util", - "//tensorflow/compiler/tf2xla/lib:while_loop", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal", @@ -142,12 +132,16 @@ tf_kernel_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:cholesky", "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:loops", "//tensorflow/compiler/xla/client/lib:math", - "//tensorflow/compiler/xla/client/lib:numeric", + "//tensorflow/compiler/xla/client/lib:matrix", "//tensorflow/compiler/xla/client/lib:pooling", "//tensorflow/compiler/xla/client/lib:prng", + "//tensorflow/compiler/xla/client/lib:qr", "//tensorflow/compiler/xla/client/lib:sorting", + "//tensorflow/compiler/xla/client/lib:triangular_solve", "//tensorflow/core:framework", "//tensorflow/core:image_ops_op_lib", "//tensorflow/core:lib", @@ -196,7 +190,6 @@ cc_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:numeric", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -216,7 +209,6 @@ cc_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:numeric", "//tensorflow/core:framework", "//tensorflow/core/kernels:bounds_check", "//tensorflow/core/kernels:conv_ops", diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc 
b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc index 4cfe946b2e6146f034867c06e996ffae42b90705..1b254e328a8c71bd81a0ec700e2af1d81a5fa67a 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc @@ -13,9 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" namespace tensorflow { namespace { @@ -28,9 +30,11 @@ class BatchMatMulOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - auto result = BatchDot(ctx->Input(0), ctx->Input(1), - /*transpose_x=*/adj_x_, /*transpose_y=*/adj_y_, - /*conjugate_x=*/adj_x_, /*conjugate_y=*/adj_y_); + auto result = + xla::BatchDot(MaybeTransposeInMinorDims( + MaybeConjugate(ctx->Input(0), adj_x_), adj_x_), + MaybeTransposeInMinorDims( + MaybeConjugate(ctx->Input(1), adj_y_), adj_y_)); ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc index 9fcbc86adc0967cbb7fb73da8bdabc58b60953da..0ed3044efa5b1060d2b0ad2d5563b0e02ebf66ec 100644 --- a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/cholesky.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/cholesky.h" namespace tensorflow { namespace { @@ -24,7 +24,7 @@ class CholeskyOp : public XlaOpKernel { public: explicit CholeskyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - ctx->SetOutput(0, Cholesky(ctx->Input(0))); + ctx->SetOutput(0, xla::Cholesky(ctx->Input(0))); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index 641fefafb357f6ad10483c454600f3dadd4f8cb7..4124b258c7788e3850f07cbf4d53930784c635fd 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -392,23 +392,31 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( builder->GetShape(activations)); TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape, builder->GetShape(gradients)); + xla::XlaOp filter_backprop; + + xla::Shape input_shape = activations_shape; + xla::Shape output_shape = out_backprop_shape; + + TensorShape input_tensor_shape, filter_tensor_shape, output_tensor_shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape)); + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape)); + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(output_shape, &output_tensor_shape)); + const xla::Shape expanded_filter_shape = attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape) : filter_shape; - // Reuse dimension computation logic from conv_grad_ops.cc. 
ConvBackpropDimensions dims; - TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes( - type_string, attrs.num_spatial_dims, activations_shape, - expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides, - attrs.padding, attrs.data_format, &dims)); - // The filter gradients are computed by a convolution of the input // activations and the output gradients, with some appropriate padding. // See the comment at the top of conv_grad_ops.h for details. - xla::ConvolutionDimensionNumbers dnums; + TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes( + type_string, attrs.num_spatial_dims, activations_shape, + expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides, + attrs.padding, attrs.data_format, &dims)); + // The activations (inputs) form the LHS of the convolution. // Activations have shape: [batch, in_rows, in_cols, ..., in_depth] // For the gradient computation, we flip the roles of the batch and @@ -420,29 +428,99 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( int n_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format); int c_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format); - // Swap n_dim and c_dim in the activations. - dnums.set_input_batch_dimension(c_dim); - dnums.set_input_feature_dimension(n_dim); + int64 total_spatial_size = 1; + for (int i = 0; i < attrs.num_spatial_dims; ++i) { + total_spatial_size *= dims.input_size(i); + } - // The gradients become the RHS of the convolution. - // The gradients have shape [batch, out_rows, out_cols, ..., out_depth] - // where the batch becomes the input feature for the convolution. - dnums.set_kernel_input_feature_dimension(n_dim); - dnums.set_kernel_output_feature_dimension(c_dim); + // We use this approach only for depthwise convolutions where feature counts + // are large but space dimensions are small. The conversion logic below + // assumes that the data format is NHWC, so we also check that here. 
+ bool should_perform_depthwise_conv = + attrs.data_format == FORMAT_NHWC && + (total_spatial_size < dims.in_depth) && + filter_tensor_shape.dim_size(num_dims - 1) == 1 && attrs.depthwise; + + int64 num_spatial_dims = + attrs.num_spatial_dims + (should_perform_depthwise_conv ? 1 : 0); + + std::vector> padding(num_spatial_dims); + std::vector rhs_dilation(num_spatial_dims); + std::vector window_strides(num_spatial_dims); + std::vector ones(num_spatial_dims, 1); + + if (should_perform_depthwise_conv) { + // This approach is similar to handling of grouped convolutions in + // the convolution_feature_group_converter.cc. Please refer to it for + // details. + + // Add spatial dimension to the activation, and reshape. + std::vector activations_reshape_sizes, gradients_reshape_sizes; + + activations_reshape_sizes.push_back(dims.batch_size); + gradients_reshape_sizes.push_back(dims.batch_size); + for (int i = 0; i < attrs.num_spatial_dims; i++) { + activations_reshape_sizes.push_back(dims.input_size(i)); + gradients_reshape_sizes.push_back(dims.output_size(i)); + } + activations_reshape_sizes.push_back(dims.in_depth); + activations_reshape_sizes.push_back(1); + gradients_reshape_sizes.push_back(dims.out_depth); + gradients_reshape_sizes.push_back(1); + + activations = xla::Reshape(activations, activations_reshape_sizes); + gradients = xla::Reshape(gradients, gradients_reshape_sizes); + + int64 new_spatial_dim = activations_reshape_sizes.size() - 1; + + // Set the newly added dimension to be the batch. + dnums.set_input_batch_dimension(new_spatial_dim); + dnums.set_input_feature_dimension(c_dim); + + // The gradients become the RHS of the convolution. + // The gradients have shape [batch, out_rows, out_cols, ..., out_depth, 1] + // where the batch becomes a spatial dimension, and 1 becomes + // the input feature for the convolution. 
+ dnums.set_kernel_input_feature_dimension(new_spatial_dim); + dnums.set_kernel_output_feature_dimension(c_dim); + + // Treat original batch dimension as a spatial dimension. + dnums.add_input_spatial_dimensions(n_dim); + dnums.add_kernel_spatial_dimensions(n_dim); + } else { + // The activations (inputs) form the LHS of the convolution. + // Activations have shape: [batch, in_rows, in_cols, ..., in_depth] + // For the gradient computation, we flip the roles of the batch and + // feature dimensions. + // Each spatial entry has size in_depth * batch + + // Swap n_dim and c_dim in the activations. + dnums.set_input_batch_dimension(c_dim); + dnums.set_input_feature_dimension(n_dim); + + // The gradients become the RHS of the convolution. + // The gradients have shape [batch, out_rows, out_cols, ..., out_depth] + // where the batch becomes the input feature for the convolution. + dnums.set_kernel_input_feature_dimension(n_dim); + dnums.set_kernel_output_feature_dimension(c_dim); + } - std::vector> padding(attrs.num_spatial_dims); - std::vector rhs_dilation(attrs.num_spatial_dims); - std::vector window_strides(attrs.num_spatial_dims); - std::vector ones(attrs.num_spatial_dims, 1); + dnums.set_output_batch_dimension(num_spatial_dims); + dnums.set_output_feature_dimension(num_spatial_dims + 1); // Tensorflow filter shape is [ H, W, ..., inC, outC ]. - for (int i = 0; i < attrs.num_spatial_dims; ++i) { + for (int i = 0; i < num_spatial_dims; ++i) { dnums.add_output_spatial_dimensions(i); } - dnums.set_output_batch_dimension(attrs.num_spatial_dims); - dnums.set_output_feature_dimension(attrs.num_spatial_dims + 1); - for (int i = 0; i < attrs.num_spatial_dims; ++i) { + if (should_perform_depthwise_conv) { + // Set the right parameters for the newly created spatial dimension. 
+ padding[0] = {0, 0}; + rhs_dilation[0] = 1; + window_strides[0] = 1; + } + + for (int64 i = 0; i < attrs.num_spatial_dims; ++i) { int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i); dnums.add_input_spatial_dimensions(dim); dnums.add_kernel_spatial_dimensions(dim); @@ -483,9 +561,10 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( const int64 pad_before = attrs.padding == Padding::SAME ? std::max(pad_total / 2, 0) : 0; - padding[i] = {pad_before, pad_total - pad_before}; - rhs_dilation[i] = dims.spatial_dims[i].stride; - window_strides[i] = attrs.dilations[dim]; + int64 dim_being_operated = should_perform_depthwise_conv ? i + 1 : i; + padding[dim_being_operated] = {pad_before, pad_total - pad_before}; + rhs_dilation[dim_being_operated] = dims.spatial_dims[i].stride; + window_strides[dim_being_operated] = attrs.dilations[dim]; } // Besides padding the input, we will also expand output_rows to @@ -496,13 +575,19 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( // // This is done by specifying the window dilation factors in the // convolution HLO below. - auto filter_backprop = - xla::ConvGeneralDilated(activations, gradients, window_strides, padding, - /*lhs_dilation=*/ones, rhs_dilation, dnums); - - if (attrs.depthwise) { - filter_backprop = ContractFilterForDepthwiseBackprop( - filter_shape, filter_backprop, activations.builder()); + filter_backprop = xla::ConvGeneralDilated( + activations, gradients, window_strides, padding, + /*lhs_dilation=*/ones, rhs_dilation, dnums, + /*feature_group_count=*/ + should_perform_depthwise_conv ? 
dims.in_depth : 1); + + if (should_perform_depthwise_conv) { + filter_backprop = xla::Reshape(filter_backprop, filter_shape.dimensions()); + } else { + if (attrs.depthwise) { + filter_backprop = ContractFilterForDepthwiseBackprop( + filter_shape, filter_backprop, activations.builder()); + } } return filter_backprop; diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index d820528a43064e327cb90e5a2889f77ab1f3f3e2..eafdba876ae9e2c38694f065cf83bb3725b8460e 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/node_def_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index 49c12fc232092873b69961644a059abc6035f64f..ee79cbc70da269be7586c47b4fd33c901f4fd581 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -19,7 +19,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 20b0de193dc060197f3062d3be0b8d45f7dcb9b1..41c31d0ed58fe9bc9bbde0bd58993c975f04fd60 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index b5e083912555c865b5eadc7697075c9ca4451ca9..4f0f0fd9aefecc3d31f8bd9c8ca40ebb0860c82d 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -56,6 +56,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { VLOG(1) << "Building If: " << input_types_.size() << " inputs"; std::vector arguments(input_types_.size()); + int num_resource_args = 0; for (int i = 0; i < input_types_.size(); ++i) { XlaCompiler::Argument& arg = arguments[i]; DataType type = ctx->input_type(i + 1); @@ -81,6 +82,8 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { << " type: " << DataTypeString(arg.type) << " shape: " << arg.shape.DebugString() << " initialized: " << arg.initialized; + + num_resource_args++; } 
else { arg.kind = XlaCompiler::Argument::kParameter; arg.type = input_types_[i]; @@ -236,9 +239,13 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { ctx->SetOutput(i, output_handle); } if (has_token_input_output_) { - // Set token output for this "if" op. + // Set token output for this "If" op. Token output is the last output of + // XLA computation, which comes after all "normal" TF outputs and resource + // updates. For an "If" node, the number of resource updates equals the number of + // resource args because we set `return_updated_values_for_all_resources` + // to true in XlaCompiler option. xla::XlaOp token_output = - xla::GetTupleElement(outputs, output_types_.size()); + xla::GetTupleElement(outputs, output_types_.size() + num_resource_args); auto shape_or = b->GetShape(token_output); OP_REQUIRES_OK(ctx, shape_or.status()); OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()), diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index e9bb0a77e99d144863b027bd214081316d61c314..96ddd42e2ae04d454e4fb85628d139e17a543d2e 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -15,12 +15,12 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/lib/sorting.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -505,9 +505,9 @@ class NonMaxSuppressionOp : public XlaOpKernel { init_values.push_back(included_iou); auto suppress_loop_result = - XlaWhileLoop(WhileCondFn(num_boxes, output_size), - SuppressBodyFn(num_boxes), init_values, "suppress_loop", - builder) + xla::WhileLoopHelper(WhileCondFn(num_boxes, output_size), + SuppressBodyFn(num_boxes), init_values, + "suppress_loop", builder) .ValueOrDie(); xla::XlaOp included_score = diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index 0c7ca602bfacd598dada0303d3a3e77fe7f1b0fc..5a10c52ba8b6d4fab73f0dda67cbd52fd625e76b 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc index f4def11d08c31513aec5aad15187016a7294c2fd..90c0ebefb24ec2c4378782e9b15d3f57c33032a4 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" namespace tensorflow { namespace { @@ -29,7 +29,7 @@ class MatrixTriangularSolveOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - auto result = TriangularSolve( + auto result = xla::TriangularSolve( ctx->Input(0), ctx->Input(1), /*left_side=*/true, /*lower=*/lower_, /*transpose_a=*/adjoint_, /*conjugate_a=*/adjoint_); ctx->SetOutput(0, result); diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index a259da6383d461fd11b0d79096bf66aae7ddef06..06c6cc37ec90192486ba15010bfeb763a9ffb987 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -152,7 +152,12 @@ class MaxPoolOp : public PoolingOp { public: MaxPoolOp(OpKernelConstruction* 
ctx, int num_spatial_dims) : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, - /*reduction_type=*/ctx->input_type(0)) {} + /*reduction_type=*/ctx->input_type(0)) { + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format")); + } void Compile(XlaOpKernelContext* ctx) override { auto ksize_or_error = GetKernelSize(ctx); @@ -180,10 +185,6 @@ class MaxPool2DOp : public MaxPoolOp { public: explicit MaxPool2DOp(OpKernelConstruction* ctx) : MaxPoolOp(ctx, /*num_spatial_dims=*/2) { - string data_format_str; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); - OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), - errors::InvalidArgument("Invalid data format")); } }; REGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp); @@ -204,7 +205,12 @@ class AvgPoolOp : public PoolingOp { AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, /*reduction_type=*/ - XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} + XlaHelpers::SumAccumulationType(ctx->input_type(0))) { + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format")); + } void Compile(XlaOpKernelContext* ctx) override { auto ksize_or_error = GetKernelSize(ctx); @@ -241,10 +247,6 @@ class AvgPool2DOp : public AvgPoolOp { public: explicit AvgPool2DOp(OpKernelConstruction* ctx) : AvgPoolOp(ctx, /*num_spatial_dims=*/2) { - string data_format_str; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); - OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), - errors::InvalidArgument("Invalid data format")); } }; REGISTER_XLA_OP(Name("AvgPool"), AvgPool2DOp); @@ -390,6 +392,11 @@ class AvgPoolGradOp : 
public XlaOpKernel { OP_REQUIRES(ctx, ksize_[0] == 1 && stride_[0] == 1, errors::Unimplemented( "Pooling is not yet supported on the batch dimension.")); + + string data_format; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); + OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); } int num_dims() const { return num_spatial_dims_ + 2; } @@ -449,10 +456,6 @@ class AvgPool2DGradOp : public AvgPoolGradOp { public: explicit AvgPool2DGradOp(OpKernelConstruction* ctx) : AvgPoolGradOp(ctx, /*num_spatial_dims=*/2) { - string data_format; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); } }; REGISTER_XLA_OP( diff --git a/tensorflow/compiler/tf2xla/kernels/qr_op.cc b/tensorflow/compiler/tf2xla/kernels/qr_op.cc index 7ea0afc1f53cbe4cfcc3f6121a4ecd55864c1b52..66ec40a946b8a063d84acd33daf81f52ea2c35ed 100644 --- a/tensorflow/compiler/tf2xla/kernels/qr_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/qr_op.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/qr.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/qr.h" namespace tensorflow { namespace { @@ -26,7 +26,7 @@ class QROp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices_)); } void Compile(XlaOpKernelContext* ctx) override { - auto result = QRDecomposition(ctx->Input(0), full_matrices_); + auto result = xla::QRDecomposition(ctx->Input(0), full_matrices_); if (!result.ok()) { ctx->SetStatus(result.status()); return; diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 8822e29f7e77b1cbc6fa6ca61d0062d9b1b0c36e..2d92056e4f522f6206e7d632f0fa1e8b793fd6e3 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -20,12 +20,12 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/lib/random.h" #include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -175,8 +175,8 @@ class RandomShuffleOp : public XlaOpKernel { }; // for i in range(n): auto swap_loop_result = - XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices}, - "indices_swap_loop", builder) + xla::ForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices}, + "indices_swap_loop", builder) .ValueOrDie(); auto swapped_indices = swap_loop_result[1]; diff --git a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc index 769e0cd1409dd7e8099178c8d80b5a9adb0b20b3..f9985d526033ca675c701a508a3d1576e46bc5f7 100644 --- a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc @@ -24,7 +24,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -126,7 +125,7 @@ XlaOp ConcatenateIota(xla::XlaBuilder* b, XlaOp indices, dimensions.back() = 1; auto batch_indices = - xla::Iota(b, xla::ShapeUtil::MakeShape(xla::U32, dimensions), + xla::Iota(b, xla::ShapeUtil::MakeShape(xla::S32, dimensions), /*iota_dimension=*/0); return xla::ConcatInDim(b, {batch_indices, indices}, dimensions.size() - 1); @@ -190,11 +189,53 @@ XlaOp ScatterToGradData(XlaOpKernelContext* ctx, XlaOp grad_data, XlaOp indices, scatter_dim_numbers); } +// Bounds samples to 0 if the warp image indices are out of the (-1, image_size) +// bound. +// The resulting dimension is given by 'result_dims'. +XlaOp BoundSamples(XlaOpKernelContext* ctx, XlaOp warp, + xla::PrimitiveType warp_type, TensorShape warp_shape, + std::vector result_dims, + std::vector broadcasted_dims, int64 last_warp_dim, + xla::Shape data_shape, XlaOp sample) { + auto is_gt_minus_one = + xla::Gt(warp, + xla::ConvertElementType( + xla::ConstantR1(ctx->builder(), {-1, -1}), warp_type), + /*broadcast_dimensions=*/{warp_shape.dims() - 1}); + auto is_lt_image_size = xla::Lt( + warp, + xla::ConvertElementType( + xla::ConstantR1( + ctx->builder(), + {/*width=*/static_cast(data_shape.dimensions(2)), + /*height=*/static_cast(data_shape.dimensions(1))}), + warp_type), + /*broadcast_dimensions=*/{warp_shape.dims() - 1}); + + auto is_in_bound_padded_x_y = xla::And(is_gt_minus_one, is_lt_image_size); + // Reduce along last dimension. The resulting dimension is: + // [batch, dim_0, ...dim_n]. 
+ auto is_in_bound = xla::Reduce( + is_in_bound_padded_x_y, xla::ConstantR0(ctx->builder(), true), + xla::CreateScalarAndComputation(xla::PrimitiveType::PRED, ctx->builder()), + {last_warp_dim}); + + // Broadcast 'is_in_bound' to the same dimension as 'result_dims'. + auto broadcasted_is_in_bound = + xla::BroadcastInDim(is_in_bound, result_dims, broadcasted_dims); + + // Set out of bound samples to zero. + auto zeros = + xla::Broadcast(xla::Zero(ctx->builder(), warp_type), result_dims); + return xla::Select(broadcasted_is_in_bound, sample, zeros); +} + // Build computation the backprop into input 'data'. // Where input: // grad_output is of dimension [batch, dim_0, ...dim_n, channel] // ratio is of dimension [batch, dim_0, ...dim_n, 2] // gather_indices is of dimension [batch, dim_0, ...dim_n, 3] +// data_shape is of dimension [batch, x(width), y(height), channel] // // Output: // scatter-add to each 2x2 grad_data neighbor: @@ -202,10 +243,12 @@ XlaOp ScatterToGradData(XlaOpKernelContext* ctx, XlaOp grad_data, XlaOp indices, // grad_data[cx, fy, chan] += output_grad * (1 - dx) * dy // grad_data[fx, cy, chan] += output_grad * dx * (1 - dy) // grad_data[cx, cy, chan] += output_grad * (1 - dx) * (1 - dy) -// where (dx, dy) is (1 - ratio). +// where (dx, dy) is (1 - ratio). If (dx, dy) is out of bound, then their +// contribution to 'grad_data' is 0. XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, - XlaOp gather_indices, xla::PrimitiveType warp_type, - TensorShape warp_shape, int64 data_channels, + XlaOp gather_indices, XlaOp warp, + xla::PrimitiveType warp_type, TensorShape warp_shape, + int64 last_warp_dim, int64 data_channels, xla::Shape data_shape) { // Weights tensor has dimension [batch, dim_0, ... dim_n, 4]. 
auto weights = BilinearWeights(ctx, ratio, warp_shape, warp_type); @@ -230,6 +273,18 @@ XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, std::iota(reshaped_weights_indices.begin(), reshaped_weights_indices.end(), 0); + // Set out of bound weights to 0. + // The dimension of the reshaped_weight: [batch, dim_0, ...dim_n, 2, 2]. + std::vector reshaped_result_dims(warp_dims.begin(), + warp_dims.end() - 1); + reshaped_result_dims.push_back(2); + reshaped_result_dims.push_back(2); + std::vector broadcasted_dims(warp_dims.size() - 1); + std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0); + reshaped_weights = BoundSamples(ctx, warp, warp_type, warp_shape, + reshaped_result_dims, broadcasted_dims, + last_warp_dim, data_shape, reshaped_weights); + // The dimension is [batch, dim_0, ..., dim_n, 2, 2, data_channel]. auto broadcast_reshaped_weights = xla::BroadcastInDim( reshaped_weights, weights_with_channels_dims, reshaped_weights_indices); @@ -246,18 +301,41 @@ XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, auto grad_data = xla::ConstantLiteral( ctx->builder(), xla::Literal::CreateFromShape(data_shape)); - return ScatterToGradData(ctx, grad_data, gather_indices, - grad_output_multiply_weights, warp_shape.dims(), - warp_type); + // Pad grad data then slice it back. + // + // After left and right column 0-padding, the new dimension of padded data + // will be [batch, x+2, y+2, channel]. 
+ auto padded_grad_data = + xla::Pad(grad_data, xla::Zero(ctx->builder(), warp_type), + xla::MakeEdgePaddingConfig({{0, 0}, {1, 1}, {1, 1}, {0, 0}})); + + auto shifting_value = xla::ConstantR1( + ctx->builder(), {/*batch=*/0, /*x(width)=*/1, /*y(height)=*/1}); + auto shifted_gather_indices = + xla::Add(gather_indices, shifting_value, {last_warp_dim}); + + auto updated_grad_data = ScatterToGradData( + ctx, padded_grad_data, shifted_gather_indices, + grad_output_multiply_weights, warp_shape.dims(), warp_type); + + const int64 batch_size = data_shape.dimensions(0); + const int64 width = data_shape.dimensions(1); + const int64 height = data_shape.dimensions(2); + // Slice out the result accounting for the padding. + return xla::Slice( + updated_grad_data, /*start_indices=*/{0, 1, 1, 0}, + /*limit_indices=*/{batch_size, width + 1, height + 1, data_channels}, + /*strides=*/{1, 1, 1, 1}); } // Build computation for the backprop into input 'warp'. // Where input: -// warp is of dimension [batch, dim_0, ...dim_n, 2] -// grad_output is of dimension [batch, dim_0, ...dim_n, channel] -// ratio is of dimension [batch, dim_0, ...dim_n, 2] -// gather_indices is of dimension [batch, dim_0, ...dim_n, 3] -// data is of dimension [batch, x, y, channel] +// warp is of dimension [batch, dim_0, ...dim_n, 2] +// grad_output is of dimension [batch, dim_0, ...dim_n, channel] +// ratio is of dimension [batch, dim_0, ...dim_n, 2] +// gather_indices is of dimension [batch, dim_0, ...dim_n, 3] where the last +// dimension of size 3 is for {batch, x(width), y(height)}. 
+// data is of dimension [batch, x, y, channel] // // Output (simplified by ignoring the batch dimensions): // Since the forward path has: @@ -276,12 +354,12 @@ XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, // grad_warp_x = py * (img_cxcy - img_fxcy) + (1-py) * (img_cxfy-img_fxfy) // grad_warp_y = px * (img_cxcy - img_cxfy) + (1-px) * (img_fxcy-img_fxfy) // -// where (px, py) is warp, (fx, fy) is the left top corner and (cx, cy) is the +// where (px, py) is warp, (fx, fy) is the top left corner and (cx, cy) is the // bottom right corner in a 2x2 neighborhood. XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, XlaOp gather_indices, XlaOp data, TensorShape warp_shape, int64 data_channels, - xla::PrimitiveType data_type) { + xla::PrimitiveType data_type, xla::Shape data_shape) { auto warp_dims = warp_shape.dim_sizes(); std::vector warp_dims_without_last_dims(warp_dims.begin(), warp_dims.end() - 1); @@ -290,12 +368,30 @@ XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, std::vector neighbor_broadcast_dims = warp_dims_without_last_dims; neighbor_broadcast_dims.push_back(4); - // The dimension is [batch, dim_0, ... dim_n, 4, data_channels] - auto neighbors_data = Gather2by2Neighbors( - ctx->builder(), data, gather_indices, data_channels, warp_shape.dims()); + // With dimension [batch, dim_0, ...dim_n, 4] + auto neighbor_broadcast_shape = + xla::ShapeUtil::MakeShape(data_type, neighbor_broadcast_dims); const int64 last_warp_dim = warp_shape.dims() - 1; + // Pad data with 0, before gathering such that 0 will be returned for samples + // in the range of (-1, 0) or (image_dimension-1, image_dimension). + // After left and right column 0-padding, the new dimension of padded data + // will be [batch, x+2, y+2, channel]. 
+ auto padded_data = + xla::Pad(data, xla::Zero(ctx->builder(), data_type), + xla::MakeEdgePaddingConfig({{0, 0}, {1, 1}, {1, 1}, {0, 0}})); + + auto shifting_value = xla::ConstantR1( + ctx->builder(), {/*batch=*/0, /*x(width)=*/1, /*y(height)=*/1}); + auto shifted_gather_indices = + xla::Add(gather_indices, shifting_value, {last_warp_dim}); + + // The dimension is [batch, dim_0, ... dim_n, 4, data_channels] + auto neighbors_data = + Gather2by2Neighbors(ctx->builder(), padded_data, shifted_gather_indices, + data_channels, warp_shape.dims()); + // Since we will be creating the dot product of: // lhs: [batch, dim_0, ...dim_n, 4] // and @@ -418,7 +514,7 @@ class ResamplerOp : public XlaOpKernel { // Find the coordinates of the top left corner for the 2x2 region to be // sampled from. The dimensions are [batch, dim_0, ... dim_n, 2] where the // last dimension of size 2 in turn is [x, y]. - XlaOp top_left = xla::ConvertElementType(warp, xla::U32); + XlaOp top_left = xla::ConvertElementType(warp, xla::S32); auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape); @@ -527,7 +623,8 @@ class ResamplerGradOp : public XlaOpKernel { size, "]")); } // Last dimension of warp shape must be of size 2. - OP_REQUIRES(ctx, warp_shape.dim_size(warp_shape.dims() - 1) == 2, + const int64 last_warp_dim = warp_shape.dims() - 1; + OP_REQUIRES(ctx, warp_shape.dim_size(last_warp_dim) == 2, errors::InvalidArgument( "the last dimension of warp must be exactly size 2.")); xla::PrimitiveType warp_type = ctx->input_xla_type(1); @@ -550,24 +647,32 @@ class ResamplerGradOp : public XlaOpKernel { // Find the top left corner coordinate for the region to be sampled from. // The dimensions are [batch, dim_0, ... dim_n, 2] where the last dimension // of size 2 in turn is [x, y]. - XlaOp top_left = xla::ConvertElementType(warp, xla::U32); + XlaOp top_left = xla::ConvertElementType(xla::Floor(warp), xla::S32); - // Dimensions are [batch, dim_0, ... 
dim_n, 2] + // Dimensions are [batch, dim_0, ... dim_n, 2]. XlaOp ratio = warp - xla::ConvertElementType(top_left, warp_type); // Indices for gathering neighboring pixels. auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape); - auto grad_data = - CalculateGradData(ctx, grad_output, ratio, gather_indices, warp_type, - warp_shape, data_channels, data_shape); + auto grad_data = CalculateGradData( + ctx, grad_output, ratio, gather_indices, warp, warp_type, warp_shape, + last_warp_dim, data_channels, data_shape); auto grad_warp = CalculateGradWarp(ctx, grad_output, ratio, gather_indices, data, - warp_shape, data_channels, data_type); + warp_shape, data_channels, data_type, data_shape); + auto warp_dims = warp_shape.dim_sizes(); + std::vector result_dims(warp_dims.begin(), warp_dims.end() - 1); + result_dims.push_back(2); + std::vector broadcasted_dims(warp_dims.size() - 1); + std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0); + auto grad_warp_bounded = + BoundSamples(ctx, warp, warp_type, warp_shape, result_dims, + broadcasted_dims, last_warp_dim, data_shape, grad_warp); ctx->SetOutput(0, grad_data); - ctx->SetOutput(1, grad_warp); + ctx->SetOutput(1, grad_warp_bounded); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index 5db52781be473a9a1aef0adf105e3edf69ccd306..50653d7b3973b73d580cdeec5d71943b575d7cc9 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -23,7 +23,6 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/lib/prng.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc index 8a0c94cfae1b298bd62a3231caf39ecf9b32880e..ee3bdf3394e37c757f31724e73e95417becaa534 100644 --- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/lib/sorting.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index 960c1462ceb8c00a2d6c96564f6c985fd1caef0f..26d4214099d1d07c1b2e275d783654d9cd948e28 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -172,6 +172,65 @@ class ResourceApplyMomentum : public XlaOpKernel { REGISTER_XLA_OP(Name("ResourceApplyMomentum").TypeConstraint("T", kFloatTypes), ResourceApplyMomentum); +class ResourceApplyKerasMomentum : public XlaOpKernel { + public: + explicit ResourceApplyKerasMomentum(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + DataType type = ctx->input_type(2); + + TensorShape var_shape, accum_shape; + xla::XlaOp var, accum; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, 
type, &var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum)); + + OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), + errors::InvalidArgument( + "var and accum do not have the same shape", + var_shape.DebugString(), " ", accum_shape.DebugString())); + + TensorShape lr_shape = ctx->InputShape(2); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), + errors::InvalidArgument("lr is not a scalar: ", + lr_shape.DebugString())); + + TensorShape grad_shape = ctx->InputShape(3); + OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), + errors::InvalidArgument( + "var and grad do not have the same shape", + var_shape.DebugString(), " ", grad_shape.DebugString())); + + TensorShape momentum_shape = ctx->InputShape(4); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape), + errors::InvalidArgument("momentum is not a scalar: ", + momentum_shape.DebugString())); + + xla::XlaOp lr = ctx->Input(2); + xla::XlaOp grad = ctx->Input(3); + xla::XlaOp momentum = ctx->Input(4); + + accum = accum * momentum - grad * lr; + if (use_nesterov_) { + // See https://github.com/tensorflow/tensorflow/pull/2798 for an + // explanation of the reparameterization used here. 
+ var = var + accum * momentum - grad * lr; + } else { + var = var + accum; + } + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum)); + } + + private: + bool use_nesterov_; +}; +REGISTER_XLA_OP( + Name("ResourceApplyKerasMomentum").TypeConstraint("T", kFloatTypes), + ResourceApplyKerasMomentum); + class ResourceApplyAdagrad : public XlaOpKernel { public: explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index ce007fc04a818869686b9936a1607cee42665e87..89b577bfc05b4665d492f4ea5cf6f869af2fa9a9 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -41,8 +41,7 @@ Status MakeXlaCompilerArgumentsFromInputs( *has_uninitialized_vars = false; *has_tensor_arrays = false; for (int i = 0; i < ctx->num_inputs(); ++i) { - VLOG(2) << " Input " << i - << " type: " << DataTypeString(ctx->input_type(i)) + VLOG(2) << " Input " << i << " type: " << DataTypeString(ctx->input_type(i)) << " shape: " << ctx->InputShape(i).DebugString(); XlaCompiler::Argument& arg = (*args)[i]; DataType type = ctx->input_type(i); @@ -233,13 +232,22 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { xla::ShapeUtil::HumanString(body_input_shape), " vs. 
", xla::ShapeUtil::HumanString(body.xla_output_shape))); - xla::Shape expected_cond_output_shape = xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::PRED, {})}); + xla::Shape expected_cond_output_shape_without_side_effect = + xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::PRED, {})}); + xla::Shape expected_cond_output_shape_with_side_effect = + xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::PRED, {}), + xla::ShapeUtil::MakeTokenShape()}); OP_REQUIRES(ctx, - xla::ShapeUtil::Compatible(cond.xla_output_shape, - expected_cond_output_shape), + xla::ShapeUtil::Compatible( + cond.xla_output_shape, + expected_cond_output_shape_without_side_effect) || + xla::ShapeUtil::Compatible( + cond.xla_output_shape, + expected_cond_output_shape_with_side_effect), errors::InvalidArgument( - "Output shape of loop condition should be (pred[]), got: ", + "Output shape of loop condition should be (pred[]) or " + "(pred[], token[]), got: ", xla::ShapeUtil::HumanString(cond.xla_output_shape))); int num_inputs = body.input_mapping.size(); diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 1ce3930fd1cd91f8e8dfb765b49be2dc969d1bd7..3d7b0bc959f9dbf3c1b9749379e2ea0d285b302b 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -15,22 +15,6 @@ filegroup( ]), ) -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") - -cc_library( - name = "batch_dot", - srcs = ["batch_dot.cc"], - hdrs = ["batch_dot.h"], - deps = [ - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/core:lib", - ], -) - cc_library( name = "broadcast", srcs = ["broadcast.cc"], @@ -47,26 +31,6 @@ cc_library( ], ) -cc_library( - name = "cholesky", - srcs = ["cholesky.cc"], - hdrs = ["cholesky.h"], - deps 
= [ - ":batch_dot", - ":triangular_solve", - ":util", - ":while_loop", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/core:lib", - ], -) - cc_library( name = "random", srcs = ["random.cc"], @@ -82,35 +46,12 @@ cc_library( ], ) -cc_library( - name = "qr", - srcs = ["qr.cc"], - hdrs = ["qr.h"], - deps = [ - ":batch_dot", - ":util", - ":while_loop", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:math", - "//tensorflow/compiler/xla/client/lib:numeric", - "//tensorflow/core:lib", - ], -) - cc_library( name = "scatter", srcs = ["scatter.cc"], hdrs = ["scatter.h"], deps = [ ":util", - ":while_loop", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -124,51 +65,6 @@ cc_library( ], ) -cc_library( - name = "triangular_solve", - srcs = ["triangular_solve.cc"], - hdrs = ["triangular_solve.h"], - deps = [ - ":batch_dot", - ":util", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/lib:constants", - 
"//tensorflow/compiler/xla/client/lib:numeric", - "//tensorflow/core:lib", - ], -) - -xla_test( - name = "triangular_solve_test", - srcs = ["triangular_solve_test.cc"], - tags = ["noasan"], # sometimes times out, http://b/78650012 - deps = [ - ":triangular_solve", - "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], -) - cc_library( name = "util", srcs = ["util.cc"], @@ -186,42 +82,3 @@ cc_library( "@com_google_absl//absl/types:span", ], ) - -xla_test( - name = "util_test", - srcs = ["util_test.cc"], - deps = [ - ":batch_dot", - ":util", - "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], -) - -cc_library( - name = "while_loop", - srcs = ["while_loop.cc"], - hdrs = ["while_loop.h"], - deps = [ - ":util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - 
"//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ], -) diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc deleted file mode 100644 index 5400e8834cb9807f6dd71abe7789b2672e29e905..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ /dev/null @@ -1,115 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" - -#include -#include - -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/core/errors.h" - -namespace tensorflow { - -xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, - bool transpose_y, bool conjugate_x, bool conjugate_y, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); - TF_ASSIGN_OR_RETURN(xla::Shape y_shape, builder->GetShape(y)); - - // Check that both tensors have the same number of dimensions. There must be - // at least two (the batch dimensions can be empty). - if (xla::ShapeUtil::Rank(x_shape) != xla::ShapeUtil::Rank(y_shape)) { - return errors::InvalidArgument( - "Arguments to BatchedDot have different ranks: ", - xla::ShapeUtil::HumanString(x_shape), " vs. ", - xla::ShapeUtil::HumanString(y_shape)); - } - const int ndims = xla::ShapeUtil::Rank(x_shape); - if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to BatchedDot must have rank >= 2: ", ndims); - } - - // The batch dimensions must be equal and the matrix dimensions must be - // valid. - std::vector batch_dimension_numbers; - for (int i = 0; i < ndims - 2; ++i) { - if (x_shape.dimensions(i) != y_shape.dimensions(i)) { - return errors::InvalidArgument( - "Dimension ", i, " of inputs to BatchedDot must be equal: ", - xla::ShapeUtil::HumanString(x_shape), " vs ", - xla::ShapeUtil::HumanString(y_shape)); - } - batch_dimension_numbers.push_back(i); - } - - int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1); - int y_inner_dim = transpose_y ? 
(ndims - 1) : (ndims - 2); - if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) { - return errors::InvalidArgument( - "Dimensions ", x_inner_dim, " and ", y_inner_dim, - " of arguments to BatchedDot must be equal: ", - xla::ShapeUtil::HumanString(x_shape), " transpose: ", transpose_x, - " vs. ", xla::ShapeUtil::HumanString(y_shape), - " transpose: ", transpose_y); - } - - // Check for zero lhs/rhs dim size. - if (xla::ShapeUtil::IsZeroElementArray(x_shape) || - xla::ShapeUtil::IsZeroElementArray(y_shape)) { - std::vector dimensions(batch_dimension_numbers.size()); - for (int i = 0; i < batch_dimension_numbers.size(); ++i) { - dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]); - } - int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2); - int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1); - dimensions.push_back(x_shape.dimensions(x_outer_dim)); - dimensions.push_back(y_shape.dimensions(y_outer_dim)); - return xla::Broadcast( - xla::ConstantLiteral(builder, - xla::LiteralUtil::Zero(x_shape.element_type())), - dimensions); - } - - if (x_shape.element_type() == xla::C64 && conjugate_x) { - x = xla::Conj(x); - } - if (y_shape.element_type() == xla::C64 && conjugate_y) { - y = xla::Conj(y); - } - - xla::PrecisionConfig precision_proto; - precision_proto.add_operand_precision(precision); - precision_proto.add_operand_precision(precision); - - xla::DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); - dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); - for (auto batch_dimension_number : batch_dimension_numbers) { - dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); - dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); - } - - return xla::DotGeneral(x, y, dot_dnums, &precision_proto); - }); -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h deleted file mode 100644 index 
6edd63a4d3b66c21aa4cce8c9f36eef0dc363cd8..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.h +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ - -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace tensorflow { - -// Multiplies slices of two tensors in batches. - -// Multiplies all slices of `Tensor` `x` and `y` (each slice can be -// viewed as an element of a batch), and arranges the individual results -// in a single output tensor of the same batch size. Each of the -// individual slices can optionally be transposed before multiplication by -// setting the `transpose_x` or `transpose_y` flag to `true`. Similarly, each -// can be elementwise-complex-conjugated by setting the `conjugate_x` or -// `conjugate_y` flag to `true`. To apply a Hermitian adjoint to `x`, set both -// `transpose_x` and `conjugate_x` to `true`, and analogously for `y`. -// -// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` -// and `[..., r_y, c_y]`. 
-// -// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: -// -// r_o = c_x if transpose_x else r_x -// c_o = r_y if transpose_y else c_y -// -// It is computed as: -// -// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) -xla::XlaOp BatchDot( - xla::XlaOp x, xla::XlaOp y, bool transpose_x = false, - bool transpose_y = false, bool conjugate_x = false, - bool conjugate_y = false, - xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index 2b1c2ced925d9fee7392986015a6e716a94d356f..688056791f9750e6b22df4b2cd4643de0b780651 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -20,7 +20,6 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index 804671fbc75b0a5a6e04b204822b6f084013cd8b..c0bd172d17c192435ba8ee196f9def0491c0bf5c 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -113,36 +113,6 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, return xla::ConstantLiteral(builder, literal); } -xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span start, - absl::Span end) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_RET_CHECK(start.size() == end.size()); - int64 n_minor_dims = start.size(); - - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - - const int64 n_dims = 
xla::ShapeUtil::Rank(shape); - TF_RET_CHECK(n_minor_dims <= n_dims); - auto major_dims = xla::AsInt64Slice(shape.dimensions()) - .subspan( - /*pos=*/0, - /*len=*/n_dims - n_minor_dims); - - // Prepends 0s in the major dim - std::vector padded_start(n_dims, 0); - std::copy(start.begin(), start.end(), - padded_start.begin() + major_dims.size()); - - // Prepends the shape of the major dims. - std::vector padded_end(n_dims); - std::copy(major_dims.begin(), major_dims.end(), padded_end.begin()); - std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size()); - - std::vector strides(n_dims, 1); - return xla::Slice(x, padded_start, padded_end, strides); - }); -} std::vector ConcatVectors(absl::Span xs, absl::Span ys) { @@ -152,100 +122,4 @@ std::vector ConcatVectors(absl::Span xs, return output; } -xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, - absl::Span starts, - absl::Span sizes) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - int64 n_minor_dims = starts.size(); - TF_RET_CHECK(n_minor_dims == sizes.size()); - TF_RET_CHECK(n_minor_dims <= n_dims); - auto major_dims = xla::AsInt64Slice(shape.dimensions()) - .subspan( - /*pos=*/0, - /*len=*/n_dims - sizes.size()); - auto padded_starts = PrependZerosInMajorDims(x, starts); - auto padded_sizes = ConcatVectors(major_dims, sizes); - return xla::DynamicSlice(x, padded_starts, padded_sizes); - }); -} - -xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, - absl::Span start) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - // TODO(phawkins): make int64 work on all backends, remove the int32 cast. 
- std::vector start_as_int32(start.begin(), start.end()); - auto start_constant = xla::ConstantR1(builder, start_as_int32); - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape, - builder->GetShape(start_constant)); - const int64 start_length = - xla::ShapeUtil::GetDimension(start_constant_shape, -1); - TF_RET_CHECK(start_length == n_dims); - return xla::DynamicUpdateSlice(x, update, start_constant); - }); -} - -xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - absl::Span start) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - const int64 n_minor_dims = start.size(); - TF_RET_CHECK(n_minor_dims <= n_dims); - std::vector padded_start(n_dims, 0); - std::copy(start.begin(), start.end(), - padded_start.begin() + (n_dims - n_minor_dims)); - return UpdateSlice(x, update, padded_start); - }); -} - -xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - absl::Span starts) { - auto padded_starts = PrependZerosInMajorDims(x, starts); - return xla::DynamicUpdateSlice(x, update, padded_starts); -} - -xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, - absl::Span starts) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - auto zero = xla::Reshape(xla::ConstantR0(builder, 0), {1}); - std::vector padded_starts(n_dims, zero); - for (int i = 0; i < starts.size(); ++i) { - padded_starts[n_dims - starts.size() + i] = xla::Reshape(starts[i], {1}); - } - return xla::ConcatInDim(builder, padded_starts, 0); - }); -} - -xla::XlaOp TransposeInMinorDims(xla::XlaOp x) { - xla::XlaBuilder* builder = 
x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_RET_CHECK(n_dims >= 2); - std::vector permutation(n_dims); - std::iota(permutation.begin(), permutation.end(), 0); - std::swap(permutation[n_dims - 1], permutation[n_dims - 2]); - return xla::Transpose(x, permutation); - }); -} - -xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - auto perform_conj = shape.element_type() == xla::C64 && conjugate; - return perform_conj ? xla::Conj(x) : x; - }); -} - } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index 80e9e5b002d49581209e608b98606e02709c5876..aec8061cb4322b8d315b6cdc80c7fff1e0cb4cb1 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -38,44 +38,10 @@ xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, int64 value); -// Builds a vector of zeros of length rank(x) with the last values being -// those in `starts`. -xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, - absl::Span starts); - -// Performs a slice in the minor dimensions of a Tensor. -xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span start, - absl::Span end); - // Returns the concatenation of `xs` and `ys`. std::vector ConcatVectors(absl::Span xs, absl::Span ys); -// Performs a dynamic slice in the minor dimensions of a Tensor. 
-xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, - absl::Span starts, - absl::Span sizes); - -// Updates a slice of 'x', i.e., -// x[start[0], ..., start[n]] = update -xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, - absl::Span start); - -// Updates a slice of 'x', where 'start' contains a list of minor dimensions: -// x[..., start[0], ..., start[n]] = update -xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - absl::Span start); - -xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - absl::Span starts); - -// Transposes a stack of matrices `x` by swapping the last two dimensions. -xla::XlaOp TransposeInMinorDims(xla::XlaOp x); - -// Applies a complex conjugation operation if `a` is complex and `conjugate_a` -// is true, otherwise returns its argument. -xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate); - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc index 72b240996fb4d9dcb5f5dfd919da618cbae08c16..ff9f1b9ccba2c4f3307890d5aac4ddb6cfaafcd9 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc @@ -65,6 +65,7 @@ CreateResourceOpInfoMap() { add("ResourceApplyFtrlV2" , kReadWrite, kVariable); add("ResourceApplyGradientDescent" , kReadWrite, kVariable); add("ResourceApplyMomentum" , kReadWrite, kVariable); + add("ResourceApplyKerasMomentum" , kReadWrite, kVariable); add("ResourceApplyPowerSign" , kReadWrite, kVariable); add("ResourceApplyProximalAdagrad" , kReadWrite, kVariable); add("ResourceApplyProximalGradientDescent" , kReadWrite, kVariable); diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc index b233e6b2c28e1968bb74901fc684e808ae45ab60..b62f8e9115229ac35c657d374c68336f1168ff77 100644 --- 
a/tensorflow/compiler/tf2xla/side_effect_util.cc +++ b/tensorflow/compiler/tf2xla/side_effect_util.cc @@ -24,6 +24,8 @@ const char kXlaTokenInputNodesAttrName[] = "_xla_token_input_nodes"; const char kXlaTokenArgNodeName[] = "_xla_token_arg_node"; +const char kXlaHasHostTransferAttrName[] = "_xla_has_host_transfer"; + std::set CalculateTokenInputsForOutputToken(const Graph& g) { std::set results; Node* first_side_effecting_node_on_path = nullptr; diff --git a/tensorflow/compiler/tf2xla/side_effect_util.h b/tensorflow/compiler/tf2xla/side_effect_util.h index f22ddb2f58e1fa5c10ca0fdb956d9136942388b7..7081b362c36c4785164b29003a5f89cd73bcf3af 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.h +++ b/tensorflow/compiler/tf2xla/side_effect_util.h @@ -35,6 +35,9 @@ extern const char kXlaTokenInputNodesAttrName[]; // node has side-effect dependency on current graph's token input. extern const char kXlaTokenArgNodeName[]; +// This node has XlaRecvAtHost/XlaSendFromHost in its associated functions. +extern const char kXlaHasHostTransferAttrName[]; + // Calculates side-effect dependencies for the graph's token output. // Returns a set of node names representing these dependencies. 
std::set CalculateTokenInputsForOutputToken(const Graph& g); diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index ab26d939ccba75ce58609ffd71c7ccadbe90cfa8..24afe595b18b823818bd8fe65bc599af8bce040a 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -91,7 +91,7 @@ TEST(ConvertGraphDefToXla, Sum) { client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()}); TF_EXPECT_OK(result_or.status()); xla::Literal result = std::move(result_or.ValueOrDie()); - EXPECT_EQ("(s32[]) (\n42\n)", result.ToString()); + EXPECT_EQ("(\ns32[] 42\n)", result.ToString()); config.mutable_feed(0)->mutable_id()->set_output_index( 123); /* invalid output_index */ diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index cc81772e8c5da710bc733f7e4f5fe820b2c2d110..18d87727c500619bf386be7d8c7085724f44aba3 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -364,6 +364,7 @@ Status AddPlaceholdersForFeeds( GraphDef gd; *gd.mutable_versions() = graph_def->versions(); *gd.add_node() = *existing; + MergeDebugInfo(NodeDebugInfo(*existing), gd.mutable_node(0)); TF_RETURN_IF_ERROR( AddDefaultAttrsToGraphDef(&gd, *op_registry, 0 /*node_offset*/)); @@ -390,6 +391,7 @@ Status AddPlaceholdersForFeeds( // in this code. for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) { const PlaceholderInfo& info = it->second; + // TODO(shikharagarwal): Add original node information. NodeDef* d = graph_def->add_node(); d->set_name(info.placeholder_name); d->set_op("PlaceholderV2"); @@ -557,6 +559,12 @@ bool HasAssociatedFunction(const NodeDef& node_def, return true; } + if (node_def.op() == "XlaHostCompute") { + // XlaHostCompute has "shape_inference_graph" func attr, but that's not + // related to graph execution. 
+ return false; + } + for (const auto& iter : node_def.attr()) { if (iter.second.has_func()) { return true; @@ -578,6 +586,9 @@ std::vector GetAssociatedFunctions( // This is a SymbolicGradient op. AttrValueMap attrs(node.attrs().begin(), node.attrs().end()); results.emplace_back(AssociatedFunctionInfo::SymbolicGradient(op, attrs)); + } else if (node.type_string() == "XlaHostCompute") { + // XlaHostCompute has "shape_inference_graph" func attr, but that's not + // related to graph execution. } else { // Collect all function attrs for the node. for (auto& iter : node.attrs()) { @@ -599,7 +610,9 @@ Status RewriteAssociatedFunction( switch (associated_function.type()) { case AssociatedFunctionInfo::kFunctionCallNode: { // Change this node to call the new function. - NodeDefBuilder builder(node->name(), rewritten_function_name, fld); + NodeDebugInfo debug_info(*node); + NodeDefBuilder builder(node->name(), rewritten_function_name, fld, + &debug_info); for (auto attr : node->attrs()) { builder.Attr(attr.first, attr.second); } diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index a1d359e97c4fad3ca74d44a358cba0e8190cdc22..c7341cf8b9e8d7a06fd304ae8766420d20f0c16e 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -206,8 +206,14 @@ class XlaCompiledCpuFunction { // // Aliasing of argument and result buffers is not allowed, and results in // undefined behavior. - void set_arg_data(size_t index, void* data) { - buffer_table_[arg_index_table_[index]] = data; + void set_arg_data(size_t index, const void* data) { + // The const_cast is safe because the generated code does not write to arg + // buffers. + // + // buffer_table_ contains pointers to buffers that _will_ be written to by + // generated code so it would be misleading to make buffer_table_ a `const + // void**`. 
+ buffer_table_[arg_index_table_[index]] = const_cast(data); } // ------------------------------ diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 4360e0857964b0ac63fc887e269b04a4b00d854a..722d1376687efa1c04158e3fd9ce539aac9d0122 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -109,7 +109,7 @@ cc_library( name = "status_macros", srcs = ["status_macros.cc"], hdrs = ["status_macros.h"], - visibility = [":friends"], + visibility = ["//visibility:public"], deps = [ ":statusor", ":types", @@ -224,6 +224,7 @@ cc_library( name = "shape_util", srcs = [ "index_util.cc", + "layout.cc", "layout_util.cc", "primitive_util.cc", "shape.cc", @@ -231,6 +232,7 @@ cc_library( ], hdrs = [ "index_util.h", + "layout.h", "layout_util.h", "primitive_util.h", "shape.h", @@ -290,6 +292,22 @@ tf_cc_test( ], ) +tf_cc_test( + name = "primitive_util_test", + srcs = ["primitive_util_test.cc"], + deps = [ + ":shape_util", + ":status_macros", + ":test", + ":test_helpers", + ":types", + ":util", + ":xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "layout_util_test", srcs = ["layout_util_test.cc"], @@ -301,6 +319,22 @@ tf_cc_test( ], ) +tf_cc_test( + name = "layout_test", + srcs = ["layout_test.cc"], + deps = [ + ":shape_util", + ":status_macros", + ":test", + ":test_helpers", + ":types", + ":util", + ":xla_data_proto", + "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", + ], +) + tf_cc_test( name = "index_util_test", srcs = ["index_util_test.cc"], @@ -575,6 +609,7 @@ cc_library( ":types", ":util", ":xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/memory", @@ -705,7 +740,6 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_evaluator", "//tensorflow/compiler/xla/service:shape_inference", - 
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index fe99564d3c671cd7890e1fa26fcd2e3384972983..e61d9d2520366f3f21a18b6c62ba924fba23308a 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -3,7 +3,7 @@ licenses(["notice"]) # Apache 2.0 -package(default_visibility = [":friends"]) +package(default_visibility = ["//visibility:public"]) package_group( name = "friends", diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 74b76f929949d3300a5d0ff45d5fa4cd9f162642..43127cae1e5d81521003a28288e27d291e33c9b9 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -186,7 +186,7 @@ StatusOr Client::ComputeConstant(const XlaComputation& computation, ComputeConstantGraphRequest request; *request.mutable_computation() = computation.proto(); if (output_layout != nullptr) { - *request.mutable_output_layout() = *output_layout; + *request.mutable_output_layout() = output_layout->ToProto(); } ComputeConstantResponse response; diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index c5733bc66deb8d55a9186ad1893abaf17ed6909e..970f00759f630f30f1c1321231fd9e0199026142 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -1,5 +1,7 @@ # Common computation builders for XLA. 
+load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites", "xla_test") + licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow/compiler/xla/client:friends"]) @@ -13,9 +15,6 @@ filegroup( ]), ) -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites") - # Generate test_suites for all backends, named "${backend}_tests". generate_backend_suites() @@ -35,6 +34,48 @@ cc_library( ], ) +cc_library( + name = "cholesky", + srcs = ["cholesky.cc"], + hdrs = ["cholesky.h"], + deps = [ + ":math", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:loops", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/client/lib:slicing", + "//tensorflow/compiler/xla/client/lib:triangular_solve", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "cholesky_test", + srcs = ["cholesky_test.cc"], + tags = ["optonly"], + deps = [ + ":arithmetic", + ":cholesky", + ":matrix", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "constants", srcs = ["constants.cc"], @@ -75,6 +116,22 @@ cc_library( ], ) +cc_library( + name = "loops", + srcs = ["loops.cc"], + hdrs = ["loops.h"], + deps = [ + ":constants", + 
"//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "math", srcs = ["math.cc"], @@ -104,13 +161,17 @@ xla_test( ) cc_library( - name = "numeric", - srcs = ["numeric.cc"], - hdrs = ["numeric.h"], + name = "matrix", + srcs = ["matrix.cc"], + hdrs = ["matrix.h"], deps = [ ":arithmetic", ":constants", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "@com_google_absl//absl/types:span", @@ -118,11 +179,12 @@ cc_library( ) xla_test( - name = "numeric_test", - srcs = ["numeric_test.cc"], + name = "matrix_test", + srcs = ["matrix_test.cc"], tags = ["enable_for_xla_interpreter"], deps = [ - ":numeric", + ":matrix", + ":slicing", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -172,6 +234,80 @@ cc_library( ], ) +cc_library( + name = "qr", + srcs = ["qr.cc"], + hdrs = ["qr.h"], + deps = [ + ":arithmetic", + ":constants", + ":loops", + ":math", + ":matrix", + ":slicing", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "qr_test", + srcs = ["qr_test.cc"], + tags = ["optonly"], + deps = [ + ":matrix", + ":qr", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array3d", + "//tensorflow/compiler/xla:literal", + 
"//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "slicing", + srcs = ["slicing.cc"], + hdrs = ["slicing.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "@com_google_absl//absl/types:span", + ], +) + +xla_test( + name = "slicing_test", + srcs = ["slicing_test.cc"], + tags = ["enable_for_xla_interpreter"], + deps = [ + ":slicing", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "sorting", srcs = ["sorting.cc"], @@ -200,6 +336,34 @@ xla_test( ], ) +cc_library( + name = "quantize", + hdrs = ["quantize.h"], + deps = [ + ":constants", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "quantize_test", + srcs = ["quantize_test.cc"], + tags = ["enable_for_xla_interpreter"], + deps = [ + ":quantize", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:xla_builder", + 
"//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "testing", srcs = ["testing.cc"], @@ -221,3 +385,48 @@ cc_library( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "triangular_solve", + srcs = ["triangular_solve.cc"], + hdrs = ["triangular_solve.h"], + deps = [ + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/client/lib:slicing", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "triangular_solve_test", + srcs = ["triangular_solve_test.cc"], + tags = ["noasan"], # sometimes times out, http://b/78650012 + deps = [ + ":triangular_solve", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/xla/client/lib/cholesky.cc similarity index 61% rename from tensorflow/compiler/tf2xla/lib/cholesky.cc 
rename to tensorflow/compiler/xla/client/lib/cholesky.cc index ab3d0a566839343828d176d9a46672824e425613..fd98049968491d80b9717a2de1f34997bd9d18c1 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/xla/client/lib/cholesky.cc @@ -13,16 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/cholesky.h" +#include "tensorflow/compiler/xla/client/lib/cholesky.h" #include #include -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" -#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" -#include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" +#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -30,7 +31,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/errors.h" -namespace tensorflow { +namespace xla { namespace { @@ -49,26 +50,25 @@ namespace { // l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) / // l[..., j, j] // return l -xla::XlaOp CholeskyUnblocked(xla::XlaOp a, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int n_dims = xla::ShapeUtil::Rank(a_shape); - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); - auto major_dims = xla::AsInt64Slice(a_shape.dimensions()) +XlaOp CholeskyUnblocked(XlaOp a, PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int n_dims = ShapeUtil::Rank(a_shape); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); + auto major_dims = AsInt64Slice(a_shape.dimensions()) .subspan( /*pos=*/0, /*len=*/n_dims - 2); - xla::XlaOp l = xla::ZerosLike(a); + XlaOp l = ZerosLike(a); // Construct the for loop body to iterate over rows. 
- auto body_fn = [&](xla::XlaOp i, absl::Span loop_vars, - xla::XlaBuilder* body_builder) - -> xla::StatusOr> { - xla::Shape col_shape; - xla::Shape row_shape; + auto body_fn = + [&](XlaOp i, absl::Span loop_vars, + XlaBuilder* body_builder) -> StatusOr> { + Shape col_shape; + Shape row_shape; for (int64 d : major_dims) { row_shape.add_dimensions(d); col_shape.add_dimensions(d); @@ -76,59 +76,49 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, row_shape.add_dimensions(1); row_shape.add_dimensions(n); row_shape.set_element_type(a_shape.element_type()); - auto mask_zeros_row = xla::Zeros(body_builder, row_shape); + auto mask_zeros_row = Zeros(body_builder, row_shape); col_shape.add_dimensions(n); col_shape.add_dimensions(1); col_shape.set_element_type(a_shape.element_type()); - auto mask_zeros_col = xla::Zeros(body_builder, col_shape); + auto mask_zeros_col = Zeros(body_builder, col_shape); std::vector mask_vector(n); std::iota(mask_vector.begin(), mask_vector.end(), 0); - auto mask_range = xla::ConstantR1(body_builder, mask_vector); + auto mask_range = ConstantR1(body_builder, mask_vector); auto mask_range_row = - xla::Broadcast(xla::Reshape(mask_range, {0}, {1, n}), major_dims); + Broadcast(Reshape(mask_range, {0}, {1, n}), major_dims); auto mask_range_col = - xla::Broadcast(xla::Reshape(mask_range, {0}, {n, 1}), major_dims); + Broadcast(Reshape(mask_range, {0}, {n, 1}), major_dims); auto body_a = loop_vars[0]; auto body_l = loop_vars[1]; // row = l[..., i, :i] // select the whole i-th row, then mask out all columns past i-1 - auto zero = xla::ConstantR0(body_builder, 0); + auto zero = ConstantR0(body_builder, 0); auto l_i = DynamicSliceInMinorDims(body_l, {i, zero}, {1, n}); - auto row = xla::Select(xla::Ge(mask_range_row, i), mask_zeros_row, l_i); + auto row = Select(Ge(mask_range_row, i), mask_zeros_row, l_i); // a[..., i, i] auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1}); // np.dot(row, np.swapaxes(row, -1, -2)) - auto diag_dot = BatchDot(row, row, - 
/*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + auto diag_dot = BatchDot(row, TransposeInMinorDims(row), precision); // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, // np.swapaxes(row, -1, -2))) - auto l_ii = - xla::Pow(a_ii - diag_dot, - FloatLiteral(body_builder, a_shape.element_type(), 0.5)); + auto l_ii = Sqrt(a_ii - diag_dot); // a[..., i+1:, i] // select the whole i-th column, then mask out all rows above i+1 auto a_0i = DynamicSliceInMinorDims(body_a, {i}, {1}); - auto a_ip1i = - xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, a_0i); + auto a_ip1i = Select(Le(mask_range_col, i), mask_zeros_col, a_0i); // l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) / // l[..., i, i] // The columns in [i, n] are zeroed out in `row`, so we just have to // zero out rows above i+1 after the BatchDot. np.dot(l[..., :, :i], // r.T) - auto dot = BatchDot(body_l, row, - /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + auto dot = BatchDot(body_l, TransposeInMinorDims(row), precision); // np.dot(l[..., i+1:, :i], r.T) - auto dot_ip1 = - xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot); + auto dot_ip1 = Select(Le(mask_range_col, i), mask_zeros_col, dot); body_l = DynamicUpdateSliceInMinorDims(body_l, (a_ip1i - dot_ip1) / l_ii, {i}); @@ -136,12 +126,12 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, // column assign will wrap around and overwrite the diagonal assign. 
body_l = DynamicUpdateSliceInMinorDims(body_l, l_ii, {i, i}); - return std::vector{body_a, body_l}; + return std::vector{body_a, body_l}; }; TF_ASSIGN_OR_RETURN( auto cholesky_while, - XlaForEachIndex(n, xla::S32, body_fn, {a, l}, "unblocked", builder)); + ForEachIndex(n, S32, body_fn, {a, l}, "unblocked", builder)); return cholesky_while[1]; }); @@ -149,34 +139,35 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, } // namespace -xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int ndims = xla::ShapeUtil::Rank(a_shape); +XlaOp Cholesky(XlaOp a, int64 block_size, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int ndims = ShapeUtil::Rank(a_shape); if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to Cholesky must have rank >= 2: ", ndims); + return InvalidArgument( + "Argument to Cholesky must have rank >= 2; shape was %s", + a_shape.ToString()); } - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); - if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) { - return errors::InvalidArgument( - "Arguments to Cholesky must be square matrices: ", - xla::ShapeUtil::HumanString(a_shape)); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); + if (n != ShapeUtil::GetDimension(a_shape, -2)) { + return InvalidArgument( + "Argument to Cholesky must be batched square matrices; got shape %s", + ShapeUtil::HumanString(a_shape)); } if (block_size < 1) { - return errors::InvalidArgument( - "block_size argument to Cholesky must be >= 1; got ", block_size); + return InvalidArgument( + "block_size argument to Cholesky must be >= 1; got %d", block_size); } // Blocked left-looking Cholesky 
factorization. // Algorithm 1 from // Haidar, Azzam, et al. "High-performance Cholesky factorization for // GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017. - xla::XlaOp l = xla::ZerosLike(a); + XlaOp l = ZerosLike(a); for (int64 i = 0; i < n; i += block_size) { int64 k = std::min(block_size, n - i); if (i > 0) { @@ -185,9 +176,7 @@ xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size, // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i])) auto lhs = SliceInMinorDims(l, {i, 0}, {n, i}); auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i}); - auto delta = BatchDot(lhs, rhs, /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + auto delta = BatchDot(lhs, TransposeInMinorDims(rhs), precision); auto before = SliceInMinorDims(a, {i, i}, {n, i + k}); a = UpdateSliceInMinorDims(a, before - delta, {i, i}); } @@ -214,4 +203,4 @@ xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size, }); } -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/xla/client/lib/cholesky.h similarity index 87% rename from tensorflow/compiler/tf2xla/lib/cholesky.h rename to tensorflow/compiler/xla/client/lib/cholesky.h index 9a561c34b92ee45059f2a05336e682838f8e36e2..0bae26837c0f14dd0cfab82cf426becc787ec11c 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.h +++ b/tensorflow/compiler/xla/client/lib/cholesky.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -namespace tensorflow { +namespace xla { // Computes the Cholesky decompositions of a batch of symmetric positive // definite matrices. @@ -34,6 +34,6 @@ xla::XlaOp Cholesky( xla::XlaOp a, int64 block_size = 256, xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); -} // namespace tensorflow +} // namespace xla -#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_ diff --git a/tensorflow/compiler/xla/client/lib/cholesky_test.cc b/tensorflow/compiler/xla/client/lib/cholesky_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ba9580a3d32225625acc1447344b7d2c16c5d8a5 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/cholesky_test.cc @@ -0,0 +1,166 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/cholesky.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace { + +using xla::int64; + +using CholeskyTest = xla::ClientLibraryTestBase; + +XLA_TEST_F(CholeskyTest, Simple) { + xla::XlaBuilder builder(TestName()); + + xla::Array2D a_vals({ + {4, 6, 8, 10}, + {6, 45, 54, 63}, + {8, 54, 146, 166}, + {10, 63, 166, 310}, + }); + + xla::XlaOp a; + auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); + xla::Cholesky(a, /*block_size=*/2); + + xla::Array2D expected({ + {2, 0, 0, 0}, + {3, 6, 0, 0}, + {4, 7, 9, 0}, + {5, 8, 10, 11}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +XLA_TEST_F(CholeskyTest, Simple2) { + xla::XlaBuilder builder(TestName()); + + xla::Array2D a_vals({ + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 456, 106}, + {12, 48, 106, 62}, + }); + + xla::XlaOp a; + auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); + xla::Cholesky(a); + + xla::Array2D expected( + {{4, 0, 0, 0}, {6, 5, 0, 0}, {2, 14, 16, 0}, {3, 6, 1, 4}}); + + ComputeAndCompareR2(&builder, expected, {a_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +XLA_TEST_F(CholeskyTest, SimpleBatched) { + xla::XlaBuilder builder(TestName()); + + xla::Array3D a_vals({ + 
{ + {4, 6, 8, 10}, + {6, 45, 54, 63}, + {8, 54, 146, 166}, + {10, 63, 166, 310}, + }, + { + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 456, 106}, + {12, 48, 106, 62}, + }, + }); + + xla::XlaOp a; + auto a_data = CreateR3Parameter(a_vals, 0, "a", &builder, &a); + xla::Cholesky(a); + + xla::Array3D expected({ + { + {2, 0, 0, 0}, + {3, 6, 0, 0}, + {4, 7, 9, 0}, + {5, 8, 10, 11}, + }, + {{4, 0, 0, 0}, {6, 5, 0, 0}, {2, 14, 16, 0}, {3, 6, 1, 4}}, + }); + + ComputeAndCompareR3(&builder, expected, {a_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +using CholeskyTestCase = std::tuple; + +class RandomCholeskyTest + : public xla::ClientLibraryTestBase, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(RandomCholeskyTest, Random) { + xla::XlaBuilder builder(TestName()); + + auto test_params = GetParam(); + std::vector dimensions = {std::get<0>(test_params), + std::get<1>(test_params), + std::get<1>(test_params)}; + xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, dimensions); + TF_ASSERT_OK_AND_ASSIGN( + auto literal, + xla::LiteralUtil::CreateRandomLiteral(shape, 0.0, 1.0)); + + auto input = xla::Parameter(&builder, 0, shape, "input"); + // Form a random positive definite matrix. 
+ auto matrix = xla::BatchDot(input, TransposeInMinorDims(input), + xla::PrecisionConfig::HIGHEST); + + auto cholesky = xla::Cholesky(matrix, /*block_size=*/4); + + // Verify that ||matrix - cholesky * cholesky_t||_2 ~= 0 + auto verification = xla::BatchDot(cholesky, TransposeInMinorDims(cholesky), + xla::PrecisionConfig::HIGHEST); + auto delta = matrix - verification; + xla::Reduce(delta * delta, xla::ConstantR0(&builder, 0.0), + CreateScalarAddComputation(xla::F32, &builder), {0, 1, 2}); + + TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(literal)); + ComputeAndCompareR0(&builder, 0.0, {input_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +INSTANTIATE_TEST_CASE_P(RandomCholeskyTestInstance, RandomCholeskyTest, + ::testing::Values(CholeskyTestCase{1, 1}, + CholeskyTestCase{1, 2}, + CholeskyTestCase{10, 5}, + CholeskyTestCase{2, 20})); + +} // namespace diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/xla/client/lib/loops.cc similarity index 50% rename from tensorflow/compiler/tf2xla/lib/while_loop.cc rename to tensorflow/compiler/xla/client/lib/loops.cc index 594ab1dfd0700f47501712183f6efe62d17e15e7..721f987628a8ac7da3f3f872939c3f0457d6bbe2 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.cc +++ b/tensorflow/compiler/xla/client/lib/loops.cc @@ -13,44 +13,43 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" -#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" + +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" -namespace tensorflow { +namespace xla { -xla::StatusOr> XlaWhileLoop( - const LoopConditionFunction& condition_function, - const LoopBodyFunction& body_function, - absl::Span initial_values, absl::string_view name, - xla::XlaBuilder* builder) { +StatusOr> WhileLoopHelper( + const WhileLoopHelperConditionFunction& condition_function, + const WhileLoopHelperBodyFunction& body_function, + absl::Span initial_values, absl::string_view name, + XlaBuilder* builder) { int arity = initial_values.size(); - std::vector var_shapes; + std::vector var_shapes; var_shapes.reserve(arity); - for (const xla::XlaOp& input : initial_values) { + for (const XlaOp& input : initial_values) { TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(input)); var_shapes.push_back(std::move(shape)); } - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(var_shapes); + Shape tuple_shape = ShapeUtil::MakeTupleShape(var_shapes); // Unpacks a tuple into its component parts. - auto unpack_tuple = [](xla::XlaOp tuple, int arity, - xla::XlaBuilder* builder) { - std::vector elements(arity); + auto unpack_tuple = [](XlaOp tuple, int arity, XlaBuilder* builder) { + std::vector elements(arity); for (int i = 0; i < arity; ++i) { - elements[i] = xla::GetTupleElement(tuple, i); + elements[i] = GetTupleElement(tuple, i); } return elements; }; // Build the condition. 
- std::unique_ptr cond_builder = + std::unique_ptr cond_builder = builder->CreateSubBuilder(absl::StrCat(name, "_condition")); { - auto parameter = - xla::Parameter(cond_builder.get(), 0, tuple_shape, "parameter"); + auto parameter = Parameter(cond_builder.get(), 0, tuple_shape, "parameter"); TF_RETURN_IF_ERROR( condition_function(unpack_tuple(parameter, arity, cond_builder.get()), @@ -60,11 +59,10 @@ xla::StatusOr> XlaWhileLoop( TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build()); // Build the body. - std::unique_ptr body_builder = + std::unique_ptr body_builder = builder->CreateSubBuilder(absl::StrCat(name, "_body")); { - auto parameter = - xla::Parameter(body_builder.get(), 0, tuple_shape, "parameter"); + auto parameter = Parameter(body_builder.get(), 0, tuple_shape, "parameter"); TF_ASSIGN_OR_RETURN( auto result, @@ -72,56 +70,54 @@ xla::StatusOr> XlaWhileLoop( body_builder.get())); TF_RET_CHECK(result.size() == initial_values.size()); - xla::Tuple(body_builder.get(), result); + Tuple(body_builder.get(), result); } TF_ASSIGN_OR_RETURN(auto body, body_builder->Build()); - auto outputs = xla::While(cond, body, xla::Tuple(builder, initial_values)); + auto outputs = While(cond, body, Tuple(builder, initial_values)); return unpack_tuple(outputs, arity, builder); } -xla::StatusOr> XlaForEachIndex( - int64 num_iterations, xla::PrimitiveType num_iterations_type, +StatusOr> ForEachIndex( + int64 num_iterations, PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - absl::Span initial_values, absl::string_view name, - xla::XlaBuilder* builder) { - auto while_cond_fn = - [&](absl::Span values, - xla::XlaBuilder* cond_builder) -> xla::StatusOr { - return xla::Lt(values[0], IntegerLiteral(cond_builder, num_iterations_type, - num_iterations)); + absl::Span initial_values, absl::string_view name, + XlaBuilder* builder) { + auto while_cond_fn = [&](absl::Span values, + XlaBuilder* cond_builder) -> StatusOr { + return Lt(values[0], 
ConstantR0WithType(cond_builder, num_iterations_type, + num_iterations)); }; - auto while_body_fn = [&](absl::Span values, - xla::XlaBuilder* body_builder) - -> xla::StatusOr> { - xla::XlaOp iteration = values[0]; + auto while_body_fn = + [&](absl::Span values, + XlaBuilder* body_builder) -> StatusOr> { + XlaOp iteration = values[0]; - std::vector updated_values; + std::vector updated_values; updated_values.reserve(values.size()); - updated_values.push_back(xla::Add( + updated_values.push_back(Add( iteration, - xla::ConstantLiteral(body_builder, - xla::LiteralUtil::One(num_iterations_type)))); + ConstantLiteral(body_builder, LiteralUtil::One(num_iterations_type)))); values.remove_prefix(1); - TF_ASSIGN_OR_RETURN(std::vector body_outputs, + TF_ASSIGN_OR_RETURN(std::vector body_outputs, body_function(iteration, values, body_builder)); updated_values.insert(updated_values.end(), body_outputs.begin(), body_outputs.end()); return updated_values; }; - std::vector values; + std::vector values; values.reserve(initial_values.size() + 1); - values.push_back(xla::ConstantLiteral( - builder, xla::LiteralUtil::Zero(num_iterations_type))); + values.push_back( + ConstantLiteral(builder, LiteralUtil::Zero(num_iterations_type))); values.insert(values.end(), initial_values.begin(), initial_values.end()); - TF_ASSIGN_OR_RETURN(values, XlaWhileLoop(while_cond_fn, while_body_fn, values, - name, builder)); + TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn, + values, name, builder)); values.erase(values.begin(), values.begin() + 1); return values; } -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.h b/tensorflow/compiler/xla/client/lib/loops.h similarity index 62% rename from tensorflow/compiler/tf2xla/lib/while_loop.h rename to tensorflow/compiler/xla/client/lib/loops.h index f2134bb4495a12b8342961d96f70e7737f816c7d..e11de59493e9c1de51fbdb6c45dab6d82b85a62a 100644 --- 
a/tensorflow/compiler/tf2xla/lib/while_loop.h +++ b/tensorflow/compiler/xla/client/lib/loops.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_ #include #include @@ -25,19 +25,18 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" -namespace tensorflow { +namespace xla { // Function that builds a loop condition. Takes as input a sequence of input // values, and returns a boolean value representing if the condition succeeds. -typedef std::function(absl::Span, - xla::XlaBuilder*)> - LoopConditionFunction; +typedef std::function(absl::Span, XlaBuilder*)> + WhileLoopHelperConditionFunction; // Function that builds a loop body. Takes as input a sequence of input values // and returns a sequence of output values. -typedef std::function>( - absl::Span, xla::XlaBuilder*)> - LoopBodyFunction; +typedef std::function>(absl::Span, + XlaBuilder*)> + WhileLoopHelperBodyFunction; // Helper function for building an XLA while loop, where the values carried by // the loop are a tuple of values, e.g., (a, b, c): @@ -47,27 +46,27 @@ typedef std::function>( // init: (a, b, c) // ) // 'name' is a descriptive name for the loop. 
-xla::StatusOr> XlaWhileLoop( - const LoopConditionFunction& condition_function, - const LoopBodyFunction& body_function, - absl::Span initial_values, absl::string_view name, - xla::XlaBuilder* builder); +StatusOr> WhileLoopHelper( + const WhileLoopHelperConditionFunction& condition_function, + const WhileLoopHelperBodyFunction& body_function, + absl::Span initial_values, absl::string_view name, + XlaBuilder* builder); // Builds an XLA loop that repeats a computation `num_iterations` times. // // The body function (ForEachIndexBodyFunction) takes as input a pair of // (current iteration number, loop-carried values), and returns an updated // vector of the loop-carried values. -typedef std::function>( - xla::XlaOp, absl::Span, xla::XlaBuilder*)> +typedef std::function>( + XlaOp, absl::Span, XlaBuilder*)> ForEachIndexBodyFunction; -xla::StatusOr> XlaForEachIndex( - int64 num_iterations, xla::PrimitiveType num_iterations_type, +StatusOr> ForEachIndex( + int64 num_iterations, PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - absl::Span initial_values, absl::string_view name, - xla::XlaBuilder* builder); + absl::Span initial_values, absl::string_view name, + XlaBuilder* builder); -} // namespace tensorflow +} // namespace xla -#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_ diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index 08a887a6e4660cb2528f0ec7244b7ccc540808d2..36fdda39b4124b9100c6054160f9c17bdf787d6f 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -268,17 +268,16 @@ XlaOp Digamma(XlaOp input) { // Implements Banker's rounding: numbers that are equidistant between two // integers are rounded towards even. 
XlaOp RoundToEven(XlaOp x) { - auto half = xla::ScalarLike(x, 0.5); - auto one = xla::ScalarLike(x, 1.0); - auto two = xla::ScalarLike(x, 2.0); + auto half = ScalarLike(x, 0.5); + auto one = ScalarLike(x, 1.0); + auto two = ScalarLike(x, 2.0); - auto round_val = xla::Floor(x); + auto round_val = Floor(x); auto fraction = x - round_val; - auto nearest_even_int = round_val - two * xla::Floor(half * x); - auto is_odd = xla::Eq(nearest_even_int, one); - return xla::Select(xla::Or(xla::Gt(fraction, half), - xla::And(xla::Eq(fraction, half), is_odd)), - round_val + one, round_val); + auto nearest_even_int = round_val - two * Floor(half * x); + auto is_odd = Eq(nearest_even_int, one); + return Select(Or(Gt(fraction, half), And(Eq(fraction, half), is_odd)), + round_val + one, round_val); } // Trigonometric functions. @@ -320,4 +319,13 @@ XlaOp Cosh(XlaOp x) { return (Exp(x) + Exp(-x)) * ScalarLike(x, 0.5); } XlaOp Sinh(XlaOp x) { return (Exp(x) - Exp(-x)) * ScalarLike(x, 0.5); } +XlaOp MaybeConjugate(XlaOp x, bool conjugate) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + auto perform_conj = shape.element_type() == C64 && conjugate; + return perform_conj ? Conj(x) : x; + }); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h index 3f06d04b9ae98b3aa75e68cd07810b2b4c24d280..17612bf9fdc0f1eabb338671c93c025c5b268872 100644 --- a/tensorflow/compiler/xla/client/lib/math.h +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -86,6 +86,10 @@ XlaOp Cosh(XlaOp x); // Computes the hyperbolic sine of 'x'. XlaOp Sinh(XlaOp x); +// Applies a complex conjugation operation if `a` is complex and `conjugate` +// is true, otherwise returns its argument. 
+xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_ diff --git a/tensorflow/compiler/xla/client/lib/matrix.cc b/tensorflow/compiler/xla/client/lib/matrix.cc new file mode 100644 index 0000000000000000000000000000000000000000..ffd744d190885b8e3f4149a48a706498b3787618 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/matrix.cc @@ -0,0 +1,185 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/matrix.h" + +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { + +XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, + int64 n) { + auto a = Iota(builder, type, m); + auto b = Iota(builder, type, n); + auto indicator = Eq(a, Broadcast(b, {m}), /*broadcast_dimensions=*/{0}); + return ConvertElementType(indicator, type); +} + +XlaOp GetMatrixDiagonal(XlaOp x) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + TF_RET_CHECK(n_dims >= 2); + const int64 m = shape.dimensions(n_dims - 2); + const int64 n = shape.dimensions(n_dims - 1); + absl::Span major_dims = + AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); + auto a = Iota(builder, U32, n); + auto b = Iota(builder, U32, m); + auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); + auto mask = Broadcast(indicator, major_dims); + + // TPUs don't support S64 add reduction at the moment. But fortunately + // OR-reductions work just as well for integers. + XlaComputation reducer = + primitive_util::IsIntegralType(shape.element_type()) + ? CreateScalarOrComputation(shape.element_type(), builder) + : CreateScalarAddComputation(shape.element_type(), builder); + + return Reduce(Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0), + reducer, {m >= n ? 
n_dims - 2 : n_dims - 1}); + }); +} + +XlaOp Triangle(XlaOp x, bool lower) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + TF_RET_CHECK(n_dims >= 2); + const int64 m = shape.dimensions(n_dims - 2); + const int64 n = shape.dimensions(n_dims - 1); + absl::Span major_dims = + AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); + auto a = Iota(builder, U32, n); + auto b = Iota(builder, U32, m); + XlaOp indicator; + if (lower) { + indicator = Ge(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); + } else { + indicator = Le(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); + } + auto mask = Broadcast(indicator, major_dims); + + return Select(mask, x, Zeros(builder, shape)); + }); +} + +XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); } + +XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); } + +XlaOp BatchDot(XlaOp x, XlaOp y, PrecisionConfig::Precision precision) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); + TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y)); + + // Check that both tensors have the same number of dimensions. There must be + // at least two (the batch dimensions can be empty). + if (ShapeUtil::Rank(x_shape) != ShapeUtil::Rank(y_shape)) { + return InvalidArgument( + "Arguments to BatchDot have different ranks: %s vs. %s", + ShapeUtil::HumanString(x_shape), ShapeUtil::HumanString(y_shape)); + } + const int ndims = ShapeUtil::Rank(x_shape); + if (ndims < 2) { + return InvalidArgument( + "Arguments to BatchDot must have rank >= 2: got %d", ndims); + } + + // The batch dimensions must be equal and the matrix dimensions must be + // valid. 
+ std::vector batch_dimension_numbers; + for (int i = 0; i < ndims - 2; ++i) { + if (x_shape.dimensions(i) != y_shape.dimensions(i)) { + return InvalidArgument( + "Dimension %d of inputs to BatchDot must be equal: shapes %s vs %s", + i, ShapeUtil::HumanString(x_shape), + ShapeUtil::HumanString(y_shape)); + } + batch_dimension_numbers.push_back(i); + } + + int x_inner_dim = ndims - 1; + int y_inner_dim = ndims - 2; + if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) { + return InvalidArgument( + "Dimensions %d and %d of arguments to BatchDot must be equal: " + "shapes %s vs %s", + x_inner_dim, y_inner_dim, ShapeUtil::HumanString(x_shape), + ShapeUtil::HumanString(y_shape)); + } + + // Check for zero lhs/rhs dim size. + if (ShapeUtil::IsZeroElementArray(x_shape) || + ShapeUtil::IsZeroElementArray(y_shape)) { + std::vector dimensions(batch_dimension_numbers.size()); + for (int i = 0; i < batch_dimension_numbers.size(); ++i) { + dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]); + } + int x_outer_dim = ndims - 2; + int y_outer_dim = ndims - 1; + dimensions.push_back(x_shape.dimensions(x_outer_dim)); + dimensions.push_back(y_shape.dimensions(y_outer_dim)); + return Broadcast( + ConstantLiteral(builder, LiteralUtil::Zero(x_shape.element_type())), + dimensions); + } + + PrecisionConfig precision_proto; + precision_proto.add_operand_precision(precision); + precision_proto.add_operand_precision(precision); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); + dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); + for (auto batch_dimension_number : batch_dimension_numbers) { + dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); + dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); + } + + return DotGeneral(x, y, dot_dnums, &precision_proto); + }); +} + +XlaOp TransposeInMinorDims(XlaOp x) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { 
+ TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + TF_RET_CHECK(n_dims >= 2); + std::vector permutation(n_dims); + std::iota(permutation.begin(), permutation.end(), 0); + std::swap(permutation[n_dims - 1], permutation[n_dims - 2]); + return Transpose(x, permutation); + }); +} + +XlaOp MaybeTransposeInMinorDims(XlaOp x, bool transpose) { + return transpose ? TransposeInMinorDims(x) : x; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/numeric.h b/tensorflow/compiler/xla/client/lib/matrix.h similarity index 56% rename from tensorflow/compiler/xla/client/lib/numeric.h rename to tensorflow/compiler/xla/client/lib/matrix.h index f62fdab4b0e5e84347cfaa1424a8c2e5c58dd3ce..8856f99c7a0fee8f315aac11fab392cf5536f57b 100644 --- a/tensorflow/compiler/xla/client/lib/numeric.h +++ b/tensorflow/compiler/xla/client/lib/matrix.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/types.h" @@ -40,6 +40,34 @@ XlaOp UpperTriangle(XlaOp x); // Get the lower triangle part of the last two dimensions XlaOp LowerTriangle(XlaOp x); +// Multiplies slices of two tensors in batches. + +// Multiplies all slices of `Tensor` `x` and `y` (each slice can be +// viewed as an element of a batch), and arranges the individual results +// in a single output tensor of the same batch size. +// +// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` +// and `[..., r_y, c_y]`. 
+// +// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: +// +// r_o = r_x +// c_o = c_y (c_x must equal r_y) +// +// It is computed as: +// +// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) +xla::XlaOp BatchDot( + xla::XlaOp x, xla::XlaOp y, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); + +// Transposes a stack of matrices `x` by swapping the last two dimensions. +xla::XlaOp TransposeInMinorDims(xla::XlaOp x); + +// Transposes `x` in its minor dimensions if `transpose` is true, otherwise +// returns `x` unchanged. +xla::XlaOp MaybeTransposeInMinorDims(xla::XlaOp x, bool transpose); + } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_ diff --git a/tensorflow/compiler/xla/client/lib/numeric_test.cc b/tensorflow/compiler/xla/client/lib/matrix_test.cc similarity index 53% rename from tensorflow/compiler/xla/client/lib/numeric_test.cc rename to tensorflow/compiler/xla/client/lib/matrix_test.cc index 7d6aedd49462bd4f075f90d0b0f85c40f1191aa1..0593a7517ac125ca8dc5395cee76f6bc23232cd3 100644 --- a/tensorflow/compiler/xla/client/lib/numeric_test.cc +++ b/tensorflow/compiler/xla/client/lib/matrix_test.cc @@ -13,7 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" + +#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -24,13 +26,13 @@ limitations under the License. 
namespace xla { namespace { -class NumericTest : public ClientLibraryTestBase { +class MatrixTest : public ClientLibraryTestBase { protected: template void TestMatrixDiagonal(); }; -XLA_TEST_F(NumericTest, Triangle) { +XLA_TEST_F(MatrixTest, Triangle) { XlaBuilder builder(TestName()); Array3D input(2, 3, 4); input.FillIota(0); @@ -45,7 +47,7 @@ XLA_TEST_F(NumericTest, Triangle) { } template -void NumericTest::TestMatrixDiagonal() { +void MatrixTest::TestMatrixDiagonal() { XlaBuilder builder("GetMatrixDiagonal"); Array3D input(2, 3, 4); input.FillIota(0); @@ -58,11 +60,46 @@ void NumericTest::TestMatrixDiagonal() { ComputeAndCompareR2(&builder, expected, {a_data.get()}); } -XLA_TEST_F(NumericTest, GetMatrixDiagonal_S32) { TestMatrixDiagonal(); } +XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S32) { TestMatrixDiagonal(); } + +XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S64) { TestMatrixDiagonal(); } + +XLA_TEST_F(MatrixTest, GetMatrixDiagonal_F32) { TestMatrixDiagonal(); } + +Array3D BatchedAValsFull() { + return {{ + {2, 0, 1, 2}, + {3, 6, 0, 1}, + {4, 7, 9, 0}, + {5, 8, 10, 11}, + }, + { + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 456, 106}, + {12, 48, 106, 62}, + }}; +} + +XLA_TEST_F(MatrixTest, RowBatchDot) { + XlaBuilder builder(TestName()); + + int n = 4; -XLA_TEST_F(NumericTest, GetMatrixDiagonal_S64) { TestMatrixDiagonal(); } + XlaOp a, row, index; + auto a_data = + CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); + auto row_data = CreateR3Parameter({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1, + "row", &builder, &row); + // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull(). 
+ auto index_data = CreateR0Parameter(1, 2, "index", &builder, &index); -XLA_TEST_F(NumericTest, GetMatrixDiagonal_F32) { TestMatrixDiagonal(); } + auto l_index = DynamicSliceInMinorDims( + a, {index, ConstantR0(&builder, 0)}, {1, n}); + BatchDot(l_index, TransposeInMinorDims(row)); + ComputeAndCompareR3(&builder, {{{33}}, {{292}}}, + {a_data.get(), row_data.get(), index_data.get()}); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/numeric.cc b/tensorflow/compiler/xla/client/lib/numeric.cc deleted file mode 100644 index 377654220b5df4487e9e194361473d54ff46a54e..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/lib/numeric.cc +++ /dev/null @@ -1,89 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include - -#include "absl/types/span.h" -#include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" - -namespace xla { - -XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, - int64 n) { - auto a = Iota(builder, type, m); - auto b = Iota(builder, type, n); - auto indicator = Eq(a, Broadcast(b, {m}), /*broadcast_dimensions=*/{0}); - return ConvertElementType(indicator, type); -} - -XlaOp GetMatrixDiagonal(XlaOp x) { - XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); - TF_RET_CHECK(n_dims >= 2); - const int64 m = shape.dimensions(n_dims - 2); - const int64 n = shape.dimensions(n_dims - 1); - absl::Span major_dims = - AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); - auto a = Iota(builder, U32, n); - auto b = Iota(builder, U32, m); - auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); - auto mask = Broadcast(indicator, major_dims); - - // TPUs don't support S64 add reduction at the moment. But fortunately - // OR-reductions work just as well for integers. - XlaComputation reducer = - primitive_util::IsIntegralType(shape.element_type()) - ? CreateScalarOrComputation(shape.element_type(), builder) - : CreateScalarAddComputation(shape.element_type(), builder); - - return Reduce(Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0), - reducer, {m >= n ? 
n_dims - 2 : n_dims - 1}); - }); -} - -XlaOp Triangle(XlaOp x, bool lower) { - XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); - TF_RET_CHECK(n_dims >= 2); - const int64 m = shape.dimensions(n_dims - 2); - const int64 n = shape.dimensions(n_dims - 1); - absl::Span major_dims = - AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); - auto a = Iota(builder, U32, n); - auto b = Iota(builder, U32, m); - xla::XlaOp indicator; - if (lower) { - indicator = Ge(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); - } else { - indicator = Le(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); - } - auto mask = Broadcast(indicator, major_dims); - - return Select(mask, x, Zeros(builder, shape)); - }); -} - -XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); } - -XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); } - -} // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/xla/client/lib/qr.cc similarity index 55% rename from tensorflow/compiler/tf2xla/lib/qr.cc rename to tensorflow/compiler/xla/client/lib/qr.cc index 6b3f2b6e065b5c99e2d0248237369ecc30188aa5..72ca653173b78d9338f632c41779f2a30db1e978 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.cc +++ b/tensorflow/compiler/xla/client/lib/qr.cc @@ -13,18 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/qr.h" +#include "tensorflow/compiler/xla/client/lib/qr.h" #include #include -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" -#include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -32,10 +31,18 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/errors.h" -namespace tensorflow { +namespace xla { namespace { +std::vector ConcatVectors(absl::Span xs, + absl::Span ys) { + std::vector output(xs.size() + ys.size()); + std::copy(xs.begin(), xs.end(), output.begin()); + std::copy(ys.begin(), ys.end(), output.begin() + xs.size()); + return output; +} + // Computes a Householder reflection of the form: // H = I - tau v v.T. // such that @@ -65,52 +72,47 @@ namespace { // return (v, tau, beta) // TODO(phawkins): LAPACK's xLARFG implementation has code for handling // overflows in the norm/beta calculations. Perhaps do the same here. 
-xla::Status House(xla::XlaOp x, xla::XlaOp k, - absl::Span batch_dims, const int64 m, - xla::XlaOp* v, xla::XlaOp* tau, xla::XlaOp* beta) { - xla::XlaBuilder* const builder = x.builder(); - TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); - const xla::PrimitiveType type = x_shape.element_type(); +Status House(XlaOp x, XlaOp k, absl::Span batch_dims, + const int64 m, XlaOp* v, XlaOp* tau, XlaOp* beta) { + XlaBuilder* const builder = x.builder(); + TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); + const PrimitiveType type = x_shape.element_type(); std::vector batch_dim_ids(batch_dims.size()); std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0); const int64 minor_dim = batch_dims.size(); - xla::XlaOp zero = xla::ScalarLike(x, 0.0); - xla::XlaOp one = xla::ScalarLike(x, 1.0); + XlaOp zero = ScalarLike(x, 0.0); + XlaOp one = ScalarLike(x, 1.0); // alpha = x[k] - xla::XlaOp alpha = - xla::Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims); + XlaOp alpha = Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims); // Compute x[k+1:] (padded with zeros in elements 0..k) - xla::XlaOp iota = xla::Iota(builder, xla::S32, m); - xla::XlaOp x_after_k = - xla::Mul(x, xla::ConvertElementType(xla::Gt(iota, k), type), - /*broadcast_dimensions=*/{minor_dim}); + XlaOp iota = Iota(builder, S32, m); + XlaOp x_after_k = Mul(x, ConvertElementType(Gt(iota, k), type), + /*broadcast_dimensions=*/{minor_dim}); // sigma = np.dot(x[k+1:], x[k+1:]) - auto sigma = - xla::Reduce(x_after_k * x_after_k, zero, - xla::CreateScalarAddComputation(type, builder), {minor_dim}); + auto sigma = Reduce(x_after_k * x_after_k, zero, + CreateScalarAddComputation(type, builder), {minor_dim}); // mu = np.sqrt(x[k]*x[k] + sigma) - auto mu = xla::Sqrt(xla::Square(alpha) + sigma); + auto mu = Sqrt(Square(alpha) + sigma); - auto sigma_is_zero = xla::Eq(sigma, zero); + auto sigma_is_zero = Eq(sigma, zero); - *beta = xla::Select(sigma_is_zero, alpha, -xla::Sign(alpha) * mu); - *tau 
= xla::Select(sigma_is_zero, xla::Broadcast(zero, batch_dims), - (*beta - alpha) / *beta); - auto divisor = xla::Select(sigma_is_zero, xla::Broadcast(one, batch_dims), - alpha - *beta); + *beta = Select(sigma_is_zero, alpha, -Sign(alpha) * mu); + *tau = Select(sigma_is_zero, Broadcast(zero, batch_dims), + (*beta - alpha) / *beta); + auto divisor = + Select(sigma_is_zero, Broadcast(one, batch_dims), alpha - *beta); - auto e_k = xla::Broadcast(xla::ConvertElementType(xla::Eq(iota, k), type), - std::vector(batch_dims.size(), 1)); + auto e_k = Broadcast(ConvertElementType(Eq(iota, k), type), + std::vector(batch_dims.size(), 1)); // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor. - *v = e_k + - xla::Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids); + *v = e_k + Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids); return Status::OK(); } @@ -143,94 +145,86 @@ xla::Status House(xla::XlaOp x, xla::XlaOp k, // return (q, vs, taus) struct QRBlockResult { // The factored R value - xla::XlaOp r; + XlaOp r; // Representation of the Householder matrices I - beta v v.T - xla::XlaOp taus; // Shape: [..., n] - xla::XlaOp vs; // Shape: [..., m, n] + XlaOp taus; // Shape: [..., n] + XlaOp vs; // Shape: [..., m, n] }; -xla::StatusOr QRBlock( - xla::XlaOp a, xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int num_dims = xla::ShapeUtil::Rank(a_shape); +StatusOr QRBlock(XlaOp a, PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int num_dims = ShapeUtil::Rank(a_shape); if (num_dims < 2) { - return errors::InvalidArgument("Arguments to QR must have rank >= 2: ", - num_dims); + return InvalidArgument("Argument to QR must have rank >= 2; got shape %s", + a_shape.ToString()); } - xla::PrimitiveType 
type = a_shape.element_type(); + PrimitiveType type = a_shape.element_type(); - const int64 m = xla::ShapeUtil::GetDimension(a_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + const int64 m = ShapeUtil::GetDimension(a_shape, -2); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); const int64 num_batch_dims = num_dims - 2; std::vector batch_dims(num_batch_dims); for (int i = 0; i < num_batch_dims; ++i) { - batch_dims[i] = xla::ShapeUtil::GetDimension(a_shape, i); + batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); } std::vector batch_dim_indices(num_batch_dims); std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); - auto qr_body_fn = - [&](xla::XlaOp j, absl::Span values, - xla::XlaBuilder* builder) -> xla::StatusOr> { + auto qr_body_fn = [&](XlaOp j, absl::Span values, + XlaBuilder* builder) -> StatusOr> { auto a = values[0]; auto vs = values[1]; auto taus = values[2]; // v, beta = house(a[:, j], j) auto x = DynamicSliceInMinorDims(a, {j}, {1}); - xla::XlaOp v, tau, beta; - TF_RETURN_IF_ERROR(House(xla::Collapse(x, {num_dims - 2, num_dims - 1}), j, + XlaOp v, tau, beta; + TF_RETURN_IF_ERROR(House(Collapse(x, {num_dims - 2, num_dims - 1}), j, batch_dims, m, &v, &tau, &beta)); std::vector shape = batch_dims; shape.push_back(1); shape.push_back(m); - auto v_broadcast = xla::Reshape(v, shape); + auto v_broadcast = Reshape(v, shape); // a[:, :] -= tau * np.dot(v[:, np.newaxis], // np.dot(v[np.newaxis, :], a[:, :])) - auto vva = - BatchDot(v_broadcast, a, /*transpose_x=*/false, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); - vva = - BatchDot(v_broadcast, vva, /*transpose_x=*/true, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); - a = a - xla::Mul(tau, vva, - /*broadcast_dimensions=*/batch_dim_indices); + auto vva = BatchDot(v_broadcast, a, precision); + vva = BatchDot(TransposeInMinorDims(v_broadcast), vva, precision); + a = a - Mul(tau, vva, + 
/*broadcast_dimensions=*/batch_dim_indices); // It is more precise to populate column 'k' explicitly, rather than // computing it implicitly by applying the Householder transformation. // a[k,k] = beta // a[k+1:,k] = np.zeros([m-k-1], dtype=a.dtype) - auto iota = xla::Reshape(xla::Iota(a.builder(), xla::S32, m), {m, 1}); - auto predecessor_mask = xla::ConvertElementType(xla::Lt(iota, j), type); - auto mask = xla::Broadcast(xla::ConvertElementType(xla::Eq(iota, j), type), - std::vector(batch_dims.size(), 1)); - auto new_x = - xla::Mul(x, predecessor_mask, - /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) + - xla::Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices); + auto iota = Reshape(Iota(a.builder(), S32, m), {m, 1}); + auto predecessor_mask = ConvertElementType(Lt(iota, j), type); + auto mask = Broadcast(ConvertElementType(Eq(iota, j), type), + std::vector(batch_dims.size(), 1)); + auto new_x = Mul(x, predecessor_mask, + /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) + + Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices); a = DynamicUpdateSliceInMinorDims(a, new_x, {j}); // vs[:, j] = v vs = DynamicUpdateSliceInMinorDims( - vs, xla::Reshape(v, ConcatVectors(batch_dims, {m, 1})), {j}); + vs, Reshape(v, ConcatVectors(batch_dims, {m, 1})), {j}); // taus[j] = tau taus = DynamicUpdateSliceInMinorDims( - taus, xla::Reshape(tau, ConcatVectors(batch_dims, {1})), {j}); - return std::vector{a, vs, taus}; + taus, Reshape(tau, ConcatVectors(batch_dims, {1})), {j}); + return std::vector{a, vs, taus}; }; - auto vs = xla::Zeros(builder, xla::ShapeUtil::MakeShape( - type, ConcatVectors(batch_dims, {m, n}))); - auto taus = xla::Zeros( - builder, xla::ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n}))); + auto vs = Zeros( + builder, ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n}))); + auto taus = Zeros(builder, + ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n}))); - TF_ASSIGN_OR_RETURN(auto values, - 
XlaForEachIndex(std::min(m, n), xla::S32, qr_body_fn, - {a, vs, taus}, "qr", builder)); + TF_ASSIGN_OR_RETURN(auto values, ForEachIndex(std::min(m, n), S32, qr_body_fn, + {a, vs, taus}, "qr", builder)); QRBlockResult result; result.r = values[0]; @@ -254,62 +248,58 @@ xla::StatusOr QRBlock( // return W // There is no need to return Y since at termination of the loop it is equal to // vs. -xla::StatusOr ComputeWYRepresentation( - xla::PrimitiveType type, absl::Span batch_dims, xla::XlaOp vs, - xla::XlaOp taus, int64 m, int64 n, - xla::PrecisionConfig::Precision precision) { +StatusOr ComputeWYRepresentation(PrimitiveType type, + absl::Span batch_dims, + XlaOp vs, XlaOp taus, int64 m, int64 n, + PrecisionConfig::Precision precision) { std::vector batch_dim_indices(batch_dims.size()); std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); int64 n_index = batch_dims.size() + 1; - auto body_fn = - [&](xla::XlaOp j, absl::Span values, - xla::XlaBuilder* builder) -> xla::StatusOr> { + auto body_fn = [&](XlaOp j, absl::Span values, + XlaBuilder* builder) -> StatusOr> { auto w = values[0]; auto y = values[1]; const auto vs = values[2]; const auto taus = values[3]; // Want j values in range [1, ... n). 
- j = j + xla::ConstantR0(builder, 1); + j = j + ConstantR0(builder, 1); // vs has shape [..., m, 1] auto v = DynamicSliceInMinorDims(vs, {j}, {1}); // beta has shape [..., 1] auto beta = DynamicSliceInMinorDims(taus, {j}, {1}); // yv has shape [..., n, 1] - auto yv = BatchDot(y, v, /*transpose_x=*/true, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + auto yv = BatchDot(TransposeInMinorDims(y), v, precision); // wyv has shape [..., m, 1] - auto wyv = - BatchDot(w, yv, /*transpose_x=*/false, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + auto wyv = BatchDot(w, yv, precision); - auto z = xla::Mul( + auto z = Mul( -beta, v + wyv, /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index})); w = DynamicUpdateSliceInMinorDims(w, z, {j}); y = DynamicUpdateSliceInMinorDims(y, v, {j}); - return std::vector{w, y, vs, taus}; + return std::vector{w, y, vs, taus}; }; - xla::XlaBuilder* builder = vs.builder(); - auto w = xla::Zeros(builder, xla::ShapeUtil::MakeShape( - type, ConcatVectors(batch_dims, {m, n}))); + XlaBuilder* builder = vs.builder(); + auto w = Zeros(builder, + ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n}))); auto y = w; auto v = SliceInMinorDims(vs, {0}, {1}); auto beta = SliceInMinorDims(taus, {0}, {1}); y = UpdateSliceInMinorDims(y, v, {0}); - auto bv = xla::Mul( - -beta, v, - /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index})); + auto bv = + Mul(-beta, v, + /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index})); w = UpdateSliceInMinorDims(w, bv, {0}); TF_ASSIGN_OR_RETURN( - auto values, XlaForEachIndex(n - 1, xla::S32, body_fn, {w, y, vs, taus}, - "wy", builder)); + auto values, + ForEachIndex(n - 1, S32, body_fn, {w, y, vs, taus}, "wy", builder)); return values[0]; } @@ -330,34 +320,34 @@ xla::StatusOr ComputeWYRepresentation( // return (q, a) // TODO(phawkins): consider using UT transformations (in the form I - V U V') // 
rather than WY transformations. -xla::StatusOr QRDecomposition( - xla::XlaOp a, bool full_matrices, int64 block_size, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int num_dims = xla::ShapeUtil::Rank(a_shape); +StatusOr QRDecomposition( + XlaOp a, bool full_matrices, int64 block_size, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int num_dims = ShapeUtil::Rank(a_shape); if (num_dims < 2) { - return errors::InvalidArgument("Arguments to QR must have rank >= 2: ", - num_dims); + return InvalidArgument("Arguments to QR must have rank >= 2: got shape %s", + a_shape.ToString()); } - xla::PrimitiveType type = a_shape.element_type(); + PrimitiveType type = a_shape.element_type(); - const int64 m = xla::ShapeUtil::GetDimension(a_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + const int64 m = ShapeUtil::GetDimension(a_shape, -2); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); const int64 p = std::min(m, n); if (block_size < 1) { - return errors::InvalidArgument( - "block_size argument to QR must be >= 1; got ", block_size); + return InvalidArgument("block_size argument to QR must be >= 1; got %d", + block_size); } const int64 num_batch_dims = num_dims - 2; std::vector batch_dims(num_batch_dims); for (int i = 0; i < num_batch_dims; ++i) { - batch_dims[i] = xla::ShapeUtil::GetDimension(a_shape, i); + batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); } - auto q = xla::Broadcast(xla::IdentityMatrix(builder, type, m, m), batch_dims); + auto q = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims); for (int64 i = 0; i < p; i += block_size) { int64 k = std::min(block_size, p - i); @@ -375,23 +365,15 @@ xla::StatusOr QRDecomposition( // a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:])) auto a_panel = SliceInMinorDims(a, {i, 
i + k}, {m, n}); - auto a_update = - BatchDot(w, a_panel, /*transpose_x=*/true, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); - a_update = - BatchDot(y, a_update, /*transpose_x=*/false, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + auto a_update = BatchDot(TransposeInMinorDims(w), a_panel, precision); + a_update = BatchDot(y, a_update, precision); a_panel = a_panel + a_update; a = UpdateSliceInMinorDims(a, a_panel, {i, i + k}); // q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T)) auto q_panel = SliceInMinorDims(q, {0, i}, {m, m}); - auto q_update = - BatchDot(q_panel, w, /*transpose_x=*/false, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); - q_update = BatchDot(q_update, y, /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + auto q_update = BatchDot(q_panel, w, precision); + q_update = BatchDot(q_update, TransposeInMinorDims(y), precision); q_panel = q_panel + q_update; q = UpdateSliceInMinorDims(q, q_panel, {0, i}); } @@ -408,4 +390,4 @@ xla::StatusOr QRDecomposition( return result; } -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/qr.h b/tensorflow/compiler/xla/client/lib/qr.h similarity index 74% rename from tensorflow/compiler/tf2xla/lib/qr.h rename to tensorflow/compiler/xla/client/lib/qr.h index 24b537ac8b63b93e734c3d0e335ea455f7d51a54..827c8eeca05ef09a0d77363eb3c40961b95813d8 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.h +++ b/tensorflow/compiler/xla/client/lib/qr.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QR_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QR_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -namespace tensorflow { +namespace xla { // Computes the QR decompositions of a batch of matrices. That is, // given a (batched) matrix a, computes an orthonormal matrix Q and an @@ -29,14 +29,14 @@ namespace tensorflow { // the block size to use. // TODO(phawkins): handle the complex case. struct QRDecompositionResult { - xla::XlaOp q; - xla::XlaOp r; + XlaOp q; + XlaOp r; }; -xla::StatusOr QRDecomposition( - xla::XlaOp a, bool full_matrices, int64 block_size = 128, - xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); +StatusOr QRDecomposition( + XlaOp a, bool full_matrices, int64 block_size = 128, + PrecisionConfig::Precision precision = PrecisionConfig::HIGHEST); -} // namespace tensorflow +} // namespace xla -#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QR_H_ diff --git a/tensorflow/compiler/xla/client/lib/qr_test.cc b/tensorflow/compiler/xla/client/lib/qr_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b27d364b62444d6d5fb1278b6e6461affc15b2e6 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/qr_test.cc @@ -0,0 +1,93 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/qr.h" + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace { + +using QrTest = xla::ClientLibraryTestBase; + +XLA_TEST_F(QrTest, Simple) { + xla::XlaBuilder builder(TestName()); + + xla::Array2D a_vals({ + {4, 6, 8, 10}, + {6, 45, 54, 63}, + {8, 54, 146, 166}, + {10, 63, 166, 310}, + }); + + xla::XlaOp a; + auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); + TF_ASSERT_OK_AND_ASSIGN( + auto result, + xla::QRDecomposition(a, /*full_matrices=*/true, /*block_size=*/2)); + + // Verifies that the decomposition composes back to the original matrix. + // + // This isn't a terribly demanding test, (e.g., we should verify that Q is + // orthonormal and R is upper-triangular) but it's awkward to write such tests + // without more linear algebra libraries. It's easier to test the numerics + // from Python, anyway, where we have access to numpy and scipy. 
+ xla::BatchDot(result.q, result.r, xla::PrecisionConfig::HIGHEST); + + ComputeAndCompareR2(&builder, a_vals, {a_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +XLA_TEST_F(QrTest, SimpleBatched) { + xla::XlaBuilder builder(TestName()); + + xla::Array3D a_vals({ + { + {4, 6, 8, 10}, + {6, 45, 54, 63}, + {8, 54, 146, 166}, + {10, 63, 166, 310}, + }, + { + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 456, 106}, + {12, 48, 106, 62}, + }, + }); + + xla::XlaOp a; + auto a_data = CreateR3Parameter(a_vals, 0, "a", &builder, &a); + TF_ASSERT_OK_AND_ASSIGN( + auto result, + xla::QRDecomposition(a, /*full_matrices=*/true, /*block_size=*/2)); + + xla::BatchDot(result.q, result.r, xla::PrecisionConfig::HIGHEST); + + ComputeAndCompareR3(&builder, a_vals, {a_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +} // namespace diff --git a/tensorflow/compiler/xla/client/lib/quantize.h b/tensorflow/compiler/xla/client/lib/quantize.h new file mode 100644 index 0000000000000000000000000000000000000000..26dbbd5b00bd1a29f4047c9a4294fcac7340cf6c --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/quantize.h @@ -0,0 +1,186 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QUANTIZE_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QUANTIZE_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/bfloat16/bfloat16.h" + +namespace xla { + +constexpr int64 kBitsOfByte = 8; + +// Represents the range used for quantization +struct QuantizedRange { + QuantizedRange() = default; + QuantizedRange(float min_in, float max_in) : min(min_in), max(max_in) {} + + bool operator==(const QuantizedRange& rhs) const { + return this->min == rhs.min && this->max == rhs.max; + } + + bool operator!=(const QuantizedRange& rhs) const { return !(*this == rhs); } + + tensorflow::bfloat16 min = tensorflow::bfloat16(0.0f); + tensorflow::bfloat16 max = tensorflow::bfloat16(0.0f); +}; + +template +inline std::vector PackToUint32(absl::Span input) { + const int64 kElementsPerPack = sizeof(uint32) / sizeof(T); + const int64 input_size = input.size(); + const int64 output_size = CeilOfRatio(input_size, kElementsPerPack); + + std::vector output_vec; + constexpr int64 kShiftBits = sizeof(T) / sizeof(uint8) * kBitsOfByte; + + for (int64 i = 0; i < output_size; i++) { + uint32 result = 0; + for (int64 p = 0; p < kElementsPerPack; p++) { + int64 index = i * kElementsPerPack + p; + if (index < input_size) { + int64 total_shift_bits = kShiftBits * (kElementsPerPack - p - 1); + result |= (input[index] << total_shift_bits); + } + } + output_vec.push_back(result); + } + + return output_vec; +} + +// Dequantize the quantized input of packed uint32 to bfloat16. +// Only uint8 or uint16 is supported for the original unpacked input. 
+// Returns a tensor of shape [d0,..., dn * unpack_size] if +// input shape is [d0, ..., dn], where unpack_size = sizeof(uint32) / sizeof(T). +// If transpose_output is true, will return a tensor of shape +// [dn * unpack_size, dn-1, ..., d1, d0]. transpose_output is faster when +// input's rank is higher than 1. The input needs to be transposed to use +// transpose_output feature. +template +inline XlaOp Dequantize(XlaOp input, const QuantizedRange& range, + absl::string_view mode_string = "MIN_COMBINED", + bool transpose_output = false) { + XlaBuilder* const builder = input.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + float half_range = + !std::is_signed::value + ? 0.0f + : (static_cast(std::numeric_limits::max()) - + std::numeric_limits::min() + 1) / + 2.0f; + const int64 unpack_size = sizeof(uint32) / sizeof(T); + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(input)); + + auto element_type = shape.element_type(); + if (element_type != U32) { + return InvalidArgument( + "Only U32 is supported for input type of xla::Dequantize Op."); + } + + // Broadcast the input to [unpack_size, d0, ..., dn] if input size is + // [d0, ..., dn]. + auto broadcast_input = Broadcast(input, {unpack_size}); + + XlaOp iota_r1 = Iota(builder, U32, unpack_size); + // Higher significant bytes need to shift more bytes than lower + // significant bytes. + XlaOp shift_bytes = + xla::ConstantR0(builder, unpack_size - 1) - iota_r1; + + const int bytes_of_type = sizeof(T) / sizeof(uint8); + std::vector shift_vec(unpack_size, kBitsOfByte * bytes_of_type); + XlaOp shift_bits = + shift_bytes * xla::ConstantR1(builder, shift_vec); + + // Make bit_mask for different data type T. 
+ uint32 bit_mask = 0x00000000; + for (int i = 0; i < bytes_of_type; i++) { + bit_mask <<= kBitsOfByte; + bit_mask |= 0x000000ff; + } + + std::vector shift_transpose_dimensions(shape.dimensions_size()); + std::iota(shift_transpose_dimensions.begin(), + shift_transpose_dimensions.end(), 0); + shift_transpose_dimensions.insert(shift_transpose_dimensions.begin(), 1, + shape.dimensions_size()); + + // Shift the input by sizeof(T) bytes and apply bit_mask to unpack. + XlaOp shifted_input = ShiftRightLogical( + broadcast_input, Transpose(Broadcast(shift_bits, shape.dimensions()), + shift_transpose_dimensions)); + XlaOp unpack_input = + And(shifted_input, xla::ConstantR0(builder, bit_mask)); + + XlaOp result; + + if (mode_string == "MIN_COMBINED") { + const tensorflow::bfloat16 scale_factor = + (range.max - range.min) / + (static_cast(std::numeric_limits::max() - + std::numeric_limits::min())); + // result = bfloat16(input + half_range) * scale_factor + range.min + XlaOp unpack_input_bf16 = ConvertElementType(unpack_input, BF16); + XlaOp half_range_bf16 = xla::ConstantR0( + builder, static_cast(half_range)); + XlaOp sum = unpack_input_bf16 + half_range_bf16; + + result = + sum * xla::ConstantR0(builder, scale_factor) + + xla::ConstantR0(builder, range.min); + } else { + // TODO(wangtao): support other modes. + return InvalidArgument( + "Only MIN_COMBINED mode is supported in xla::Dequantize Op."); + } + + std::vector transpose_dimensions(shape.dimensions_size()); + std::iota(transpose_dimensions.begin(), transpose_dimensions.end(), 1); + std::reverse(transpose_dimensions.begin(), transpose_dimensions.end()); + transpose_dimensions.insert(transpose_dimensions.begin() + 1, 1, 0); + + // Transpose the result to be [dn, unpack_size, dn-1, ..., d1, d0]. + XlaOp transposed_result = Transpose(result, transpose_dimensions); + + // Reshape to be [dn * unpack_size, dn-1, ..., d1, d0]. 
+ XlaOp reshaped_result = Collapse(transposed_result, {0, 1}); + + // Return the transpose result if transpose_output is true. + if (transpose_output) { + return reshaped_result; + } + + // Transpose the result to be [d0, d1, ..., dn-1, dn * unpack_size]. + std::vector result_dimensions(shape.dimensions_size()); + std::iota(result_dimensions.begin(), result_dimensions.end(), 0); + std::reverse(result_dimensions.begin(), result_dimensions.end()); + + return Transpose(reshaped_result, result_dimensions); + }); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QUANTIZE_H_ diff --git a/tensorflow/compiler/xla/client/lib/quantize_test.cc b/tensorflow/compiler/xla/client/lib/quantize_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..be3603d9e11670913c21a834d2216a999306d582 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/quantize_test.cc @@ -0,0 +1,337 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/quantize.h" + +#include + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +namespace { + +using bfloat16 = tensorflow::bfloat16; + +template +std::vector GenerateInput() { + std::vector input; + + for (int64 i = std::numeric_limits::min(); + i < std::numeric_limits::max(); ++i) { + input.push_back(static_cast(i)); + } + + return input; +} + +template +Array2D GenerateLargeSizeInput(int num_columns, int num_rows) { + Array2D input(num_columns, num_rows); + + input.FillRandom(6, 128); + + return input; +} + +template +Array2D PackLargeInput(Array2D &input) { + const int64 size_per_pack = sizeof(uint32) / sizeof(NativeT); + int64 width = input.width(); + + int64 padded_output_width = CeilOfRatio(width, size_per_pack); + + Array2D pack_input(input.height(), padded_output_width); + + for (int h = 0; h < input.height(); h++) { + std::vector input_row; + for (int w = 0; w < width; w++) { + input_row.push_back(input({h, w})); + } + + auto pack_input_vec = PackToUint32(input_row); + + for (int w = 0; w < padded_output_width; w++) { + pack_input(h, w) = pack_input_vec[w]; + } + } + + return pack_input; +} + +template +Array2D GenerateLargeSizeMinCombinedOutput( + Array2D &input, const QuantizedRange &range, + bool transpose_output = false) { + const int64 size_per_pack = sizeof(uint32) / sizeof(NativeT); + int64 width = input.width(); + + int64 padded_output_width = CeilOfRatio(width, size_per_pack) * size_per_pack; + + int64 output_height; + int64 output_width; + + if (transpose_output) { + output_height = padded_output_width; + output_width = input.height(); + } else 
{ + output_height = input.height(); + output_width = padded_output_width; + } + + Array2D output(output_height, output_width, bfloat16(0.0)); + + float half_range = + !std::is_signed::value + ? 0.0f + : (static_cast(std::numeric_limits::max() - + std::numeric_limits::min() + 1)) / + 2.0f; + const bfloat16 scale_factor = + (range.max - range.min) / + (static_cast(std::numeric_limits::max() - + std::numeric_limits::min())); + + for (int h = 0; h < input.height(); h++) { + std::vector input_row; + for (int w = 0; w < width; w++) { + bfloat16 result = + static_cast(input(h, w) + half_range) * scale_factor + + range.min; + if (transpose_output) { + output(w, h) = result; + } else { + output(h, w) = result; + } + } + } + + return output; +} + +template +std::vector GenerateMinCombinedOutput(const QuantizedRange &range) { + float half_range = + !std::is_signed::value + ? 0.0f + : (static_cast(std::numeric_limits::max() - + std::numeric_limits::min() + 1)) / + 2.0f; + const bfloat16 scale_factor = + (range.max - range.min) / + (static_cast(std::numeric_limits::max() - + std::numeric_limits::min())); + std::vector output; + for (int64 i = std::numeric_limits::min(); + i < std::numeric_limits::max(); ++i) { + bfloat16 result = + static_cast(i + half_range) * scale_factor + range.min; + output.push_back(result); + } + + const int64 pack_size = sizeof(uint32) / sizeof(NativeT); + const int64 output_size = output.size(); + + int64 num_tailing_zeros = + CeilOfRatio(output_size, pack_size) * pack_size - output_size; + + output.insert(output.end(), num_tailing_zeros, bfloat16(0.0)); + return output; +} + +// TODO(wangtao): add a test to make sure this op is the inverse of the existing +// TF quantize op defined in: third_party/tensorflow/core/kernels/quantize_op.cc + +using DequantizeTest = ClientLibraryTestBase; + +TEST(PackTest, PackUint8ToUint32) { + std::vector input = {0xAB, 0x0B, 0x00, 0xF0, 0x01}; + auto output = PackToUint32(input); + EXPECT_THAT(output, 
::testing::ElementsAre(0xAB0B00F0, 0x01000000)); +} + +TEST(PackTest, PackInt8ToUint32) { + std::vector input = {static_cast(0x81), 0x0B, 0x00, 0x20, + 0x01}; + auto output = PackToUint32(input); + EXPECT_THAT(output, ::testing::ElementsAre(0x810B0020, 0x01000000)); +} + +TEST(PackTest, PackUint8ToUint32PerfectSize) { + std::vector input = {3, 2, 1, 0}; + auto output = PackToUint32(input); + EXPECT_THAT(output, ::testing::ElementsAre(0x03020100)); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint16R1) { + XlaBuilder builder(TestName()); + auto input = GenerateInput(); + auto x = ConstantR1(&builder, PackToUint32(input)); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED"); + auto expected = GenerateMinCombinedOutput(range); + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8R1) { + XlaBuilder builder(TestName()); + auto input = GenerateInput(); + auto x = ConstantR1(&builder, PackToUint32(input)); + QuantizedRange range(0, 127.0f); + xla::Dequantize(x, range, "MIN_COMBINED"); + auto expected = GenerateMinCombinedOutput(range); + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8R2) { + XlaBuilder builder(TestName()); + std::vector> input = { + {0, 1, 2, 3}, + {4, 5, 6, 7}, + {8, 9, 10, 11}, + {12, 13, 16, 15}, + }; + auto x = ConstantR2(&builder, {{PackToUint32(input[0])[0]}, + {PackToUint32(input[1])[0]}, + {PackToUint32(input[2])[0]}, + {PackToUint32(input[3])[0]}}); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED"); + const Array2D expected = { + {bfloat16(0.0), bfloat16(1.0), bfloat16(2.0), bfloat16(3.0)}, + {bfloat16(4.0), bfloat16(5.0), bfloat16(6.0), bfloat16(7.0)}, + {bfloat16(8.0), bfloat16(9.0), bfloat16(10.0), bfloat16(11.0)}, + {bfloat16(12.0), bfloat16(13.0), bfloat16(16.0), bfloat16(15.0)}, + }; + ComputeAndCompareR2(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8R2TransposeOutput) { 
+ XlaBuilder builder(TestName()); + std::vector> input = { + {0, 1, 2, 3}, + {4, 5, 6, 7}, + {8, 9, 10, 11}, + {12, 13, 16, 15}, + }; + auto x = ConstantR2(&builder, {{PackToUint32(input[0])[0]}, + {PackToUint32(input[1])[0]}, + {PackToUint32(input[2])[0]}, + {PackToUint32(input[3])[0]}}); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED", /*transpose_output=*/true); + const Array2D expected = { + {bfloat16(0.0), bfloat16(4.0), bfloat16(8.0), bfloat16(12.0)}, + {bfloat16(1.0), bfloat16(5.0), bfloat16(9.0), bfloat16(13.0)}, + {bfloat16(2.0), bfloat16(6.0), bfloat16(10.0), bfloat16(16.0)}, + {bfloat16(3.0), bfloat16(7.0), bfloat16(11.0), bfloat16(15.0)}, + }; + ComputeAndCompareR2(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8R2TailingZero) { + XlaBuilder builder(TestName()); + std::vector> input = { + {0, 1, 2, 3, 16}, + {4, 5, 6, 7, 17}, + {8, 9, 10, 11, 18}, + {12, 13, 16, 15, 19}, + }; + auto x = ConstantR2( + &builder, + {{PackToUint32(input[0])[0], PackToUint32(input[0])[1]}, + {PackToUint32(input[1])[0], PackToUint32(input[1])[1]}, + {PackToUint32(input[2])[0], PackToUint32(input[2])[1]}, + {PackToUint32(input[3])[0], PackToUint32(input[3])[1]}}); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED"); + + const Array2D expected = { + {bfloat16(0.0), bfloat16(1.0), bfloat16(2.0), bfloat16(3.0), + bfloat16(16.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + {bfloat16(4.0), bfloat16(5.0), bfloat16(6.0), bfloat16(7.0), + bfloat16(17.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + {bfloat16(8.0), bfloat16(9.0), bfloat16(10.0), bfloat16(11.0), + bfloat16(18.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + {bfloat16(12.0), bfloat16(13.0), bfloat16(16.0), bfloat16(15.0), + bfloat16(19.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + }; + ComputeAndCompareR2(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8R2TailingZeroTransposeOutput) { + XlaBuilder 
builder(TestName()); + std::vector> input = { + {0, 1, 2, 3, 16}, + {4, 5, 6, 7, 17}, + {8, 9, 10, 11, 18}, + {12, 13, 16, 15, 19}, + }; + auto x = ConstantR2( + &builder, + {{PackToUint32(input[0])[0], PackToUint32(input[0])[1]}, + {PackToUint32(input[1])[0], PackToUint32(input[1])[1]}, + {PackToUint32(input[2])[0], PackToUint32(input[2])[1]}, + {PackToUint32(input[3])[0], PackToUint32(input[3])[1]}}); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED", /*transpose_output=*/true); + + const Array2D expected = { + {bfloat16(0.0), bfloat16(4.0), bfloat16(8.0), bfloat16(12.0)}, + {bfloat16(1.0), bfloat16(5.0), bfloat16(9.0), bfloat16(13.0)}, + {bfloat16(2.0), bfloat16(6.0), bfloat16(10.0), bfloat16(16.0)}, + {bfloat16(3.0), bfloat16(7.0), bfloat16(11.0), bfloat16(15.0)}, + {bfloat16(16.0), bfloat16(17.0), bfloat16(18.0), bfloat16(19.0)}, + {bfloat16(0.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + {bfloat16(0.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + {bfloat16(0.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + }; + ComputeAndCompareR2(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8LargeSizeTest) { + XlaBuilder builder(TestName()); + Array2D input = GenerateLargeSizeInput(500, 3547); + Array2D input_packed = PackLargeInput(input); + + auto x = ConstantR2FromArray2D(&builder, input_packed); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED"); + + const Array2D expected = + GenerateLargeSizeMinCombinedOutput(input, range); + ComputeAndCompareR2(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8LargeSizeTestTransposeOutput) { + XlaBuilder builder(TestName()); + Array2D input = GenerateLargeSizeInput(500, 3547); + Array2D input_packed = PackLargeInput(input); + + auto x = ConstantR2FromArray2D(&builder, input_packed); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED", /*transpose_output=*/true); + + const Array2D expected = 
GenerateLargeSizeMinCombinedOutput( + input, range, /*transpose_output=*/true); + ComputeAndCompareR2(&builder, expected, {}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/slicing.cc b/tensorflow/compiler/xla/client/lib/slicing.cc new file mode 100644 index 0000000000000000000000000000000000000000..f8c7df3ff5189c817202eaf39adb572f7e232ec2 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/slicing.cc @@ -0,0 +1,134 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/slicing.h" + +namespace xla { + +XlaOp SliceInMinorDims(XlaOp x, absl::Span start, + absl::Span end) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_RET_CHECK(start.size() == end.size()); + int64 n_minor_dims = start.size(); + + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + + const int64 n_dims = ShapeUtil::Rank(shape); + TF_RET_CHECK(n_minor_dims <= n_dims); + auto major_dims = AsInt64Slice(shape.dimensions()) + .subspan( + /*pos=*/0, + /*len=*/n_dims - n_minor_dims); + + // Prepends 0s in the major dim + std::vector padded_start(n_dims, 0); + std::copy(start.begin(), start.end(), + padded_start.begin() + major_dims.size()); + + // Prepends the shape of the major dims. 
+ std::vector padded_end(n_dims); + std::copy(major_dims.begin(), major_dims.end(), padded_end.begin()); + std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size()); + + std::vector strides(n_dims, 1); + return Slice(x, padded_start, padded_end, strides); + }); +} + +XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span start) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + // TODO(phawkins): make int64 work on all backends, remove the int32 cast. + std::vector start_as_int32(start.begin(), start.end()); + auto start_constant = ConstantR1(builder, start_as_int32); + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + TF_ASSIGN_OR_RETURN(Shape start_constant_shape, + builder->GetShape(start_constant)); + const int64 start_length = + ShapeUtil::GetDimension(start_constant_shape, -1); + TF_RET_CHECK(start_length == n_dims); + return DynamicUpdateSlice(x, update, start_constant); + }); +} + +XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update, + absl::Span start) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + const int64 n_minor_dims = start.size(); + TF_RET_CHECK(n_minor_dims <= n_dims); + std::vector padded_start(n_dims, 0); + std::copy(start.begin(), start.end(), + padded_start.begin() + (n_dims - n_minor_dims)); + return UpdateSlice(x, update, padded_start); + }); +} + +namespace { + +std::vector ConcatVectors(absl::Span xs, + absl::Span ys) { + std::vector output(xs.size() + ys.size()); + std::copy(xs.begin(), xs.end(), output.begin()); + std::copy(ys.begin(), ys.end(), output.begin() + xs.size()); + return output; +} + +XlaOp PrependZerosInMajorDims(XlaOp x, absl::Span starts) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + 
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + auto zero = Reshape(ConstantR0(builder, 0), {1}); + std::vector padded_starts(n_dims, zero); + for (int i = 0; i < starts.size(); ++i) { + padded_starts[n_dims - starts.size() + i] = Reshape(starts[i], {1}); + } + return ConcatInDim(builder, padded_starts, 0); + }); +} + +} // namespace + +XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, + absl::Span sizes) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + int64 n_minor_dims = starts.size(); + TF_RET_CHECK(n_minor_dims == sizes.size()); + TF_RET_CHECK(n_minor_dims <= n_dims); + auto major_dims = AsInt64Slice(shape.dimensions()) + .subspan( + /*pos=*/0, + /*len=*/n_dims - sizes.size()); + auto padded_starts = PrependZerosInMajorDims(x, starts); + auto padded_sizes = ConcatVectors(major_dims, sizes); + return DynamicSlice(x, padded_starts, padded_sizes); + }); +} + +XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, + absl::Span starts) { + auto padded_starts = PrependZerosInMajorDims(x, starts); + return DynamicUpdateSlice(x, update, padded_starts); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/slicing.h b/tensorflow/compiler/xla/client/lib/slicing.h new file mode 100644 index 0000000000000000000000000000000000000000..6c482a38b5489c9fb17c3dca9ee3d2a1b8fd1890 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/slicing.h @@ -0,0 +1,48 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/types.h" + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SLICING_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SLICING_H_ + +namespace xla { + +// Updates a slice of 'x', i.e., +// x[start[0], ..., start[n]] = update +XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span start); + +// Performs a slice in the minor dimensions of a tensor. +// x[..., start[0]:end[0], ..., start[n]:end[n]] +XlaOp SliceInMinorDims(XlaOp x, absl::Span start, + absl::Span end); + +// Updates a slice of 'x', where 'start' contains a list of minor dimensions: +// x[..., start[0]:..., ..., start[n]:...] = update +XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update, + absl::Span start); + +// Performs a dynamic slice in the minor dimensions of a tensor. 
+XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, + absl::Span sizes); + +XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, + absl::Span starts); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SLICING_H_ diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/xla/client/lib/slicing_test.cc similarity index 67% rename from tensorflow/compiler/tf2xla/lib/util_test.cc rename to tensorflow/compiler/xla/client/lib/slicing_test.cc index 442fe92c34ca26cb1a854cc90da8dc034bca79bb..8d362119e01006555db0f82d02626175936e1d05 100644 --- a/tensorflow/compiler/tf2xla/lib/util_test.cc +++ b/tensorflow/compiler/xla/client/lib/slicing_test.cc @@ -13,28 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" -#include -#include -#include - -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" -#include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" -#include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/status_test_util.h" -namespace tensorflow { +namespace xla { namespace { -using UtilTest = xla::ClientLibraryTestBase; -using UtilLeftLookingTest = xla::ClientLibraryTestBase; +using SlicingTest = xla::ClientLibraryTestBase; xla::Array2D BValsRight() { return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; @@ -63,7 +54,7 @@ xla::Array3D BatchedAValsFull() 
{ }}; } -XLA_TEST_F(UtilTest, Simple2dLookup) { +XLA_TEST_F(SlicingTest, Simple2dLookup) { xla::XlaBuilder builder(TestName()); xla::XlaOp a, x, y; @@ -77,7 +68,7 @@ XLA_TEST_F(UtilTest, Simple2dLookup) { xla::ErrorSpec(1e-2, 1e-2)); } -XLA_TEST_F(UtilTest, Simple3dLookup) { +XLA_TEST_F(SlicingTest, Simple3dLookup) { xla::XlaBuilder builder(TestName()); xla::XlaOp a, index; @@ -92,7 +83,7 @@ XLA_TEST_F(UtilTest, Simple3dLookup) { {a_data.get(), index_data.get()}); } -XLA_TEST_F(UtilTest, SimpleSliceUpdate) { +XLA_TEST_F(SlicingTest, SimpleSliceUpdate) { xla::XlaBuilder builder(TestName()); xla::XlaOp a, b, x, y; @@ -111,26 +102,5 @@ XLA_TEST_F(UtilTest, SimpleSliceUpdate) { {a_data.get(), b_data.get(), x_data.get(), y_data.get()}); } -XLA_TEST_F(UtilTest, RowBatchDot) { - xla::XlaBuilder builder(TestName()); - - int n = 4; - - xla::XlaOp a, row, index; - auto a_data = - CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); - auto row_data = CreateR3Parameter({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1, - "row", &builder, &row); - // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull(). 
- auto index_data = CreateR0Parameter(1, 2, "index", &builder, &index); - - auto l_index = DynamicSliceInMinorDims( - a, {index, xla::ConstantR0(&builder, 0)}, {1, n}); - BatchDot(l_index, row, /*transpose_x=*/false, /*transpose_y=*/true); - - ComputeAndCompareR3(&builder, {{{33}}, {{292}}}, - {a_data.get(), row_data.get(), index_data.get()}); -} - } // namespace -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index a95bbf2c8c860914877d3195b97342097dafc725..5db9d10dff4c50d71cde934b3f3c345bee571f29 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -59,22 +59,25 @@ XlaOp BuildFakeDataOpOnDevice(const Shape& shape, XlaBuilder* builder) { return Tuple(builder, parts); } -std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, - Client* client) { +std::unique_ptr MakeFakeDataViaDeviceOrDie( + const Shape& shape, Client* client, DebugOptions* debug_opts) { XlaBuilder b(absl::StrCat("make_fake_", ShapeUtil::HumanString(shape))); BuildFakeDataOpOnDevice(shape, &b); XlaComputation computation = b.Build().ConsumeValueOrDie(); auto execution_options = CreateDefaultExecutionOptions(); *execution_options.mutable_shape_with_output_layout() = shape.ToProto(); + if (debug_opts) { + *execution_options.mutable_debug_options() = *debug_opts; + } return client->Execute(computation, /*arguments=*/{}, &execution_options) .ConsumeValueOrDie(); } } // namespace -std::unique_ptr MakeFakeDataOrDie(const Shape& shape, - Client* client) { +std::unique_ptr MakeFakeDataOrDie( + const Shape& shape, Client* client, DebugOptions* debug_opts /*=nullptr*/) { if (DataSizeOfShape(shape) < (1LL << 20)) { StatusOr literal_status = MakeFakeLiteral(shape); if (!literal_status.ok()) { @@ -82,24 +85,25 @@ std::unique_ptr MakeFakeDataOrDie(const Shape& shape, // an on-device computation. 
CHECK_EQ(literal_status.status().code(), tensorflow::error::UNIMPLEMENTED); - return MakeFakeDataViaDeviceOrDie(shape, client); + return MakeFakeDataViaDeviceOrDie(shape, client, debug_opts); } return client->TransferToServer(literal_status.ValueOrDie()).ValueOrDie(); } // If the data is large, generate it on-device. - return MakeFakeDataViaDeviceOrDie(shape, client); + return MakeFakeDataViaDeviceOrDie(shape, client, debug_opts); } std::vector> MakeFakeArgumentsOrDie( - const XlaComputation& computation, Client* client) { + const XlaComputation& computation, Client* client, + DebugOptions* debug_opts /*=nullptr*/) { CHECK(computation.proto().has_host_program_shape()) << "Computation should have progran shape."; auto program_shape = computation.proto().host_program_shape(); std::vector> results; for (const ShapeProto& shape : program_shape.parameters()) { - results.push_back(MakeFakeDataOrDie(Shape(shape), client)); + results.push_back(MakeFakeDataOrDie(Shape(shape), client, debug_opts)); } return results; } diff --git a/tensorflow/compiler/xla/client/lib/testing.h b/tensorflow/compiler/xla/client/lib/testing.h index 03695ce2a339735e3e49522f4fe1bbf2d83a3834..428fa3e93d1b46983aae60176e7c2242d2552fdb 100644 --- a/tensorflow/compiler/xla/client/lib/testing.h +++ b/tensorflow/compiler/xla/client/lib/testing.h @@ -29,14 +29,19 @@ namespace xla { // Generates fake data of the given shape on the device or dies. The fake data // is created by performing a computation on the device rather than transferring // data from the host to the device. -std::unique_ptr MakeFakeDataOrDie(const Shape& shape, - Client* client); +// +// The optional DebugOptions are used when generating fake data on the device. +std::unique_ptr MakeFakeDataOrDie( + const Shape& shape, Client* client, DebugOptions* debug_opts = nullptr); // Returns vector of GlobalData handles of fake data (created using // MakeFakeDataOrDie) that are correctly shaped arguments for the given // xla computation. 
+// +// The optional DebugOptions are used when generating fake data on the device. std::vector> MakeFakeArgumentsOrDie( - const XlaComputation& computation, Client* client); + const XlaComputation& computation, Client* client, + DebugOptions* debug_opts = nullptr); } // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/xla/client/lib/triangular_solve.cc similarity index 62% rename from tensorflow/compiler/tf2xla/lib/triangular_solve.cc rename to tensorflow/compiler/xla/client/lib/triangular_solve.cc index 6524c2a9b1ada632d80edd234272760c2b545cc4..ac58090dfe33a8ae350019771e0b970d6f26e476 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/xla/client/lib/triangular_solve.cc @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" +#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" #include #include -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" -#include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" @@ -29,21 +29,20 @@ limitations under the License. 
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/math/math_util.h" -namespace tensorflow { +namespace xla { // Get the diagonal blocks of the coefficient matrix -xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { - xla::XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(a)); - int ndims = xla::ShapeUtil::Rank(shape); - int64 n = xla::ShapeUtil::GetDimension(shape, -1); +XlaOp DiagonalBlocks(XlaOp a, int64 block_size) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(a)); + int ndims = ShapeUtil::Rank(shape); + int64 n = ShapeUtil::GetDimension(shape, -1); int64 num_blocks = n / block_size; - xla::XlaOp diag_blocks; + XlaOp diag_blocks; // If the coefficient matrix is exactly the block size, we just add a // singleton dimension i.e. 
[..., n, n] -> [..., 1, n, n] @@ -58,13 +57,13 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { if (n > block_size) { // Construct the starting indices of the diagonal blocks auto start_indices = - Transpose(Broadcast(Mul(Iota(builder, xla::S32, num_blocks), - xla::ConstantR0(builder, block_size)), + Transpose(Broadcast(Mul(Iota(builder, S32, num_blocks), + ConstantR0(builder, block_size)), /*broadcast_sizes=*/{2}), /*permutation=*/{1, 0}); // Gather the diagonal blocks - xla::GatherDimensionNumbers dim_numbers; + GatherDimensionNumbers dim_numbers; dim_numbers.add_offset_dims(ndims - 1); dim_numbers.add_offset_dims(ndims); dim_numbers.add_start_index_map(ndims - 2); @@ -80,7 +79,7 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { // Pad with zeros auto last_blocks = SliceInMinorDims(a, {n - n % block_size, n - n % block_size}, {n, n}); - xla::PaddingConfig config = xla::MakeNoPaddingConfig(ndims); + PaddingConfig config = MakeNoPaddingConfig(ndims); int64 padding = block_size - n % block_size; config.mutable_dimensions(ndims - 1)->set_edge_padding_high(padding); config.mutable_dimensions(ndims - 2)->set_edge_padding_high(padding); @@ -89,9 +88,8 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { // Add a singleton dimension // i.e. 
[..., block_size, block_size] -> [..., 1, block_size, block_size] - TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape, - builder->GetShape(last_blocks)); - auto shape_dims = xla::AsInt64Slice(blocks_shape.dimensions()); + TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(last_blocks)); + auto shape_dims = AsInt64Slice(blocks_shape.dimensions()); auto last_blocks_dims = std::vector(ndims); std::copy(shape_dims.begin(), shape_dims.end(), last_blocks_dims.begin()); last_blocks_dims.insert(last_blocks_dims.end() - 2, 1); @@ -100,7 +98,7 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { // Concatenate with the other blocks if necessary if (n > block_size) { diag_blocks = - xla::ConcatInDim(builder, {diag_blocks, last_blocks}, ndims - 2); + ConcatInDim(builder, {diag_blocks, last_blocks}, ndims - 2); } else { diag_blocks = last_blocks; } @@ -110,16 +108,16 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { }); } -xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, - bool transpose_a, bool conjugate_a, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = diag_blocks.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { +XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower, bool transpose_a, + bool conjugate_a, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = diag_blocks.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { // Input is a batch of square lower triangular square matrices. Its shape is // (..., size, size). We resize this to (num_blocks, size, size). 
- TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(diag_blocks)); - int64 block_size = xla::ShapeUtil::GetDimension(shape, -1); - int64 num_blocks = xla::ShapeUtil::ElementsIn(shape) / + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(diag_blocks)); + int64 block_size = ShapeUtil::GetDimension(shape, -1); + int64 num_blocks = ShapeUtil::ElementsIn(shape) / tensorflow::MathUtil::IPow(block_size, 2); diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size}); @@ -131,9 +129,9 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, // zero (which can happen if the last block was padded) otherwise it will // introduce nans which will propagate auto diags = GetMatrixDiagonal(diag_blocks); - TF_ASSIGN_OR_RETURN(xla::Shape diags_shape, builder->GetShape(diags)); + TF_ASSIGN_OR_RETURN(Shape diags_shape, builder->GetShape(diags)); auto one = ScalarLike(diags, 1); - auto ones = Broadcast(one, xla::AsInt64Slice(diags_shape.dimensions())); + auto ones = Broadcast(one, AsInt64Slice(diags_shape.dimensions())); diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags); auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2}); @@ -159,40 +157,40 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, auto start_index = (lower) ? 0 : block_size - 1; auto output_block = DynamicUpdateSlice( neg_identity, pos_one, - /*start_indices=*/xla::ConstantR1(builder, 2, start_index)); + /*start_indices=*/ConstantR1(builder, 2, start_index)); // Broadcast diag([1, -1, -1, ...]) to every block - xla::XlaOp output = Broadcast(output_block, - /*broadcast_sizes=*/{num_blocks}); + XlaOp output = Broadcast(output_block, + /*broadcast_sizes=*/{num_blocks}); // Now we construct a loop that performs matrix-vector multiplications // inverting the blocks one row at a time - std::vector tuple_shapes = { + std::vector tuple_shapes = { // The loop iteration counter is a scalar, incremented each iteration. 
- xla::ShapeUtil::MakeShape(xla::S32, {}), + ShapeUtil::MakeShape(S32, {}), // The output has the shape of A, with one row updated each iteration. - xla::ShapeUtil::MakeShape(shape.element_type(), - {num_blocks, block_size, block_size}), + ShapeUtil::MakeShape(shape.element_type(), + {num_blocks, block_size, block_size}), // The input is a loop invariant. - xla::ShapeUtil::MakeShape(shape.element_type(), - {num_blocks, block_size, block_size})}; - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); + ShapeUtil::MakeShape(shape.element_type(), + {num_blocks, block_size, block_size})}; + Shape tuple_shape = ShapeUtil::MakeTupleShape(tuple_shapes); - auto init_i = One(builder, xla::S32); - auto init = xla::Tuple(builder, {init_i, output, scaled_diag_blocks}); + auto init_i = One(builder, S32); + auto init = Tuple(builder, {init_i, output, scaled_diag_blocks}); // Construct the loop condition function. - std::unique_ptr condb = + std::unique_ptr condb = builder->CreateSubBuilder("InvertDiagCond"); { auto i = GetTupleElement( Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0); - Lt(i, xla::ConstantR0(condb.get(), block_size)); + Lt(i, ConstantR0(condb.get(), block_size)); } TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); // Construct the loop body function. - std::unique_ptr bodyb = + std::unique_ptr bodyb = builder->CreateSubBuilder("InvertDiagBody"); { auto input_tuple = @@ -202,21 +200,21 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, auto body_out = GetTupleElement(input_tuple, 1); auto body_input = GetTupleElement(input_tuple, 2); - auto zero = xla::ConstantR1(bodyb.get(), 1, 0); + auto zero = ConstantR1(bodyb.get(), 1, 0); auto j = (lower) ? 
i : ScalarLike(i, block_size - 1) - i; auto start_indices = - xla::ConcatInDim(bodyb.get(), {zero, Reshape(j, {1}), zero}, 0); + ConcatInDim(bodyb.get(), {zero, Reshape(j, {1}), zero}, 0); auto input_row = DynamicSlice(body_input, start_indices, /*slice_sizes=*/{num_blocks, 1, block_size}); // We want -L21 L11^{-1} - xla::DotDimensionNumbers dnums; + DotDimensionNumbers dnums; dnums.add_lhs_batch_dimensions(0); dnums.add_rhs_batch_dimensions(0); dnums.add_lhs_contracting_dimensions(2); dnums.add_rhs_contracting_dimensions(1); - xla::PrecisionConfig precision_proto; + PrecisionConfig precision_proto; precision_proto.add_operand_precision(precision); precision_proto.add_operand_precision(precision); auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto); @@ -224,7 +222,7 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, body_out = DynamicUpdateSlice(body_out, update, start_indices); auto next_i = i + ScalarLike(i, 1); - xla::Tuple(bodyb.get(), {next_i, body_out, body_input}); + Tuple(bodyb.get(), {next_i, body_out, body_input}); } TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); @@ -238,27 +236,26 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, /*broadcast_dimensions=*/{0, 1}); // Reshape back to original batch major dimensions - return Reshape(inv_diag_blocks, xla::AsInt64Slice(shape.dimensions())); + return Reshape(inv_diag_blocks, AsInt64Slice(shape.dimensions())); }); } -xla::XlaOp SolveWithInvertedDiagonalBlocks( - xla::XlaOp a, xla::XlaOp b, xla::XlaOp inv_diag_blocks, bool left_side, - bool lower, bool transpose_a, bool conjugate_a, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape, - builder->GetShape(inv_diag_blocks)); - TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); - int64 block_size = xla::ShapeUtil::GetDimension(blocks_shape, -1); - - 
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - int64 ndims = xla::ShapeUtil::Rank(a_shape); - int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); +XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks, + bool left_side, bool lower, + bool transpose_a, bool conjugate_a, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(inv_diag_blocks)); + TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b)); + int64 block_size = ShapeUtil::GetDimension(blocks_shape, -1); + + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + int64 ndims = ShapeUtil::Rank(a_shape); + int64 n = ShapeUtil::GetDimension(a_shape, -1); int64 num_blocks = n / block_size + (n % block_size != 0); int64 m_dim = (left_side) ? -1 : -2; - int64 m = xla::ShapeUtil::GetDimension(b_shape, m_dim); + int64 m = ShapeUtil::GetDimension(b_shape, m_dim); // Initialize the solution auto x = ZerosLike(b); @@ -294,7 +291,7 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks( } auto b_row = SliceInMinorDims(b, start, end); - xla::XlaOp remainder; + XlaOp remainder; if (i == 0) { remainder = b_row; } else { @@ -311,29 +308,27 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks( auto a_row = MaybeConjugate(SliceInMinorDims(a, start, end), conjugate_a); if (left_side) { - remainder = b_row - BatchDot(a_row, x, transpose_a, false, - /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + remainder = + b_row - BatchDot(MaybeTransposeInMinorDims(a_row, transpose_a), x, + precision); } else { - remainder = b_row - BatchDot(x, a_row, false, transpose_a, - /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + remainder = + b_row - BatchDot(x, MaybeTransposeInMinorDims(a_row, transpose_a), + precision); } } - xla::XlaOp x_update; - auto zero = Zero(builder, xla::S32); - auto start_index = - xla::ConstantR0WithType(builder, xla::S32, 
j * block_size); - std::vector update_starts = {start_index, zero}; + XlaOp x_update; + auto zero = Zero(builder, S32); + auto start_index = ConstantR0WithType(builder, S32, j * block_size); + std::vector update_starts = {start_index, zero}; if (left_side) { - x_update = - BatchDot(inv_block, remainder, transpose_a, false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + x_update = BatchDot(MaybeTransposeInMinorDims(inv_block, transpose_a), + remainder, precision); } else { - x_update = - BatchDot(remainder, inv_block, false, transpose_a, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + x_update = BatchDot(remainder, + MaybeTransposeInMinorDims(inv_block, transpose_a), + precision); std::swap(update_starts[0], update_starts[1]); } x = DynamicUpdateSliceInMinorDims(x, x_update, /*starts=*/update_starts); @@ -343,24 +338,24 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks( }); } -xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, - bool lower, bool transpose_a, bool conjugate_a, - int64 block_size, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); - if (xla::ShapeUtil::Rank(a_shape) != xla::ShapeUtil::Rank(b_shape)) { - return errors::InvalidArgument( - "Arguments to TriangularSolve have different ranks: ", - xla::ShapeUtil::HumanString(a_shape), " vs. 
", - xla::ShapeUtil::HumanString(b_shape)); +XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, + bool transpose_a, bool conjugate_a, int64 block_size, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b)); + if (ShapeUtil::Rank(a_shape) != ShapeUtil::Rank(b_shape)) { + return InvalidArgument( + "Arguments to TriangularSolve have shapes with different ranks: " + "%s vs. %s", + ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape)); } - const int64 ndims = xla::ShapeUtil::Rank(a_shape); + const int64 ndims = ShapeUtil::Rank(a_shape); if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to TriangularSolve must have rank >= 2: ", ndims); + return InvalidArgument( + "Arguments to TriangularSolve was rank %d but must have rank >= 2.", + ndims); } // The batch dimensions must be equal. 
std::vector batch_dimensions; @@ -368,35 +363,42 @@ xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, int64 a_size = a_shape.dimensions(i); int64 b_size = b_shape.dimensions(i); if (a_size != b_size) { - return errors::InvalidArgument( - "Batch dimensions of arguments to TriangularSolve must be equal: ", - xla::ShapeUtil::HumanString(a_shape), " vs ", - xla::ShapeUtil::HumanString(b_shape)); + return InvalidArgument( + "Batch dimensions of arguments to TriangularSolve must be equal; " + "shapes were %s and %s.", + ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape)); } batch_dimensions.push_back(a_size); } - if (xla::ShapeUtil::GetDimension(a_shape, -1) != - xla::ShapeUtil::GetDimension(a_shape, -2)) { - return errors::InvalidArgument( - "The 'a' arguments to TriangularSolve must be square matrices: ", - xla::ShapeUtil::HumanString(a_shape)); + if (ShapeUtil::GetDimension(a_shape, -1) != + ShapeUtil::GetDimension(a_shape, -2)) { + return InvalidArgument( + "The 'a' argument to TriangularSolve must be a batched square matrix;" + " shape was: %s", + ShapeUtil::HumanString(a_shape)); } - const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); - if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(a_shape, -1)) { - return errors::InvalidArgument( - "Arguments to TriangularSolve have incompatible matrix shapes: ", - xla::ShapeUtil::HumanString(a_shape), " vs ", - xla::ShapeUtil::HumanString(b_shape)); + const int64 m = ShapeUtil::GetDimension(b_shape, -2); + const int64 n = ShapeUtil::GetDimension(b_shape, -1); + if ((left_side ? 
m : n) != ShapeUtil::GetDimension(a_shape, -1)) { + return InvalidArgument( + "Arguments to TriangularSolve have incompatible matrix shapes %s and " + "%s", + ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape)); } if (block_size < 1) { - return errors::InvalidArgument( - "block_size argument to TriangularSolve must be >= 1; got ", + return InvalidArgument( + "block_size argument to TriangularSolve must be >= 1; got %d", block_size); } + if (ShapeUtil::IsZeroElementArray(b_shape)) { + // The output has the same shape as 'b', and since the output has zero + // elements, any such array will do. + return b; + } + // We find the diagonal blocks of the coefficient matrix auto diag_blocks = DiagonalBlocks(a, block_size); @@ -413,4 +415,4 @@ xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, }); } -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/xla/client/lib/triangular_solve.h similarity index 88% rename from tensorflow/compiler/tf2xla/lib/triangular_solve.h rename to tensorflow/compiler/xla/client/lib/triangular_solve.h index 2303234f361e54cd2a0ad495cb03b371bed76877..50a3b30ebd1c15eb6d2ace4e351cb41f21db7093 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h +++ b/tensorflow/compiler/xla/client/lib/triangular_solve.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TRIANGULAR_SOLVE_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TRIANGULAR_SOLVE_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -namespace tensorflow { +namespace xla { // Solves systems of linear equations with lower or upper triangular coefficient // matrices by forward- or back-substitution. Broadcasting along leading @@ -57,11 +57,11 @@ namespace tensorflow { // // Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no // blocking is used. -xla::XlaOp TriangularSolve( - xla::XlaOp a, xla::XlaOp b, bool left_side, bool lower, bool transpose_a, +XlaOp TriangularSolve( + XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a, bool conjugate_a, int64 block_size = 128, - xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); + PrecisionConfig::Precision precision = PrecisionConfig::HIGHEST); -} // namespace tensorflow +} // namespace xla -#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TRIANGULAR_SOLVE_H_ diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc similarity index 78% rename from tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc rename to tensorflow/compiler/xla/client/lib/triangular_solve_test.cc index aeebf16028d40189203cdfd815f06a339ee72902..d0188e8ea06d0edacdba330f46647af201747abf 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc +++ b/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" +#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" #include #include @@ -30,59 +30,71 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" -namespace tensorflow { +namespace xla { namespace { -using TriangularSolveTest = xla::ClientLibraryTestBase; -using TriangularSolveLeftLookingTest = xla::ClientLibraryTestBase; -using complex64 = xla::complex64; +using TriangularSolveTest = ClientLibraryTestBase; +using TriangularSolveLeftLookingTest = ClientLibraryTestBase; -xla::Array2D AValsLower() { +Array2D AValsLower() { return {{2, 0, 0, 0}, {3, 6, 0, 0}, {4, 7, 9, 0}, {5, 8, 10, 11}}; } -xla::Array2D AValsUpper() { +Array2D AValsUpper() { return {{2, 3, 4, 5}, {0, 6, 7, 8}, {0, 0, 9, 10}, {0, 0, 0, 11}}; } -xla::Array2D BValsRight() { +Array2D BValsRight() { return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; } -xla::Array2D BValsLeft() { +Array2D BValsLeft() { return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; } -xla::Array2D AValsLowerComplex() { +Array2D AValsLowerComplex() { return {{2, 0, 0, 0}, {complex64(3, 1), 6, 0, 0}, {4, complex64(7, 2), 9, 0}, {5, 8, complex64(10, 3), 11}}; } -xla::Array2D AValsUpperComplex() { +Array2D AValsUpperComplex() { return {{2, 3, complex64(4, 3), 5}, {0, 6, complex64(7, 2), 8}, {0, 0, complex64(9, 1), 10}, {0, 0, 0, 11}}; } -xla::Array2D BValsRightComplex() { +Array2D BValsRightComplex() { return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; } -xla::Array2D BValsLeftComplex() { +Array2D BValsLeftComplex() { return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; } -xla::Array2D AValsFull() { - return {{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 7, 9, 0}, {5, 8, 10, 11}}; +XLA_TEST_F(TriangularSolveTest, EmptyArrays) { + XlaBuilder builder(TestName()); + + XlaOp a, b; + auto a_data = + 
CreateR2Parameter(Array2D(0, 0), 0, "a", &builder, &a); + auto b_data = + CreateR2Parameter(Array2D(0, 10), 1, "b", &builder, &b); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/true, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); + + ComputeAndCompareR2(&builder, Array2D(0, 10), + {a_data.get(), b_data.get()}); } XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -90,20 +102,20 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { /*transpose_a=*/true, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {0.5, 0.08333334, 0.04629629, 0.03367003}, {2.5, -0.25, -0.1388889, -0.1010101}, {4.5, -0.58333331, -0.32407406, -0.23569024}, }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -111,20 +123,20 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { /*transpose_a=*/false, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, {0.64393939, 0.06565657, -0.03030303, 0.72727273}, {1.4520202, 0.2003367, 0.01010101, 1.09090909}, }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { - 
xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -132,20 +144,20 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { /*transpose_a=*/true, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, {0.64393939, 0.06565657, -0.03030303, 0.72727273}, {1.4520202, 0.2003367, 0.01010101, 1.09090909}, }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -153,20 +165,20 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { /*transpose_a=*/false, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {0.5, 0.08333334, 0.04629629, 0.03367003}, {2.5, -0.25, -0.1388889, -0.1010101}, {4.5, -0.58333331, -0.32407406, -0.23569024}, }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -174,7 +186,7 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { /*transpose_a=*/true, /*conjugate_a=*/false, /*block_size=*/2); - 
xla::Array2D expected({ + Array2D expected({ {-0.89646465, -0.69444444, -0.49242424}, {-0.27441077, -0.24074074, -0.20707071}, {-0.23232323, -0.22222222, -0.21212121}, @@ -182,13 +194,13 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -196,7 +208,7 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { /*transpose_a=*/false, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {0.5, 1.0, 1.5}, {0.41666667, 0.33333333, 0.25}, {0.23148148, 0.18518519, 0.13888889}, @@ -204,13 +216,13 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -218,7 +230,7 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) { /*transpose_a=*/false, /*conjugate_a=*/false, /*block_size=*/3); - xla::Array2D expected({ + Array2D expected({ {0.5, 1.0, 1.5}, {0.41666667, 0.33333333, 0.25}, {0.23148148, 0.18518519, 0.13888889}, @@ -226,13 +238,13 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) { }); ComputeAndCompareR2(&builder, expected, 
{a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -240,7 +252,7 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { /*transpose_a=*/true, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {0.5, 1.0, 1.5}, {0.41666667, 0.33333333, 0.25}, {0.23148148, 0.18518519, 0.13888889}, @@ -248,13 +260,13 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -262,7 +274,7 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { /*transpose_a=*/false, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {-0.89646465, -0.69444444, -0.49242424}, {-0.27441077, -0.24074074, -0.20707071}, {-0.23232323, -0.22222222, -0.21212121}, @@ -270,13 +282,13 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsLowerComplex(), 0, 
"a", &builder, &a); auto b_data = @@ -286,7 +298,7 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { /*transpose_a=*/true, /*conjugate_a=*/true, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {0.5, complex64(0.08333333, 0.08333333), complex64(0.02777778, -0.0462963), complex64(0.06313131, -0.01094276)}, {2.5, complex64(-0.25, 0.41666667), complex64(-0.23148148, -0.37962963), @@ -295,15 +307,14 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { complex64(0.11026936, -0.03114478)}, }); - ComputeAndCompareR2(&builder, expected, - {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ComputeAndCompareR2( + &builder, expected, {a_data.get(), b_data.get()}, ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpperComplex(), 0, "a", &builder, &a); auto b_data = @@ -313,7 +324,7 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { /*transpose_a=*/true, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {0.5, 1., 1.5}, {0.41666667, 0.33333333, 0.25}, {complex64(0.20020325, -2.81504065e-01), @@ -324,10 +335,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { complex64(0.15798226, 5.12749446e-01)}, }); - ComputeAndCompareR2(&builder, expected, - {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ComputeAndCompareR2( + &builder, expected, {a_data.get(), b_data.get()}, ErrorSpec(1e-2, 1e-2)); } } // namespace -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index aaa5d6989eefb94edb8921d13f96e3705aa3e3a4..049cd15738a619294b19d5cf74ca514d7b4a00ad 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ 
b/tensorflow/compiler/xla/client/local_client.cc @@ -71,9 +71,9 @@ Status LocalExecutable::ValidateExecutionOptions( "parameter " "%d: want %s, got %s", i, - ShapeUtil::HumanString( + ShapeUtil::HumanStringWithLayout( computation_layout.parameter_layout(i).shape()), - ShapeUtil::HumanString(arguments[i]->on_host_shape())); + ShapeUtil::HumanStringWithLayout(arguments[i]->on_host_shape())); } } diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index a40330a9b1fe201b6ec83d1bfe1a21e294e18f55..a9a91648ac377987e7f226116e11c9c697ace103 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -22,49 +22,49 @@ limitations under the License. #include "tensorflow/compiler/xla/parse_flags_from_env.h" namespace xla { -namespace { -DebugOptions* flag_values; -std::vector* flag_objects; -std::once_flag flags_init; - -void SetDebugOptionsDefaults(DebugOptions* flags) { - flags->set_xla_llvm_enable_alias_scope_metadata(true); - flags->set_xla_llvm_enable_noalias_metadata(true); - flags->set_xla_llvm_enable_invariant_load_metadata(true); - flags->set_xla_llvm_disable_expensive_passes(false); - flags->set_xla_backend_optimization_level(3); - flags->set_xla_cpu_multi_thread_eigen(true); - flags->set_xla_gpu_cuda_data_dir("./cuda_sdk_lib"); - flags->set_xla_eliminate_hlo_implicit_broadcast(true); +DebugOptions DefaultDebugOptionsIgnoringFlags() { + DebugOptions opts; + opts.set_xla_llvm_enable_alias_scope_metadata(true); + opts.set_xla_llvm_enable_noalias_metadata(true); + opts.set_xla_llvm_enable_invariant_load_metadata(true); + opts.set_xla_llvm_disable_expensive_passes(false); + opts.set_xla_backend_optimization_level(3); + opts.set_xla_cpu_multi_thread_eigen(true); + opts.set_xla_gpu_cuda_data_dir("./cuda_sdk_lib"); + opts.set_xla_eliminate_hlo_implicit_broadcast(true); + opts.set_xla_hlo_dump_as_html(false); #ifdef INTEL_MKL - flags->set_xla_cpu_use_mkl_dnn(true); 
+ opts.set_xla_cpu_use_mkl_dnn(true); #endif // INTEL_MKL - flags->set_xla_gpu_max_kernel_unroll_factor(4); + opts.set_xla_gpu_max_kernel_unroll_factor(4); // Set cudnn batchnorm off by default; it does not provide a performance win // on average. - flags->set_xla_gpu_use_cudnn_batchnorm(false); + opts.set_xla_gpu_use_cudnn_batchnorm(false); // Run all GPU work on one stream by default. Using multiple streams // increases memory usage and we lack strong motivating benchmarks for tuning // the heuristics needed to decide when to run on multiple streams. See // b/77879207. - flags->set_xla_gpu_disable_multi_streaming(true); + opts.set_xla_gpu_disable_multi_streaming(true); // TODO(jlebar): Disable fastmath once doing so is not a performance // regression. - flags->set_xla_cpu_enable_fast_math(true); - flags->set_xla_gpu_enable_fast_math(true); + opts.set_xla_cpu_enable_fast_math(true); + opts.set_xla_gpu_enable_fast_min_max(true); - flags->set_xla_force_host_platform_device_count(1); + opts.set_xla_force_host_platform_device_count(1); + return opts; } +static DebugOptions* flag_values; +static std::vector* flag_objects; +static std::once_flag flags_init; + // Allocates flag_values and flag_objects; this function must not be called more // than once - its call done via call_once. -void AllocateFlags() { - flag_values = new DebugOptions; - - SetDebugOptionsDefaults(flag_values); +static void AllocateFlags() { + flag_values = new DebugOptions(DefaultDebugOptionsIgnoringFlags()); // Returns a lambda that calls "member_setter" on "flag_values" with the // argument passed in to the lambda. 
@@ -133,6 +133,11 @@ void AllocateFlags() { bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_graphdef), flag_values->xla_hlo_dump_as_graphdef(), "Dump HLO graphs as TensorFlow GraphDefs."), + tensorflow::Flag("xla_hlo_dump_as_html", + bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_html), + flag_values->xla_hlo_dump_as_html(), + "Dump HLO graphs as an HTML (DOT rendered into SVG " + "inlined in HTML)."), tensorflow::Flag( "xla_hlo_graph_sharding_color", bool_setter_for(&DebugOptions::set_xla_hlo_graph_sharding_color), @@ -160,11 +165,11 @@ void AllocateFlags() { "Enable unsafe fast-math optimizations in the CPU compiler; " "this may produce faster code at the expense of some accuracy."), tensorflow::Flag( - "xla_gpu_enable_fast_math", - bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math), - flag_values->xla_cpu_enable_fast_math(), - "Enable unsafe fast-math optimizations in the GPU compiler; " - "this may produce faster code at the expense of some accuracy."), + "xla_gpu_enable_fast_min_max", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_fast_min_max), + flag_values->xla_gpu_enable_fast_min_max(), + "Enable fast floating point min/max lowering that does not propagate " + "NaNs."), tensorflow::Flag( "xla_llvm_enable_alias_scope_metadata", bool_setter_for( @@ -202,6 +207,16 @@ void AllocateFlags() { "Comma-separated list of hlo passes to be disabled. These names " "must exactly match the passes' names; no whitespace around " "commas."), + tensorflow::Flag( + "xla_disable_all_hlo_passes", + bool_setter_for(&DebugOptions::set_xla_disable_all_hlo_passes), false, + "Disables all HLO passes. Notes that some passes are necessary for " + "correctness and the invariants that must be satisfied by 'fully " + "optimized' HLO are different for different devices and may change " + "over time. 
The only 'guarantee', such as it is, is that if you " + "compile XLA and dump the optimized HLO for some graph, you should " + "be able to run it again on the same device with the same build of " + "XLA."), tensorflow::Flag( "xla_embed_ir_in_executable", bool_setter_for(&DebugOptions::set_xla_embed_ir_in_executable), @@ -334,12 +349,16 @@ void AllocateFlags() { "overhead from context switching but we let the user override this " "behavior to help run tests on the host that run models in parallel " "across multiple devices."), + tensorflow::Flag( + "xla_gpu_disable_ptxas_optimizations", + bool_setter_for( + &DebugOptions::set_xla_gpu_disable_ptxas_optimizations), + flag_values->xla_gpu_disable_ptxas_optimizations(), + "In XLA:GPU run ptxas in -O0 (default is -O3)."), }); ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", *flag_objects); } -} // namespace - void AppendDebugOptionsFlags(std::vector* flag_list) { std::call_once(flags_init, &AllocateFlags); flag_list->insert(flag_list->end(), flag_objects->begin(), diff --git a/tensorflow/compiler/xla/debug_options_flags.h b/tensorflow/compiler/xla/debug_options_flags.h index 60e59abc2a2e0f1cce3de1afc928f9fe36f75b33..dbf86a40f052af09c61da0e1abb3116ef5214357 100644 --- a/tensorflow/compiler/xla/debug_options_flags.h +++ b/tensorflow/compiler/xla/debug_options_flags.h @@ -29,7 +29,10 @@ void AppendDebugOptionsFlags(std::vector* flag_list); // Fetches a DebugOptions proto message from flags provided to the program. // Flags must be registered with the flags parser using AppendDebugOptionsFlags // first. -xla::DebugOptions GetDebugOptionsFromFlags(); +DebugOptions GetDebugOptionsFromFlags(); + +// Gets a DebugOptions proto that reflects the defaults as if no flags were set. 
+DebugOptions DefaultDebugOptionsIgnoringFlags(); } // namespace xla diff --git a/tensorflow/compiler/xla/g3doc/_book.yaml b/tensorflow/compiler/xla/g3doc/_book.yaml index 12b7094705e75305dc43a013576f4549dd5f4185..267701e9c0e42a21d2cda6238520f6a9692e7e76 100644 --- a/tensorflow/compiler/xla/g3doc/_book.yaml +++ b/tensorflow/compiler/xla/g3doc/_book.yaml @@ -31,3 +31,5 @@ upper_tabs: - title: XLA compile API path: /xla/tutorials/xla_compile status: experimental + +- include: /_upper_tabs_right.yaml diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index e0807518bc401808266cd3b198efa9697d6804de..002ebc31b992826b4dfc53f31a9e3625cde3c5d0 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -38,25 +38,25 @@ Alltoall is a collective operation that sends data from all cores to all cores. It has two phases: 1. the scatter phase. On each core, the operand is split into `split_count` - number of blocks along the `split_dimensions`, and the blocks are scattered - to all cores, e.g., the ith block is send to the ith core. +number of blocks along the `split_dimensions`, and the blocks are scattered +to all cores, e.g., the ith block is send to the ith core. 2. the gather phase. Each core concatenates the received blocks along the - `concat_dimension`. +`concat_dimension`. The participating cores can be configured by: - `replica_groups`: each ReplicaGroup contains a list of replica id. If empty, - all replicas belong to one group in the order of 0 - (n-1). Alltoall will be - applied within subgroups in the specified order. For example, replica - groups = {{1,2,3},{4,5,0}} means, an Alltoall will be applied within replica - 1, 2, 3, and in the gather phase, the received blocks will be concatenated - in the order of 1, 2, 3; another Alltoall will be applied within replica 4, - 5, 0, and the concatenation order is 4, 5, 0. 
+all replicas belong to one group in the order of 0 - (n-1). Alltoall will be +applied within subgroups in the specified order. For example, replica +groups = {{1,2,3},{4,5,0}} means, an Alltoall will be applied within replica +1, 2, 3, and in the gather phase, the received blocks will be concatenated +in the order of 1, 2, 3; another Alltoall will be applied within replica 4, +5, 0, and the concatenation order is 4, 5, 0. Prerequisites: - The dimension size of the operand on the split_dimension is divisible by - split_count. +split_count. - The operand's shape is not tuple. `AllToAll(operand, split_dimension, concat_dimension, split_count, @@ -93,7 +93,7 @@ AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/4); ```
- +
In this example, there are 4 cores participating the Alltoall. On each core, the @@ -387,34 +387,34 @@ For example, let v be an array of 24 elements: ``` let v = f32[4x2x3] {{{10, 11, 12}, {15, 16, 17}}, - {{20, 21, 22}, {25, 26, 27}}, - {{30, 31, 32}, {35, 36, 37}}, - {{40, 41, 42}, {45, 46, 47}}}; +{{20, 21, 22}, {25, 26, 27}}, +{{30, 31, 32}, {35, 36, 37}}, +{{40, 41, 42}, {45, 46, 47}}}; // Collapse to a single dimension, leaving one dimension. let v012 = Collapse(v, {0,1,2}); then v012 == f32[24] {10, 11, 12, 15, 16, 17, - 20, 21, 22, 25, 26, 27, - 30, 31, 32, 35, 36, 37, - 40, 41, 42, 45, 46, 47}; +20, 21, 22, 25, 26, 27, +30, 31, 32, 35, 36, 37, +40, 41, 42, 45, 46, 47}; // Collapse the two lower dimensions, leaving two dimensions. let v01 = Collapse(v, {0,1}); then v01 == f32[4x6] {{10, 11, 12, 15, 16, 17}, - {20, 21, 22, 25, 26, 27}, - {30, 31, 32, 35, 36, 37}, - {40, 41, 42, 45, 46, 47}}; +{20, 21, 22, 25, 26, 27}, +{30, 31, 32, 35, 36, 37}, +{40, 41, 42, 45, 46, 47}}; // Collapse the two higher dimensions, leaving two dimensions. let v12 = Collapse(v, {1,2}); then v12 == f32[8x3] {{10, 11, 12}, - {15, 16, 17}, - {20, 21, 22}, - {25, 26, 27}, - {30, 31, 32}, - {35, 36, 37}, - {40, 41, 42}, - {45, 46, 47}}; +{15, 16, 17}, +{20, 21, 22}, +{25, 26, 27}, +{30, 31, 32}, +{35, 36, 37}, +{40, 41, 42}, +{45, 46, 47}}; ``` @@ -441,9 +441,9 @@ replicas. Note that there are the following restrictions on the `source_target_pair`: - Any two pairs should not have the same target replica id, and they should - not have the same source replica id. +not have the same source replica id. - If a replica id is not a target in any pair, then the output on that replica - is a tensor consists of 0(s) with the same shape as the input. +is a tensor consists of 0(s) with the same shape as the input. 
## Concatenate @@ -480,25 +480,25 @@ Concat({{2, 3}, {4, 5}, {6, 7}}, 0) ``` let a = { - {1, 2}, - {3, 4}, - {5, 6}, +{1, 2}, +{3, 4}, +{5, 6}, }; let b = { - {7, 8}, +{7, 8}, }; Concat({a, b}, 0) >>> { - {1, 2}, - {3, 4}, - {5, 6}, - {7, 8}, +{1, 2}, +{3, 4}, +{5, 6}, +{7, 8}, } ``` Diagram:
- +
## Conditional @@ -566,20 +566,20 @@ the rhs is also an input. In a neural network, these are the input activations. The n+2 dimensions are, in this order: * `batch`: Each coordinate in this dimension represents an independent input - for which convolution is carried out. +for which convolution is carried out. * `z/depth/features`: Each (y,x) position in the base area has a vector - associated to it, which goes into this dimension. +associated to it, which goes into this dimension. * `spatial_dims`: Describes the `n` spatial dimensions that define the base - area that the window moves across. +area that the window moves across. The `rhs` argument is a rank n+2 array describing the convolutional filter/kernel/window. The dimensions are, in this order: * `output-z`: The `z` dimension of the output. * `input-z`: The size of this dimension times `feature_group_count` should - equal the size of the `z` dimension in lhs. +equal the size of the `z` dimension in lhs. * `spatial_dims`: Describes the `n` spatial dimensions that define the n-d - window that moves across the base area. +window that moves across the base area. The `window_strides` argument specifies the stride of the convolutional window in the spatial dimensions. For example, if the stride in the first spatial @@ -633,7 +633,7 @@ The output shape has these dimensions, in this order: * `batch`: Same size as `batch` on the input (`lhs`). * `z`: Same size as `output-z` on the kernel (`rhs`). * `spatial_dims`: One value for each valid placement of the convolutional - window. +window. The valid placements of the convolutional window are determined by the strides and the size of the base area after padding. 
@@ -658,15 +658,15 @@ Here is pseudo-code for a 2d convolution with padding and striding: ``` for (b, oz, oy, ox) { // output coordinates - value = 0; - for (iz, ky, kx) { // kernel coordinates and input z - iy = oy*stride_y + ky - pad_low_y; - ix = ox*stride_x + kx - pad_low_x; - if ((iy, ix) inside the base area considered without padding) { - value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx); - } - } - output(b, oz, oy, ox) = value; +value = 0; +for (iz, ky, kx) { // kernel coordinates and input z +iy = oy*stride_y + ky - pad_low_y; +ix = ox*stride_x + kx - pad_low_x; +if ((iy, ix) inside the base area considered without padding) { +value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx); +} +} +output(b, oz, oy, ox) = value; } ``` @@ -777,19 +777,19 @@ Here is an example of an implementation of `myfunc`: ``` extern "C" void myfunc(void* out, void** in) { - float (&x)[2] = *static_cast(in[0]); - float (&y)[2][3] = *static_cast(in[1]); - EXPECT_EQ(1, x[0]); - EXPECT_EQ(2, x[1]); - EXPECT_EQ(10, y[0][0]); - EXPECT_EQ(20, y[0][1]); - EXPECT_EQ(30, y[0][2]); - EXPECT_EQ(40, y[1][0]); - EXPECT_EQ(50, y[1][1]); - EXPECT_EQ(60, y[1][2]); - float (&z)[3][3] = *static_cast(out); - z[0][0] = x[1] + y[1][0]; - // ... +float (&x)[2] = *static_cast(in[0]); +float (&y)[2][3] = *static_cast(in[1]); +EXPECT_EQ(1, x[0]); +EXPECT_EQ(2, x[1]); +EXPECT_EQ(10, y[0][0]); +EXPECT_EQ(20, y[0][1]); +EXPECT_EQ(30, y[0][2]); +EXPECT_EQ(40, y[1][0]); +EXPECT_EQ(50, y[1][1]); +EXPECT_EQ(60, y[1][2]); +float (&z)[3][3] = *static_cast(out); +z[0][0] = x[1] + y[1][0]; +// ... 
} ``` @@ -864,17 +864,17 @@ Example with contracting dimension numbers: ``` lhs = { {1.0, 2.0, 3.0}, - {4.0, 5.0, 6.0} } +{4.0, 5.0, 6.0} } rhs = { {1.0, 1.0, 1.0}, - {2.0, 2.0, 2.0} } +{2.0, 2.0, 2.0} } DotDimensionNumbers dnums; dnums.add_lhs_contracting_dimensions(1); dnums.add_rhs_contracting_dimensions(1); DotGeneral(lhs, rhs, dnums) -> { {6.0, 12.0}, - {15.0, 30.0} } +{15.0, 30.0} } ``` Associated batch dimension numbers from the 'lhs' and 'rhs' must have the same @@ -886,14 +886,14 @@ Example with batch dimension numbers (batch size 2, 2x2 matrices): ``` lhs = { { {1.0, 2.0}, - {3.0, 4.0} }, - { {5.0, 6.0}, - {7.0, 8.0} } } +{3.0, 4.0} }, +{ {5.0, 6.0}, +{7.0, 8.0} } } rhs = { { {1.0, 0.0}, - {0.0, 1.0} }, - { {1.0, 0.0}, - {0.0, 1.0} } } +{0.0, 1.0} }, +{ {1.0, 0.0}, +{0.0, 1.0} } } DotDimensionNumbers dnums; dnums.add_lhs_contracting_dimensions(2); @@ -902,9 +902,9 @@ dnums.add_lhs_batch_dimensions(0); dnums.add_rhs_batch_dimensions(0); DotGeneral(lhs, rhs, dnums) -> { { {1.0, 2.0}, - {3.0, 4.0} }, - { {5.0, 6.0}, - {7.0, 8.0} } } +{3.0, 4.0} }, +{ {5.0, 6.0}, +{7.0, 8.0} } } ``` | Input | Output | Semantics | @@ -963,22 +963,22 @@ let a = {0.0, 1.0, 2.0, 3.0, 4.0} let s = {2} DynamicSlice(a, s, {2}) produces: - {2.0, 3.0} +{2.0, 3.0} ``` 2-dimensional example: ``` let b = - { {0.0, 1.0, 2.0}, - {3.0, 4.0, 5.0}, - {6.0, 7.0, 8.0}, - {9.0, 10.0, 11.0} } +{ {0.0, 1.0, 2.0}, +{3.0, 4.0, 5.0}, +{6.0, 7.0, 8.0}, +{9.0, 10.0, 11.0} } let s = {2, 1} DynamicSlice(b, s, {2, 2}) produces: - { { 7.0, 8.0}, - {10.0, 11.0} } +{ { 7.0, 8.0}, +{10.0, 11.0} } ``` ## DynamicUpdateSlice @@ -1027,29 +1027,29 @@ let u = {5.0, 6.0} let s = {2} DynamicUpdateSlice(a, u, s) produces: - {0.0, 1.0, 5.0, 6.0, 4.0} +{0.0, 1.0, 5.0, 6.0, 4.0} ``` 2-dimensional example: ``` let b = - { {0.0, 1.0, 2.0}, - {3.0, 4.0, 5.0}, - {6.0, 7.0, 8.0}, - {9.0, 10.0, 11.0} } +{ {0.0, 1.0, 2.0}, +{3.0, 4.0, 5.0}, +{6.0, 7.0, 8.0}, +{9.0, 10.0, 11.0} } let u = - { {12.0, 13.0}, - {14.0, 15.0}, - 
{16.0, 17.0} } +{ {12.0, 13.0}, +{14.0, 15.0}, +{16.0, 17.0} } let s = {1, 1} DynamicUpdateSlice(b, u, s) produces: - { {0.0, 1.0, 2.0}, - {3.0, 12.0, 13.0}, - {6.0, 14.0, 15.0}, - {9.0, 16.0, 17.0} } +{ {0.0, 1.0, 2.0}, +{3.0, 12.0, 13.0}, +{6.0, 14.0, 15.0}, +{9.0, 16.0, 17.0} } ``` ## Element-wise binary arithmetic operations @@ -1235,42 +1235,42 @@ shape of `start_indices` to be `[6,7,1]`). The bounds for the output array along dimension `i` is computed as follows: - 1. If `i` is present in `batch_dims` (i.e. is equal to `batch_dims[k]` for - some `k`) then we pick the corresponding dimension bounds out of - `start_indices.shape`, skipping `index_vector_dim` (i.e. pick - `start_indices.shape.dims`[`k`] if `k` < `index_vector_dim` and - `start_indices.shape.dims`[`k`+`1`] otherwise). +1. If `i` is present in `batch_dims` (i.e. is equal to `batch_dims[k]` for +some `k`) then we pick the corresponding dimension bounds out of +`start_indices.shape`, skipping `index_vector_dim` (i.e. pick +`start_indices.shape.dims`[`k`] if `k` < `index_vector_dim` and +`start_indices.shape.dims`[`k`+`1`] otherwise). - 2. If `i` is present in `offset_dims` (i.e. equal to `offset_dims`[`k`] for - some `k`) then we pick the corresponding bound out of `slice_sizes` after - accounting for `collapsed_slice_dims` (i.e. we pick - `adjusted_slice_sizes`[`k`] where `adjusted_slice_sizes` is `slice_sizes` - with the bounds at indices `collapsed_slice_dims` removed). +2. If `i` is present in `offset_dims` (i.e. equal to `offset_dims`[`k`] for +some `k`) then we pick the corresponding bound out of `slice_sizes` after +accounting for `collapsed_slice_dims` (i.e. we pick +`adjusted_slice_sizes`[`k`] where `adjusted_slice_sizes` is `slice_sizes` +with the bounds at indices `collapsed_slice_dims` removed). Formally, the operand index `In` corresponding to an output index `Out` is computed as follows: - 1. Let `G` = { `Out`[`k`] for `k` in `batch_dims` }. 
Use `G` to slice out - vector `S` such that `S`[`i`] = `start_indices`[Combine(`G`, `i`)] where - Combine(A, b) inserts b at position `index_vector_dim` into A. Note that - this is well defined even if `G` is empty -- if `G` is empty then `S` = - `start_indices`. - - 2. Create a starting index, `S``in`, into `operand` using `S` by - scattering `S` using `start_index_map`. More precisely: - 1. `S``in`[`start_index_map`[`k`]] = `S`[`k`] if `k` < - `start_index_map.size`. - 2. `S``in`[`_`] = `0` otherwise. - - 3. Create an index `O``in` into `operand` by scattering the indices - at the offset dimensions in `Out` according to the `collapsed_slice_dims` - set. More precisely: - 1. `O``in`[`expand_offset_dims`(`k`)] = - `Out`[`offset_dims`[`k`]] if `k` < `offset_dims.size` - (`expand_offset_dims` is defined below). - 2. `O``in`[`_`] = `0` otherwise. - 4. `In` is `O``in` + `S``in` where + is element-wise - addition. +1. Let `G` = { `Out`[`k`] for `k` in `batch_dims` }. Use `G` to slice out +vector `S` such that `S`[`i`] = `start_indices`[Combine(`G`, `i`)] where +Combine(A, b) inserts b at position `index_vector_dim` into A. Note that +this is well defined even if `G` is empty -- if `G` is empty then `S` = +`start_indices`. + +2. Create a starting index, `S``in`, into `operand` using `S` by +scattering `S` using `start_index_map`. More precisely: +1. `S``in`[`start_index_map`[`k`]] = `S`[`k`] if `k` < +`start_index_map.size`. +2. `S``in`[`_`] = `0` otherwise. + +3. Create an index `O``in` into `operand` by scattering the indices +at the offset dimensions in `Out` according to the `collapsed_slice_dims` +set. More precisely: +1. `O``in`[`expand_offset_dims`(`k`)] = +`Out`[`offset_dims`[`k`]] if `k` < `offset_dims.size` +(`expand_offset_dims` is defined below). +2. `O``in`[`_`] = `0` otherwise. +4. `In` is `O``in` + `S``in` where + is element-wise +addition. 
`expand_offset_dims` is the monotonic function with domain [`0`, `offset.size`) and range [`0`, `operand.rank`) \ `collapsed_slice_dims`. So if, e.g., @@ -1282,21 +1282,21 @@ and range [`0`, `operand.rank`) \ `collapsed_slice_dims`. So if, e.g., Informally, every index `Out` in the output array corresponds to an element `E` in the operand array, computed as follows: - - We use the batch dimensions in `Out` to look up a starting index from - `start_indices`. +- We use the batch dimensions in `Out` to look up a starting index from +`start_indices`. - - We use `start_index_map` to map the starting index (which may have size less - than operand.rank) to a "full" starting index into operand. +- We use `start_index_map` to map the starting index (which may have size less +than operand.rank) to a "full" starting index into operand. - - We dynamic-slice out a slice with size `slice_sizes` using the full starting - index. +- We dynamic-slice out a slice with size `slice_sizes` using the full starting +index. - - We reshape the slice by collapsing the `collapsed_slice_dims` dimensions. - Since all collapsed slice dimensions have to have bound 1 this reshape is - always legal. +- We reshape the slice by collapsing the `collapsed_slice_dims` dimensions. +Since all collapsed slice dimensions have to have bound 1 this reshape is +always legal. - - We use the offset dimensions in `Out` to index into this slice to get the - input element, `E`, corresponding to output index `Out`. +- We use the offset dimensions in `Out` to index into this slice to get the +input element, `E`, corresponding to output index `Out`. `index_vector_dim` is set to `start_indices.rank` - `1` in all of the examples that follow. More interesting values for `index_vector_dim` does not @@ -1315,7 +1315,7 @@ the output shape, and maps it to an element in the input array in the following way:
- +
We first select an (`X`,`Y`) vector from the gather indices array using `G`. @@ -1334,7 +1334,7 @@ version of the example above using a "gather indices" array of shape `[4,5,2]` would translate indices like this:
- +
Again, this acts as a batch dynamic slice `G``0` and @@ -1343,27 +1343,27 @@ Again, this acts as a batch dynamic slice `G``0` and The gather operation in XLA generalizes the informal semantics outlined above in the following ways: - 1. We can configure which dimensions in the output shape are the offset - dimensions (dimensions containing `O``0`, `O``1` in - the last example). The output batch dimensions (dimensions containing - `G``0`, `G``1` in the last example) are defined to be - the output dimensions that are not offset dimensions. +1. We can configure which dimensions in the output shape are the offset +dimensions (dimensions containing `O``0`, `O``1` in +the last example). The output batch dimensions (dimensions containing +`G``0`, `G``1` in the last example) are defined to be +the output dimensions that are not offset dimensions. - 2. The number of output offset dimensions explicitly present in the output - shape may be smaller than the input rank. These "missing" dimensions, which - are listed explicitly as `collapsed_slice_dims`, must have a slice size of - `1`. Since they have a slice size of `1` the only valid index for them is - `0` and eliding them does not introduce ambiguity. +2. The number of output offset dimensions explicitly present in the output +shape may be smaller than the input rank. These "missing" dimensions, which +are listed explicitly as `collapsed_slice_dims`, must have a slice size of +`1`. Since they have a slice size of `1` the only valid index for them is +`0` and eliding them does not introduce ambiguity. - 3. The slice extracted from the "Gather Indices" array ((`X`, `Y`) in the last - example) may have fewer elements than the input array rank, and an explicit - mapping dictates how the index should be expanded to have the same rank as - the input. +3. 
The slice extracted from the "Gather Indices" array ((`X`, `Y`) in the last +example) may have fewer elements than the input array rank, and an explicit +mapping dictates how the index should be expanded to have the same rank as +the input. As a final example, we use (2) and (3) to implement `tf.gather_nd`:
- +
`G``0` and `G``1` are used to slice out a starting index @@ -1442,11 +1442,11 @@ dependency between the while loops. ``` result1 = while (condition, init = init_value) { - Infeed(shape) +Infeed(shape) } result2 = while (condition, init = result1) { - Infeed(shape) +Infeed(shape) } ``` @@ -1464,12 +1464,15 @@ Infeed of the device. Builds a constant literal on device rather than a potentially large host transfer. Creates a rank 1 array of values starting at zero and incrementing by -one. +one. For floating-point types, the produced array is equivalent to +`ConvertElementType(Iota(...))` where the `Iota` is of integral type and the +conversion is to the floating-point type. -Arguments | Type | Semantics ---------- | --------------- | ------------------------------------ -`type` | `PrimitiveType` | type U -`size` | `int64` | The number of elements in the array. +Arguments | Type | Semantics +---------------- | --------------- | ------------------------------------ +`type` | `PrimitiveType` | type U +`size` | `int64` | The number of elements in the array. +`iota_dimension` | `int64` | The dimension to increment along. ## Map diff --git a/tensorflow/compiler/xla/layout.cc b/tensorflow/compiler/xla/layout.cc new file mode 100644 index 0000000000000000000000000000000000000000..e3b5fcd5274881cec31ecf906e3461685f82a1f4 --- /dev/null +++ b/tensorflow/compiler/xla/layout.cc @@ -0,0 +1,96 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/layout.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/layout_util.h" + +namespace xla { + +TileProto Tile::ToProto() const { + TileProto tile_proto; + for (int64 i : dimensions()) { + tile_proto.add_dimensions(i); + } + return tile_proto; +} + +string Tile::ToString() const { + return absl::StrCat("(", absl::StrJoin(dimensions(), ","), ")"); +} + +/* static */ Layout Layout::CreateFromProto(const LayoutProto& proto) { + Layout layout; + layout.set_format(proto.format()); + layout.minor_to_major_.reserve(proto.minor_to_major_size()); + for (const int64 dimension : proto.minor_to_major()) { + layout.add_minor_to_major(dimension); + } + layout.set_max_sparse_elements(proto.max_sparse_elements()); + for (const TileProto& tile_proto : proto.tiles()) { + *layout.add_tiles() = Tile::CreateFromProto(tile_proto); + } + layout.set_element_size_in_bits(proto.element_size_in_bits()); + return layout; +} + +LayoutProto Layout::ToProto() const { + LayoutProto proto; + proto.set_format(format_); + proto.mutable_minor_to_major()->Reserve(minor_to_major_size()); + for (const int64 dimension : minor_to_major()) { + proto.add_minor_to_major(dimension); + } + proto.set_max_sparse_elements(max_sparse_elements_); + for (const Tile& tile : tiles()) { + *proto.add_tiles() = tile.ToProto(); + } + proto.set_element_size_in_bits(element_size_in_bits()); + return proto; +} + +string Layout::ToString() const { + // TODO(b/119839262): Emit tiles in string. 
+ if (format() == SPARSE) { + return absl::StrCat("sparse{", max_sparse_elements(), "}"); + } else if (format() == DENSE) { + return absl::StrCat("{", absl::StrJoin(minor_to_major(), ","), "}"); + } else { + CHECK_EQ(format(), INVALID_FORMAT); + return "invalid{}"; + } +} + +bool Layout::operator==(const Layout& other) const { + return (other.format() == format() && + other.minor_to_major() == minor_to_major() && + other.element_size_in_bits() == element_size_in_bits() && + other.max_sparse_elements() == max_sparse_elements() && + other.tiles() == tiles()); +} + +std::ostream& operator<<(std::ostream& out, const Tile& tile) { + out << tile.ToString(); + return out; +} + +std::ostream& operator<<(std::ostream& out, const Layout& layout) { + out << layout.ToString(); + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/layout.h b/tensorflow/compiler/xla/layout.h new file mode 100644 index 0000000000000000000000000000000000000000..313368c39e4c976fc481941eb17325101f2ba69a --- /dev/null +++ b/tensorflow/compiler/xla/layout.h @@ -0,0 +1,187 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_LAYOUT_H_ +#define TENSORFLOW_COMPILER_XLA_LAYOUT_H_ + +#include + +#include "absl/types/span.h" + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Describes a tile used in tiling-based layout. Refer to +// g3doc/third_party/tensorflow/compiler/xla/g3doc/layout_with_tiling.md for +// details. +class Tile { + public: + Tile() = default; + explicit Tile(absl::Span dimensions) + : dimensions_(dimensions.begin(), dimensions.end()) {} + + // De/Serialize a Tile to and from a TileProto. + static Tile CreateFromProto(const TileProto& tile_proto) { + return Tile(AsInt64Slice(tile_proto.dimensions())); + } + TileProto ToProto() const; + + bool operator==(const Tile& other) const { + return dimensions() == other.dimensions(); + } + bool operator!=(const Tile& other) const { return !(*this == other); } + + string ToString() const; + + // Returns the bound of the tile in the given dimension index. + int64 dimension(int i) const { return dimensions_.at(i); } + + // Returns the dimensions of the tile. + const std::vector& dimensions() const { return dimensions_; } + + private: + // The bounds of the tile. + std::vector dimensions_; +}; + +class Layout { + public: + Layout() = default; + + // Constructs a dense layout with the given minor-to-major order. + explicit Layout(absl::Span minor_to_major) + : format_(DENSE), + minor_to_major_(minor_to_major.begin(), minor_to_major.end()) {} + + // Constructs a dense tiled layout with the given minor-to-major order and + // tiles. + Layout(absl::Span minor_to_major, absl::Span tiles) + : format_(DENSE), + minor_to_major_(minor_to_major.begin(), minor_to_major.end()), + tiles_(tiles.begin(), tiles.end()) {} + + // Constructs a layout from a LayoutProto.
+ static Layout CreateFromProto(const LayoutProto& proto); + + // Returns a LayoutProto representation of the Layout. + LayoutProto ToProto() const; + + // Returns a human-readable string that represents this layout. + string ToString() const; + + bool operator==(const Layout& other) const; + bool operator!=(const Layout& other) const { return !(*this == other); } + + // The following methods mirror the protobuf generated code interface for the + // message LayoutProto. This enabled easy migration of this data structure + // from a proto to a proper C++ class. + // + // TODO(b/29771030): Replace or augment these methods with a more ergonomic + // interface. + + // Methods for accessing the format. + Format format() const { return format_; } + Layout& set_format(Format value) { + format_ = value; + return *this; + } + + // Methods for accessing the minor-to-major array. + int minor_to_major_size() const { return minor_to_major_.size(); } + int64 minor_to_major(int index) const { return minor_to_major_.at(index); } + Layout& set_minor_to_major(int index, int64 value) { + minor_to_major_.at(index) = value; + return *this; + } + Layout& add_minor_to_major(int64 value) { + minor_to_major_.push_back(value); + return *this; + } + Layout& clear_minor_to_major() { + minor_to_major_.clear(); + return *this; + } + const std::vector& minor_to_major() const { return minor_to_major_; } + std::vector* mutable_minor_to_major() { return &minor_to_major_; } + + // Methods for accessing the tile field. + int tiles_size() const { return tiles_.size(); } + const Tile& tiles(int index) const { return tiles_.at(index); } + Tile* mutable_tiles(int index) { return &tiles_.at(index); } + Tile* add_tiles() { + tiles_.push_back(Tile()); + return &tiles_.back(); + } + Layout& clear_tiles() { + tiles_.clear(); + return *this; + } + const std::vector& tiles() const { return tiles_; } + std::vector* mutable_tiles() { return &tiles_; } + + // Methods for accessing the int64 fields. 
+ int64 max_sparse_elements() const { return max_sparse_elements_; } + Layout& set_max_sparse_elements(int64 value) { + max_sparse_elements_ = value; + return *this; + } + int64 element_size_in_bits() const { return element_size_in_bits_; } + Layout& set_element_size_in_bits(int64 value) { + element_size_in_bits_ = value; + return *this; + } + + void Swap(Layout* other) { + using std::swap; + swap(*this, *other); + } + + void Clear() { + format_ = INVALID_FORMAT; + minor_to_major_.clear(); + max_sparse_elements_ = 0; + element_size_in_bits_ = 0; + } + + public: + // The format of this layout. + Format format_ = INVALID_FORMAT; + + // Sequence of dimension numbers, from minor (fastest varying index) to major + // (slowest varying index). + std::vector minor_to_major_; + + // The maximum number of elements that can be stored for SPARSE formats. This + // can be used to determine the maximum size in bytes of arrays stored in + // memory. This field must be zero unless the format is SPARSE. + int64 max_sparse_elements_ = 0; + + // The number of bits used to store an individual array element. + int64 element_size_in_bits_ = 0; + + // The tiles used in tiling-based layout. + std::vector tiles_; +}; + +std::ostream& operator<<(std::ostream& out, const Tile& Tile); +std::ostream& operator<<(std::ostream& out, const Layout& layout); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LAYOUT_H_ diff --git a/tensorflow/compiler/xla/layout_test.cc b/tensorflow/compiler/xla/layout_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..fb6abd3f6523b978e72b21ec082ae06973e86243 --- /dev/null +++ b/tensorflow/compiler/xla/layout_test.cc @@ -0,0 +1,104 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/layout.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +class LayoutTest : public ::testing::Test {}; + +TEST_F(LayoutTest, ToString) { + EXPECT_EQ(Layout().ToString(), "invalid{}"); + EXPECT_EQ(Layout({4, 5, 6}).ToString(), "{4,5,6}"); + EXPECT_EQ(Layout().set_format(SPARSE).set_max_sparse_elements(123).ToString(), + "sparse{123}"); + EXPECT_EQ(Layout({4, 5, 6}).ToString(), "{4,5,6}"); + EXPECT_EQ(Layout({3, 2, 1, 0}, {Tile({42, 123}), Tile({4, 5})}).ToString(), + "{3,2,1,0}"); + EXPECT_EQ( + Layout({1, 0}, {Tile({2, 55})}).set_element_size_in_bits(42).ToString(), + "{1,0}"); +} + +TEST_F(LayoutTest, StreamOut) { + { + std::ostringstream oss; + oss << Tile({7, 8}); + EXPECT_EQ(oss.str(), "(7,8)"); + } + + { + std::ostringstream oss; + oss << Layout({0, 1, 2}); + EXPECT_EQ(oss.str(), "{0,1,2}"); + } +} + +TEST_F(LayoutTest, SparseLayoutMaxElements) { + EXPECT_EQ(LayoutUtil::MaxSparseElements(LayoutUtil::MakeSparseLayout(101)), + 101); +} + +TEST_F(LayoutTest, Equality) { + EXPECT_EQ(Layout(), Layout()); + const 
std::vector empty_dims; + EXPECT_EQ(Layout(empty_dims), Layout(empty_dims)); + EXPECT_NE(Layout(), Layout(empty_dims)); + EXPECT_EQ(Layout({0, 1, 2, 3}), Layout({0, 1, 2, 3})); + EXPECT_NE(Layout({0, 1, 2, 3}), Layout({0, 1, 2})); + EXPECT_EQ(Layout({0, 1, 2}, {Tile({42, 44})}), + Layout({0, 1, 2}, {Tile({42, 44})})); + EXPECT_NE(Layout({0, 1, 2}, {Tile({42, 44})}), + Layout({0, 1, 2}, {Tile({42, 45})})); + EXPECT_NE(Layout({0, 1, 2}, {Tile({42, 44})}), Layout({0, 1, 2, 3})); + EXPECT_EQ(Layout({0, 1, 2}).set_element_size_in_bits(33), + Layout({0, 1, 2}).set_element_size_in_bits(33)); + EXPECT_NE(Layout({0, 1, 2}).set_element_size_in_bits(33), + Layout({0, 1, 2}).set_element_size_in_bits(7)); + EXPECT_EQ(Layout().set_format(SPARSE), Layout().set_format(SPARSE)); + EXPECT_EQ(Layout().set_format(SPARSE).set_max_sparse_elements(42), + Layout().set_format(SPARSE).set_max_sparse_elements(42)); + EXPECT_NE(Layout().set_format(SPARSE).set_max_sparse_elements(42), + Layout().set_format(SPARSE).set_max_sparse_elements(24)); +} + +TEST_F(LayoutTest, LayoutToFromProto) { + // Round-trips a Layout through proto de/serialization. 
+ auto expect_unchanged = [](const Layout& layout) { + EXPECT_EQ(layout, Layout::CreateFromProto(layout.ToProto())); + }; + + expect_unchanged(Layout()); + expect_unchanged(Layout({1, 3, 2, 0})); + expect_unchanged(Layout().set_format(SPARSE)); + expect_unchanged(Layout().set_format(SPARSE).set_max_sparse_elements(123)); + expect_unchanged(Layout({0, 1}).set_element_size_in_bits(42)); + expect_unchanged(Layout({3, 2, 1, 0}, {Tile({42, 123}), Tile({4, 5})})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index dbb81381acde645f08639737b6e7b6f6ad971f9b..ddccd8c798df5b926d2e5aea8975cb6cb6640824 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -41,15 +41,13 @@ namespace { // Internal helper for GetDefaultLayoutForShape and SetToDefaultLayout. Sets // minor_to_major to the value that represents the default layout. -void SetDefaultLayoutToContainer( - tensorflow::protobuf::RepeatedField* - minor_to_major) { +void SetDefaultLayoutToContainer(std::vector* minor_to_major) { // The default XLA layout is major-to-minor (dim 0 is major). 
// For more information on XLA layouts, see: // https://www.tensorflow.org/performance/xla/shapes const int64 size = minor_to_major->size(); for (int64 i = 0; i < size; ++i) { - minor_to_major->Set(i, size - 1 - i); + (*minor_to_major)[i] = size - 1 - i; } } @@ -94,9 +92,8 @@ namespace { Layout CreateDefaultLayoutForRank(int64 rank) { Layout layout; layout.set_format(DENSE); - tensorflow::protobuf::RepeatedField* - minor_to_major = layout.mutable_minor_to_major(); - minor_to_major->Resize(rank, 0); + std::vector* minor_to_major = layout.mutable_minor_to_major(); + minor_to_major->resize(rank, 0); SetDefaultLayoutToContainer(minor_to_major); return layout; } @@ -139,9 +136,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) { shape->clear_layout(); } else if (ShapeUtil::IsArray(*shape)) { shape->mutable_layout()->set_format(DENSE); - tensorflow::protobuf::RepeatedField* - minor_to_major = shape->mutable_layout()->mutable_minor_to_major(); - minor_to_major->Resize(shape->dimensions_size(), 0); + auto* minor_to_major = shape->mutable_layout()->mutable_minor_to_major(); + minor_to_major->resize(shape->dimensions_size(), 0); SetDefaultLayoutToContainer(minor_to_major); } else { // Opaque, token types etc. have no layout. 
@@ -210,9 +206,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } if (layout.format() == INVALID_FORMAT || !Format_IsValid(layout.format())) { - return InvalidArgument( - "Layout has an invalid format (%d) in layout {%s}, shape {%s}", - layout.format(), layout.ShortDebugString(), shape.ShortDebugString()); + return InvalidArgument("Layout has an invalid format (%d)", + layout.format()); } if (layout.format() == DENSE) { @@ -316,7 +311,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ bool LayoutUtil::Equal(const Layout& lhs, const Layout& rhs) { - return protobuf_util::ProtobufEquals(lhs, rhs); + return lhs == rhs; } /* static */ absl::Span LayoutUtil::MinorToMajor( @@ -358,11 +353,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ string LayoutUtil::HumanString(const Layout& layout) { - if (IsSparse(layout)) { - return absl::StrCat("sparse{", layout.max_sparse_elements(), "}"); - } - CHECK(IsDense(layout)); - return absl::StrCat("{", absl::StrJoin(layout.minor_to_major(), ","), "}"); + return layout.ToString(); } namespace { @@ -444,11 +435,6 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { return true; } -std::ostream& operator<<(std::ostream& out, const Layout& layout) { - out << LayoutUtil::HumanString(layout); - return out; -} - /*static*/ size_t LayoutUtil::Hash(const Layout& layout) { using tensorflow::hash; using tensorflow::Hash64Combine; diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 6c298e57252449ce3f1f9055436e918f2d9f17f1..609dba67bcdbcb11be0906b7d87a52a17ba0dfbd 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -21,6 +21,7 @@ limitations under the License. 
#include #include "absl/types/span.h" +#include "tensorflow/compiler/xla/layout.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" @@ -195,8 +196,6 @@ class LayoutUtil { TF_DISALLOW_COPY_AND_ASSIGN(LayoutUtil); }; -std::ostream& operator<<(std::ostream& out, const Layout& layout); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_LAYOUT_UTIL_H_ diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index 12ce2d2d7c6fa8c590035f9ff2af50001ccf80d8..4cc94c270cd64eb19761cc1044861c7d185b7888 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -317,17 +317,6 @@ TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) { ShapeUtil::MakeShape(F32, {10, 20, 30, 15, 25})))); } -TEST_F(LayoutUtilTest, SparseLayoutMaxElements) { - EXPECT_EQ(LayoutUtil::MaxSparseElements(LayoutUtil::MakeSparseLayout(101)), - 101); -} - -TEST_F(LayoutUtilTest, StreamOut) { - std::ostringstream oss; - oss << LayoutUtil::MakeLayout({0, 1, 2}); - EXPECT_EQ(oss.str(), "{0,1,2}"); -} - TEST_F(LayoutUtilTest, ValidateLayout_ValidArrayLayout) { Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {0, 1}); auto status = diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 8f480c1f1079b4e1a5be53958ebdf6e004ad9ebe..277c98721e59ac12965392500fdfdc3d91e59a8b 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -1028,20 +1028,21 @@ string ShapeToString(bool print_layout, const Shape& shape) { } void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, - bool print_layout, std::vector* pieces); + bool print_shape, bool print_layout, + std::vector* pieces); void TupleToStringHelper(const LiteralBase& literal, - const ShapeIndex& shape_index, bool print_layout, - std::vector* pieces) { + const ShapeIndex& shape_index, bool 
print_shape, + bool print_layout, std::vector* pieces) { const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); - pieces->push_back(ShapeToString(print_layout, subshape)); - pieces->push_back(" (\n"); + pieces->push_back("(\n"); std::vector tuple_pieces; for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) { ShapeIndex element_index = shape_index; element_index.push_back(i); std::vector element_pieces; - ToStringHelper(literal, element_index, print_layout, &element_pieces); + ToStringHelper(literal, element_index, print_shape, print_layout, + &element_pieces); tuple_pieces.push_back(absl::StrJoin(element_pieces, "")); } pieces->push_back(absl::StrJoin(tuple_pieces, ",\n")); @@ -1049,9 +1050,11 @@ void TupleToStringHelper(const LiteralBase& literal, } void SparseArrayToStringHelper(const LiteralBase& literal, - const Shape& subshape, bool print_layout, - std::vector* pieces) { - pieces->push_back(ShapeToString(print_layout, subshape)); + const Shape& subshape, bool print_shape, + bool print_layout, std::vector* pieces) { + if (print_shape) { + pieces->push_back(ShapeToString(print_layout, subshape)); + } pieces->push_back("{"); int64 rank = ShapeUtil::Rank(subshape); int64 num_elements = literal.sparse_element_count(); @@ -1073,8 +1076,8 @@ void SparseArrayToStringHelper(const LiteralBase& literal, } void DenseArrayToStringHelper(const LiteralBase& literal, - const ShapeIndex& shape_index, bool print_layout, - std::vector* pieces) { + const ShapeIndex& shape_index, bool print_shape, + bool print_layout, std::vector* pieces) { const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); int64 rank = ShapeUtil::Rank(subshape); @@ -1135,7 +1138,7 @@ void DenseArrayToStringHelper(const LiteralBase& literal, } }; - if (rank > 1) { + if (print_shape) { pieces->push_back(ShapeToString(print_layout, subshape)); pieces->push_back(" "); } @@ -1146,19 +1149,23 @@ void DenseArrayToStringHelper(const LiteralBase& literal, } 
void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, - bool print_layout, std::vector* pieces) { + bool print_shape, bool print_layout, + std::vector* pieces) { const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); CHECK(LayoutUtil::HasLayout(literal.shape())); CHECK(LayoutUtil::HasLayout(subshape)); if (ShapeUtil::IsTuple(subshape)) { - TupleToStringHelper(literal, shape_index, print_layout, pieces); + TupleToStringHelper(literal, shape_index, print_shape, print_layout, + pieces); } else if (ShapeUtil::IsToken(subshape)) { pieces->push_back("token"); } else if (LayoutUtil::IsSparseArray(subshape)) { - SparseArrayToStringHelper(literal, subshape, print_layout, pieces); + SparseArrayToStringHelper(literal, subshape, print_shape, print_layout, + pieces); } else { CHECK(LayoutUtil::IsDenseArray(subshape)); - DenseArrayToStringHelper(literal, shape_index, print_layout, pieces); + DenseArrayToStringHelper(literal, shape_index, print_shape, print_layout, + pieces); } } @@ -1169,10 +1176,27 @@ int64 LiteralBase::sparse_element_count() const { return sparse_indices()->index_count(); } -string LiteralBase::ToString(bool print_layout) const { +string LiteralBase::ToString() const { + std::vector pieces; + CHECK(LayoutUtil::HasLayout(this->shape())); + ToStringHelper(*this, {}, /*print_shape=*/true, + /*print_layout=*/false, &pieces); + return absl::StrJoin(pieces, ""); +} + +string LiteralBase::ToStringWithoutShape() const { + std::vector pieces; + CHECK(LayoutUtil::HasLayout(this->shape())); + ToStringHelper(*this, {}, /*print_shape=*/false, + /*print_layout=*/false, &pieces); + return absl::StrJoin(pieces, ""); +} + +string LiteralBase::ToStringWithLayout() const { std::vector pieces; CHECK(LayoutUtil::HasLayout(this->shape())); - ToStringHelper(*this, {}, print_layout, &pieces); + ToStringHelper(*this, {}, /*print_shape=*/true, + /*print_layout=*/true, &pieces); return absl::StrJoin(pieces, ""); } diff --git 
a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index fa9a71af4ceb998a7a289443cbef70eb52cb1a11..67e908e7ec4d4346f4e26a99a42aac26928ec0c2 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -92,9 +92,20 @@ class LiteralBase { // array. string GetR1U8AsString() const; - // Returns a string representation of the literal value. - // Warning: this function can take minutes for multi-million element Literals. - string ToString(bool print_layout = false) const; + // Returns a string representation of the literal value. The Shape of the + // literal is a prefix of the literal value in the string. + + // Warning: this function can take minutes for multi-million + // element Literals. + string ToString() const; + + // Returns a string representation of the literal value which does *not* + // include the shape string. + string ToStringWithoutShape() const; + + // Returns a string representation of the literal value which includes the + // shape string with its layout. + string ToStringWithLayout() const; // Gets an element in the literal at the given index. The multi_index is // CHECKed against the dimension sizes. diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index b044f0ad73f13a0599e77f1f43888bc974e31f73..1ac9a48e805daa86f0dc65b54626195c89241020 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -46,68 +46,102 @@ uint16 GetRawValue(Eigen::half val) { return val.x; } // between the left-hand-side and right-hand-side, by bit-casting to UnsignedT // -- on miscompare, a nice error message is given in the AssertionFailure.
template -Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs, - absl::Span multi_index) { +bool CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs, + absl::Span multi_index) { + auto ulhs = absl::bit_cast(GetRawValue(lhs)); + auto urhs = absl::bit_cast(GetRawValue(rhs)); + return ulhs == urhs; +} + +// Templated comparator that specializes for float equality comparison with the +// bitwise helper above (this is the un-specialized fallback, to just use the +// default gunit implementation). +template +bool CompareEqual(NativeT lhs, NativeT rhs, + absl::Span multi_index) { + return lhs == rhs; +} + +// Specializations for floating types that do bitwise comparisons when equality +// comparison is requested. +template <> +bool CompareEqual(bfloat16 lhs, bfloat16 rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +} +template <> +bool CompareEqual(Eigen::half lhs, Eigen::half rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +} +template <> +bool CompareEqual(float lhs, float rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +} +template <> +bool CompareEqual(double lhs, double rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +} +template <> +bool CompareEqual(complex64 lhs, complex64 rhs, + absl::Span multi_index) { + return CompareEqual(lhs.real(), rhs.real(), multi_index) && + CompareEqual(lhs.imag(), rhs.imag(), multi_index); +} + +template +Status MakeBitwiseErrorStatus(NativeT lhs, NativeT rhs, + absl::Span multi_index) { auto ulhs = absl::bit_cast(GetRawValue(lhs)); auto urhs = absl::bit_cast(GetRawValue(rhs)); auto lhs_double = static_cast(lhs); auto rhs_double = static_cast(rhs); - if (ulhs != urhs) { return InvalidArgument( "floating values are not bitwise-equal; and equality testing " "was requested: %s=%g=%a vs %s=%g=%a at array index %s", StrCat(absl::Hex(ulhs)), lhs_double, lhs_double, 
StrCat(absl::Hex(urhs)), rhs_double, rhs_double, LiteralUtil::MultiIndexAsString(multi_index)); - } - return Status::OK(); } -// Templated comparator that specializes for float equality comparison with the -// bitwise helper above (this is the un-specialized fallback, to just use the -// default gunit implementation). template -Status CompareEqual(NativeT lhs, NativeT rhs, - absl::Span multi_index) { - if (lhs == rhs) { - return Status::OK(); - } +Status MakeErrorStatus(NativeT lhs, NativeT rhs, + absl::Span multi_index) { return InvalidArgument( "first mismatch at array index %s:\n expected value: %s\n actual " "value: %s", LiteralUtil::MultiIndexAsString(multi_index), StrCat(lhs), StrCat(rhs)); } -// Specializations for floating types that do bitwise comparisons when equality -// comparison is requested. template <> -Status CompareEqual(bfloat16 lhs, bfloat16 rhs, - absl::Span multi_index) { - return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +Status MakeErrorStatus(bfloat16 lhs, bfloat16 rhs, + absl::Span multi_index) { + return MakeBitwiseErrorStatus(lhs, rhs, multi_index); } template <> -Status CompareEqual(Eigen::half lhs, Eigen::half rhs, - absl::Span multi_index) { - return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +Status MakeErrorStatus(Eigen::half lhs, Eigen::half rhs, + absl::Span multi_index) { + return MakeBitwiseErrorStatus(lhs, rhs, multi_index); } template <> -Status CompareEqual(float lhs, float rhs, - absl::Span multi_index) { - return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +Status MakeErrorStatus(float lhs, float rhs, + absl::Span multi_index) { + return MakeBitwiseErrorStatus(lhs, rhs, multi_index); } template <> -Status CompareEqual(double lhs, double rhs, - absl::Span multi_index) { - return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +Status MakeErrorStatus(double lhs, double rhs, + absl::Span multi_index) { + return MakeBitwiseErrorStatus(lhs, rhs, multi_index); } template <> -Status CompareEqual(complex64 
lhs, complex64 rhs, - absl::Span multi_index) { - auto res = CompareEqual(lhs.real(), rhs.real(), multi_index); - if (!res.ok()) { - return res; +Status MakeErrorStatus(complex64 lhs, complex64 rhs, + absl::Span multi_index) { + if (!CompareEqual(lhs.real(), rhs.real(), multi_index)) { + return MakeErrorStatus(lhs.real(), rhs.real(), multi_index); } - return CompareEqual(lhs.imag(), rhs.imag(), multi_index); + return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index); } // A recursive function which iterates through every index of expected and @@ -119,7 +153,11 @@ Status Equal(LiteralSlice expected, LiteralSlice actual, if (dimension == expected.shape().dimensions_size()) { NativeT expected_value = expected.Get(multi_index); NativeT actual_value = actual.Get(multi_index); - return CompareEqual(expected_value, actual_value, multi_index); + bool result = + CompareEqual(expected_value, actual_value, multi_index); + return result ? Status::OK() + : MakeErrorStatus(expected_value, actual_value, + multi_index); } Status result; @@ -330,7 +368,7 @@ class NearComparator { NanMismatch(expected, actual, error_.relaxed_nans); float abs_error; float rel_error; - if (CompareEqual(expected, actual, {linear_index}).ok()) { + if (CompareEqual(expected, actual, {linear_index})) { abs_error = 0; rel_error = 0; } else if (is_nan_mismatch) { @@ -344,7 +382,7 @@ class NearComparator { } else if (IsInf(expected) || IsInf(actual)) { // If either the expected or actual value is infinity but not both, // then both absolute and relative error are regarded as inifity. 
- CHECK(!CompareEqual(expected, actual, {linear_index}).ok()); + CHECK(!CompareEqual(expected, actual, {linear_index})); abs_error = std::numeric_limits::infinity(); rel_error = std::numeric_limits::infinity(); } else { diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index 49363ad802ddb9520f89b53257216bc7ddaf8ff5..d8c7141cacb8f60cb4ce56d07ac5827a8dbf9b20 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -98,42 +98,42 @@ class LiteralUtilTest : public ::testing::Test { TEST_F(LiteralUtilTest, LiteralScalarToString) { auto true_lit = LiteralUtil::CreateR0(true); - EXPECT_EQ("true", true_lit.ToString()); + EXPECT_EQ("pred[] true", true_lit.ToString()); auto false_lit = LiteralUtil::CreateR0(false); - EXPECT_EQ("false", false_lit.ToString()); + EXPECT_EQ("pred[] false", false_lit.ToString()); auto u32_lit = LiteralUtil::CreateR0(42); - EXPECT_EQ("42", u32_lit.ToString()); + EXPECT_EQ("u32[] 42", u32_lit.ToString()); auto s32_lit = LiteralUtil::CreateR0(-999); - EXPECT_EQ("-999", s32_lit.ToString()); + EXPECT_EQ("s32[] -999", s32_lit.ToString()); auto f32_lit = LiteralUtil::CreateR0(3.14f); - EXPECT_EQ("3.14", f32_lit.ToString()); + EXPECT_EQ("f32[] 3.14", f32_lit.ToString()); auto f16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); - EXPECT_EQ("0.5", f16_lit.ToString()); + EXPECT_EQ("f16[] 0.5", f16_lit.ToString()); auto c64_lit = LiteralUtil::CreateR0({3.14f, 2.78f}); - EXPECT_EQ("(3.14, 2.78)", c64_lit.ToString()); + EXPECT_EQ("c64[] (3.14, 2.78)", c64_lit.ToString()); auto bf16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); - EXPECT_EQ("0.5", bf16_lit.ToString()); + EXPECT_EQ("bf16[] 0.5", bf16_lit.ToString()); // 3.14 will be rounded to 3.14062 in bfloat16 format. 
auto bf16_lit_truncated = LiteralUtil::CreateR0(static_cast(3.14f)); - ASSERT_EQ("3.14062", bf16_lit_truncated.ToString()); + ASSERT_EQ("bf16[] 3.14062", bf16_lit_truncated.ToString()); auto bf16_lit_truncated2 = LiteralUtil::CreateR0(static_cast(9.001f)); - EXPECT_EQ("9", bf16_lit_truncated2.ToString()); + EXPECT_EQ("bf16[] 9", bf16_lit_truncated2.ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { auto pred_vec = LiteralUtil::CreateR1({true, false, true}); - EXPECT_EQ("{1, 0, 1}", pred_vec.ToString()); + EXPECT_EQ("pred[3] {1, 0, 1}", pred_vec.ToString()); } TEST_F(LiteralUtilTest, R2ToString) { @@ -210,8 +210,8 @@ TEST_F(LiteralUtilTest, TupleToString) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); - const string expected = R"((f32[], f32[2,2]) ( -1, + const string expected = R"(( +f32[] 1, f32[2,2] { { 1, 2 }, { 3, 4 } @@ -1890,7 +1890,7 @@ TEST_F(LiteralUtilTest, SortSparseElements) { literal.AppendSparseElement({3, 4, 5}, 3.0); literal.AppendSparseElement({1, 2, 3}, 1.0); literal.SortSparseElements(); - EXPECT_EQ(literal.ToString(false), + EXPECT_EQ(literal.ToString(), "f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}"); } diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index 0f86f9f35e105713aa3072a9ebf572d33d35d66d..339660cf44fd64fc5859e72255d63762fcf20efe 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -42,8 +42,7 @@ PackedLiteralReader::~PackedLiteralReader() { delete file_; } StatusOr PackedLiteralReader::Read(const Shape& shape, const Layout* layout) { VLOG(3) << "reading shape from file: " << ShapeUtil::HumanString(shape) - << " layout: " - << (layout == nullptr ? "" : layout->ShortDebugString()); + << " layout: " << (layout == nullptr ? 
"" : layout->ToString()); Shape literal_shape = shape; if (layout != nullptr) { TF_RETURN_IF_ERROR( diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc index b16147e3be71771269d8b7a18528bef3a8c72d99..00ad01fc407017624a9183d69e61cb0d382e3f11 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/primitive_util.h" +#include "absl/strings/ascii.h" +#include "absl/strings/numbers.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -90,5 +93,65 @@ bool IsArrayType(PrimitiveType primitive_type) { primitive_type != OPAQUE && primitive_type != TOKEN; } +// Class to memoize the computation of +// absl::AsciiStrToLower(PrimitiveType_Name(p)) +// for all PrimitiveType values "p" +class PrimitiveTypeNameGenerator { + public: + PrimitiveTypeNameGenerator() { + for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { + if (PrimitiveType_IsValid(i)) { + lowercase_name_[i] = absl::AsciiStrToLower( + PrimitiveType_Name(static_cast(i))); + } + } + } + const string& LowercaseName(PrimitiveType t) { + return lowercase_name_[static_cast(t)]; + } + + private: + string lowercase_name_[PrimitiveType_ARRAYSIZE]; +}; + +const string& LowercasePrimitiveTypeName(PrimitiveType s) { + static auto* gen = new PrimitiveTypeNameGenerator(); + return gen->LowercaseName(s); +} + +namespace { + +// Returns a map from lower-case primitive type name to primitive type. 
+const std::unordered_map& GetPrimitiveTypeStringMap() { + static std::unordered_map* name_to_type = [] { + static auto* map = new std::unordered_map; + for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { + if (PrimitiveType_IsValid(i) && i != PRIMITIVE_TYPE_INVALID) { + auto value = static_cast(i); + (*map)[LowercasePrimitiveTypeName(value)] = value; + } + } + return map; + }(); + return *name_to_type; +} + +} // namespace + +StatusOr StringToPrimitiveType(absl::string_view name) { + const auto& map = GetPrimitiveTypeStringMap(); + auto found = map.find(string(name)); + if (found == map.end()) { + return InvalidArgument("Invalid element type string: \"%s\".", name); + } + return found->second; +} + +bool IsPrimitiveTypeName(absl::string_view name) { + const auto& map = GetPrimitiveTypeStringMap(); + auto found = map.find(string(name)); + return found != map.end(); +} + } // namespace primitive_util } // namespace xla diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index 889e9a1ceca675689406d255d348c82c398563aa..70603b6fed1be50c427799e6dce7b8bf9631a6f4 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -20,6 +20,9 @@ limitations under the License. #include +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -221,6 +224,17 @@ template <> struct PrimitiveTypeToNative { using type = complex64; }; + +// Returns the lower-case name of the given primitive type. +const string& LowercasePrimitiveTypeName(PrimitiveType s); + +// Returns the PrimitiveType matching the given name. The given name is expected +// to be lower-case. +StatusOr StringToPrimitiveType(absl::string_view name); + +// Returns true if the given name is a primitive type string (lower-case). 
+bool IsPrimitiveTypeName(absl::string_view name); + } // namespace primitive_util } // namespace xla diff --git a/tensorflow/compiler/xla/primitive_util_test.cc b/tensorflow/compiler/xla/primitive_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1f765d6da9ef65849fe8ede56ced7597d623cb59 --- /dev/null +++ b/tensorflow/compiler/xla/primitive_util_test.cc @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/primitive_util.h" + +#include +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +TEST(PrimitiveUtilTest, StringToPrimitiveType) { + auto expect_ok_and_equal = [](const string& str, PrimitiveType expected) { + TF_ASSERT_OK_AND_ASSIGN(PrimitiveType actual, + primitive_util::StringToPrimitiveType(str)); + EXPECT_EQ(expected, actual); + }; + expect_ok_and_equal("f32", F32); + expect_ok_and_equal("tuple", TUPLE); + expect_ok_and_equal("pred", PRED); + expect_ok_and_equal("s32", S32); + + EXPECT_IS_NOT_OK(primitive_util::StringToPrimitiveType("F32").status()); + EXPECT_IS_NOT_OK(primitive_util::StringToPrimitiveType("Pred").status()); + EXPECT_IS_NOT_OK(primitive_util::StringToPrimitiveType("preD").status()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 63ac1c6649210cbae9e238a74e0a45fb8ee4da63..4a57b1051e081a706267df66e239dc9d330c57ba 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -66,7 +66,10 @@ cc_library( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:cholesky", "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/lib:qr", + "//tensorflow/compiler/xla/client/lib:triangular_solve", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xrt:xrt_proto", diff --git 
a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index c0b57e7d26581662476fb64ddaedafe4d55d8619..5d191f5a18ebad8213c29fcc08f317db9626e4ed 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -24,7 +24,10 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/xla/client/lib/cholesky.h" #include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/qr.h" +#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/executable_run_options.h" @@ -148,14 +151,19 @@ static StatusOr ToBuffer(LocalClient* client, /* static */ StatusOr LocalShapedBuffer::FromLiteral( - const Literal& argument, const absl::optional& shape_with_layout) { + const Literal& argument, const absl::optional& shape_with_layout, + int replica_number) { LocalClient* client = GetOrCreateLocalClient(); + TF_ASSIGN_OR_RETURN(int device_ordinal, + client->ReplicaNumberToDeviceOrdinal(replica_number)); + VLOG(1) << "Creating shaped buffer from literal on replica/ordinal: " + << replica_number << "/" << device_ordinal; StatusOr buf = [&] { if (shape_with_layout) { Literal relaid = argument.Relayout(shape_with_layout.value()); - return ToBuffer(client, /*device_ordinal=*/0, relaid); + return ToBuffer(client, device_ordinal, relaid); } - return ToBuffer(client, /*device_ordinal=*/0, argument); + return ToBuffer(client, device_ordinal, argument); }(); TF_RETURN_IF_ERROR(buf.status()); return new LocalShapedBuffer(std::move(buf).ValueOrDie()); @@ -312,66 +320,127 @@ CompiledLocalComputation::CompiledLocalComputation( StatusOr 
CompiledLocalComputation::Execute( absl::Span argument_handles) { LocalClient* client = GetOrCreateLocalClient(); + StatusOr device_ordinal_status = client->ReplicaNumberToDeviceOrdinal(0); + StatusOr result_buffer_status; + if (!device_ordinal_status.ok()) { + result_buffer_status = device_ordinal_status.status(); + } else { + const int device_ordinal = device_ordinal_status.ValueOrDie(); + VLOG(3) << "Replica 0 mapped to device ordinal for execution: " + << device_ordinal; + + std::vector argument_buffers; + argument_buffers.reserve(argument_handles.size()); + for (auto& handle : argument_handles) { + argument_buffers.push_back(handle->shaped_buffer()); + } + + DeviceAssignment device_assignment = + client->backend() + .computation_placer() + ->AssignDevices(1, /*computation_count=*/1) + .ConsumeValueOrDie(); + + ExecutableRunOptions options; + options.set_device_ordinal(device_ordinal); + options.set_allocator(client->backend().memory_allocator()); + options.set_intra_op_thread_pool( + client->backend().eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment); + + result_buffer_status = executable_->Run(argument_buffers, options); + } + + if (!result_buffer_status.ok()) { + return InternalError( + "Failed running replica 0 (other replicas may have failed as well): " + "%s.", + result_buffer_status.status().ToString()); + } + return new LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie()); +} + +StatusOr CompiledLocalComputation::ExecutePerReplica( + absl::Span> argument_handles) { + LocalClient* client = GetOrCreateLocalClient(); + const int num_replicas = GetReplicaCount(); - VLOG(1) << "Execution requested with " << GetReplicaCount() << " replicas."; + if (argument_handles.size() != num_replicas) { + return InvalidArgument( + "Attempted to execute with %d replicas when replica count is %d", + argument_handles.size(), num_replicas); + } + + VLOG(1) << "Executing with " << num_replicas << " replicas."; // Each replica 
populates a StatusOr result, but only the output value of // replica zero is returned. - std::vector> results(GetReplicaCount()); - { + std::vector> results(num_replicas); + auto execute = [this, client, num_replicas, &argument_handles, + &results](int replica) { + StatusOr device_ordinal_status = + client->ReplicaNumberToDeviceOrdinal(replica); + if (!device_ordinal_status.ok()) { + results[replica] = device_ordinal_status.status(); + return; + } + const int device_ordinal = device_ordinal_status.ValueOrDie(); + VLOG(3) << "Replica " << replica + << " mapped to device ordinal for execution: " << device_ordinal; + + std::vector argument_buffers; + argument_buffers.reserve(argument_handles[replica].size()); + for (auto& handle : argument_handles[replica]) { + argument_buffers.push_back(handle->shaped_buffer()); + } + + DeviceAssignment device_assignment = + client->backend() + .computation_placer() + ->AssignDevices(num_replicas, /*computation_count=*/1) + .ConsumeValueOrDie(); + + ExecutableRunOptions options; + options.set_device_ordinal(device_ordinal); + options.set_allocator(client->backend().memory_allocator()); + options.set_intra_op_thread_pool( + client->backend().eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment); + StatusOr result_buffer_status = + executable_->Run(argument_buffers, options); + + results[replica] = std::move(result_buffer_status); + }; + + if (num_replicas == 1) { + // Fast-path if there is only one replica — run the computation on the + // current thread. + execute(0); + } else { + // TODO(phawkins): don't recreate the threadpool for each execution. 
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun", - GetReplicaCount()); - - for (int replica = 0; replica < GetReplicaCount(); ++replica) { - pool.Schedule([this, client, replica, &argument_handles, &results] { - StatusOr device_ordinal_status = - client->ReplicaNumberToDeviceOrdinal(replica); - if (!device_ordinal_status.ok()) { - results[replica] = device_ordinal_status.status(); - return; - } - const int device_ordinal = device_ordinal_status.ValueOrDie(); - VLOG(3) << "Replica " << replica - << " mapped to device ordinal for execution: " - << device_ordinal; - - std::vector argument_buffers; - argument_buffers.reserve(argument_handles.size()); - for (auto& handle : argument_handles) { - argument_buffers.push_back(handle->shaped_buffer()); - } - - DeviceAssignment device_assignment = - client->backend() - .computation_placer() - ->AssignDevices(GetReplicaCount(), /*computation_count=*/1) - .ConsumeValueOrDie(); - - ExecutableRunOptions options; - options.set_device_ordinal(device_ordinal); - options.set_allocator(client->backend().memory_allocator()); - options.set_intra_op_thread_pool( - client->backend().eigen_intra_op_thread_pool_device()); - options.set_device_assignment(&device_assignment); - StatusOr result_buffer_status = - executable_->Run(argument_buffers, options); - - results[replica] = std::move(result_buffer_status); - }); + num_replicas - 1); + + for (int replica = 0; replica < num_replicas - 1; ++replica) { + pool.Schedule([&execute, replica] { execute(replica); }); } + execute(num_replicas - 1); } - for (int replica = 0; replica < GetReplicaCount(); ++replica) { - const auto& statusor = results[replica]; + std::vector wrapped_results(num_replicas); + for (int replica = 0; replica < num_replicas; ++replica) { + auto& statusor = results[replica]; if (!statusor.ok()) { return InternalError( "Failed running replica %d (other replicas may have failed as well): " "%s.", replica, statusor.status().ToString()); } + 
wrapped_results[replica] = + new LocalShapedBuffer(std::move(statusor).ValueOrDie()); } - return new LocalShapedBuffer(std::move(results[0]).ValueOrDie()); + return new LocalShapedBufferTuple(std::move(wrapped_results)); } static StatusOr GetReturnValueShape(const XlaComputation& computation) { @@ -578,6 +647,15 @@ LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) { return xla::ConstantLiteral(&builder_, literal); } +LocalOp LocalComputationBuilder::Iota(PrimitiveType element_type, int64 size) { + return xla::Iota(&builder_, element_type, size); +} + +LocalOp LocalComputationBuilder::BroadcastedIota(const Shape& shape, + int64 dimension) { + return xla::Iota(&builder_, shape, dimension); +} + LocalOp LocalComputationBuilder::Broadcast( const LocalOp& operand, absl::Span broadcast_sizes) { return xla::Broadcast(operand.op(), broadcast_sizes); @@ -714,6 +792,21 @@ LocalOp LocalComputationBuilder::Call(const LocalComputation& local_computation, return xla::Call(&builder_, local_computation.computation(), xla_ops); } +LocalOp LocalComputationBuilder::CustomCall( + const string& call_target_name, absl::Span operands, + const Shape& shape_with_layout, + const std::vector& operand_shapes_with_layout, + const string& opaque) { + std::vector xla_ops; + xla_ops.reserve(operands.size()); + for (const auto& op : operands) { + xla_ops.push_back(op.op()); + } + return xla::CustomCallWithLayout(&builder_, call_target_name, xla_ops, + shape_with_layout, + operand_shapes_with_layout, opaque); +} + LocalOp LocalComputationBuilder::Transpose( const LocalOp& operand, absl::Span permutation) { return xla::Transpose(operand.op(), permutation); @@ -799,6 +892,27 @@ LocalOp LocalComputationBuilder::SortKeyVal(const LocalOp& keys, return xla::Sort(keys.op(), {values.op()}, dimension); } +LocalOp LocalComputationBuilder::Cholesky(const LocalOp& a) { + return xla::Cholesky(a.op()); +} + +LocalOp LocalComputationBuilder::QR(const LocalOp& a, bool full_matrices) { + 
XlaBuilder* builder = a.op().builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto qr, xla::QRDecomposition(a.op(), full_matrices)); + return xla::Tuple(builder, {qr.q, qr.r}); + }); +} + +LocalOp LocalComputationBuilder::TriangularSolve(const LocalOp& a, + const LocalOp& b, + bool left_side, bool lower, + bool transpose_a, + bool conjugate_a) { + return xla::TriangularSolve(a.op(), b.op(), left_side, lower, transpose_a, + conjugate_a); +} + StatusOr LocalComputationBuilder::BuildConstantSubGraph( const LocalOp& operand) { TF_ASSIGN_OR_RETURN(XlaComputation computation, diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index c9b7ae824a4e5dac3360de0f95859d7c1deb360f..c6e58ac971d93662c41fc7a6001f94fb26d2eff5 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -71,7 +71,8 @@ StatusOr TransferFromOutfeedLocalReplica(const Shape& shape, class LocalShapedBuffer { public: static StatusOr FromLiteral( - const Literal& argument, const absl::optional& shape_with_layout); + const Literal& argument, const absl::optional& shape_with_layout, + int replica_number); LocalShapedBuffer(ScopedShapedBuffer shaped_buffer); StatusOr ToLiteral() const; @@ -175,6 +176,12 @@ class CompiledLocalComputation { StatusOr Execute( absl::Span argument_handles); + // Execute on many replicas. Takes a sequence of argument lists (one argument + // list per replica) and returns a tuple of results (one result per replica). + // The number of argument lists must be equal to the replica count. 
+ StatusOr ExecutePerReplica( + absl::Span > argument_handles); + private: std::unique_ptr executable_; }; @@ -279,6 +286,10 @@ class LocalComputationBuilder { LocalOp ConstantLiteral(const Literal& literal); + LocalOp Iota(PrimitiveType element_type, int64 size); + + LocalOp BroadcastedIota(const Shape& shape, int64 dimension); + LocalOp Broadcast(const LocalOp& operand, absl::Span broadcast_sizes); @@ -345,6 +356,12 @@ class LocalComputationBuilder { LocalOp Call(const LocalComputation& local_computation, absl::Span operands); + LocalOp CustomCall(const string& call_target_name, + absl::Span operands, + const Shape& shape_with_layout, + const std::vector& operand_shapes_with_layout, + const string& opaque); + LocalOp Transpose(const LocalOp& operand, absl::Span permutation); @@ -387,6 +404,13 @@ class LocalComputationBuilder { LocalOp SortKeyVal(const LocalOp& keys, const LocalOp& values, int64 dimension); + LocalOp QR(const LocalOp& a, bool full_matrices); + + LocalOp Cholesky(const LocalOp& a); + + LocalOp TriangularSolve(const LocalOp& a, const LocalOp& b, bool left_side, + bool lower, bool transpose_a, bool conjugate_a); + StatusOr BuildConstantSubGraph(const LocalOp& operand); #define _FORWARD(method_name, return_sig, args_sig) \ diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 5c2538dcc36d93008382a517fd4dc680caaa4347..11fb00e616ad410fd1e5b0225ca3cd5362fef59b 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -363,6 +363,37 @@ tensorflow::ImportNumpy(); $1 = temps; } +%typemap(in) absl::Span > + (std::vector > temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + temps.reserve(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + 
std::vector vec; + const int vec_size = PySequence_Size(o); + vec.reserve(vec_size); + for (int j = 0; j < vec_size; ++j) { + PyObject* vec_elt = PySequence_GetItem(o, j); + LocalShapedBuffer* lsbp; + if ((SWIG_ConvertPtr(vec_elt, (void**) &lsbp, $descriptor(xla::swig::LocalShapedBuffer*), + SWIG_POINTER_EXCEPTION)) == -1) { + Py_DECREF(vec_elt); + Py_DECREF(o); + SWIG_fail; + } + vec.push_back(lsbp); + Py_DECREF(vec_elt); + } + temps.push_back(vec); + Py_DECREF(o); + } + $1 = temps; +} + %typemap(in) absl::Span (std::vector temps) { if (!PySequence_Check($input)) { @@ -998,6 +1029,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::XrtAllocationTuple::size; %unignore xla::swig::CompiledLocalComputation; %unignore xla::swig::CompiledLocalComputation::Execute; +%unignore xla::swig::CompiledLocalComputation::ExecutePerReplica; %unignore xla::swig::CompiledXrtComputation; %unignore xla::swig::CompiledXrtComputation::Execute; %unignore xla::swig::LocalComputation; @@ -1019,6 +1051,8 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Outfeed; %unignore xla::swig::LocalComputationBuilder::ConstantLiteral; %unignore xla::swig::LocalComputationBuilder::ConstantR0; +%unignore xla::swig::LocalComputationBuilder::Iota; +%unignore xla::swig::LocalComputationBuilder::BroadcastedIota; %unignore xla::swig::LocalComputationBuilder::Broadcast; %unignore xla::swig::LocalComputationBuilder::BroadcastInDim; %unignore xla::swig::LocalComputationBuilder::Pad; @@ -1112,6 +1146,10 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Imag; %unignore xla::swig::LocalComputationBuilder::Conj; %unignore xla::swig::LocalComputationBuilder::Complex; +%unignore xla::swig::LocalComputationBuilder::Cholesky; +%unignore xla::swig::LocalComputationBuilder::QR; +%unignore xla::swig::LocalComputationBuilder::TriangularSolve; +%unignore xla::swig::LocalComputationBuilder::CustomCall; %unignore xla::swig::DeleteLocalComputation; %unignore 
xla::swig::DestructureLocalShapedBufferTuple; %unignore xla::swig::DestructureXrtAllocationTuple; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index e5fba0d7acb838788f8e7e05a4634e807d9d21d0..4166fa0327eba5edd0dee030e283c86ade627040 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -222,24 +222,33 @@ class LocalBuffer(object): means the referent is in device memory. """ - def __init__(self, c_buffer, backend): + def __init__(self, c_buffer, backend, replica): self.c_buffer = c_buffer self._backend = backend + self._replica = replica if backend.backend_type == BackendType.XRT: self._delete = c_api.DeleteXrtAllocation else: self._delete = c_api.DeleteLocalShapedBuffer @staticmethod - def from_pyval(pyval, backend=XLA_LOCAL_BACKEND): + def from_pyval(pyval, replica=0, backend=XLA_LOCAL_BACKEND): """Allocate and copy to XLA the given python value.""" pyval = require_numpy_array_layout(pyval) + num_replicas = get_replica_count() + if not 0 <= replica < num_replicas: + raise ValueError( + 'Attempt to place buffer on replica {} when the replica count is {}' + .format(replica, num_replicas)) if backend.backend_type == BackendType.XRT: + if replica != 0: + raise NotImplementedError( + 'Multi-replica execution is not yet supported via the XRT backend.') cbuf = c_api.XrtAllocation.FromLiteral( pyval, _maybe_encode_string(backend.target)) else: - cbuf = c_api.LocalShapedBuffer.FromLiteral(pyval, None) - return LocalBuffer(cbuf, backend) + cbuf = c_api.LocalShapedBuffer.FromLiteral(pyval, None, replica) + return LocalBuffer(cbuf, backend, replica) def to_py(self): return self.c_buffer.ToLiteral() @@ -247,6 +256,9 @@ class LocalBuffer(object): def shape(self): return _wrap_shape(self.c_buffer.shape()) + def replica(self): + return self._replica + def delete(self): if self.c_buffer is not None: self._delete(self.c_buffer) @@ -263,7 +275,8 @@ class 
LocalBuffer(object): self.delete() size = result.size() destructured = tuple( - LocalBuffer(result.Release(i), backend=self._backend) + LocalBuffer( + result.Release(i), replica=self._replica, backend=self._backend) for i in xrange(size)) return destructured @@ -575,23 +588,87 @@ class LocalComputation(object): compile_options=compile_options, layout_fn=layout_fn) - def Execute(self, arguments=()): - """Execute with LocalBuffer arguments and return value.""" + def GetReturnValueShape(self): + return _wrap_shape(self._c_computation.GetReturnValueShape()) + + def Execute(self, arguments=(), check_for_deleted_args=True): + """Execute on one replica with LocalBuffer arguments and return value.""" + if check_for_deleted_args and any(arg.is_deleted() for arg in arguments): + raise ValueError('Executing with deleted local buffer argument') + raw_args = [arg.c_buffer for arg in arguments] + output_buffer = self._c_computation.Execute(raw_args) + return LocalBuffer(output_buffer, backend=self._backend, replica=0) + + def ExecutePerReplica(self, arguments=None): + """Execute on many replicas with LocalBuffer arguments and return value. + + Args: + arguments: A sequence of sequences of LocalBuffers. The i'th inner + sequence comprises the arguments for execution on the i'th replica. + + Returns: + A list of the computation's outputs on each replica, as a LocalBuffer. If + a shallow sequence of arguments was passed in for `arguments`, then the + sole, zero'th replica's output is returned instead, as a LocalBuffer. 
+ """ if not self._is_compiled: raise ValueError('Cannot execute an uncompiled local XLA computation.') - arguments = tuple(arguments) - if any(arg.is_deleted() for arg in arguments): - raise ValueError('Executing with deleted local buffer argument') - return LocalBuffer( - self._c_computation.Execute([arg.c_buffer for arg in arguments]), - backend=self._backend) + if arguments is None: + arguments = ((),) * get_replica_count() + else: + arguments = [list(replica_args) for replica_args in arguments] + + # Check arguments + for replica, replica_args in enumerate(arguments): + for arg in replica_args: + if arg.is_deleted(): + raise ValueError('Executing with deleted local buffer argument') + if arg.replica() != replica: + raise ValueError( + 'Executing on replica {} with argument from replica {}'.format( + replica, arg.replica())) + + # Pull out argument buffer handles + stripped_args = [ + [arg.c_buffer for arg in replica_args] for replica_args in arguments + ] + + # Execute + if self._backend.backend_type == BackendType.XRT: + if len(stripped_args) > 1: + raise NotImplementedError( + 'Multi-replica execution is not yet supported via the XRT backend.') + output_buffers = [self._c_computation.Execute(stripped_args[0])] + else: + output_buffer_tup = self._c_computation.ExecutePerReplica(stripped_args) + size = output_buffer_tup.size() + output_buffers = [output_buffer_tup.Release(i) for i in xrange(size)] + + # Wrap output handles in LocalBuffer instances + return tuple( + LocalBuffer(output_buffer, backend=self._backend, replica=replica) + for replica, output_buffer in enumerate(output_buffers)) def ExecuteWithPythonValues(self, arguments=()): - """Execute with Python values as arguments and return value.""" - arguments = tuple( - LocalBuffer.from_pyval(arg, backend=self._backend) for arg in arguments) + """Execute on one replica with Python values as arguments and output.""" + + def put(arg): + return LocalBuffer.from_pyval(arg, backend=self._backend) + + arguments 
= [put(arg) for arg in arguments] return self.Execute(arguments).to_py() + def ExecuteWithPythonValuesPerReplica(self, arguments): + """Execute on many replicas with Python values as arguments and output.""" + + def put(arg, replica): + return LocalBuffer.from_pyval(arg, replica, backend=self._backend) + + arguments = [[put(arg, replica) + for arg in replica_args] + for replica, replica_args in enumerate(arguments)] + return [out.to_py() for out in self.ExecutePerReplica(arguments)] + def __del__(self): self._delete(self._c_computation) @@ -754,6 +831,33 @@ class ComputationBuilder(object): return self.ParameterWithShape( Shape.from_pyval(value), name=name, parameter_num=parameter_num) + def Iota(self, dtype, size): + """Enqueues an iota constant onto the computation. + + Args: + dtype: expected numpy dtype of the output. + size: integer, the number of elements in the array. + + Returns: + A LocalOp representing the added iota constant. + """ + element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))] + return self._client.Iota(element_type, size) + + def BroadcastedIota(self, dtype, shape, dimension): + """Enqueues a broadcasted iota constant onto the computation. + + Args: + dtype: expected numpy dtype of the output. + shape: tuple of integers, the expected output shape (dimensions). + dimension: positive integer, dimension along which to increment values. + + Returns: + A LocalOp representing the added broadcasted iota constant. + """ + xla_shape = Shape.array_shape(dtype, shape) + return self._client.BroadcastedIota(xla_shape, dimension) + def Broadcast(self, operand, sizes): """Enqueues a broadcast operation onto the computation. @@ -1025,6 +1129,31 @@ class ComputationBuilder(object): """ return self._client.Call(computation_to_apply.computation, operands) + def CustomCall(self, + call_target_name, + operands, + shape_with_layout, + operand_shapes_with_layout, + opaque=None): + """Enqueues a custom call operation onto the computation. 
+ + Args: + call_target_name: the name of the function to call. + operands: an iterable of LocalOp. The number and types of operands must + match the arity of `operand_shapes_with_layout`. + shape_with_layout: the shape of the operator's output, with layout. + operand_shapes_with_layout: the shapes of `operands`, including the + expected layouts. + opaque: an opaque string passed to the backend. + + Returns: + A LocalOp representing the added custom call op. + """ + opaque = opaque or '' + return self._client.CustomCall(call_target_name, operands, + shape_with_layout, + operand_shapes_with_layout, opaque) + def Map(self, operands, computation_to_apply, dimensions): """Enqueues a map operation onto the computation. @@ -1334,6 +1463,20 @@ class ComputationBuilder(object): """Enqueues a key-value sort operation onto the computation.""" return self._client.SortKeyVal(keys, values, dimension) + def Cholesky(self, a): + """Enqueues a Cholesky decomposition onto the computation.""" + return self._client.Cholesky(a) + + def QR(self, a, full_matrices=True): + """Enqueues a QR decomposition onto the computation.""" + return self._client.QR(a, full_matrices) + + def TriangularSolve(self, a, b, left_side=False, lower=False, + transpose_a=False, conjugate_a=False): + """Enqueues a triangular-solve operation onto the computation.""" + return self._client.TriangularSolve( + a, b, left_side, lower, transpose_a, conjugate_a) + def _forward_methods_to_local_builder(): """Forward remaining ComputationBuilder methods to the C API. 
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 21b5c93b615ec429a5da0b4ffe89e8f75f59ef1b..95c6dc8c4570564e361c27fd2bca5c90eebb4661 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import itertools import threading @@ -51,9 +52,11 @@ class LocalComputationTest(unittest.TestCase): def _ExecuteAndCompareExact(self, c, arguments=(), expected=None): self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, expected) - def _ExecuteAndCompareClose(self, c, arguments=(), expected=None): - self._ExecuteAndAssertWith(np.testing.assert_allclose, c, arguments, - expected) + def _ExecuteAndCompareClose(self, c, arguments=(), expected=None, rtol=1e-7, + atol=0): + self._ExecuteAndAssertWith( + functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol), + c, arguments, expected) def NumpyArrayF32(*args, **kwargs): @@ -143,6 +146,17 @@ class ComputationsWithConstantsTest(LocalComputationTest): c.Pow(c.Constant(NumpyArrayF64([1.5, 2.5, 3.0])), c.ConstantF64Scalar(2.)) self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.]) + def testIota(self): + c = self._NewComputation() + c.Iota(np.float32, 10) + self._ExecuteAndCompareExact(c, expected=np.arange(10, dtype=np.float32)) + + def testBroadcastedIota(self): + c = self._NewComputation() + c.BroadcastedIota(np.int64, (2, 3), 1) + expected = np.array([[0, 1, 2], [0, 1, 2]], dtype=np.int64) + self._ExecuteAndCompareExact(c, expected=expected) + def testBooleanAnd(self): c = self._NewComputation() c.And( @@ -1057,6 +1071,38 @@ class SingleOpTest(LocalComputationTest): self.assertTrue(np.all(lo <= result)) self.assertTrue(np.all(result < hi)) + def testCholesky(self): + l = np.array([[4, 0, 0, 0], [6, 5, 0, 0], [2, 14, 16, 0], [3, 6, 1, 
4]], + dtype=np.float32) + c = self._NewComputation() + c.Cholesky(c.Constant(np.dot(l, l.T))) + self._ExecuteAndCompareClose(c, expected=l, rtol=1e-4) + + def testQR(self): + a = np.array( + [[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]], + dtype=np.float32) + c = self._NewComputation() + c.QR(c.Constant(a), full_matrices=True) + q, r = self._Execute(c, ()) + np.testing.assert_allclose(np.dot(q, r), a, rtol=1e-4) + + def testTriangularSolve(self): + a_vals = np.array( + [[2, 0, 0, 0], [3, 6, 0, 0], [4, 7, 9, 0], [5, 8, 10, 11]], + dtype=np.float32) + b_vals = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], + dtype=np.float32) + + c = self._NewComputation() + c.TriangularSolve(c.Constant(a_vals), c.Constant(b_vals), left_side=False, + lower=True, transpose_a=True) + self._ExecuteAndCompareClose(c, expected=np.array([ + [0.5, 0.08333334, 0.04629629, 0.03367003], + [2.5, -0.25, -0.1388889, -0.1010101], + [4.5, -0.58333331, -0.32407406, -0.23569024], + ], dtype=np.float32), rtol=1e-4) + def testIsConstant(self): c = self._NewComputation() a = c.ConstantS32Scalar(3) diff --git a/tensorflow/compiler/xla/python_api/xla_shape.py b/tensorflow/compiler/xla/python_api/xla_shape.py index 95b2bf300ec67e9f034f77450416544cb088ae55..bdcd4abd6cc708795416b15412f37dde10d7fe97 100644 --- a/tensorflow/compiler/xla/python_api/xla_shape.py +++ b/tensorflow/compiler/xla/python_api/xla_shape.py @@ -20,6 +20,8 @@ from __future__ import print_function import numpy as _np # Avoids becoming a part of public Tensorflow API. 
+from six.moves import xrange + from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python_api import types diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index ceb5e74db7c3b9305e9d77068df9ae0a3690af8a..a27e2005dae3a44f4e49032e70f62d633f64779a 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -21,7 +21,6 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -32,48 +31,19 @@ limitations under the License. namespace xla { -namespace { - -template -std::unique_ptr> MatmulArray2DImpl( - const Array2D& lhs, const Array2D& rhs, - const std::function& impl_fn) { - CHECK_EQ(lhs.width(), rhs.height()); - int m = lhs.height(); - int n = rhs.width(); - int k = lhs.width(); - auto result = absl::make_unique>(m, n); - // Because Eigen is a header-oriented library, make sure that the Eigen code - // is the same as the code used by the CPU backend (otherwise the linker will - // randomly pick *some* definition). 
- impl_fn( - /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m, - k, - /*transpose_lhs=*/0, - /*transpose_rhs=*/0); - return result; -} - -} // namespace - /* static */ std::unique_ptr> ReferenceUtil::MatmulArray2D( const Array2D& lhs, const Array2D& rhs) { - return MatmulArray2DImpl( - lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF16); + return HloEvaluator::MatmulArray2D(lhs, rhs); } /* static */ std::unique_ptr> ReferenceUtil::MatmulArray2D( const Array2D& lhs, const Array2D& rhs) { - return MatmulArray2DImpl( - lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF32); + return HloEvaluator::MatmulArray2D(lhs, rhs); } /* static */ std::unique_ptr> ReferenceUtil::MatmulArray2D( const Array2D& lhs, const Array2D& rhs) { - return MatmulArray2DImpl( - lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF64); + return HloEvaluator::MatmulArray2D(lhs, rhs); } /* static */ std::unique_ptr> ReferenceUtil::Array2DF32ToF64( diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 429b4e490cc2f1ab894924e95db3ad7e80342a72..55cadfdec64047a1d8cd4e2cd1d649d4c3f717e2 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -241,6 +241,7 @@ cc_library( ":hlo_casting_utils", ":hlo_query", ":shape_inference", + "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -249,6 +250,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", @@ -281,10 +283,12 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_element_type_converter", "//tensorflow/compiler/xla/tests:hlo_test_base", 
"//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], ) @@ -1010,6 +1014,7 @@ cc_library( srcs = ["name_uniquer.cc"], hdrs = ["name_uniquer.h"], deps = [ + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", @@ -1410,6 +1415,7 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", ], ) @@ -1574,6 +1580,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -1589,7 +1596,10 @@ tf_cc_test( ":hlo", ":hlo_casting_utils", ":hlo_matchers", + ":hlo_parser", ":hlo_pass", + ":pattern_matcher", + ":pattern_matcher_gmock", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1777,6 +1787,7 @@ tf_cc_test( ":hlo_cse", ":hlo_dce", ":hlo_matchers", + ":hlo_parser", ":hlo_pass", ":hlo_pass_pipeline", ":tuple_simplifier", @@ -1905,6 +1916,41 @@ cc_library( ], ) +cc_library( + name = "dynamic_dimension_inference", + srcs = ["dynamic_dimension_inference.cc"], + hdrs = ["dynamic_dimension_inference.h"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + 
name = "dynamic_dimension_inference_test", + srcs = ["dynamic_dimension_inference_test.cc"], + deps = [ + ":dynamic_dimension_inference", + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_runner", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + ], +) + tf_cc_test( name = "reshape_mover_test", srcs = ["reshape_mover_test.cc"], @@ -2062,7 +2108,8 @@ tf_cc_test( srcs = ["hlo_computation_test.cc"], deps = [ ":hlo", - ":hlo_matchers", + ":pattern_matcher", + ":pattern_matcher_gmock", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -2656,7 +2703,6 @@ tf_cc_test( ":algebraic_simplifier", ":computation_layout", ":hlo", - ":hlo_matchers", ":layout_assignment", ":pattern_matcher", ":pattern_matcher_gmock", @@ -2670,6 +2716,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/types:span", @@ -3122,6 +3169,7 @@ cc_library( name = "hlo_graph_dumper", srcs = [ "hlo_graph_dumper.cc", + "hlo_graph_html_renderer.cc", ], hdrs = ["hlo_graph_dumper.h"], deps = [ @@ -3129,6 +3177,7 @@ cc_library( ":hlo_casting_utils", ":hlo_execution_profile", ":hlo_tfgraph_builder", + ":pattern_matcher", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", @@ -3582,7 +3631,6 @@ cc_library( srcs = ["hlo_lexer.cc"], hdrs = [ 
"hlo_lexer.h", - "hlo_token.h", ], deps = [ "//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index a348bcf0a232994a046df51563a9167faac08190..1287dcf546d9fe575dd440d48323ed8efbf1de9d 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include +#include +#include #include #include #include @@ -24,6 +26,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/types/optional.h" @@ -68,6 +71,45 @@ bool IsAll(const HloInstruction* op, int8 value) { } } +// Checks whether `op` is a floating-point constant or broadcast of a constant +// of the form +/- 2^k for some integer k positive, negative, or zero. Such +// values are interesting because multiplying by a power of 2 just moves the +// exponent. +bool IsAllFpConstantPowerOf2(const HloInstruction* op) { + // Unwrap the broadcast if necessary. + const HloInstruction* c; + if (!Match(op, m::ConstantEffectiveScalar(&c)) && + !Match(op, m::Broadcast(m::Constant(&c).WithShape( + m::Shape().IsEffectiveScalar())))) { + return false; + } + auto val = [&]() -> absl::optional { + switch (c->shape().element_type()) { + case BF16: + return static_cast(c->literal().GetFirstElement()); + case F16: + return static_cast(c->literal().GetFirstElement()); + case F32: + return c->literal().GetFirstElement(); + case F64: + return c->literal().GetFirstElement(); + default: + // Cowardly refuse to consider complex types. 
+ return absl::nullopt; + } + }(); + if (!val) { + return false; + } + + int exp; + double mantissa = std::frexp(*val, &exp); + // frexp returns a value in the range (-1; -0.5] U [0.5, 1). A return value + // of +/-0.5 therefore indicates that the floating point value is a power of + // 2. + return mantissa == 0.5 || mantissa == -0.5; +} + // Returns whether the given transpose produces a result which is bit-wise // identical to its operand and thus may be replaced with a bitcast. bool TransposeIsBitcast(const HloInstruction* transpose) { @@ -199,6 +241,13 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // more fusion than leaving the nodes as Dot operations. StatusOr HandleDotStrengthReduction(HloInstruction* dot); + // Removes dimension dim from hlo. + HloInstruction* StripDim(HloInstruction* hlo, int64 dim) { + CHECK_EQ(hlo->shape().dimensions(dim), 1); + return computation_->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::DeleteDimension(dim, hlo->shape()), hlo)); + } + // Reshapes an instruction to rank 1 if it is not already rank 1. HloInstruction* Flatten(HloInstruction* hlo) { if (ShapeUtil::Rank(hlo->shape()) == 1) { @@ -415,6 +464,40 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { sum_of_constants)); } + // A*C + B*C => (A+B)*C + // + // - If A, B, and C are integers, do this unconditionally. Proof of + // correctness: https://rise4fun.com/Alive/u9X. + // + // - If A, B, and C are floating point, do this if C is a scalar constant or + // broadcast of scalar constant and is equal to +/- 2^k for some (possibly + // negative) integer k. + // + // Multiplying by a power of 2 just moves the exponent, so our answer is + // exact modulo rounding of intermediate results so long as + // + // - none of the three products has an exponent which underflows (so the + // result is 0 or denormal), and + // - none of the three products overflows to inf. 
+ // + // Proof: See algebraic_simplifier_proof_distributive_property.py. + // + // We deem these differences in rounding, underflow, and overflow + // acceptable in the ML context. + HloInstruction *b, *c; + if (((Match(lhs, m::Multiply(m::Op(&a), m::Op(&c))) && + Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b)))) || + (Match(lhs, m::Multiply(m::Op(&c), m::Op(&a))) && + Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b))))) && + (ShapeUtil::ElementIsIntegral(add->shape()) || + IsAllFpConstantPowerOf2(c))) { + return ReplaceWithNewInstruction( + add, HloInstruction::CreateBinary( + add->shape(), HloOpcode::kMultiply, + computation_->AddInstruction(HloInstruction::CreateBinary( + add->shape(), HloOpcode::kAdd, a, b)), + c)); + } return Status::OK(); } @@ -834,21 +917,51 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( HloInstruction* dot) { HloInstruction *lhs, *rhs; CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); - int64 lhs_collapsing_dim = - dot->dot_dimension_numbers().lhs_contracting_dimensions(0); + + const auto kept_dim = [](int64 rank, int64 contracting_dimension, + absl::Span batch_dimensions) -> int64 { + for (int64 i = 0; i < rank; ++i) { + if (i != contracting_dimension && + !absl::c_linear_search(batch_dimensions, i)) { + return i; + } + } + return -1; + }; + + const int64 dot_rank = ShapeUtil::Rank(dot->shape()); + const int64 rhs_rank = ShapeUtil::Rank(rhs->shape()); + const int64 lhs_rank = ShapeUtil::Rank(lhs->shape()); + const auto& dnums = dot->dot_dimension_numbers(); + if (dnums.rhs_contracting_dimensions_size() > 1) { + return false; + } + if (dot_rank > 2 && (lhs_rank != rhs_rank || lhs_rank != dot_rank)) { + return false; + } + int64 lhs_collapsing_dim = dnums.lhs_contracting_dimensions(0); + int64 lhs_kept_dim = kept_dim(lhs_rank, lhs_collapsing_dim, + AsInt64Slice(dnums.lhs_batch_dimensions())); + // If there is no non-contracting dimension in rank 2, do not strength reduce. 
+ if (lhs_kept_dim == -1 && lhs_rank > 1) { + return false; + } if (lhs->IsRank2Transpose()) { lhs = lhs->mutable_operand(0); - lhs_collapsing_dim = 1 - lhs_collapsing_dim; + std::swap(lhs_collapsing_dim, lhs_kept_dim); } - const int64 lhs_kept_dim = 1 - lhs_collapsing_dim; - int64 rhs_collapsing_dim = - dot->dot_dimension_numbers().rhs_contracting_dimensions(0); + int64 rhs_collapsing_dim = dnums.rhs_contracting_dimensions(0); + int64 rhs_kept_dim = kept_dim(rhs_rank, rhs_collapsing_dim, + AsInt64Slice(dnums.rhs_batch_dimensions())); + // If there is no non-contracting dimension in rank 2, do not strength reduce. + if (rhs_kept_dim == -1 && rhs_rank > 1) { + return false; + } if (rhs->IsRank2Transpose()) { rhs = rhs->mutable_operand(0); - rhs_collapsing_dim = 1 - rhs_collapsing_dim; + std::swap(rhs_collapsing_dim, rhs_kept_dim); } - const int64 rhs_kept_dim = 1 - rhs_collapsing_dim; auto as_type = [&](HloInstruction* hlo, const PrimitiveType element_type) { if (hlo->shape().element_type() == element_type) { @@ -871,10 +984,15 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( return AddReduce(as_type(hlo, F32), dim); }; + auto broadcast = [&](HloInstruction* hlo, const Shape& shape, + absl::Span dims) { + return computation_->AddInstruction( + HloInstruction::CreateBroadcast(shape, hlo, dims)); + }; + auto broadcast_to_dim = [&](HloInstruction* hlo, const Shape& shape, int64 dim) { - return computation_->AddInstruction( - HloInstruction::CreateBroadcast(shape, hlo, {dim})); + return broadcast(hlo, shape, {dim}); }; auto multiply = [&](HloInstruction* local_lhs, HloInstruction* local_rhs) { @@ -885,11 +1003,9 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( // Strength reduce dot(a[K] , b[K]) = // reshape(result.shape, // reduce_sum(multiply(a, b), {0})) - if (ShapeUtil::Rank(rhs->shape()) == 1 && - ShapeUtil::Rank(lhs->shape()) == 1) { - TF_RETURN_IF_ERROR( - ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32( - 
multiply(Flatten(lhs), Flatten(rhs)), 0)))); + if (rhs_rank == 1 && lhs_rank == 1) { + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary(add_reduce_in_f32(multiply(lhs, rhs), 0)))); return true; } @@ -903,8 +1019,7 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( // Simplify outer product into multiply with implicit broadcasting. // // A dot(a[M, 1], b[1, N]) = multiply(a [M,1], b [1, N]) - if (ShapeUtil::Rank(rhs->shape()) == 2 && - rhs->shape().dimensions(rhs_collapsing_dim) == 1) { + if (rhs_rank == 2 && rhs->shape().dimensions(rhs_collapsing_dim) == 1) { TF_RETURN_IF_ERROR(ReplaceInstruction( dot, multiply(broadcast_to_dim(Flatten(lhs), dot->shape(), 0), broadcast_to_dim(Flatten(rhs), dot->shape(), 1)))); @@ -918,9 +1033,8 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( // {0}) // ) // ) - if (ShapeUtil::Rank(lhs->shape()) == 1 || - (ShapeUtil::Rank(lhs->shape()) == 2 && - lhs->shape().dimensions(lhs_kept_dim) == 1)) { + if (lhs_rank == 1 || + (lhs_rank == 2 && lhs->shape().dimensions(lhs_kept_dim) == 1)) { if (ShapeUtil::Rank(rhs->shape()) == 1) { TF_RETURN_IF_ERROR( ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32( @@ -940,9 +1054,8 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( // reshape(result.shape, // reduce_sum(multiply(a, broadcast(reshape([K],b), {1})), {0}) // ) - if (ShapeUtil::Rank(rhs->shape()) == 1 || - (ShapeUtil::Rank(rhs->shape()) == 2 && - rhs->shape().dimensions(rhs_kept_dim) == 1)) { + if (rhs_rank == 1 || + (rhs_rank == 2 && rhs->shape().dimensions(rhs_kept_dim) == 1)) { TF_RETURN_IF_ERROR(ReplaceInstruction( dot, reshape_if_necessary(add_reduce_in_f32( multiply(lhs, broadcast_to_dim(Flatten(rhs), lhs->shape(), @@ -950,6 +1063,97 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( lhs_collapsing_dim)))); return true; } + + // Only consider kDot with batch dimension. 
+ if (dot_rank <= 2) { + return false; + } + + CHECK_EQ(rhs_rank, lhs_rank); + CHECK_EQ(dot_rank, lhs_rank); + // If there is more than one non-contracting dimension or the batch dimensions + // are not equal, bail out since transposes may be required to do a strength + // reduction. + if (dnums.rhs_batch_dimensions_size() + 2 != dot_rank || + !absl::c_equal(dnums.lhs_batch_dimensions(), + dnums.rhs_batch_dimensions())) { + return false; + } + + auto broadcast_dims = [](int64 rank, int64 non_broadcast_dim) { + absl::InlinedVector dims; + for (int64 i = 0; i < rank; ++i) { + if (i != non_broadcast_dim) { + dims.push_back(i); + } + } + return dims; + }; + + // If the contracting dimension is 1, remove the degenerate dimensions from + // the lhs and rhs, broadcast each to the result shape and multiply. + if (lhs->shape().dimensions(lhs_collapsing_dim) == 1 && + (rhs_kept_dim == rhs_rank - 1 || + (rhs_collapsing_dim == rhs_rank - 1 && rhs_kept_dim == rhs_rank - 2))) { + CHECK_EQ(rhs->shape().dimensions(rhs_collapsing_dim), 1); + const int64 lhs_kept_dim_in_output = + lhs_kept_dim > lhs_collapsing_dim ? (lhs_kept_dim - 1) : lhs_kept_dim; + absl::InlinedVector lhs_broadcast_dims; + for (const int64 dim : dnums.lhs_batch_dimensions()) { + lhs_broadcast_dims.push_back(dim > lhs_collapsing_dim ? 
(dim - 1) : dim); + } + absl::InlinedVector rhs_broadcast_dims = lhs_broadcast_dims; + lhs_broadcast_dims.push_back(lhs_kept_dim_in_output); + absl::c_sort(lhs_broadcast_dims); + rhs_broadcast_dims.push_back(dot_rank - 1); + absl::c_sort(rhs_broadcast_dims); + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary( + multiply(broadcast(StripDim(lhs, lhs_collapsing_dim), + dot->shape(), lhs_broadcast_dims), + broadcast(StripDim(rhs, rhs_collapsing_dim), + dot->shape(), rhs_broadcast_dims))))); + return true; + } + + // If the lhs and rhs non-contracting dimensions are both one, strip each one, + // multiply and then reduce the collapsing dimension + if (lhs->shape().dimensions(lhs_kept_dim) == 1 && + rhs->shape().dimensions(rhs_kept_dim) == 1 && + lhs_kept_dim == rhs_kept_dim) { + auto new_lhs = StripDim(lhs, lhs_kept_dim); + auto new_rhs = StripDim(rhs, rhs_kept_dim); + const int64 reduce_dim = rhs_kept_dim < rhs_collapsing_dim + ? (rhs_collapsing_dim - 1) + : rhs_collapsing_dim; + TF_RETURN_IF_ERROR( + ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32( + multiply(new_lhs, new_rhs), reduce_dim)))); + return true; + } + + // If the lhs non-contracting dimension is one, strip the one, broadcast to + // the rhs shape, multiply and then reduce the collapsing dimension + if (lhs->shape().dimensions(lhs_kept_dim) == 1) { + auto new_lhs = broadcast(StripDim(lhs, lhs_kept_dim), rhs->shape(), + broadcast_dims(rhs_rank, rhs_kept_dim)); + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary(add_reduce_in_f32(multiply(new_lhs, rhs), + rhs_collapsing_dim)))); + return true; + } + + // If the rhs non-contracting dimension is one, strip the one, broadcast to + // the lhs shape, multiply and then reduce the collapsing dimension + if (rhs->shape().dimensions(rhs_kept_dim) == 1) { + auto new_rhs = broadcast(StripDim(rhs, rhs_kept_dim), lhs->shape(), + broadcast_dims(lhs_rank, lhs_kept_dim)); + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, 
reshape_if_necessary(add_reduce_in_f32(multiply(lhs, new_rhs), + lhs_collapsing_dim)))); + return true; + } + return false; } @@ -1228,25 +1432,31 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { HloInstruction *lhs, *rhs; CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); - // Only optimize F32 or BF16 dot operations where the dot, rhs and lhs are - // rank 2 or below. - if ((dot->shape().element_type() != F32 && - dot->shape().element_type() != BF16) || - ShapeUtil::Rank(lhs->shape()) > 2 || ShapeUtil::Rank(rhs->shape()) > 2 || - ShapeUtil::Rank(dot->shape()) > 2) { - return Status::OK(); - } - // Replace a zero element dot with a broadcast of the constant 0. if (ShapeUtil::IsZeroElementArray(dot->shape()) || ShapeUtil::IsZeroElementArray(lhs->shape()) || ShapeUtil::IsZeroElementArray(rhs->shape())) { - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + auto zero = computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(dot->shape().element_type()))); return ReplaceWithNewInstruction( dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); } + // Only optimize F32 or BF16 dot operations where the dot, rhs and lhs are + // rank 2 or below. 
+ if (dot->shape().element_type() != F32 && + dot->shape().element_type() != BF16) { + return Status::OK(); + } + if (ShapeUtil::Rank(lhs->shape()) > 2 || ShapeUtil::Rank(rhs->shape()) > 2 || + ShapeUtil::Rank(dot->shape()) > 2) { + if (options_.enable_dot_strength_reduction() && + !options_.is_layout_sensitive()) { + TF_RETURN_IF_ERROR(HandleDotStrengthReduction(dot).status()); + } + return Status::OK(); + } + TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_concat_optimized, OptimizeDotOfConcat(dot)); if (dot_of_concat_optimized) { @@ -1952,6 +2162,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { reshape, HloInstruction::CreateReshape(reshape->shape(), operand->mutable_operand(0))); } + if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) { *operand->mutable_shape() = reshape->shape(); return ReplaceInstruction(reshape, operand); @@ -2674,6 +2885,22 @@ Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) { return Status::OK(); } +namespace { +bool OnlyPermutesMoreThanOneDegenerateDim(const Shape& shape, + absl::Span perm) { + std::vector new_permutation; + int64 degenerate_count = 0; + for (int64 i = 0; i < perm.size(); ++i) { + if (shape.dimensions(i) != 1) { + new_permutation.push_back(perm[i]); + } else { + ++degenerate_count; + } + } + return degenerate_count > 1 && absl::c_is_sorted(new_permutation); +} +} // namespace + Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { auto operand = transpose->mutable_operand(0); if (std::is_sorted(transpose->dimensions().begin(), @@ -2690,6 +2917,15 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { transpose->dimensions()))); } + // Replace transpose with a reshape if more than one degenerate dimension is + // permuted. 
+ if (OnlyPermutesMoreThanOneDegenerateDim(transpose->shape(), + transpose->dimensions())) { + return ReplaceWithNewInstruction( + transpose, HloInstruction::CreateReshape( + transpose->shape(), transpose->mutable_operand(0))); + } + if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) { *operand->mutable_shape() = transpose->shape(); return ReplaceInstruction(transpose, operand); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_proof_distributive_property.py b/tensorflow/compiler/xla/service/algebraic_simplifier_proof_distributive_property.py new file mode 100644 index 0000000000000000000000000000000000000000..5da13da041b4ded813876af7ca379025187545ab --- /dev/null +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_proof_distributive_property.py @@ -0,0 +1,82 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Proof that transforming (A*C)+(B*C) <=> (A+B)*C is "safe" if C=2^k. + +Specifically, for all floating-point values A, B, and C, if + + - C is equal to +/- 2^k for some (possibly negative) integer k, and + - A, B, C, A*C, B*C, and A+B are not subnormal, zero, or inf, + +then there exists a rounding mode rm in [RTZ, RNE] such that + + (A*C) + (B*C) == (A+B) * C (computed with rounding mode rm). 
+ +Informally, this means that the equivalence holds for powers of 2 C, modulo +flushing to zero or inf, and modulo rounding of intermediate results. + +Requires z3 python bindings; try `pip install z3-solver`. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import z3 + +# We do float16 because it lets the solver run much faster. These results +# should generalize to fp32 and fp64, and you can verify this by changing the +# value of FLOAT_TY (and then waiting a while). +FLOAT_TY = z3.Float16 + +a = z3.FP("a", FLOAT_TY()) +b = z3.FP("b", FLOAT_TY()) +c = z3.FP("c", FLOAT_TY()) + +s = z3.Solver() + +# C must be a power of 2, i.e. significand bits must all be 0. +s.add(z3.Extract(FLOAT_TY().sbits() - 1, 0, z3.fpToIEEEBV(c)) == 0) + +for rm in [z3.RTZ(), z3.RNE()]: + z3.set_default_rounding_mode(rm) + before = a * c + b * c + after = (a + b) * c + + # Check that before == after, allowing that 0 == -0. + s.add( + z3.Not( + z3.Or( + before == after, # + z3.And(z3.fpIsZero(before), z3.fpIsZero(after))))) + + for x in [ + (a * c), + (b * c), + (a + b), + ]: + s.add(z3.Not(z3.fpIsSubnormal(x))) + s.add(z3.Not(z3.fpIsZero(x))) + s.add(z3.Not(z3.fpIsInf(x))) + +if s.check() == z3.sat: + m = s.model() + print("Counterexample found!") + print(m) + print("a*c: ", z3.simplify(m[a] * m[c])) + print("b*c: ", z3.simplify(m[b] * m[c])) + print("a+b: ", z3.simplify(m[a] + m[b])) + print("a*c + b*c: ", z3.simplify(m[a] * m[c] + m[b] * m[c])) + print("(a+b) * c: ", z3.simplify((m[a] + m[b]) * m[c])) +else: + print("Proved!") diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 48f689c96a98065498818aa081d4a5a911aea5a6..cfb4c48277605a6f90ef51debac1c3bc26bed070 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -27,9 +27,11 @@ limitations 
under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -42,8 +44,7 @@ namespace xla { namespace { using ::testing::ElementsAre; - -namespace op = xla::testing::opcode_matchers; +namespace m = match; AlgebraicSimplifierOptions::ValidBitcastCallback bitcasting_callback() { return [](const Shape&, const Shape&) { return true; }; @@ -79,6 +80,128 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { EXPECT_EQ(root, param0); } +TEST_F(AlgebraicSimplifierTest, FactorIntegerAddition) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = s32[8] parameter(0) + p1 = s32[8] parameter(1) + p2 = s32[8] parameter(2) + x = s32[8] multiply(p0, p2) + y = s32[8] multiply(p1, p2) + ROOT sum = s32[8] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::AddAnyOrder(m::Parameter(0), m::Parameter(1)), m::Parameter(2)))); +} + +// A*C + B*C => (A+B)*C if C is a floating-point power of 2. 
+TEST_F(AlgebraicSimplifierTest, FactorFpAddition) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + c = f32[] constant(0.125) + x = f32[] multiply(p0, c) + y = f32[] multiply(p1, c) + ROOT sum = f32[] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::AddAnyOrder(m::Parameter(0), m::Parameter(1)), + m::ConstantScalar(0.125)))); +} + +// A*C + B*C => (A+B)*C if C is a broadcast of a floating-point power of 2. +TEST_F(AlgebraicSimplifierTest, FactorFpAdditionWithBroadcast) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[4] parameter(0) + p1 = f32[4] parameter(1) + c = f32[] constant(0.125) + b = f32[4] broadcast(c), dimensions={} + x = f32[4] multiply(p0, b) + y = f32[4] multiply(p1, b) + ROOT sum = f32[4] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::AddAnyOrder(m::Parameter(0), m::Parameter(1)), + m::Broadcast(m::ConstantScalar(0.125))))); +} + +// A*C + B*C => (A+B)*C simplification should not happen if C is not a +// floating-point power of 2. 
+TEST_F(AlgebraicSimplifierTest, FactorFpAdditionNotPowerOf2) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + c = f32[] constant(0.3) + x = f32[] multiply(p0, c) + y = f32[] multiply(p1, c) + ROOT sum = f32[] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + EXPECT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); +} + +// A*C + B*C => (A+B)*C simplification should not happen if A, B, and C are +// complex numbers. +TEST_F(AlgebraicSimplifierTest, FactorFpAdditionComplex) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = c64[8] parameter(0) + p1 = c64[8] parameter(1) + p2 = c64[8] parameter(2) + x = c64[8] multiply(p0, p2) + y = c64[8] multiply(p1, p2) + ROOT sum = c64[8] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + EXPECT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); +} + +// A*C + B*C => (A+B)*C simplification is OK if A, B, and C are bfloat16.
+TEST_F(AlgebraicSimplifierTest, FactorFpAdditionBfloat16) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = bf16[4] parameter(0) + p1 = bf16[4] parameter(1) + c = bf16[] constant(0.125) + b = bf16[4] broadcast(c), dimensions={} + x = bf16[4] multiply(p0, b) + y = bf16[4] multiply(p1, b) + ROOT sum = bf16[4] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::AddAnyOrder(m::Parameter(0), m::Parameter(1)), + m::Broadcast(m::ConstantScalar(0.125))))); +} + // Test that A * 0 is simplified to 0 TEST_F(AlgebraicSimplifierTest, MulZero) { auto m = CreateNewVerifiedModule(); @@ -197,7 +320,7 @@ TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); HloInstruction* root = m->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Reduce(param, zero)); + EXPECT_THAT(root, GmockMatch(m::Reduce(m::Parameter(0), m::Op().Is(zero)))); EXPECT_EQ(root->dimensions(), std::vector({0, 2, 3})); } @@ -219,7 +342,7 @@ TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Add(param0, op::Constant())); + EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(0), m::Constant()))); } // Test that [(A + C1) + C2] => [A + (C1 + C2)] for constants C1 and C2. 
@@ -245,7 +368,9 @@ TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Add(param0, op::Add(constant1, constant2))); + EXPECT_THAT(root, GmockMatch(m::Add( + m::Op().Is(param0), + m::Add(m::Op().Is(constant1), m::Op().Is(constant2))))); } TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { @@ -303,7 +428,8 @@ TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Add(param0, op::Broadcast(zero))); + EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(0), + m::Broadcast(m::Op().Is(zero))))); } TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { @@ -336,11 +462,11 @@ TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); + EXPECT_THAT(root, GmockMatch(m::Constant())); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(op::Constant())); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant()))); EXPECT_EQ(3.14f, root->operand(0)->literal().GetFirstElement()); } @@ -352,11 +478,11 @@ TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); + EXPECT_THAT(root, GmockMatch(m::Constant())); AlgebraicSimplifier simplifier(default_options_); ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); + EXPECT_THAT(root, 
GmockMatch(m::Constant())); } TEST_F(AlgebraicSimplifierTest, IotaToBroadcast) { @@ -367,11 +493,11 @@ TEST_F(AlgebraicSimplifierTest, IotaToBroadcast) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); + EXPECT_THAT(root, GmockMatch(m::Constant())); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Iota()); + EXPECT_THAT(root, GmockMatch(m::Iota())); } // Test that A - 0 is simplified to A @@ -413,7 +539,8 @@ TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Add(param0, op::Negate(constant))); + EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(0), + m::Negate(m::Op().Is(constant))))); } // Test that (A/B)/C is simplified to A/(B*C). @@ -435,13 +562,16 @@ TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Divide(op::Divide(param0, param1), param2)); + GmockMatch(m::Divide(m::Divide(m::Parameter(0), m::Parameter(1)), + m::Parameter(2)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(param0, op::Multiply(param1, param2))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Parameter(0), + m::Multiply(m::Parameter(1), m::Parameter(2))))); } // Test that A/(B/C) is simplified to (A*C)/B. 
@@ -462,14 +592,18 @@ TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(param0, op::Divide(param1, param2))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Parameter(0), + m::Divide(m::Parameter(1), m::Parameter(2))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(op::Multiply(param0, param2), param1)); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Multiply(m::Parameter(0), m::Parameter(2)), + m::Parameter(1)))); } // Test that (A/B)/(C/D) is simplified to (A*D)/(B*C). @@ -496,14 +630,16 @@ TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { EXPECT_THAT( computation->root_instruction(), - op::Divide(op::Divide(param0, param1), op::Divide(param2, param3))); + GmockMatch(m::Divide(m::Divide(m::Parameter(0), m::Parameter(1)), + m::Divide(m::Parameter(2), m::Parameter(3))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), - op::Divide(op::Multiply(param0, param3), op::Multiply(param1, param2))); + GmockMatch(m::Divide(m::Multiply(m::Parameter(0), m::Parameter(3)), + m::Multiply(m::Parameter(1), m::Parameter(2))))); } // Test that A/exp(B) is simplified to A*exp(-B). 
@@ -523,13 +659,14 @@ TEST_F(AlgebraicSimplifierTest, DivOfExp) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Divide(param0, op::Exp(param1))); + GmockMatch(m::Divide(m::Parameter(0), m::Exp(m::Parameter(1))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Multiply(param0, op::Exp(op::Negate(param1)))); + GmockMatch(m::Multiply(m::Parameter(0), + m::Exp(m::Negate(m::Parameter(1)))))); } // Test that A/pow(B,C) is simplified to A*pow(B,-C). @@ -550,14 +687,18 @@ TEST_F(AlgebraicSimplifierTest, DivOfPower) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(param0, op::Power(param1, param2))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Parameter(0), + m::Power(m::Parameter(1), m::Parameter(2))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Multiply(param0, op::Power(param1, op::Negate(param2)))); + GmockMatch(m::Multiply( + m::Parameter(0), + m::Power(m::Parameter(1), m::Negate(m::Parameter(2)))))); } // Test that broadcasting is done on the right step when simplifying A/pow(B,C) @@ -579,14 +720,18 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(param0, op::Power(param1, param2))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Parameter(0), + m::Power(m::Parameter(1), m::Parameter(2))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); ASSERT_THAT(computation->root_instruction(), - op::Multiply(param0, op::Power(param1, op::Negate(param2)))); + GmockMatch(m::Multiply( + 
m::Parameter(0), + m::Power(m::Parameter(1), m::Negate(m::Parameter(2)))))); } // A / Const => A * InvertedConst @@ -608,7 +753,7 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) { ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Multiply(param0, op::Constant())); + GmockMatch(m::Multiply(m::Parameter(0), m::Constant()))); } // pow(pow(A, X), Y) => pow(A, X*Y) @@ -630,8 +775,10 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPower) { auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Power(base, op::Multiply(exp1, exp2))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Power(m::Op().Is(base), + m::Multiply(m::Op().Is(exp1), m::Op().Is(exp2))))); } // Don't simplify pow(pow(A, X), Y) => pow(A, X*Y) if X and Y are complex @@ -794,7 +941,7 @@ TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Add(param1, param2)); + EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(1), m::Parameter(2)))); } // Test that exp(A)/exp(B) is simplified to exp(A-B) @@ -815,14 +962,16 @@ TEST_F(AlgebraicSimplifierTest, ExpDiv) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(op::Exp(param0), op::Exp(param1))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Exp(m::Parameter(0)), m::Exp(m::Parameter(1))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Exp(op::Subtract(param0, param1))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Exp(m::Subtract(m::Parameter(0), 
m::Parameter(1))))); } // Test that exp(A)*exp(B) is simplified to exp(A+B) @@ -844,13 +993,14 @@ TEST_F(AlgebraicSimplifierTest, ExpMul) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Multiply(op::Exp(param0), op::Exp(param1))); + GmockMatch(m::Multiply(m::Exp(m::Parameter(0)), + m::Exp(m::Parameter(1))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Exp(op::Add(param0, param1))); + GmockMatch(m::Exp(m::Add(m::Parameter(0), m::Parameter(1))))); } // Test that pow(exp(A), B) is simplified to exp(A*B) @@ -870,13 +1020,14 @@ TEST_F(AlgebraicSimplifierTest, PowExp) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Power(op::Exp(param0), param1)); + GmockMatch(m::Power(m::Exp(m::Parameter(0)), m::Parameter(1)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Exp(op::Multiply(param0, param1))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Exp(m::Multiply(m::Parameter(0), m::Parameter(1))))); } // Test that ln(pow(A, B)) is simplified to ln(A)*B @@ -896,13 +1047,14 @@ TEST_F(AlgebraicSimplifierTest, LnPow) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Log(op::Power(param0, param1))); + GmockMatch(m::Log(m::Power(m::Parameter(0), m::Parameter(1))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Multiply(op::Log(param0), param1)); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Multiply(m::Log(m::Parameter(0)), m::Parameter(1)))); } // Test that ln(exp(A)) is simplified to A @@ -919,7 +1071,8 @@ 
TEST_F(AlgebraicSimplifierTest, LnExp) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Log(op::Exp(param0))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Log(m::Exp(m::Parameter(0))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); @@ -948,12 +1101,14 @@ TEST_F(AlgebraicSimplifierTest, LnExpDiv) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Log(op::Divide(op::Exp(param0), op::Exp(param1)))); + GmockMatch(m::Log(m::Divide(m::Exp(m::Parameter(0)), + m::Exp(m::Parameter(1)))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Subtract(param0, param1)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Subtract(m::Parameter(0), m::Parameter(1)))); } // Test that pow(A, 0) where A is a scalar is simplified to the scalar @@ -971,13 +1126,14 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Power(m::Parameter(0), m::Op().Is(zero)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); + EXPECT_THAT(root, GmockMatch(m::Constant())); EXPECT_EQ(root->literal().GetFirstElement(), 1); } @@ -995,13 +1151,14 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Power(m::Parameter(0), m::Op().Is(zero)))); AlgebraicSimplifier 
simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast()); + EXPECT_THAT(root, GmockMatch(m::Broadcast())); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), r1f32)) << ShapeUtil::HumanString(root->shape()); EXPECT_EQ(root->dimensions().size(), 0); @@ -1023,7 +1180,8 @@ TEST_F(AlgebraicSimplifierTest, Pow1) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Power(param0, one)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Power(m::Parameter(0), m::Op().Is(one)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); @@ -1045,12 +1203,14 @@ TEST_F(AlgebraicSimplifierTest, Pow2) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Power(param0, two)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Power(m::Parameter(0), m::Op().Is(two)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0)))); } // Test that pow(A, -1) is simplified to 1/A. 
@@ -1067,13 +1227,14 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Power(param0, negative_one)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Power(m::Parameter(0), m::Op().Is(negative_one)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Divide(op::Broadcast(), param0)); + EXPECT_THAT(root, GmockMatch(m::Divide(m::Broadcast(), m::Parameter(0)))); EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kBroadcast); EXPECT_EQ(root->operand(0)->operand(0)->literal().GetFirstElement(), 1); @@ -1116,10 +1277,10 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { m->AddEntryComputation(builder.Build()); HloPassFix simplifier(default_options_); EXPECT_THAT(m->entry_computation()->root_instruction(), - op::Convolution(lhs, rhs)); + GmockMatch(m::Convolution(m::Op().Is(lhs), m::Op().Is(rhs)))); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(m->entry_computation()->root_instruction(), - op::Broadcast(op::Constant())); + GmockMatch(m::Broadcast(m::Constant()))); } TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) { @@ -1158,10 +1319,10 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) { m->AddEntryComputation(builder.Build()); HloPassFix simplifier(default_options_); EXPECT_THAT(m->entry_computation()->root_instruction(), - op::ReduceWindow(param, op::Constant())); + GmockMatch(m::ReduceWindow(m::Parameter(0), m::Constant()))); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(m->entry_computation()->root_instruction(), - op::Broadcast(op::Constant())); + GmockMatch(m::Broadcast(m::Constant()))); } TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) { @@ -1184,11 +1345,11 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) { padding)); m->AddEntryComputation(builder.Build()); 
EXPECT_THAT(m->entry_computation()->root_instruction(), - op::Pad(param, op::Constant())); + GmockMatch(m::Pad(m::Parameter(0), m::Constant()))); HloPassFix simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(m->entry_computation()->root_instruction(), - op::Broadcast(op::Constant())); + GmockMatch(m::Broadcast(m::Constant()))); } TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { @@ -1209,7 +1370,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { m->AddEntryComputation(std::move(computation)); EXPECT_THAT(m->entry_computation()->root_instruction(), - op::Reshape(op::Broadcast(op::Reshape(op)))); + GmockMatch(m::Reshape(m::Broadcast(m::Reshape(m::Op().Is(op)))))); HloPassFix simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); @@ -1228,7 +1389,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Convert(input)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Convert(m::Op().Is(input)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); @@ -1248,7 +1410,8 @@ TEST_F(AlgebraicSimplifierTest, RemoveCopy) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); @@ -1269,21 +1432,24 @@ TEST_F(AlgebraicSimplifierTest, CopyEqualsBitcast) { *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 2, 0, 3}); auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Copy(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); AlgebraicSimplifierOptions 
options(non_bitcasting_callback()); options.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier1(options); ASSERT_FALSE(simplifier1.Run(m.get()).ValueOrDie()); // Verify that the copy is not replaced. - EXPECT_THAT(computation->root_instruction(), op::Copy(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); AlgebraicSimplifierOptions options2(bitcasting_callback()); options2.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier2(options2); ASSERT_TRUE(simplifier2.Run(m.get()).ValueOrDie()); // Verify that the copy is replaced. - EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Bitcast(m::Parameter(0)))); } // Test that unary concatenates are removed. @@ -1298,7 +1464,8 @@ TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Concatenate(m::Parameter(0)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); @@ -1327,15 +1494,17 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT( - computation->root_instruction(), - op::Concatenate(empty_literal, param0, param0, empty_slice, param1)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Concatenate( + m::Op().Is(empty_literal), m::Parameter(0), m::Parameter(0), + m::Op().Is(empty_slice), m::Parameter(1)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Concatenate(param0, param0, param1)); + GmockMatch(m::Concatenate(m::Parameter(0), m::Parameter(0), + m::Parameter(1)))); } // Test that reduce of concat is simplified. 
@@ -1383,8 +1552,9 @@ TEST_F(AlgebraicSimplifierTest, SimplifyReduceOfConcat) { EXPECT_THAT( computation->root_instruction(), - op::Map(op::Map(op::Reduce(param0, zero), op::Reduce(param1, zero)), - op::Reduce(param2, zero))); + GmockMatch(m::Map(m::Map(m::Reduce(m::Parameter(0), m::Op().Is(zero)), + m::Reduce(m::Parameter(1), m::Op().Is(zero))), + m::Reduce(m::Parameter(2), m::Op().Is(zero))))); } // Test a concatenate with only empty operands is removed. @@ -1407,7 +1577,8 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Concatenate(empty_literal, empty_slice)); + GmockMatch(m::Concatenate(m::Op().Is(empty_literal), + m::Op().Is(empty_slice)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); @@ -1434,7 +1605,8 @@ TEST_F(AlgebraicSimplifierTest, ConcatenateOfBroadcastBecomesPad) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Pad(param0, param1)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Pad(m::Parameter(0), m::Parameter(1)))); } TEST_F(AlgebraicSimplifierTest, SimplifyConcatenateOfSlices) { @@ -1495,10 +1667,10 @@ TEST_F(AlgebraicSimplifierTest, SimplifyConcatenateOfSlices) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + auto s = m::Slice(m::Parameter(0)); EXPECT_THAT( computation->root_instruction(), - op::Concatenate(op::Slice(param0), op::Slice(param0), op::Slice(param0), - op::Slice(param0), op::Slice(param0), op::Slice(param1))); + GmockMatch(m::Concatenate(s, s, s, s, s, m::Slice(m::Parameter(1))))); // The operand 3 should be a merge of 'slice3', 'slice4' and 'slice5', so its // shape should have dimensions {50, 30}. 
EXPECT_TRUE( @@ -1524,7 +1696,8 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); - EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); AlgebraicSimplifierOptions options(non_bitcasting_callback()); options.set_is_layout_sensitive(true); @@ -1532,7 +1705,8 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); // Copy has not been removed. - EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); } // Test that a simplification which preserves layouts is performed if layout @@ -1552,7 +1726,8 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); - EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); AlgebraicSimplifierOptions options(non_bitcasting_callback()); options.set_is_layout_sensitive(true); @@ -1581,7 +1756,8 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Parameter(0)))); AlgebraicSimplifierOptions options(non_bitcasting_callback()); options.set_is_layout_sensitive(true); @@ -1589,7 +1765,8 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); // Reshape is not replaced with a bitcast. 
- EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Parameter(0)))); } // Test transforming reshapes and transposes of rng. @@ -1617,9 +1794,9 @@ TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) { (AlgebraicSimplifierOptions(bitcasting_callback()))); EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - // Verify that that reshape(transpose(rng)) is replace by a single rng of the + // Verify that reshape(transpose(rng)) is replaced by a single rng of the // same shape as the reshape. - EXPECT_THAT(computation->root_instruction(), op::Rng()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Rng())); EXPECT_TRUE(ShapeUtil::Equal(computation->root_instruction()->shape(), reshape_shape)); } @@ -1661,8 +1838,9 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Tuple(transformable_reshape, dimensions_wrong_reshape, - layout_wrong_reshape)); + GmockMatch(m::Tuple(m::Op().Is(transformable_reshape), + m::Op().Is(dimensions_wrong_reshape), + m::Op().Is(layout_wrong_reshape)))); AlgebraicSimplifierOptions options(bitcasting_callback()); options.set_is_layout_sensitive(true); @@ -1672,7 +1850,8 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { // Verify that only the first reshape is replaced.
EXPECT_THAT( computation->root_instruction(), - op::Tuple(op::Bitcast(), dimensions_wrong_reshape, layout_wrong_reshape)); + GmockMatch(m::Tuple(m::Bitcast(), m::Op().Is(dimensions_wrong_reshape), + m::Op().Is(layout_wrong_reshape)))); } // Regression test for a bug where if we failed to sink a reshape, we'd set the @@ -1741,7 +1920,8 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Transpose(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Transpose(m::Parameter(0)))); AlgebraicSimplifierOptions options(bitcasting_callback()); options.set_is_layout_sensitive(true); @@ -1749,7 +1929,8 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); // Verify that the reshape is replaced. - EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Bitcast(m::Parameter(0)))); } TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) { @@ -1769,7 +1950,8 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Transpose(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Transpose(m::Parameter(0)))); AlgebraicSimplifierOptions options(bitcasting_callback()); options.set_is_layout_sensitive(true); @@ -1777,7 +1959,8 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) { ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); // Verify that the reshape is replaced. 
- EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Bitcast(m::Parameter(0)))); } TEST_F(AlgebraicSimplifierTest, ReshapesMerged) { @@ -1797,12 +1980,13 @@ TEST_F(AlgebraicSimplifierTest, ReshapesMerged) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Reshape(param0))); + GmockMatch(m::Reshape(m::Reshape(m::Parameter(0))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Parameter(0)))); } TEST_F(AlgebraicSimplifierTest, CopiesMerged) { @@ -1823,14 +2007,16 @@ TEST_F(AlgebraicSimplifierTest, CopiesMerged) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Copy(op::Copy(param0))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Copy(m::Parameter(0))))); AlgebraicSimplifierOptions options(non_bitcasting_callback()); options.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier(options); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); } TEST_F(AlgebraicSimplifierTest, TransposesMerged) { @@ -1849,16 +2035,39 @@ TEST_F(AlgebraicSimplifierTest, TransposesMerged) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Transpose(transpose1)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Transpose(m::Op().Is(transpose1)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Transpose(param0)); + 
EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Transpose(m::Parameter(0)))); EXPECT_EQ(std::vector({2, 1, 0}), computation->root_instruction()->dimensions()); } +// Test that reshape(transpose(reshape(A))), where the transpose only moves +// size-1 dimensions, is simplified all the way down to A. +TEST_F(AlgebraicSimplifierTest, TransposeIsReshape) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param = f32[10] parameter(0) + reshaped = f32[1,1,10] reshape(f32[10] param) + transposed = f32[10,1,1] transpose(f32[1,1,10] reshaped), dimensions={2,1,0} + ROOT reshaped_again = f32[10] reshape(f32[10,1,1] transposed) + } + )"; + // Parse with the verifying helper, like the other HLO-text tests in this file. + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + // Run to a fixed point: collapsing the chain takes more than one rewrite. + HloPassFix<AlgebraicSimplifier> simplifier(default_options_); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Parameter())); +} + // Test merging reshape and broadcast. TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) { auto m = CreateNewVerifiedModule(); @@ -1873,12 +2082,13 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Broadcast(op::Reshape(param0))); + GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Parameter(0)))); } // Test merging broadcast and reshape.
@@ -1895,12 +2105,13 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param0))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Parameter(0)))); } TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) { @@ -1916,13 +2127,13 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); AlgebraicSimplifier simplifier(default_options_); EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); } TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) { @@ -1938,12 +2149,13 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) { HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Broadcast(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Parameter(0)))); EXPECT_THAT(computation->root_instruction()->dimensions(), ::testing::ElementsAre(3)); } @@ -1961,12 +2173,13 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) { HloComputation* 
computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Broadcast(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Parameter(0)))); const std::vector broadcast_dims = computation->root_instruction()->dimensions(); EXPECT_EQ(1, broadcast_dims.size()); @@ -1986,13 +2199,13 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) { HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); AlgebraicSimplifier simplifier(default_options_); EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); } TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) { @@ -2005,12 +2218,13 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Iota()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Iota())); EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape)); } @@ -2024,13 +2238,13 @@ TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), 
op::Iota()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Iota())); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); auto root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(op::Constant())); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant()))); EXPECT_EQ(0.0f, root->operand(0)->literal().GetFirstElement()); EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape)); @@ -2046,12 +2260,14 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2_6) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); AlgebraicSimplifier simplifier(default_options_); EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); } TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) { @@ -2064,12 +2280,13 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) { HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Iota()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Iota())); EXPECT_EQ(Cast(computation->root_instruction()) ->iota_dimension(), 3); @@ -2085,12 +2302,13 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x2_6x1x1x2) { HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + 
EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Iota()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Iota())); const int64 iota_dim = Cast(computation->root_instruction()) ->iota_dimension(); @@ -2107,12 +2325,14 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4x2_6x8) { HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); AlgebraicSimplifier simplifier(default_options_); EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); } TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { @@ -2135,7 +2355,8 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); @@ -2179,12 +2400,14 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { return false; }; - EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero)))); EXPECT_TRUE(has_negative_padding(pad)); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero))); + EXPECT_THAT(computation->root_instruction(), + 
GmockMatch(m::Slice(m::Pad(m::Parameter(0), m::Op().Is(zero))))); EXPECT_FALSE( has_negative_padding(computation->root_instruction()->operand(0))); } @@ -2213,12 +2436,14 @@ TEST_F(AlgebraicSimplifierTest, TrivialInteriorPadding) { AlgebraicSimplifier simplifier(default_options_); - ASSERT_THAT(computation->root_instruction(), op::Pad(param, zero)); + ASSERT_THAT(computation->root_instruction(), + GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero)))); ASSERT_TRUE(HasInteriorPadding(pad->padding_config())); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero)))); EXPECT_FALSE( HasInteriorPadding(computation->root_instruction()->padding_config())); } @@ -2234,7 +2459,8 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) { auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Parameter(0)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); @@ -2256,7 +2482,8 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) { auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Slice(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Parameter(0)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); @@ -2284,12 +2511,14 @@ TEST_F(AlgebraicSimplifierTest, SliceOfSliceToSlice) { auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), 
op::Slice(op::Slice(param))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Slice(m::Parameter(0))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Slice(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Parameter(0)))); EXPECT_EQ(computation->root_instruction()->slice_starts(0), 3); EXPECT_EQ(computation->root_instruction()->slice_starts(1), 5); EXPECT_EQ(computation->root_instruction()->slice_limits(0), dim0 - 2); @@ -2315,12 +2544,14 @@ TEST_F(AlgebraicSimplifierTest, SliceOfReshapeToReshapeOfSlice) { auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Slice(op::Reshape(param))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Reshape(m::Parameter(0))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Slice(param))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Slice(m::Parameter(0))))); } TEST_F(AlgebraicSimplifierTest, SliceOfReshapeUnchanged) { @@ -2339,7 +2570,8 @@ TEST_F(AlgebraicSimplifierTest, SliceOfReshapeUnchanged) { auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Slice(op::Reshape(param))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Reshape(m::Parameter(0))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); @@ -2380,10 +2612,10 @@ TEST_F(AlgebraicSimplifierTest, ReplacePermutationSortWithScatter) { EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = 
module->entry_computation()->root_instruction(); EXPECT_THAT(root, - op::Tuple(op::Iota(), - op::Scatter(op::Iota(), - op::Concatenate(op::Iota(), op::Reshape()), - op::Reshape()))); + GmockMatch(m::Tuple( + m::Iota(), + m::Scatter(m::Iota(), m::Concatenate(m::Iota(), m::Reshape()), + m::Reshape())))); } TEST_F(AlgebraicSimplifierTest, DontReplacePermutationSortIfNonIntegral) { @@ -2451,7 +2683,8 @@ TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Tuple(keys, values0, values1)); + GmockMatch(m::Tuple(m::Op().Is(keys), m::Op().Is(values0), + m::Op().Is(values1)))); } // Test that A && True is simplified to A @@ -2753,7 +2986,8 @@ TEST_P(ConvInputPaddingTest, DoTest) { ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto* conv = module->entry_computation()->root_instruction(); SCOPED_TRACE(module->ToString()); - ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter())); + ASSERT_THAT(conv, + GmockMatch(m::Convolution(m::Parameter(), m::Parameter()))); EXPECT_EQ(window_util::ToString(conv->window()), absl::StrCat("size=3x3 ", testcase.expected_conv_window)); } @@ -2870,7 +3104,8 @@ TEST_P(ConvFilterPaddingTest, DoIt) { ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto* conv = module->entry_computation()->root_instruction(); SCOPED_TRACE(module->ToString()); - ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter())); + ASSERT_THAT(conv, + GmockMatch(m::Convolution(m::Parameter(), m::Parameter()))); EXPECT_EQ(window_util::ToString(conv->window()), absl::StrFormat("size=%dx%d %s", conv->operand(1)->shape().dimensions(2), @@ -3142,10 +3377,9 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { // Running simplification again should not result in any further changes. 
ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); - - root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(scalar_param)); - EXPECT_TRUE(ShapeUtil::Equal(root->shape(), slice_shape)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Op().Is(scalar_param)) + .WithShapeEqualTo(&slice_shape))); } // Test that reshape(transpose(broadcast(/*scalar value*/))) simplifies to a @@ -3176,10 +3410,9 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - - root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(forty_two)); - EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reshape_shape)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Op().Is(forty_two)) + .WithShapeEqualTo(&reshape_shape))); } // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x). @@ -3248,7 +3481,8 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // Verify the result root = computation->root_instruction(); - EXPECT_THAT(root, op::ReduceWindow(operand, op::Constant())); + EXPECT_THAT(root, + GmockMatch(m::ReduceWindow(m::Op().Is(operand), m::Constant()))); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape)) << ShapeUtil::HumanString(root->shape()) << " vs " << ShapeUtil::HumanString(reduce_window_shape); @@ -3333,7 +3567,8 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { // Verify the result root = computation->root_instruction(); - EXPECT_THAT(root, op::ReduceWindow(op::Convert(parameter), op::Constant())); + EXPECT_THAT(root, GmockMatch(m::ReduceWindow(m::Convert(m::Parameter(0)), + m::Constant()))); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape)) << ShapeUtil::HumanString(root->shape()) << " vs " << ShapeUtil::HumanString(reduce_window_shape); @@ -3414,7 +3649,7 @@ TEST_F(AlgebraicSimplifierTest, 
ConstantTupleBecomesTupleOfConstants) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Tuple(op::Constant(), op::Constant())); + GmockMatch(m::Tuple(m::Constant(), m::Constant()))); } // A dynamic-slice is trivial if its start indices are all zeroes and the size @@ -3436,7 +3671,7 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Parameter()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Parameter())); } // A dynamic-update-slice is trivial if its start indices are all zeroes and the @@ -3470,7 +3705,7 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::DynamicSlice(op::Parameter(), op::Parameter())); + GmockMatch(m::DynamicSlice(m::Parameter(), m::Parameter()))); } // Test that two consecutive broadcasts can be merged to one. 
@@ -3492,7 +3727,7 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(op::Constant())); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant()))); EXPECT_THAT(root->dimensions(), ElementsAre(2)); } @@ -3518,7 +3753,7 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(op::Parameter(0))); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Parameter(0)))); EXPECT_THAT(root->dimensions(), ElementsAre(1, 3)); } @@ -3538,7 +3773,7 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Iota()); + EXPECT_THAT(root, GmockMatch(m::Iota())); EXPECT_EQ(Cast(root)->iota_dimension(), 2); } @@ -3559,7 +3794,7 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Iota()); + EXPECT_THAT(root, GmockMatch(m::Iota())); EXPECT_EQ(Cast(root)->iota_dimension(), 2); } @@ -3581,7 +3816,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadLow) { AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Reshape(op::Constant())); + EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant()))); } TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) { @@ -3602,7 +3837,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) { AlgebraicSimplifier simplifier(options); 
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Reshape(op::Constant())); + EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant()))); } TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) { @@ -3642,7 +3877,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) { AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Parameter()); + EXPECT_THAT(root, GmockMatch(m::Parameter())); } TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) { @@ -3664,7 +3899,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) { AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Parameter(1)); + EXPECT_THAT(root, GmockMatch(m::Parameter(1))); } TEST_F(AlgebraicSimplifierTest, SliceOfConcatNonScalarInput) { @@ -3686,7 +3921,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfConcatNonScalarInput) { AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Slice(op::Parameter(2))); + EXPECT_THAT(root, GmockMatch(m::Slice(m::Parameter(2)))); EXPECT_EQ(root->slice_starts(0), 1); EXPECT_EQ(root->slice_limits(0), 2); } @@ -3708,7 +3943,7 @@ TEST_F(AlgebraicSimplifierTest, NegateNegate) { AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Parameter(0)); + EXPECT_THAT(root, GmockMatch(m::Parameter(0))); } TEST_F(AlgebraicSimplifierTest, NotNot) { @@ -3728,7 +3963,7 @@ TEST_F(AlgebraicSimplifierTest, NotNot) { AlgebraicSimplifier simplifier(options); 
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Parameter(0)); + EXPECT_THAT(root, GmockMatch(m::Parameter(0))); } struct PadReduceWindowEffectiveBroadcastCase { @@ -3832,10 +4067,10 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { ShapeUtil::Equal(computation->root_instruction()->shape(), output_shape)); if (param.should_become_broadcast) { - EXPECT_THAT(computation->root_instruction(), op::Broadcast(::testing::_)); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Broadcast())); } else { EXPECT_THAT(computation->root_instruction(), - op::ReduceWindow(::testing::_, zero)); + GmockMatch(m::ReduceWindow(m::Op(), m::Op().Is(zero)))); } } @@ -3869,6 +4104,57 @@ INSTANTIATE_TEST_CASE_P( PadReduceWindowEffectiveBroadcastTest, ::testing::ValuesIn(PadReduceWindowEffectiveBroadcastCases())); +class BatchDotStrengthReductionTest + : public AlgebraicSimplifierTest, + public ::testing::WithParamInterface< + ::testing::tuple> {}; +TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) { + auto module = CreateNewVerifiedModule(); + int m, k, n; + PrimitiveType element_type; + std::tie(m, k, n, element_type) = GetParam(); + + Shape dot_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, m, n}); + Shape lhs_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, m, k}); + Shape rhs_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, k, n}); + HloComputation::Builder builder(TestName()); + + auto lhs = builder.AddInstruction( + HloInstruction::CreateParameter(0, lhs_shape, "lhs")); + auto rhs = builder.AddInstruction( + HloInstruction::CreateParameter(1, rhs_shape, "rhs")); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_batch_dimensions(0); + dot_dnums.add_lhs_batch_dimensions(1); + dot_dnums.add_lhs_batch_dimensions(2); + dot_dnums.add_rhs_batch_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + dot_dnums.add_rhs_batch_dimensions(2); + 
dot_dnums.add_lhs_contracting_dimensions(4); + dot_dnums.add_rhs_contracting_dimensions(3); + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); + auto computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get())); + const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1; + const bool computation_should_be_modified = dot_should_be_transformed; + EXPECT_EQ(changed, computation_should_be_modified); + bool has_no_dot = true; + for (const auto& hlo : computation->instructions()) { + if (hlo->opcode() == HloOpcode::kDot) { + has_no_dot = false; + break; + } + } + EXPECT_EQ(has_no_dot, dot_should_be_transformed); +} + +INSTANTIATE_TEST_CASE_P( + BatchDotStrengthReductionTestInstantiation, BatchDotStrengthReductionTest, + ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2), + ::testing::Values(1, 2), ::testing::Values(F32, BF16))); + class DotStrengthReductionTest : public AlgebraicSimplifierTest, public ::testing::WithParamInterface< @@ -3989,11 +4275,12 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); - auto match_dot_0 = op::Dot(op::Slice(op::Constant()), op::Parameter(0)); - auto match_dot_1 = op::Dot(op::Slice(op::Constant()), op::Parameter(1)); - auto match_dot_2 = op::Dot(op::Slice(op::Constant()), op::Parameter(2)); - EXPECT_THAT(computation->root_instruction(), - op::Add(op::Add(match_dot_0, match_dot_1), match_dot_2)); + auto match_dot_0 = m::Dot(m::Slice(m::Constant()), m::Parameter(0)); + auto match_dot_1 = m::Dot(m::Slice(m::Constant()), m::Parameter(1)); + auto match_dot_2 = m::Dot(m::Slice(m::Constant()), m::Parameter(2)); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Add(m::Add(match_dot_0, match_dot_1), match_dot_2))); } // Test that we 
transform @@ -4052,13 +4339,14 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) { EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); - auto match_dot_0 = op::Dot(op::Parameter(0), op::Slice(op::Constant())); - auto match_dot_1 = op::Dot(op::Parameter(1), op::Slice(op::Constant())); - auto match_dot_2 = op::Dot(op::Parameter(2), op::Slice(op::Constant())); - auto match_dot_3 = op::Dot(op::Parameter(3), op::Slice(op::Constant())); - EXPECT_THAT(computation->root_instruction(), - op::Add(op::Add(op::Add(match_dot_0, match_dot_1), match_dot_2), - match_dot_3)); + auto match_dot_0 = m::Dot(m::Parameter(0), m::Slice(m::Constant())); + auto match_dot_1 = m::Dot(m::Parameter(1), m::Slice(m::Constant())); + auto match_dot_2 = m::Dot(m::Parameter(2), m::Slice(m::Constant())); + auto match_dot_3 = m::Dot(m::Parameter(3), m::Slice(m::Constant())); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Add(m::Add(m::Add(match_dot_0, match_dot_1), match_dot_2), + match_dot_3))); } DotOfConcatTestSpec kDotOfConcatTestSpecs[] = { @@ -4175,8 +4463,8 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { HloOpcode::kDynamicSlice); } else { EXPECT_THAT(computation->root_instruction(), - op::DynamicSlice(op::Dot(op::Constant(), op::Constant()), - op::Concatenate())); + GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()), + m::Concatenate()))); } } @@ -4245,8 +4533,8 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { HloOpcode::kDynamicSlice); } else { EXPECT_THAT(computation->root_instruction(), - op::DynamicSlice(op::Dot(op::Constant(), op::Constant()), - op::Concatenate())); + GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()), + m::Concatenate()))); } } diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.cc b/tensorflow/compiler/xla/service/ar_crs_combiner.cc index c11452a6fbd49a1fc382d11d24a7d7b7eeab0bcc..47d2c7e35705698d49950c2fa042af1c6327d521 100644 --- 
a/tensorflow/compiler/xla/service/ar_crs_combiner.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.cc @@ -36,31 +36,47 @@ namespace { namespace m = match; -// If the argument instruction is a CRS in the sequence -// AR -> Convert -> Add -> CRS -// then return the AR in the sequence. -// TODO(b/117554291): Rewrite this to recognize more general patterns, -// not just the specific one of AR -> Add -> Convert -> CRS. -absl::optional MatchesArCrsPattern( - HloInstruction* instruction) { - HloInstruction *ar, *convert, *add, *crs; - if (Match(instruction, - m::CrossReplicaSum( - &crs, m::Add(&add, m::Op(), - m::Convert(&convert, - m::CrossReplicaSum(&ar, m::Op()))))) && - ar->users().size() == 1 && ar->shape().element_type() == BF16 && - convert->shape().element_type() == F32 && !crs->all_reduce_id()) { - return ar; +// Returns true iff the argument instruction is an AllReduce, followed by a +// certain sequence of instructions and then a CRS. It must be possible to move +// the AR past each instruction in the sequence. 
+bool MatchesArCrsPattern(HloInstruction* instruction) { + auto can_ar_move_past_instruction = [](HloInstruction* instruction) -> bool { + if (instruction->user_count() != 1) { + return false; + } + auto opcode = instruction->opcode(); + return opcode == HloOpcode::kBitcast || opcode == HloOpcode::kTranspose || + opcode == HloOpcode::kReshape || opcode == HloOpcode::kConvert || + opcode == HloOpcode::kAdd || opcode == HloOpcode::kSubtract || + opcode == HloOpcode::kMultiply; + }; + + auto computation_is_addition = [](HloComputation* c) { + return c->instruction_count() == 3 && + Match(c->root_instruction(), m::Add(m::Parameter(), m::Parameter())); + }; + + if (!instruction->IsCrossModuleAllReduce() || + !computation_is_addition(instruction->called_computations()[0]) || + instruction->user_count() != 1) { + return false; } - return absl::optional(); + auto next = instruction->users()[0]; + while (!next->IsCrossReplicaAllReduce()) { + if (can_ar_move_past_instruction(next)) { + next = next->users()[0]; + } else { + return false; + } + } + return computation_is_addition(next->called_computations()[0]); } } // namespace absl::optional ArCrsCombiner::WhileFromBodyParameter( HloInstruction* instruction) { - CHECK(HloOpcode::kParameter == instruction->opcode()); + CHECK_EQ(HloOpcode::kParameter, instruction->opcode()); HloComputation* computation = instruction->parent(); auto caller_instructions = call_graph_->GetComputationCallers(computation); if (caller_instructions.size() == 1) { @@ -120,7 +136,7 @@ bool ArCrsCombiner::TupleElementsComputeSameValue( return false; } for (auto tuple : tuples) { - CHECK(tuple->opcode() == HloOpcode::kTuple); + CHECK_EQ(tuple->opcode(), HloOpcode::kTuple); if (!InstructionsComputeSameValue(tuple->mutable_operand(i1), tuple->mutable_operand(i2), visited_pairs)) { @@ -160,13 +176,6 @@ bool ArCrsCombiner::InstructionsComputeSameValue( if (opcode1 != i2->opcode() || operands1.size() != i2->operands().size()) { return false; } - if (opcode1 == 
HloOpcode::kConstant || i1->IsCrossModuleAllReduce()) { - return i1->Identical( - *i2, - /*eq_operands=*/std::equal_to(), - /*eq_computations=*/std::equal_to(), - /*layout_sensitive=*/false); - } visited_pairs->emplace(min_uid, max_uid); for (int i = 0; i < operands1.size(); ++i) { auto operand1 = operands1[i]; @@ -175,22 +184,35 @@ bool ArCrsCombiner::InstructionsComputeSameValue( return false; } } + if (opcode1 == HloOpcode::kParameter) { + // In the general case, we don't try to prove equality of parameters. + // We only try in the context of get-tuple-element + // (see TupleElementsComputeSameValue). + return false; + } if (opcode1 == HloOpcode::kGetTupleElement) { - if (i1->tuple_index() == i2->tuple_index()) { - return true; - } - return TupleElementsComputeSameValue(operands1[0], i1->tuple_index(), + return i1->tuple_index() == i2->tuple_index() || + TupleElementsComputeSameValue(operands1[0], i1->tuple_index(), i2->tuple_index(), visited_pairs); } - return true; + // Don't check that the operands are identical, because Identical can + // return false for instructions that compute the same value but are not + // identical, which we don't want. We have checked the arguments with + // InstructionsComputeSameValue earlier. 
+ auto eq_instructions = [](const HloInstruction* i1, + const HloInstruction* i2) -> bool { return true; }; + auto eq_computations = [](const HloComputation* a, const HloComputation* b) { + return *a == *b; + }; + return i1->Identical(*i2, eq_instructions, eq_computations, + /*layout_sensitive=*/false); } void ArCrsCombiner::GroupAllReducesById(HloModule* module) { for (HloComputation* computation : module->MakeNonfusionComputations()) { for (HloInstruction* instruction : computation->instructions()) { - auto ar = MatchesArCrsPattern(instruction); - if (ar) { - all_reduce_map_[*((*ar)->all_reduce_id())].push_back(*ar); + if (MatchesArCrsPattern(instruction)) { + all_reduce_map_[*(instruction->all_reduce_id())].push_back(instruction); } } } @@ -198,21 +220,23 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) { void ArCrsCombiner::KeepProvablyEqualInstructionGroups() { for (auto it : all_reduce_map_) { + auto all_reduce_id = it.first; auto instruction_vec = it.second; CHECK_EQ(instruction_vec.size(), num_spatial_partitions_); - auto instr_0 = instruction_vec[0]; - auto add_0 = instr_0->users()[0]->users()[0]; - CHECK(HloOpcode::kAdd == add_0->opcode()); - for (int i = 1; i < instruction_vec.size(); ++i) { auto instr_i = instruction_vec[i]; - auto add_i = instr_i->users()[0]->users()[0]; - CHECK(HloOpcode::kAdd == add_i->opcode()); + auto next_0 = instr_0->users()[0]; + auto next_i = instr_i->users()[0]; absl::flat_hash_map visited_pairs; - if (!InstructionsComputeSameValue(add_0, add_i, &visited_pairs)) { - all_reduce_map_.erase(it.first); - } + do { + if (!InstructionsComputeSameValue(next_0, next_i, &visited_pairs)) { + all_reduce_map_.erase(all_reduce_id); + break; + } + next_0 = next_0->users()[0]; + next_i = next_i->users()[0]; + } while (!next_0->IsCrossReplicaAllReduce()); } } } @@ -221,55 +245,51 @@ StatusOr ArCrsCombiner::RewriteGraph() { if (all_reduce_map_.empty()) { return false; } - - auto computation_is_addition = [](HloComputation* c) { - 
return c->instruction_count() == 3 && - Match(c->root_instruction(), m::Add(m::Parameter(), m::Parameter())); - }; - for (auto it : all_reduce_map_) { auto instruction_vec = it.second; for (auto all_reduce : instruction_vec) { auto parent_computation = all_reduce->parent(); - auto convert = all_reduce->users()[0]; - auto add = convert->users()[0]; - auto crs = add->users()[0]; - - if (!computation_is_addition(all_reduce->called_computations()[0]) || - !computation_is_addition(crs->called_computations()[0])) { - continue; - } - HloInstruction* other_summand = (add->operands()[0] == convert) - ? add->operands()[1] - : add->operands()[0]; - // Remove the AllReduce and replace the CRS with: - // AllReduce - (other_summand * (num_spatial_partitions_ - 1)) - TF_CHECK_OK( - all_reduce->ReplaceAllUsesWith(all_reduce->mutable_operand(0))); - crs->set_all_reduce_id(all_reduce->all_reduce_id()); - auto new_shape = crs->shape(); - HloInstruction* to_subtract; - if (num_spatial_partitions_ == 2) { - to_subtract = other_summand; - } else { - Literal partitions_minus_1_lit = Literal(new_shape); - partitions_minus_1_lit.PopulateWithValue( - num_spatial_partitions_ - 1); - auto partitions_minus_1_const = parent_computation->AddInstruction( - HloInstruction::CreateConstant(partitions_minus_1_lit.Clone())); - to_subtract = - parent_computation->AddInstruction(HloInstruction::CreateBinary( - new_shape, HloOpcode::kMultiply, other_summand, - partitions_minus_1_const)); - } - auto sub = - parent_computation->AddInstruction(HloInstruction::CreateBinary( - new_shape, HloOpcode::kSubtract, crs, to_subtract)); - TF_CHECK_OK(crs->ReplaceAllUsesWith(sub)); + auto all_reduce_id = all_reduce->all_reduce_id(); + auto prev = all_reduce->mutable_operand(0); + auto next = all_reduce->users()[0]; + TF_CHECK_OK(all_reduce->ReplaceUseWith(next, prev)); TF_CHECK_OK(parent_computation->RemoveInstruction(all_reduce)); + while (!next->IsCrossReplicaAllReduce()) { + switch (next->opcode()) { + case 
HloOpcode::kBitcast: + case HloOpcode::kTranspose: + case HloOpcode::kReshape: + case HloOpcode::kConvert: + case HloOpcode::kMultiply: + break; + case HloOpcode::kAdd: + case HloOpcode::kSubtract: { + auto other_operand = (next->operands()[0] == prev) + ? next->operands()[1] + : next->operands()[0]; + // To move the AR past the addition/subtraction, we need to divide + // other_operand by the number of spatial partitions. + auto shape = other_operand->shape(); + Literal lit(shape); + lit.PopulateWithValue(num_spatial_partitions_); + auto divisor = parent_computation->AddInstruction( + HloInstruction::CreateConstant(lit.Clone())); + auto division = + parent_computation->AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDivide, other_operand, divisor)); + TF_CHECK_OK(other_operand->ReplaceUseWith(next, division)); + break; + } + default: + LOG(FATAL) << "Unexpected instruction: " << next->ToShortString(); + } + prev = next; + next = next->users()[0]; + } + // The AllReduce and the CRS are combined to an all-core AllReduce. + next->set_all_reduce_id(all_reduce_id); } } - return true; } diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.h b/tensorflow/compiler/xla/service/ar_crs_combiner.h index f6a7ef76ec3b76972d1b2c7fb548cecfb9423160..6be7e1002dc6822bf0b563721f00896da171c0a9 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.h +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.h @@ -25,9 +25,12 @@ limitations under the License. namespace xla { -// Combine an AllReduce and a CrossReplicaSum when they are close to each other -// in the graph, to use an efficient CrossReplicaSum implementation that -// fully utilizes the interconnect bandwidth. +// When the HLO graph contains an AllReduce, followed by some simple linear +// operations, followed by a CrossReplicaSum, we can combine the AR and the CRS, +// to use an efficient CrossReplicaSum implementation that fully utilizes the +// interconnect bandwidth. 
+// Such sequences appear in spatially partitioned models. +// This pass must run right after spatial partitioning. class ArCrsCombiner : public HloModulePass { public: ArCrsCombiner(int num_spatial_partitions) diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc index 9d5eaf63ccf32cd78b8c11f12f9bccdfd1fec3e0..8a4fd0ee1b25ec82f5dadfc8446af185914d4033 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc @@ -32,8 +32,8 @@ HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { %p = f32[2,2] parameter(0) - %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) - %constant.f32.2 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant({{1, 2}, {3, 4}}) ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2) } )"; @@ -48,13 +48,50 @@ ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); } +TEST_F(ArCrsCombinerTest, SameValueTestBasecase2) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (x: f32[]) -> (f32[], f32[]) { + %x = f32[] parameter(0) + ROOT %tuple = (f32[], f32[]) tuple(%x, %x) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestBasecase3) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (x: f32[], y: f32[]) -> (f32[], f32[]) { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %tuple = (f32[], f32[]) tuple(%x, %y) +} +)"; + + 
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + TEST_F(ArCrsCombinerTest, SameValueTestNumOperands) { const char* module_str = R"( HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> ((f32[2,2]), (f32[2,2], f32[2,2])) { %p = f32[2,2] parameter(0) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}}) %tuple1 = (f32[2,2]) tuple(%constant.f32) %tuple2 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) ROOT %tuple = ((f32[2,2]), (f32[2,2], f32[2,2])) tuple(%tuple1, %tuple2) @@ -69,13 +106,53 @@ ENTRY %entrycomp (p: f32[2,2]) -> ((f32[2,2]), (f32[2,2], f32[2,2])) { EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); } +TEST_F(ArCrsCombinerTest, SameValueTestSliceIndicesMatch) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2]) -> (f32[1], f32[1]) { + %p = f32[2] parameter(0) + %slice.1 = f32[1] slice(f32[2] %p), slice={[0:1]} + %slice.2 = f32[1] slice(f32[2] %p), slice={[0:1]} + ROOT %tuple = (f32[1], f32[1]) tuple(%slice.1, %slice.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestSliceIndicesDontMatch) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2]) -> (f32[1], f32[1]) { + %p = f32[2] parameter(0) + %slice.1 = f32[1] slice(f32[2] %p), slice={[0:1]} + %slice.2 = f32[1] slice(f32[2] %p), slice={[1:2]} + ROOT %tuple = (f32[1], f32[1]) 
tuple(%slice.1, %slice.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + TEST_F(ArCrsCombinerTest, SameValueTestTupleElementSameIndex) { const char* module_str = R"( HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { %p = f32[2,2] parameter(0) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}}) %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0 %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=0 @@ -97,7 +174,7 @@ HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { %p = f32[2,2] parameter(0) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}}) %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0 %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=1 @@ -119,8 +196,8 @@ HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { %p = f32[2,2] parameter(0) - %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) - %constant.f32.2 = f32[2,2] constant(f32[2,2] {{2, 3}, {4, 5}}) + %constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant({{2, 3}, {4, 5}}) %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2) %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0 %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=1 @@ -149,7 +226,7 @@ HloModule foobar %body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { %x = (f32[2,2], 
f32[2,2]) parameter(0) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}}) %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0 %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1 %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32) @@ -158,7 +235,7 @@ HloModule foobar } ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { - %constant.f32 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) + %constant.f32 = f32[2,2] constant({{3, 4}, {5, 6}}) %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body } @@ -186,7 +263,7 @@ HloModule foobar %body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { %x = (f32[2,2], f32[2,2]) parameter(0) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}}) %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0 %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1 %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32) @@ -195,8 +272,8 @@ HloModule foobar } ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { - %constant.f32.1 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) - %constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {7, 8}}) + %constant.f32.1 = f32[2,2] constant({{3, 4}, {5, 6}}) + %constant.f32.2 = f32[2,2] constant({{3, 4}, {7, 8}}) %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2) ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body } @@ -224,8 +301,8 @@ HloModule foobar %body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { %x = (f32[2,2], f32[2,2]) parameter(0) - %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) - %constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {1, 2}}) + %constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant({{3, 4}, {1, 2}}) 
%get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0 %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1 %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32.1) @@ -234,7 +311,7 @@ HloModule foobar } ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { - %constant.f32 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) + %constant.f32 = f32[2,2] constant({{3, 4}, {5, 6}}) %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body } @@ -249,11 +326,27 @@ ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); } -TEST_F(ArCrsCombinerTest, RewritePatternArConvertAddCrs) { +void CompareReplicaGroups(const std::vector& groups_before, + const std::vector& groups_after) { + ASSERT_EQ(groups_before.size(), groups_after.size()); + for (int i = 0; i < groups_before.size(); ++i) { + // Somewhat verbose way to compare the replica_ids, because EqualsProto + // is not available in the open-source build. 
+ auto group_before = groups_before[i]; + std::vector ids_before(group_before.replica_ids().begin(), + group_before.replica_ids().end()); + auto group_after = groups_after[i]; + std::vector ids_after(group_after.replica_ids().begin(), + group_after.replica_ids().end()); + EXPECT_EQ(ids_before, ids_after); + } +} + +TEST_F(ArCrsCombinerTest, RewriteArConvertCrs) { const char* module_str = R"( HloModule foobar -%binary_add (a: bf16[], b: bf16[]) -> bf16[] { +%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] { %a = bf16[] parameter(0) %b = bf16[] parameter(1) ROOT %add = bf16[] add(%a, %b) @@ -265,48 +358,257 @@ HloModule foobar ROOT %add = f32[] add(%x, %y) } -ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { - %p = f32[2,2] parameter(0) - %constant.bf16 = bf16[2,2] constant(bf16[2,2] {{1, 2}, {3, 4}}) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) +ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { + %p = bf16[] parameter(0) + + %cross-replica-sum.ar.1 = bf16[] + cross-replica-sum(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.bf16, + sharding={maximal device=0} + %convert.1 = f32[] + convert(%cross-replica-sum.ar.1), + sharding={maximal device=0} + %cross-replica-sum.1 = f32[] + cross-replica-sum(%convert.1), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=0} + + %cross-replica-sum.ar.2 = bf16[] + cross-replica-sum(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.bf16, + sharding={maximal device=1} + %convert.2 = f32[] + convert(%cross-replica-sum.ar.2), + sharding={maximal device=1} + %cross-replica-sum.2 = f32[] + cross-replica-sum(%convert.2), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=1} - %cross-replica-sum.ar.1 = bf16[2,2] + ROOT %tuple = (f32[], f32[]) + tuple(%cross-replica-sum.1, %cross-replica-sum.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + 
ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::CrossReplicaSum(op::Convert(op::Parameter())), + op::CrossReplicaSum(op::Convert(op::Parameter())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, RewriteArBitcastCrs) { + const char* module_str = R"( +HloModule foobar + +%sum.1 (a: f32[2,1], b: f32[2,1]) -> f32[2,1] { + %a = f32[2,1] parameter(0) + %b = f32[2,1] parameter(1) + ROOT %add = f32[2,1] add(%a, %b) +} + +%sum.2 (x: f32[2], y: f32[2]) -> f32[2] { + %x = f32[2] parameter(0) + %y = f32[2] parameter(1) + ROOT %add = f32[2] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) { + %p = f32[2,1] parameter(0) + + %cross-replica-sum.ar.1 = f32[2,1] + cross-replica-sum(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.1, + sharding={maximal device=0} + %bitcast.1 = f32[2]{0} bitcast(f32[2,1]{1,0} %cross-replica-sum.ar.1) + %cross-replica-sum.1 = f32[2] + cross-replica-sum(%bitcast.1), + replica_groups={{0,1}}, + to_apply=%sum.2, + sharding={maximal device=0} + + %cross-replica-sum.ar.2 = f32[2,1] + cross-replica-sum(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.1, + sharding={maximal device=1} + %bitcast.2 = f32[2]{0} bitcast(f32[2,1]{1,0} %cross-replica-sum.ar.2) + %cross-replica-sum.2 = f32[2] + cross-replica-sum(%bitcast.2), + replica_groups={{0,1}}, + to_apply=%sum.2, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%cross-replica-sum.1, %cross-replica-sum.2), + 
sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::CrossReplicaSum(op::Bitcast(op::Parameter())), + op::CrossReplicaSum(op::Bitcast(op::Parameter())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, RewriteArMultiplyCrs) { + const char* module_str = R"( +HloModule foobar + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %constant.f32 = f32[] constant(123) + + %cross-replica-sum.ar.1 = f32[] + cross-replica-sum(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.f32, + sharding={maximal device=0} + %multiply.1 = f32[] + multiply(%cross-replica-sum.ar.1, %constant.f32), + sharding={maximal device=0} + %cross-replica-sum.1 = f32[] + cross-replica-sum(%multiply.1), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=0} + + %cross-replica-sum.ar.2 = f32[] + cross-replica-sum(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.f32, + sharding={maximal device=1} + %multiply.2 = f32[] + multiply(%cross-replica-sum.ar.2, %constant.f32), + sharding={maximal device=1} + %cross-replica-sum.2 = f32[] + cross-replica-sum(%multiply.2), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=1} + + ROOT %tuple = (f32[], 
f32[]) + tuple(%cross-replica-sum.1, %cross-replica-sum.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple( + op::CrossReplicaSum(op::Multiply(op::Parameter(), op::Constant())), + op::CrossReplicaSum(op::Multiply(op::Parameter(), op::Constant())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, RewriteArConvertAddCrs) { + const char* module_str = R"( +HloModule foobar + +%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] { + %a = bf16[] parameter(0) + %b = bf16[] parameter(1) + ROOT %add = bf16[] add(%a, %b) +} + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %constant.bf16 = bf16[] constant(1) + %constant.f32 = f32[] constant(2) + + %cross-replica-sum.ar.1 = bf16[] cross-replica-sum(%constant.bf16), replica_groups={{0},{1}}, all_reduce_id=1, - to_apply=%binary_add, + to_apply=%sum.bf16, sharding={maximal device=0} - %convert.1 = f32[2,2] + %convert.1 = f32[] convert(%cross-replica-sum.ar.1), sharding={maximal device=0} - %add.1 = f32[2,2] + %add.1 = f32[] add(%constant.f32, %convert.1), sharding={maximal device=0} - %cross-replica-sum.1 = f32[2,2] + %cross-replica-sum.1 = f32[] cross-replica-sum(%add.1), replica_groups={{0,1}}, to_apply=%sum.f32, sharding={maximal device=0} - 
%cross-replica-sum.ar.2 = bf16[2,2] + %cross-replica-sum.ar.2 = bf16[] cross-replica-sum(%constant.bf16), replica_groups={{0},{1}}, all_reduce_id=1, - to_apply=%binary_add, + to_apply=%sum.bf16, sharding={maximal device=1} - %convert.2 = f32[2,2] + %convert.2 = f32[] convert(%cross-replica-sum.ar.2), sharding={maximal device=1} - %add.2 = f32[2,2] + %add.2 = f32[] add(%constant.f32, %convert.2), sharding={maximal device=1} - %cross-replica-sum.2 = f32[2,2] + %cross-replica-sum.2 = f32[] cross-replica-sum(%add.2), replica_groups={{0,1}}, to_apply=%sum.f32, sharding={maximal device=1} - ROOT %tuple = (f32[2,2], f32[2,2]) + ROOT %tuple = (f32[], f32[]) tuple(%cross-replica-sum.1, %cross-replica-sum.2), sharding={{maximal device=0}, {maximal device=1}} } @@ -320,31 +622,24 @@ ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { ArCrsCombiner combiner(2); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); - EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Tuple(op::Subtract(op::CrossReplicaSum(), op::Constant()), - op::Subtract(op::CrossReplicaSum(), op::Constant()))); - auto sub = module->entry_computation()->root_instruction()->operands()[0]; - auto crs_after = sub->operands()[0]; + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple( + op::CrossReplicaSum(op::Add( + op::Divide(op::Constant(), op::Constant()), op::Convert())), + op::CrossReplicaSum(op::Add( + op::Divide(op::Constant(), op::Constant()), op::Convert())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_after = crs_after->replica_groups(); - ASSERT_EQ(replica_groups_before.size(), replica_groups_after.size()); - for (int i = 0; i < replica_groups_before.size(); ++i) { - // Somewhat verbose way to compare the replica_ids, because EqualsProto - // is not available in the open-source build. 
- auto group_before = replica_groups_before[i]; - std::vector ids_before(group_before.replica_ids().begin(), - group_before.replica_ids().end()); - auto group_after = replica_groups_after[i]; - std::vector ids_after(group_after.replica_ids().begin(), - group_after.replica_ids().end()); - EXPECT_EQ(ids_before, ids_after); - } + CompareReplicaGroups(replica_groups_before, replica_groups_after); } TEST_F(ArCrsCombinerTest, OtherSummandNotTheSameDontRewrite) { const char* module_str = R"( HloModule foobar -%binary_add (a: bf16[], b: bf16[]) -> bf16[] { +%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] { %a = bf16[] parameter(0) %b = bf16[] parameter(1) ROOT %add = bf16[] add(%a, %b) @@ -356,49 +651,49 @@ HloModule foobar ROOT %add = f32[] add(%x, %y) } -ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { - %p = f32[2,2] parameter(0) - %constant.bf16 = bf16[2,2] constant(bf16[2,2] {{1, 2}, {3, 4}}) - %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) - %constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %constant.bf16 = bf16[] constant(1) + %constant.f32.1 = f32[] constant(2) + %constant.f32.2 = f32[] constant(3) - %cross-replica-sum.ar.1 = bf16[2,2] + %cross-replica-sum.ar.1 = bf16[] cross-replica-sum(%constant.bf16), replica_groups={{0},{1}}, all_reduce_id=1, - to_apply=%binary_add, + to_apply=%sum.bf16, sharding={maximal device=0} - %convert.1 = f32[2,2] + %convert.1 = f32[] convert(%cross-replica-sum.ar.1), sharding={maximal device=0} - %add.1 = f32[2,2] + %add.1 = f32[] add(%constant.f32.1, %convert.1), sharding={maximal device=0} - %cross-replica-sum.1 = f32[2,2] + %cross-replica-sum.1 = f32[] cross-replica-sum(%add.1), replica_groups={{0,1}}, to_apply=%sum.f32, sharding={maximal device=0} - %cross-replica-sum.ar.2 = bf16[2,2] + %cross-replica-sum.ar.2 = bf16[] cross-replica-sum(%constant.bf16), replica_groups={{0},{1}}, all_reduce_id=1, - to_apply=%binary_add, + 
to_apply=%sum.bf16, sharding={maximal device=1} - %convert.2 = f32[2,2] + %convert.2 = f32[] convert(%cross-replica-sum.ar.2), sharding={maximal device=1} - %add.2 = f32[2,2] + %add.2 = f32[] add(%constant.f32.2, %convert.2), sharding={maximal device=1} - %cross-replica-sum.2 = f32[2,2] + %cross-replica-sum.2 = f32[] cross-replica-sum(%add.2), replica_groups={{0,1}}, to_apply=%sum.f32, sharding={maximal device=1} - ROOT %tuple = (f32[2,2], f32[2,2]) + ROOT %tuple = (f32[], f32[]) tuple(%cross-replica-sum.1, %cross-replica-sum.2), sharding={{maximal device=0}, {maximal device=1}} } diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index f70f6ddfec69c0113a1afe2073a2392098f49456..0e6ca1871b379a2f55b92207133822fc6258b007 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -107,19 +107,37 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { } std::unique_ptr Mean( - int64 element_count, HloInstruction* operand, + HloInstruction* element_count, HloInstruction* operand, const std::function)>& add_instruction) { - HloInstruction* elem_count_recip = - add_instruction(HloInstruction::CreateBroadcast( - operand->shape(), - add_instruction(HloInstruction::CreateConvert( - ShapeUtil::MakeShape(operand->shape().element_type(), {}), - add_instruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR0(1.0 / element_count))))), - {})); - return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kMultiply, - operand, elem_count_recip); + auto broadcast = add_instruction( + HloInstruction::CreateBroadcast(operand->shape(), element_count, {})); + return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kDivide, + operand, broadcast); + } + + std::unique_ptr DynamicElementCountPerFeature( + HloInstruction* operand, int64 feature_index, + const std::function)>& + add_instruction) { + auto 
elements_per_feature_u32 = add_instruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + + for (int64 i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) { + if (i == feature_index) { + continue; + } + auto dynamic_dimension_size = + add_instruction(HloInstruction::CreateGetDimensionSize( + ShapeUtil::MakeShape(U32, {}), operand, i)); + elements_per_feature_u32 = add_instruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(U32, {}), HloOpcode::kMultiply, + dynamic_dimension_size, elements_per_feature_u32)); + } + + return HloInstruction::CreateConvert( + ShapeUtil::MakeShape(operand->shape().element_type(), {}), + elements_per_feature_u32); } // Replaces the existing HLO instruction old_instruction, with @@ -195,9 +213,6 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( const Shape operand_shape = operand->shape(); PrimitiveType ptype = operand_shape.element_type(); int64 feature_index = batch_norm->feature_index(); - const int64 feature_count = operand_shape.dimensions(feature_index); - const int64 size_in_elements = ShapeUtil::ElementsIn(operand_shape); - int64 elements_per_feature_int64 = size_in_elements / feature_count; HloInstruction* scale = batch_norm->mutable_operand(1); HloInstruction* offset = batch_norm->mutable_operand(2); @@ -220,6 +235,9 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( } } + auto elements_per_feature = + add(DynamicElementCountPerFeature(operand, feature_index, add)); + auto scale_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index})); @@ -243,13 +261,13 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( add_reduce_computation)); // E[X]. - auto mean = add(Mean(elements_per_feature_int64, sum, add)); + auto mean = add(Mean(elements_per_feature, sum, add)); auto mean_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index})); // E[X^2]. 
- auto square_mean = add(Mean(elements_per_feature_int64, squared_sum, add)); + auto square_mean = add(Mean(elements_per_feature, squared_sum, add)); // E^2[X]. auto mean_square = @@ -458,9 +476,8 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( int64 feature_index = batch_norm->feature_index(); - const int64 size_in_elements = ShapeUtil::ElementsIn(activation_shape); - const int64 feature_count = activation_shape.dimensions(feature_index); - const int64 elements_per_feature_int64 = size_in_elements / feature_count; + auto elements_per_feature = + add(DynamicElementCountPerFeature(activation, feature_index, add)); auto zero_literal = LiteralUtil::CreateR0(0.0f); TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype)); @@ -553,15 +570,9 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( add_binary(activation_shape, HloOpcode::kMultiply, scale_broadcasted, rsqrt_var_add_epsilon_broadcasted); - scale_times_rsqrt_var_add_epsilon = add( - Mean(elements_per_feature_int64, scale_times_rsqrt_var_add_epsilon, add)); + scale_times_rsqrt_var_add_epsilon = + add(Mean(elements_per_feature, scale_times_rsqrt_var_add_epsilon, add)); - auto elements_per_feature_literal = - LiteralUtil::CreateR0(elements_per_feature_int64); - TF_ASSIGN_OR_RETURN(elements_per_feature_literal, - elements_per_feature_literal.Convert(ptype)); - auto elements_per_feature = add( - HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output, add(HloInstruction::CreateBroadcast( activation_shape, elements_per_feature, {}))); diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc index 08cf8026177d77ff98cca5e5d168ac3194936b35..8e8fbbd935b154e5a77d68e60d861601d740bf03 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc @@ -36,7 +36,21 @@ 
limitations under the License. namespace xla { namespace { -using BatchNormExpanderTest = HloTestBase; +class BatchNormExpanderTest : public HloTestBase { + protected: + // BatchNorm should have a dynamic sized dividor for mean operations. + int64 CountGetDimensionSize(const HloModule& module) { + int64 count = 0; + for (HloComputation* comp : module.computations()) { + for (HloInstruction* inst : comp->instructions()) { + if (inst->opcode() == HloOpcode::kGetDimensionSize) { + count++; + } + } + } + return count; + } +}; // Test that we expand BatchNormTraining. TEST_F(BatchNormExpanderTest, BatchNormTraining) { @@ -68,6 +82,7 @@ TEST_F(BatchNormExpanderTest, BatchNormTraining) { /*rewrite_grad_op=*/true); ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); + EXPECT_EQ(CountGetDimensionSize(*module), 3); // Make sure this operation is expanded. EXPECT_EQ(root->opcode(), HloOpcode::kTuple); } @@ -110,6 +125,7 @@ TEST_F(BatchNormExpanderTest, BatchNormGrad) { /*rewrite_grad_op=*/true); ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); + EXPECT_EQ(CountGetDimensionSize(*module), 3); // Make sure this operation is expanded. EXPECT_EQ(root->opcode(), HloOpcode::kTuple); } diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc index 09c3f32860b3176ee5afbb147872ddafc51af256..95c7724c3c93507ae61a984301ecfc0111bef192 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc @@ -205,38 +205,6 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { // If the code generator handles depthwise separable convolutions // inherently, then no filter expansion is needed. 
if (!filter_expansion_ && depthwise_separable) { - const int64 old_kernel_input_feature_dimension = - dim_numbers.kernel_input_feature_dimension(); - const int64 old_kernel_output_feature_dimension = - dim_numbers.kernel_output_feature_dimension(); - - // For depthwise convolutions, we want the kernel input feature dimension - // to be smaller than the output feature dimension. If that's not the - // case, we swap the dimensions. - if (old_kernel_input_feature_dimension > - old_kernel_output_feature_dimension) { - Shape reshaped_filter_shape = filter->shape(); - auto& dimensions = *reshaped_filter_shape.mutable_dimensions(); - std::swap(dimensions[old_kernel_input_feature_dimension], - dimensions[old_kernel_output_feature_dimension]); - - auto reshaped_filter = - add(HloInstruction::CreateReshape(reshaped_filter_shape, filter)); - - dim_numbers.set_kernel_input_feature_dimension( - old_kernel_output_feature_dimension); - - dim_numbers.set_kernel_output_feature_dimension( - old_kernel_input_feature_dimension); - - auto new_convolution = HloInstruction::CreateConvolve( - convolution->shape(), convolution->mutable_operand(0), - reshaped_filter, group_count, convolution->window(), dim_numbers, - convolution->precision_config()); - - TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( - convolution, std::move(new_convolution))); - } return Status::OK(); } // We want to repeat 'filter' in the 'input_feature_dim' dimension @@ -271,130 +239,72 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { convolution, std::move(new_convolution))); } else { int64 activation_input_feature_dim = dim_numbers.input_feature_dimension(); - auto activation = convolution->mutable_operand(0); int64 output_feature = filter->shape().dimensions(kernel_output_feature_dim); - int64 input_feature = - activation->shape().dimensions(activation_input_feature_dim); - // If group_count == output_feature, then we map those grouped convolutions - // onto depthwise 
convolution + reduce. E.g., we would turn + // onto depthwise convolution. This is done by adding an additional spatial + // dimension to the activations, kernel, and the output. + // E.g., we would turn // [2, 12]{B, IF} conv [3, 4]{IF, OF} into - // [2, 12]{B, IF} depth conv [1, 12]{IF, OF}, and then use a reduce window - // of {1, 3} on the generated [2, 12] output to produce the final result of - // [2, 4]. + // [3, 2, 4]{S, B, IF} depth conv [3, 1, 4]{S, IF, OF}, where S is the + // additional spatial dimension. The generated convolution output will be + // [1, 2, 4]{S, B, OF} and then reshape the output back to [2, 4] {B, OF}. + if (group_count == output_feature && !filter_expansion_) { - Shape reshaped_filter_shape = filter->shape(); + auto filter = convolution->mutable_operand(1); + auto activation = convolution->mutable_operand(0); - if (kernel_input_feature_dim < kernel_output_feature_dim) { - // Transpose IF and OF on the kernel. - std::vector filter_dims; - for (int64 i = 0; i < dim_numbers.kernel_spatial_dimensions().size(); - ++i) { - filter_dims.push_back(dim_numbers.kernel_spatial_dimensions(i)); - } - filter_dims.push_back(kernel_output_feature_dim); - filter_dims.push_back(kernel_input_feature_dim); - - Shape transposed_filter = filter->shape(); - auto& dimensions = *transposed_filter.mutable_dimensions(); - std::swap(dimensions[kernel_input_feature_dim], - dimensions[kernel_output_feature_dim]); - - filter = add(HloInstruction::CreateTranspose(transposed_filter, filter, - filter_dims)); - } else { - // For depthwise convolutions, we want the kernel input feature - // dimension to be smaller than the output feature dimension. If that's - // not the case, we swap the dimensions. 
- - auto& dimensions = *reshaped_filter_shape.mutable_dimensions(); - std::swap(dimensions[kernel_input_feature_dim], - dimensions[kernel_output_feature_dim]); - - dim_numbers.set_kernel_input_feature_dimension( - kernel_output_feature_dim); - - dim_numbers.set_kernel_output_feature_dimension( - kernel_input_feature_dim); - std::swap(kernel_output_feature_dim, kernel_input_feature_dim); - } + // Add spatial dimension to the activation, and reshape. + Shape reshaped_activation_shape = activation->shape(); + ShapeUtil::AppendMajorDimension(group_size, &reshaped_activation_shape); - reshaped_filter_shape.set_dimensions(kernel_input_feature_dim, 1); - reshaped_filter_shape.set_dimensions(kernel_output_feature_dim, - group_count * group_size); - auto reshaped_filter = - add(HloInstruction::CreateReshape(reshaped_filter_shape, filter)); + int64 new_spatial_dim = reshaped_activation_shape.dimensions().size() - 1; - Shape reshaped_convolution_shape = convolution->shape(); - reshaped_convolution_shape.set_dimensions( - dim_numbers.output_feature_dimension(), group_count * group_size); - auto new_convolution = add(HloInstruction::CreateConvolve( - reshaped_convolution_shape, convolution->mutable_operand(0), - reshaped_filter, /*feature_group_count=*/input_feature, - convolution->window(), dim_numbers, convolution->precision_config())); - - // Create the reduce window. 
- Window window; - for (int64 i = 0; i < new_convolution->shape().dimensions_size(); ++i) { - auto* dim = window.add_dimensions(); - dim->set_padding_low(0); - dim->set_padding_high(0); - dim->set_window_dilation(1); - dim->set_base_dilation(1); - if (i == dim_numbers.output_feature_dimension()) { - dim->set_stride(group_size); - dim->set_size(group_size); - } else { - dim->set_stride(1); - dim->set_size(1); - } - } + reshaped_activation_shape.set_dimensions(activation_input_feature_dim, + group_count); + activation = add( + HloInstruction::CreateReshape(reshaped_activation_shape, activation)); - auto reduce_window_shape = new_convolution->shape(); - reduce_window_shape.set_dimensions(dim_numbers.output_feature_dimension(), - group_count); - - auto zero_literal = LiteralUtil::CreateR0(0.0f); - TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(F32)); - auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); - - auto reduce_function = [&]() -> HloComputation* { - HloComputation::Builder b("add_computation"); - Shape shape = ShapeUtil::MakeShape(F32, {}); - auto lhs = - b.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs")); - auto rhs = - b.AddInstruction(HloInstruction::CreateParameter(1, shape, "rhs")); - auto scalar_op = b.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, lhs, rhs)); - return computation_->parent()->AddEmbeddedComputation( - b.Build(scalar_op)); - }; - - // Ensure that data input to reduce window is of type F32. - if (primitive_util::BitWidth(new_convolution->shape().element_type()) < - primitive_util::BitWidth(F32)) { - Shape convert_shape = new_convolution->shape(); - convert_shape.set_element_type(F32); - new_convolution = add(HloInstruction::CreateBitcastConvert( - convert_shape, new_convolution)); - } + // Add spatial dimension to the filter, and reshape. 
+ Shape reshaped_filter_shape = filter->shape(); + ShapeUtil::AppendMajorDimension(1, &reshaped_filter_shape); - auto reduce_window = add(HloInstruction::CreateReduceWindow( - reduce_window_shape, new_convolution, zero, window, - reduce_function())); + filter = + add(HloInstruction::CreateReshape(reshaped_filter_shape, filter)); + + Shape new_output_shape = convolution->shape(); + ShapeUtil::AppendMajorDimension(1, &new_output_shape); + + // Edit convolution dimension numbers. Note that kernel_input_feature_dim + // now becomes a spatial dimension, and the newly added dimension of size + // 1 is the new kernel_input_feature_dim. + dim_numbers.add_input_spatial_dimensions(new_spatial_dim); + dim_numbers.add_kernel_spatial_dimensions(kernel_input_feature_dim); + dim_numbers.set_kernel_input_feature_dimension(new_spatial_dim); + dim_numbers.add_output_spatial_dimensions(new_spatial_dim); + + // Add window for the new spatial dimension. + Window new_window = convolution->window(); + auto* dim = new_window.add_dimensions(); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + dim->set_stride(1); + dim->set_size(group_size); + + auto new_convolution = add(HloInstruction::CreateConvolve( + new_output_shape, activation, filter, group_count, new_window, + dim_numbers, convolution->precision_config())); - Shape convert_back_shape = reduce_window->shape(); - convert_back_shape.set_element_type(activation->shape().element_type()); + // Delete the extra spatial dimension, and reshape. + Shape reshaped_convolution_shape = + ShapeUtil::DeleteDimension(new_spatial_dim, new_convolution->shape()); + auto reshaped_convolution = HloInstruction::CreateReshape( + reshaped_convolution_shape, new_convolution); - // Convert reduced data back to the original data type. 
- auto reduce_window_converted = HloInstruction::CreateBitcastConvert( - convert_back_shape, reduce_window); TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( - convolution, std::move(reduce_window_converted))); + convolution, std::move(reshaped_convolution))); } else { // The filter expansion mechanism adds zeroes in the kernel. diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index ce4c2a9cc69240b9565b35a3f2504d7fc9373917..4173af5179ba096523db973ca7e0466faefda38a 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -572,6 +572,7 @@ cc_library( ":runtime_matvec", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/core:framework_lite", + "//tensorflow/core/kernels:eigen_contraction_kernel", "//third_party/eigen3", ], ) diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 2bf24c15c1f050b200b1d9af2d95286f9a9dbe4c..f3dfa4d64264808e0d5c9f86693bb844b2011964 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -250,7 +250,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( &pipeline, module->config().debug_options(), ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION); - pipeline.AddPass(); pipeline.AddPass(); // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner @@ -270,6 +269,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); + pipeline.AddPass(); AlgebraicSimplifierOptions options( [](const Shape&, const Shape&) { return false; }); options.set_enable_dot_strength_reduction(false); @@ -635,18 +635,17 @@ StatusOr> CpuCompiler::RunBackend( .EmitComputation( embedded_computation, embedded_computation->name(), /*is_top_level_computation=*/false, - 
&schedule.sequence(embedded_computation).instructions()) + schedule.sequence(embedded_computation).instructions()) .status()); } string function_name_prefix = entry_computation->name().empty() ? "__compute" : entry_computation->name(); - TF_ASSIGN_OR_RETURN( - llvm::Function * entry_function, - ir_emitter.EmitComputation( - entry_computation, function_name_prefix, - /*is_top_level_computation=*/true, - &schedule.sequence(entry_computation).instructions())); + TF_ASSIGN_OR_RETURN(llvm::Function * entry_function, + ir_emitter.EmitComputation( + entry_computation, function_name_prefix, + /*is_top_level_computation=*/true, + schedule.sequence(entry_computation).instructions())); string function_name = [&]() { llvm::SmallVector function_name_vector; @@ -835,7 +834,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, .EmitComputation( embedded_computation, embedded_computation->name(), /*is_top_level_computation=*/false, - &schedule.sequence(embedded_computation).instructions()) + schedule.sequence(embedded_computation).instructions()) .status()); } const string& entry_point_name = options.entry_point_name(); @@ -843,7 +842,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, ir_emitter.EmitComputation( computation, entry_point_name, /*is_top_level_computation=*/true, - &schedule.sequence(computation).instructions())); + schedule.sequence(computation).instructions())); CHECK(entry_function->getName() == llvm_ir::AsStringRef(entry_point_name)); diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 97f9b85a606e140fd7f3b1e3ecfb0dd5ba289f03..a33035ad1081d7d73ceed6ce3a208af5910d2d2c 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -323,11 +323,11 @@ void ColumnMajorMatrixVectorProductEmitter::Emit() { int64 column_remainder = k() % tile_cols(); int64 column_limit = k() - column_remainder; - 
ksl_.ForReturnVoid("dot.outer.tiled", - /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(), - [&](llvm::Value* column, bool is_first_column) { - EmitOuterLoopBody(column, tile_cols(), is_first_column); - }); + ksl_.For("dot.outer.tiled", + /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(), + [&](llvm::Value* column, bool is_first_column) { + EmitOuterLoopBody(column, tile_cols(), is_first_column); + }); if (column_remainder != 0) { EmitOuterLoopBody(b_->getInt64(column_limit), column_remainder, @@ -340,7 +340,7 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( int64 columns, bool is_first_column) { int64 row_limit = m() - (m() % tile_rows()); - ksl_.ForReturnVoid( + ksl_.For( "dot.inner.tiled", /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(), [&](llvm::Value* row) { std::vector lhs_tile = @@ -372,7 +372,7 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( // // initialized. // } - ksl_.ForReturnVoid( + ksl_.For( "dot.inner.epilg.outer", /*start=*/current_tile_col, /*end=*/b_->CreateAdd(columns_llvm, current_tile_col), /*step=*/1, /*peel_first_iteration=*/false, @@ -381,14 +381,14 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( llvm::Value* total_offset = b_->CreateMul(col, b_->getInt64(m())); llvm::Value* lhs_base_pointer = vsl_.ComputeOffsetPointer(lhs_, total_offset); - ksl_.ForReturnVoid( + ksl_.For( "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m(), /*step=*/1, [&](llvm::Value* scalar_row) { llvm::Value* product = vsl_.Mul( vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element); llvm::Value* setting_result_first_time = b_->CreateAnd( is_first_scalar_col, b_->getInt1(is_first_tiled_column)); - ksl_.IfReturnVoid( + ksl_.If( setting_result_first_time, /*true_block_generator=*/ [&]() { @@ -568,10 +568,9 @@ void RowMajorMatrixVectorProductEmitter::Emit() { int64 row_remainder = m() % tile_rows(); int64 row_limit = m() - row_remainder; - ksl_.ForReturnVoid( - "dot.outer.tiled", - 
/*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(), - [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); }); + ksl_.For("dot.outer.tiled", + /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(), + [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); }); if (row_remainder != 0) { EmitOuterLoopBody(b_->getInt64(row_limit), row_remainder); @@ -583,17 +582,17 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( std::vector* vector_accumulators) { int64 column_limit = k() - (k() % tile_cols()); - ksl_.ForReturnVoid("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, - /*step=*/tile_cols(), [&](llvm::Value* col) { - std::vector lhs_tile = - lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col); - llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); - for (int i = 0; i < rows; i++) { - llvm::Value* old_sum = (*vector_accumulators)[i].Get(); - (*vector_accumulators)[i].Set(vsl_.Add( - old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); - } - }); + ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, + /*step=*/tile_cols(), [&](llvm::Value* col) { + std::vector lhs_tile = + lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col); + llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); + for (int i = 0; i < rows; i++) { + llvm::Value* old_sum = (*vector_accumulators)[i].Get(); + (*vector_accumulators)[i].Set( + vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); + } + }); } void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( @@ -609,7 +608,7 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( b_->CreateAdd(b_->getInt64(r), current_tile_row), b_->getInt64(k())); llvm::Value* lhs_base_pointer = vsl_.ComputeOffsetPointer(lhs_, total_offset); - ksl_.ForReturnVoid( + ksl_.For( "dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(), /*step=*/1, [&](llvm::Value* scalar_col) { llvm::Value* product = @@ -813,7 +812,7 @@ void TiledSmallGemmEmitter::HandleResiduesOnN() { if (n_start != dims().n()) { VectorSupportLibrary 
vsl(scalar_type(), 1, b_, "gemm"); - ksl_.ForReturnVoid("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { + ksl_.For("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { llvm::Value* n_i_next = b_->CreateAdd(n_i, b_->getInt64(1)); HandleResiduesOnK(&vsl, n_i, n_i_next); }); @@ -924,7 +923,7 @@ void TiledSmallGemmEmitter::EmitTiledGemm( VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end, int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) { - ksl_.ForReturnVoid( + ksl_.For( "dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) { MemoryTile result_memory_tile( vsl, b_, /*matrix=*/result_, @@ -935,11 +934,11 @@ void TiledSmallGemmEmitter::EmitTiledGemm( /*matrix_size_along_minor_dim=*/dims().k(), /*major_dim_offset=*/m_i, /*tile_size_along_major_dim=*/tile_size_m); - ksl_.ForReturnVoid( + ksl_.For( "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) { TileVariable result_tile_var(vsl, result_memory_tile.LoadTile(n_i)); - ksl_.ForReturnVoid( + ksl_.For( "dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) { MemoryTile rhs_memory_tile(vsl, b_, rhs_, dims().n(), k_i, tile_size_k); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 4032c2da2f33ee61da8771ae6225a14172cbe6e8..62a4e8d3507a4e678e80c1abea680c030d048de5 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -111,10 +111,9 @@ IrEmitter::IrEmitter( StatusOr IrEmitter::EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - const std::vector* instruction_order) { + absl::Span instruction_order) { string function_name = name_uniquer_.GetUniqueName(function_name_prefix); - VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix - << "]; ordered? 
" << (instruction_order != nullptr); + VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]"; is_top_level_computation_ = is_top_level_computation; num_dynamic_loop_bounds_ = 0; if (!computation->root_instruction()->outer_dimension_partitions().empty()) { @@ -141,11 +140,7 @@ StatusOr IrEmitter::EmitComputation( bool use_rdtscp = arch_type_ == llvm::Triple::ArchType::x86 || arch_type_ == llvm::Triple::ArchType::x86_64; profiling_state_ = ProfilingState(use_rdtscp); - if (instruction_order == nullptr) { - TF_RETURN_IF_ERROR(computation->Accept(this)); - } else { - TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, *instruction_order)); - } + TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, instruction_order)); llvm::Function* ir_function = compute_function_->function(); InsertOrDie(&emitted_functions_, computation, ir_function); // Delete 'compute_function', finalizing 'ir_function' and restoring caller @@ -2271,6 +2266,22 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { /*isVarArg=*/false))); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call)); + // Write the tuple table if the output is a tuple. 
+ if (ShapeUtil::IsTuple(custom_call->shape())) { + std::vector base_ptrs; + for (int i = 0; i < ShapeUtil::TupleElementCount(custom_call->shape()); + ++i) { + const Shape& elem_shape = + ShapeUtil::GetTupleElementShape(custom_call->shape(), i); + TF_RET_CHECK(!ShapeUtil::IsTuple(elem_shape)) + << "Nested tuples not implemented"; + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, + assignment_.GetUniqueSlice(custom_call, {i})); + llvm::Value* addr = EmitBufferPointer(slice, elem_shape); + base_ptrs.push_back(addr); + } + llvm_ir::EmitTuple(GetIrArrayFor(custom_call), base_ptrs, &b_, module_); + } auto* output_address_arg = PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 559a8162a2d53f28ea6817653503c216af90a610..1db75cc8becea80f121289a843d4eb16ee9a8c8a 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -101,7 +101,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, StatusOr EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - const std::vector* instruction_order); + absl::Span instruction_order); llvm::IRBuilder<>* b() { return &b_; } diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index f0b65046c14ccec5336abf7c4d05d1d755f783bd..35ae62b42dfa768c6abd0508097d6b235b2ebf54 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -112,10 +112,10 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) { const string hlo_string = R"( HloModule TestTaskParallel_infeed_outfeed ENTRY InfeedOutfeed { - token = token[] after-all() - infeed0 = (u32[12345678,2]{1,0}, token[]) 
infeed(token) + token0 = token[] after-all() + infeed0 = (u32[12345678,2]{1,0}, token[]) infeed(token0) infeed0.data = u32[12345678,2]{1,0} get-tuple-element((u32[12345678,2]{1,0}, token[]) infeed0), index=0 - ROOT outfeed0 = token[] outfeed(infeed0.data, token) + ROOT outfeed0 = token[] outfeed(infeed0.data, token0) } )"; diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc index a71a85913cfef271bc2a226cb0cf2dd4204499a4..56f018abdd496e804dc4dea5420d400175491db3 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc @@ -23,6 +23,10 @@ limitations under the License. #include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/types.h" +#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) +#include "tensorflow/core/kernels/eigen_contraction_kernel.h" +#endif + using tensorflow::int32; using tensorflow::int64; diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index efccadedf27181a4cddf4f1dc3610f7c6db1d821..296f39a4853f2d3f7030209a921001e92c39d609 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -139,7 +139,7 @@ llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) { } if (func_addr == nullptr) { - VLOG(2) << "Unable to resolve runtime symbol: " << name; + LOG(ERROR) << "Unable to resolve runtime symbol: " << name; return nullptr; } llvm::JITEvaluatedSymbol symbol_info(reinterpret_cast(func_addr), @@ -296,6 +296,9 @@ bool RegisterKnownJITSymbols() { REGISTER_LIBM_SYMBOL(sin, double (*)(double)); #ifdef __APPLE__ REGISTER_LIBM_SYMBOL(__sincos, void (*)(double, double*, double*)); + registry->Register("__sincosf_stret", + reinterpret_cast(__sincosf_stret)); + registry->Register("__sincos_stret", 
reinterpret_cast(__sincos_stret)); #else REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*)); #endif @@ -311,6 +314,13 @@ bool RegisterKnownJITSymbols() { registry->Register("memcpy", reinterpret_cast(memcpy)); registry->Register("memmove", reinterpret_cast(memmove)); registry->Register("memset", reinterpret_cast(memset)); + +#ifdef __APPLE__ + registry->Register("__bzero", reinterpret_cast(bzero)); + registry->Register("memset_pattern16", + reinterpret_cast(memset_pattern16)); +#endif + return true; } diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc index fa0e09ff6b5694c0e97963b83c6e541b858a1376..0584c0484f810a03ccccd522163f54535440ef8b 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -31,29 +31,27 @@ HloModule RepeatedConstants while_body { arg_body = f32[2,3,2] parameter(0) ROOT const = f32[2,3,2] constant( - f32[2,3,2] {{{1, 2}, {1001, 1002}, {2001, 2002}}, {{2, 1}, {2001, 3002}, {2001, 2002}}}) } while_cond { arg_cond = f32[2,3,2] parameter(0) - token = token[] after-all() - infeed = (pred[], token[]) infeed(token) + token0 = token[] after-all() + infeed = (pred[], token[]) infeed(token0) ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0 } ENTRY main { param = f32[2,3,2] parameter(0) const_a = f32[2,3,2] constant( - f32[2,3,2] {{{1, 2}, {1001, 1002}, {2001, 2002}}, {{2, 1}, {2001, 3002}, {2001, 2002}}}) const_b = f32[2,3,2] while(f32[2,3,2] const_a), condition=while_cond, body=while_body - token = token[] after-all() - out0 = token[] outfeed(f32[2,3,2] const_a, token[] token) - ROOT out1 = token[] outfeed(f32[2,3,2] const_b, token[] token) + token0 = token[] after-all() + out0 = token[] outfeed(f32[2,3,2] const_a, token[] token0) + ROOT out1 = token[] outfeed(f32[2,3,2] const_b, token[] token0) } )"; @@ 
-82,24 +80,24 @@ HloModule RepeatedConstants while_body { arg_body = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) - ROOT const = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} )) + ROOT const = (f32[2,1]{1,0}, f32[1]{0}) constant(({ { 1 }, { 2 } }, {2} )) } while_cond { arg_cond = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) - token = token[] after-all() - infeed = (pred[], token[]) infeed(token) + token0 = token[] after-all() + infeed = (pred[], token[]) infeed(token0) ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0 } ENTRY main { param = f32[2,3,2] parameter(0) - const_a = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} )) + const_a = (f32[2,1]{1,0}, f32[1]{0}) constant(( { { 1 }, { 2 } }, {2} )) const_b = (f32[2,1]{1,0}, f32[1]{0}) while((f32[2,1]{1,0}, f32[1]{0}) const_a), condition=while_cond, body=while_body - token = token[] after-all() - out0 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_a, token[] token) - ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_b, token[] token) + token0 = token[] after-all() + out0 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_a, token[] token0) + ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_b, token[] token0) } )"; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc index e2c7af541eede5265f274c72f55305549f059839..aab7f0b393881642437f1891256bd138823a3b87 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc @@ -28,12 +28,11 @@ HloModule Outfeed ENTRY main { const_a = f32[2,3,2] constant( - f32[2,3,2] {{{1, 2}, {1001, 1002}, {2001, 2002}}, {{2, 1}, {2001, 3002}, {2001, 2002}}}) - token = token[] after-all() - outfeed = token[] outfeed(f32[2,3,2] const_a, token) + token0 = token[] after-all() + outfeed = token[] outfeed(f32[2,3,2] const_a, 
token0) ROOT root = () tuple() } )"; diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc new file mode 100644 index 0000000000000000000000000000000000000000..6d0472689bf48092ceef2e9792c1358687d707ec --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc @@ -0,0 +1,459 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace xla { + +namespace { +bool IsTrivialWindowDimension(const WindowDimension& window_dimension) { + return window_dimension.size() == 1 && window_dimension.stride() == 1 && + window_dimension.padding_low() == 0 && + window_dimension.padding_high() == 0 && + window_dimension.window_dilation() == 1 && + window_dimension.base_dilation() == 1; +} +} // namespace + +class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault { + public: + explicit DynamicDimensionInferenceVisitor( + const DynamicParameterBinding& param_bindings, + DynamicDimensionInference* parent) + : 
param_bindings_(param_bindings), parent_(parent) {} + + Status DefaultAction(HloInstruction* hlo) override; + + static Status Run(HloComputation* computation, + const DynamicParameterBinding& param_bindings, + DynamicDimensionInference* parent) { + DynamicDimensionInferenceVisitor visitor(param_bindings, parent); + return computation->Accept(&visitor); + } + + Status HandleParameter(HloInstruction* hlo) override; + + Status HandleReduce(HloInstruction* hlo) override; + + Status HandleDot(HloInstruction* hlo) override; + + Status HandleTranspose(HloInstruction* hlo) override; + + Status HandleReshape(HloInstruction* hlo) override; + + Status HandlePad(HloInstruction* hlo) override; + + Status HandleBroadcast(HloInstruction* hlo) override; + + Status HandleGetDimensionSize(HloInstruction* hlo) override; + + Status HandleSelect(HloInstruction* hlo) override; + + Status HandleConvolution(HloInstruction* hlo) override; + + Status HandleReduceWindow(HloInstruction* hlo) override; + + Status HandleSelectAndScatter(HloInstruction* hlo) override; + + Status HandleGetTupleElement(HloInstruction* hlo) override; + + Status HandleElementwiseUnary(HloInstruction* hlo) override; + + Status HandleElementwiseBinary(HloInstruction* hlo) override; + + private: + using OperandDynamicDimensionFn = std::function; + + Status ForEachOperandDynamicDimension(HloInstruction* inst, + const OperandDynamicDimensionFn&); + + // Pass through a dynamic dimension from the input to the output with the same + // value and index in the shape. This is a helper function to handle trivial + // instructions like elementwise operations. + Status PassThroughDynamicDimension(HloInstruction*); + + // The dynamic parameter bindings of this computation. + const DynamicParameterBinding& param_bindings_; + + // A pointer to DynamicDimensionInference, used to update the dynamic mapping. 
+ DynamicDimensionInference* parent_; +}; + +Status DynamicDimensionInferenceVisitor::DefaultAction(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + return UnimplementedStrCat( + "Asked to propagate a dynamic dimension from hlo ", + operand->ToString(), "@", index.ToString(), "@", dimension, + " to hlo ", hlo->ToString(), ", which is not implemented."); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleGetTupleElement( + HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + if (hlo->tuple_index() == index[0]) { + ShapeIndex new_index = + ShapeIndexView(index).ConsumeFront().ToShapeIndex(); + parent_->SetDynamicSize(hlo, new_index, dimension, dynamic_size); + } + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleBroadcast(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + int64 broadcast_dim = hlo->dimensions(dimension); + parent_->SetDynamicSize(hlo, index, broadcast_dim, dynamic_size); + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + if (operand_index != 0) { + return Unimplemented( + "Dynamic dimension on padding value is not supported"); + } + const PaddingConfig_PaddingConfigDimension& padding_config = + hlo->padding_config().dimensions(dimension); + if (padding_config.interior_padding() == 0 && + padding_config.edge_padding_low() == 0 && + padding_config.edge_padding_high() == 0) { + 
+ parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size); + return Status::OK(); + } else { + return Unimplemented( + "Dynamic dimension propagation on padding dimension is not " + "supported."); + } + }); +} + +Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* reduce = hlo; + int64 operand_count = reduce->operand_count(); + CHECK_EQ(operand_count % 2, 0); + if (operand_index >= operand_count / 2) { + // Init values don't have dynamic sizes. + return Status::OK(); + } + if ((absl::c_count(reduce->dimensions(), dimension) != 0)) { + // This dimension is being reduced; stop tracing it. + return Status::OK(); + } + + // Find out the new dynamic dimension after reduce. + int64 dimensions_not_reduced_count = 0; + for (int i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) { + if (dimension == i) { + parent_->SetDynamicSize(reduce, {}, dimensions_not_reduced_count, + dynamic_size); + + return Status::OK(); + } + if (absl::c_count(reduce->dimensions(), i) == 0) { + dimensions_not_reduced_count++; + } + } + + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* dot = hlo; + const DotDimensionNumbers& dimension_numbers = + dot->dot_dimension_numbers(); + // A map from the operand dimensions to result dimension.
+ absl::flat_hash_map result_dim_mapping; + int64 current_result_dims = 0; + std::unordered_set batch_dims( + dimension_numbers.rhs_batch_dimensions().begin(), + dimension_numbers.rhs_batch_dimensions().end()); + + for (int64 i : dimension_numbers.rhs_batch_dimensions()) { + result_dim_mapping[i] = current_result_dims++; + } + + for (int64 i = 0; i < ShapeUtil::Rank(dot->operand(0)->shape()); i++) { + if (!absl::c_linear_search( + dimension_numbers.lhs_contracting_dimensions(), i)) { + if (operand_index == 0) { + result_dim_mapping[i] = current_result_dims; + } + current_result_dims++; + } + } + + for (int64 i = 0; i < ShapeUtil::Rank(dot->operand(1)->shape()); i++) { + if (!absl::c_linear_search( + dimension_numbers.rhs_contracting_dimensions(), i) && + !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), + i)) { + if (operand_index == 1) { + result_dim_mapping[i] = current_result_dims; + } + current_result_dims++; + } + } + + // Check if the operand dim is in the result shape. If so, add another + // work item to trace that dimension. 
+ auto iter = result_dim_mapping.find(dimension); + if (iter != result_dim_mapping.end()) { + parent_->SetDynamicSize(dot, {}, iter->second, dynamic_size); + } + + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + parent_->SetDynamicSize(hlo, {}, hlo->dimensions()[dimension], + dynamic_size); + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleConvolution( + HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* conv = hlo; + const ConvolutionDimensionNumbers& dimension_numbers = + conv->convolution_dimension_numbers(); + + if (operand_index == 0) { + if (dimension == dimension_numbers.input_batch_dimension()) { + parent_->SetDynamicSize(conv, {}, + dimension_numbers.output_batch_dimension(), + dynamic_size); + return Status::OK(); + } + + if (dimension == dimension_numbers.input_feature_dimension()) { + return Status::OK(); + } + } else { + if (dimension == dimension_numbers.kernel_input_feature_dimension()) { + return Status::OK(); + } + } + + return Unimplemented("Dynamic Spatial Convolution is not supported: %s", + conv->ToString()); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleGetDimensionSize( + HloInstruction*) { + // Dynamic dimension doesn't propagate through GetDimensionSize: + // + // Input: F32[x, y, z] + // | + // GetDimensionSize(1): U32[] + // + // The returned value is a scalar, which doesn't have any dynamic dimension in + // the shape (although the value contains the real size of the dynamic + // dimension of the input). 
+ return Status::OK(); +} + +Status DynamicDimensionInferenceVisitor::PassThroughDynamicDimension( + HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + parent_->SetDynamicSize(hlo, index, dimension, dynamic_size); + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleElementwiseUnary( + HloInstruction* hlo) { + return PassThroughDynamicDimension(hlo); +} + +Status DynamicDimensionInferenceVisitor::HandleSelect(HloInstruction* hlo) { + return PassThroughDynamicDimension(hlo); +} + +Status DynamicDimensionInferenceVisitor::HandleElementwiseBinary( + HloInstruction* hlo) { + return PassThroughDynamicDimension(hlo); +} + +Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* reshape = hlo; + std::vector> unmodified_dims = + ShapeUtil::DimensionsUnmodifiedByReshape(operand->shape(), + reshape->shape()); + for (auto& unmodified : unmodified_dims) { + if (unmodified.first == dimension) { + parent_->SetDynamicSize(reshape, {}, unmodified.second, + dynamic_size); + return Status::OK(); + } + } + return Unimplemented( + "Dynamic Reshape on modified dimensions is yet not supported: %s", + reshape->ToString()); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleReduceWindow( + HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* reduce_window = hlo; + const WindowDimension& window_dimension = + reduce_window->window().dimensions(dimension); + + if (!IsTrivialWindowDimension(window_dimension)) { + return Unimplemented( + "Dynamic Spatial reduce 
window is not supported: %s", + reduce_window->ToString()); + } + + parent_->SetDynamicSize(reduce_window, {}, dimension, dynamic_size); + + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleSelectAndScatter( + HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* select_and_scatter = hlo; + const WindowDimension& window_dimension = + select_and_scatter->window().dimensions(dimension); + + if (!IsTrivialWindowDimension(window_dimension)) { + return Unimplemented( + "Dynamic Spatial select and scatter is not supported: %s", + select_and_scatter->ToString()); + } + + parent_->SetDynamicSize(select_and_scatter, {}, dimension, + dynamic_size); + + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleParameter(HloInstruction* hlo) { + return param_bindings_.ForEachBinding( + [&](const DynamicParameterBinding::DynamicParameter& dynamic_parameter, + const DynamicParameterBinding::DynamicDimension& dynamic_dimension) { + if (dynamic_dimension.parameter_num != hlo->parameter_number()) { + return Status::OK(); + } + HloComputation* computation = hlo->parent(); + HloInstruction* target_parameter = + computation->parameter_instruction(dynamic_dimension.parameter_num); + + HloInstruction* dynamic_size = + computation->parameter_instruction(dynamic_parameter.parameter_num); + for (int64 i : dynamic_parameter.parameter_index) { + dynamic_size = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(dynamic_size->shape(), {i}), + dynamic_size, i)); + } + + parent_->SetDynamicSize(target_parameter, + dynamic_dimension.parameter_index, + dynamic_dimension.dimension, dynamic_size); + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::ForEachOperandDynamicDimension( + HloInstruction* inst, const 
OperandDynamicDimensionFn& fn) { + for (int64 operand_index = 0; operand_index < inst->operand_count(); + ++operand_index) { + auto iter = + parent_->per_hlo_dynamic_dimensions_.find(inst->operand(operand_index)); + if (iter != parent_->per_hlo_dynamic_dimensions_.end()) { + for (auto& dynamic_dimension : iter->second) { + HloInstruction* dynamic_size = parent_->GetDynamicSize( + dynamic_dimension.inst, dynamic_dimension.index, + dynamic_dimension.dim); + TF_RETURN_IF_ERROR(fn(dynamic_dimension.inst, dynamic_dimension.index, + dynamic_dimension.dim, operand_index, + dynamic_size)); + } + } + } + return Status::OK(); +} + +/* static */ +StatusOr DynamicDimensionInference::Run( + HloModule* module) { + VLOG(0) << "Param Config " << module->dynamic_parameter_binding().ToString(); + DynamicDimensionInference inference(module); + TF_RETURN_IF_ERROR(inference.AnalyzeDynamicDimensions()); + return inference; +} + +DynamicDimensionInference::DynamicDimensionInference(HloModule* module) + : module_(module) {} + +Status DynamicDimensionInference::AnalyzeDynamicDimensions() { + return DynamicDimensionInferenceVisitor::Run( + module_->entry_computation(), module_->dynamic_parameter_binding(), this); +} + +HloInstruction* DynamicDimensionInference::GetDynamicSize( + HloInstruction* inst, const ShapeIndex& index, int64 dim) const { + auto iter = dynamic_mapping_.find(DynamicDimension{inst, index, dim}); + if (iter != dynamic_mapping_.end()) { + return iter->second; + } + return nullptr; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h new file mode 100644 index 0000000000000000000000000000000000000000..164d15bf111a92e3da957f609b54ee0662ef18b1 --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h @@ -0,0 +1,112 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { + +// DynamicDimensionInference analyzes each HLO instruction in a graph and +// infers which dimensions are dynamic and which scalar instructions +// represent the real runtime size of those dynamic dimensions. +class DynamicDimensionInference { + public: + static StatusOr Run(HloModule* module); + + string ToString() const; + + // If the dimension `dim` of instruction `inst` at `index` has a dynamic size, + // returns a scalar HloInstruction that represents the runtime size of that + // dimension. Otherwise returns nullptr.
+ HloInstruction* GetDynamicSize(HloInstruction* inst, const ShapeIndex& index, + int64 dim) const; + + friend class DynamicDimensionInferenceVisitor; + + private: + explicit DynamicDimensionInference(HloModule* module); + + // DynamicDimension is used as a key in the dynamic key-value mapping. It + // unambiguously represents a dynamic dimension of an instruction at a given + // index. + struct DynamicDimension { + // HloInstruction that holds the dimension. + HloInstruction* inst; + // Subshape of the instruction that holds the dimension. + ShapeIndex index; + // The dimension number of the dynamic dimension at the given index of a given + // instruction. + int64 dim; + + // Artifacts needed to make this struct able to be used as a `key` in absl + // maps. "friend" keywords are added so these functions can be found through + // ADL. + template + friend H AbslHashValue(H h, const DynamicDimension& m) { + return H::combine(std::move(h), m.inst, m.index, m.dim); + } + + friend bool operator==(const DynamicDimension& lhs, + const DynamicDimension& rhs) { + return lhs.inst == rhs.inst && lhs.index == rhs.index && + lhs.dim == rhs.dim; + } + }; + + // Update the dynamic mapping so that we know dimension `dim` of instruction + // `inst` at `index` has a dynamic size, and its runtime size is represented + // by a scalar instruction `size`. + void SetDynamicSize(HloInstruction* inst, const ShapeIndex& index, int64 dim, + HloInstruction* size) { + dynamic_mapping_.try_emplace(DynamicDimension{inst, index, dim}, size); + auto iter = per_hlo_dynamic_dimensions_.try_emplace(inst); + iter.first->second.emplace(DynamicDimension{inst, index, dim}); + } + + // AnalyzeDynamicDimensions starts the analysis of the dynamic dimensions in + // module_. + Status AnalyzeDynamicDimensions(); + + // HloModule being analyzed. + HloModule* module_; + + // dynamic_mapping_ holds the result of the analysis.
It maps a dynamic + // dimension to a scalar HloInstruction that represents the real dynamic size + // of the dynamic dimension. + using DynamicMapping = absl::flat_hash_map; + DynamicMapping dynamic_mapping_; + + using PerHloDynamicDimensions = + absl::flat_hash_map>; + PerHloDynamicDimensions per_hlo_dynamic_dimensions_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_ diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ea9ebed45d99797ce4f80376ec3d0b758da3ca17 --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc @@ -0,0 +1,535 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_runner.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace { + +class DynamicDimensionInferenceTest : public HloTestBase { + protected: + DynamicDimensionInferenceTest() : HloTestBase() { + module_ = CreateNewVerifiedModule(); + } + + Status RunInference() { + hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before alias analysis"); + TF_ASSIGN_OR_RETURN(DynamicDimensionInference inference, + DynamicDimensionInference::Run(module_.get())); + + inference_ = absl::make_unique(inference); + return Status::OK(); + } + + HloComputation* GetAdd() { + auto embedded_builder = HloComputation::Builder("add"); + auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "lhs")); + auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "rhs")); + embedded_builder.AddInstruction( + HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs)); + return 
module_->AddEmbeddedComputation(embedded_builder.Build()); + } + + std::unique_ptr module_; + std::unique_ptr inference_; + const Shape scalar_shape_ = ShapeUtil::MakeShape(S32, {}); +}; + +TEST_F(DynamicDimensionInferenceTest, ParamTest) { + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "param")); + auto param2 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "param")); + + module_->AddEntryComputation(builder.Build()); + // Set up dynamic parameter binding. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(param, {}, 1), param2); + EXPECT_EQ(inference_->GetDynamicSize(param, {}, 0), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(param2, {}, 0), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, ParamTestTuple) { + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + + auto param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({input_shape, scalar_shape_}), "param")); + + module_->AddEntryComputation(builder.Build()); + // Set up dynamic parameter binding. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{0, {1}}, + DynamicParameterBinding::DynamicDimension{0, {0}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_THAT(inference_->GetDynamicSize(param, {0}, 1), + op::GetTupleElement(param, 1)); + + EXPECT_EQ(inference_->GetDynamicSize(param, {0}, 0), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, GetTupleElement) { + // When data flows through GTE, the dynamic dimension size keeps the + // same, and the index has its front popped. 
+ auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + + auto param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({input_shape, scalar_shape_}), "param")); + + auto gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(input_shape, param, 0)); + + module_->AddEntryComputation(builder.Build()); + // Set up dynamic parameter binding. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{0, {1}}, + DynamicParameterBinding::DynamicDimension{0, {0}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_THAT(inference_->GetDynamicSize(param, {0}, 1), + op::GetTupleElement(param, 1)); + + EXPECT_THAT(inference_->GetDynamicSize(gte, {}, 1), + op::GetTupleElement(param, 1)); + + EXPECT_EQ(inference_->GetDynamicSize(param, {0}, 0), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, ElementwiseTest) { + // When data flows through elementwise, the dynamic dimension size keeps the + // same. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + + auto data_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "data_param")); + auto size_param = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "size_param")); + + auto* negate = builder.AddInstruction( + HloInstruction::CreateUnary(input_shape, HloOpcode::kNegate, data_param)); + + module_->AddEntryComputation(builder.Build()); + // Set up dynamic parameter binding. 
+ TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(negate, {}, 1), size_param); +} + +TEST_F(DynamicDimensionInferenceTest, ReduceTestI) { + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + auto reduce_shape = ShapeUtil::MakeShape(F32, {2}); + + auto data_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "data_param")); + auto size_param = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "size_param")); + + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(input_shape, HloOpcode::kNegate, data_param)); + + auto init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + + auto reduce = builder.AddInstruction(HloInstruction::CreateReduce( + reduce_shape, negate, init, {0, 2}, GetAdd())); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 0), size_param); +} + +TEST_F(DynamicDimensionInferenceTest, ReduceTestII) { + // Same as ReduceTestI, but only reduce one dimension. 
+ auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + auto reduce_shape = ShapeUtil::MakeShape(F32, {1, 2}); + + auto data_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "data_param")); + auto size_param = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "size_param")); + + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(input_shape, HloOpcode::kNegate, data_param)); + + auto init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + + auto reduce = builder.AddInstruction( + HloInstruction::CreateReduce(reduce_shape, negate, init, {1}, GetAdd())); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 2})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 1), size_param); + EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 0), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, DotTest) { + auto builder = HloComputation::Builder(TestName()); + constexpr int xdim = 3; + constexpr int ydim = 2; + constexpr int zdim = 1; + auto xy_shape = ShapeUtil::MakeShape(F32, {xdim, ydim}); + auto yz_shape = ShapeUtil::MakeShape(F32, {ydim, zdim}); + auto xz_shape = ShapeUtil::MakeShape(F32, {xdim, zdim}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, xy_shape, "A")); + auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, yz_shape, "B")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, scalar_shape_, "size_param")); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + 
dot_dnums.add_rhs_contracting_dimensions(0); + auto dot = builder.AddInstruction( + HloInstruction::CreateDot(xz_shape, a_param, b_param, dot_dnums, + HloTestBase::DefaultPrecisionConfig(2))); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding for non-contracting dimension. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + // Set up binding for contracting dimensions. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{1, {}, 0})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 0), size_param); + EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, ConvolutionTest) { + auto builder = HloComputation::Builder(TestName()); + constexpr int xdim = 3; + constexpr int ydim = 2; + constexpr int zdim = 1; + auto xy_shape = ShapeUtil::MakeShape(F32, {xdim, ydim}); + auto yz_shape = ShapeUtil::MakeShape(F32, {ydim, zdim}); + auto zx_shape = ShapeUtil::MakeShape(F32, {zdim, xdim}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, xy_shape, "A")); + auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, yz_shape, "B")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, scalar_shape_, "size_param")); + + auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(0); + + dnums.set_kernel_input_feature_dimension(0); + dnums.set_kernel_output_feature_dimension(1); + dnums.set_input_batch_dimension(0); + 
dnums.set_output_batch_dimension(1); + dnums.set_output_feature_dimension(0); + + Window window; + + auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + zx_shape, a_param, b_param, /*feature_group_count=*/1, window, dnums, + HloTestBase::DefaultPrecisionConfig(2))); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding for non-contracting dimension. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + // Set up binding for contracting dimensions. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(conv, {}, 1), size_param); + EXPECT_EQ(inference_->GetDynamicSize(conv, {}, 0), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, TransposeTest) { + // Test the ability to trace unmodified dimensions + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 3}); + auto output_shape = ShapeUtil::MakeShape(F32, {3, 2, 1}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + auto* size_param_1 = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + auto* size_param_2 = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, scalar_shape_, "size_param")); + auto* size_param_3 = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/3, scalar_shape_, "size_param")); + + auto* transpose = builder.AddInstruction( + HloInstruction::CreateTranspose(output_shape, a_param, {2, 1, 0})); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + 
DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{3, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 2})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 0), size_param_3); + EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 1), size_param_2); + EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 2), size_param_1); +} + +TEST_F(DynamicDimensionInferenceTest, ReshapeTest) { + // Test the ability to trace unmodified reshape dimensions. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5, 6}); + auto output_shape = ShapeUtil::MakeShape(F32, {6, 4, 1, 5, 2, 3}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + + auto* reshape = builder.AddInstruction( + HloInstruction::CreateReshape(output_shape, a_param)); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 2})); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 3})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 0), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 1), size_param); + EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 2), nullptr); + 
+ EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 3), size_param); + EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 4), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 5), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, ReshapeTestUnimplemented) { + // Test that tracing a reshape-modified dimension returns UNIMPLEMENTED. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5, 6}); + auto output_shape = ShapeUtil::MakeShape(F32, {6, 4, 1, 5, 2, 3}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + + builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + + builder.AddInstruction(HloInstruction::CreateReshape(output_shape, a_param)); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + Status status = RunInference(); + EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); +} + +TEST_F(DynamicDimensionInferenceTest, BroadcastTest) { + // Test the ability to trace broadcast dimension.
+ auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {2}); + auto output_shape = ShapeUtil::MakeShape(F32, {3, 2, 4}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + + auto* broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(output_shape, a_param, {1})); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, 0), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, 1), size_param); + EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, 2), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, ReduceWindowBatchTest) { + // Test the ability to trace reduce window batch dimensions. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4}); + auto output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2}); + + Window window; + // First dimension is unchanged. + WindowDimension* batch_dim = window.add_dimensions(); + batch_dim->set_size(1); + batch_dim->set_stride(1); + batch_dim->set_padding_low(0); + batch_dim->set_padding_high(0); + batch_dim->set_window_dilation(1); + batch_dim->set_base_dilation(1); + + // Second and third dimension are reduced. 
+ for (int64 i = 0; i < 2; ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_size(2); + dim->set_stride(2); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + } + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + + auto init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + + auto* reduce_window = + builder.AddInstruction(HloInstruction::CreateReduceWindow( + output_shape, a_param, init, window, GetAdd())); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(reduce_window, {}, 0), size_param); +} + +TEST_F(DynamicDimensionInferenceTest, SelectAndScatterTest) { + // Test the ability to trace select and scatter batch dimensions. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4}); + auto output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2}); + + Window window; + // First dimension is unchanged. + WindowDimension* batch_dim = window.add_dimensions(); + batch_dim->set_size(1); + batch_dim->set_stride(1); + batch_dim->set_padding_low(0); + batch_dim->set_padding_high(0); + batch_dim->set_window_dilation(1); + batch_dim->set_base_dilation(1); + + // Second and third dimension are reduced. 
+ for (int64 i = 0; i < 2; ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_size(2); + dim->set_stride(2); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + } + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + + auto init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + + auto* reduce_window = + builder.AddInstruction(HloInstruction::CreateReduceWindow( + output_shape, a_param, init, window, GetAdd())); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(reduce_window, {}, 0), size_param); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc index d7829045cc127deaa4c2c9b705dca5285d704af2..3a09d4d4716950a09d65dd093272482d55ac5c27 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc @@ -43,13 +43,14 @@ bool IsForwardConvolutionCanonical(const HloInstruction& conv) { // dilation), returns kPad and/or kSlice instructions that explicitly apply the // padding; otherwise returns the original input operand. When there is both // positive padding (including dilation) and negative padding, we insert both -// kPad and kSlice. +// kPad and kSlice. Modifies 'conv_window' accordingly if any padding was moved +// into a kPad or kSlice op. 
HloInstruction* MaybePaddedAndSlicedInput( - const Window& conv_window, const ConvolutionDimensionNumbers& conv_dnums, + Window* conv_window, const ConvolutionDimensionNumbers& conv_dnums, HloInstruction* input) { HloComputation* computation = input->parent(); - if (!window_util::HasSymmetricPadding(conv_window) || - window_util::HasBaseDilation(conv_window)) { + if (!window_util::HasSymmetricPadding(*conv_window) || + window_util::HasBaseDilation(*conv_window)) { // If padding is uneven or has dilation, we insert a kPad instruction that // applies positive padding and dilation. // @@ -62,12 +63,21 @@ HloInstruction* MaybePaddedAndSlicedInput( MakeNoPaddingConfig(input->shape().dimensions_size()); for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) { int64 dim = conv_dnums.input_spatial_dimensions(i); - padding_config.mutable_dimensions(dim)->set_edge_padding_low( - std::max(0LL, conv_window.dimensions(i).padding_low())); - padding_config.mutable_dimensions(dim)->set_edge_padding_high( - std::max(0LL, conv_window.dimensions(i).padding_high())); - padding_config.mutable_dimensions(dim)->set_interior_padding( - conv_window.dimensions(i).base_dilation() - 1); + if (conv_window->dimensions(i).padding_low() > 0) { + padding_config.mutable_dimensions(dim)->set_edge_padding_low( + conv_window->dimensions(i).padding_low()); + conv_window->mutable_dimensions(i)->set_padding_low(0); + } + if (conv_window->dimensions(i).padding_high() > 0) { + padding_config.mutable_dimensions(dim)->set_edge_padding_high( + conv_window->dimensions(i).padding_high()); + conv_window->mutable_dimensions(i)->set_padding_high(0); + } + if (conv_window->dimensions(i).base_dilation() != 1) { + padding_config.mutable_dimensions(dim)->set_interior_padding( + conv_window->dimensions(i).base_dilation() - 1); + conv_window->mutable_dimensions(i)->set_base_dilation(1); + } } PrimitiveType element_type = input->shape().element_type(); HloInstruction* padding = 
computation->AddInstruction( @@ -75,7 +85,7 @@ HloInstruction* MaybePaddedAndSlicedInput( input = MakePadHlo(input, padding, padding_config).ValueOrDie(); } - if (window_util::HasNegativePadding(conv_window)) { + if (window_util::HasNegativePadding(*conv_window)) { // If the window has negative padding, insert a kSlice that explicitly // applies negative padding. // @@ -89,10 +99,14 @@ HloInstruction* MaybePaddedAndSlicedInput( int64 dim = conv_dnums.input_spatial_dimensions(i); // If dimension "dim" has negative padding, increase the start index or // decrement the limit index by the amount of negative padding. - start_indices[dim] += - std::max(0LL, -conv_window.dimensions(i).padding_low()); - limit_indices[dim] -= - std::max(0LL, -conv_window.dimensions(i).padding_high()); + if (conv_window->dimensions(i).padding_low() < 0) { + start_indices[dim] += -conv_window->dimensions(i).padding_low(); + conv_window->mutable_dimensions(i)->set_padding_low(0); + } + if (conv_window->dimensions(i).padding_high() < 0) { + limit_indices[dim] -= -conv_window->dimensions(i).padding_high(); + conv_window->mutable_dimensions(i)->set_padding_high(0); + } } input = @@ -140,25 +154,22 @@ bool CudnnConvPaddingLegalization::CanonicalizeForwardConvolution( // Insert slices and/or pads between the convolution and its input and/or // kernel operand. + Window new_conv_window = conv->window(); HloInstruction* new_input = MaybePaddedAndSlicedInput( - conv->window(), conv->convolution_dimension_numbers(), + &new_conv_window, conv->convolution_dimension_numbers(), conv->mutable_operand(0)); HloInstruction* new_kernel = - MaybePaddedKernel(conv->window(), conv->convolution_dimension_numbers(), + MaybePaddedKernel(new_conv_window, conv->convolution_dimension_numbers(), conv->mutable_operand(1)); - // Remove the padding from convolution's window field. These paddings are - // made explicit with the inserted pads. 
- Window new_conv_window = conv->window(); + // Remove the window dilation from convolution's window field. These paddings + // are made explicit with the pads inserted by MaybePaddedKernel(). for (size_t i = 0; i < new_conv_window.dimensions_size(); ++i) { WindowDimension* dim = new_conv_window.mutable_dimensions(i); // The size of the kernel may have changed so update the Window to match. dim->set_size(new_kernel->shape().dimensions( conv->convolution_dimension_numbers().kernel_spatial_dimensions(i))); - dim->set_padding_low(0); - dim->set_padding_high(0); - dim->set_base_dilation(1); dim->set_window_dilation(1); } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc index 443883a89f66a747def1049bc5afb53fec3c2409..73af18f87aeeedaefac4fc37fb7b6f78f506bb4f 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc @@ -599,7 +599,7 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveConstantFilter) { Array4D constant_arr(4, 4, 2, 2); constant_arr.FillIota(0); string constant_str = - LiteralUtil::CreateR4FromArray4D(constant_arr).ToString(); + LiteralUtil::CreateR4FromArray4D(constant_arr).ToStringWithoutShape(); const string module_str = absl::StrFormat(R"( HloModule test diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 6dcdaf1cfe06e446deed847aaf29088a7ed10e13..2ab754a471070d5f90a3eaebd0600ff180d2fe5d 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -161,6 +161,16 @@ StatusOr GpuElementalIrEmitter::EmitFloatBinaryOp( PrimitiveType lhs_input_type = op->operand(0)->shape().element_type(); PrimitiveType rhs_input_type = op->operand(1)->shape().element_type(); PrimitiveType output_type = 
op->shape().element_type(); + HloOpcode opcode = op->opcode(); + + if (hlo_module_config_.debug_options().xla_gpu_enable_fast_min_max() && + (opcode == HloOpcode::kMaximum || opcode == HloOpcode::kMinimum)) { + return llvm_ir::EmitCallToIntrinsic( + opcode == HloOpcode::kMaximum ? llvm::Intrinsic::maxnum + : llvm::Intrinsic::minnum, + {lhs_value, rhs_value}, {lhs_value->getType()}, b_); + } + switch (op->opcode()) { case HloOpcode::kRemainder: { return EmitLibdeviceMathCall("__nv_fmod", {lhs_value, rhs_value}, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 2ffc8bfb49b205dced0d540ba72426e72d95e596..29756d27260b0f41b2dd4b649ea9b1610ff90268 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -369,7 +369,7 @@ TEST_F(LayoutAssignmentTest, SortLayout) { const char* hlo_text = R"( HloModule SortLayout ENTRY sort { - keys = f32[3,2]{0,1} constant(f32[3,2]{0,1}{{0,1},{0,1},{0,1}}) + keys = f32[3,2]{0,1} constant({{0,1},{0,1},{0,1}}) values = f32[2,3]{1,0} parameter(0) transpose = f32[3,2]{1,0} transpose(values), dimensions={1,0} ROOT sort = (f32[3,2]{1,0}, f32[3,2]{1,0}) sort(keys, transpose), diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 42fb38dffae31b0f4322216545027e067cab250d..33e41a2782b5932430eea621d3cea2c6634f292f 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -268,5 +268,17 @@ string CudnnConvKindToString(CudnnConvKind kind) { } } +llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b) { + return b->CreateAnd( + b->CreateICmpEQ( + b->getInt32(0), + llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b)), + b->CreateICmpEQ( + b->getInt32(0), + 
llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b))); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index f373d4a8393a047aba599b0fae954e98a740161e..ebf4d926b7a280e10b09a2532caba7ad6ab3ceb2 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -155,6 +155,10 @@ llvm::Value* EmitPrintf(absl::string_view fmt, llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, llvm::IRBuilder<>* builder); +// Emits code that determines whether the current thread is thread 0 within +// block 0 of the kernel. +llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 31591914cc553f0f5ecd81cb514faa1dc56ea041..6693f66d62d8b04d1b78e001fdb515b34539c67f 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -63,9 +63,6 @@ IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config, &ir_emitter_context->buffer_assignment(), &b_, module_, is_nested), hlo_module_config_(hlo_module_config) { - b_.setFastMathFlags(llvm_ir::GetFastMathFlags( - /*fast_math_enabled=*/hlo_module_config.debug_options() - .xla_gpu_enable_fast_math())); } Status IrEmitter::DefaultAction(HloInstruction* hlo) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index bbe1583c01167b3fbb50e066ad59a48e45f5e683..87d16c0afcc3c115f652558b5d8c24606ff56733 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -2059,8 +2059,16 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( 
GetIndexTypeForKernel(&hlo, launch_dimensions.launch_bound(), &b_)); } - // For multioutput fusion, we need to emit each operand and the root. + // Emit the tuple pointers in one thread. We could do this at any point in + // the kernel, but we do it at the beginning in the hopes of reducing register + // pressure, since we touch threadIdx.x and blockIdx.x at the beginning of the + // kernel *anyway*. std::vector output_arrays = ConstructIrArrayForOutputs(hlo); + KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] { + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_); + }); + + // For multioutput fusion, we need to emit each operand and the root. TF_RETURN_IF_ERROR( ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions, &b_, unroll_factor) @@ -2069,8 +2077,6 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( &hlo, launch_dimensions.launch_bound(), &b_))); b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator()); - llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_); - return Status::OK(); } @@ -2130,65 +2136,36 @@ int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape( namespace { -// Reads thread_idx.x and converts it to a (y,x) coordinate, assuming that the -// thread lives within a square tile of size tile_size (so thread blocks are of -// size tile_size * tile_size). -std::tuple CalculateYXCoordinateWithinTile( - llvm::IRBuilder<>* builder, llvm::Value* tile_size, - int64 threads_per_tile) { - // Calculate the starting element coordinate within a tile for the current - // thread, (y, x) from thread_id. 
- llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, builder); - llvm_ir::AddRangeMetadata(0, threads_per_tile, - llvm::cast(thread_id)); - thread_id = builder->CreateIntCast(thread_id, tile_size->getType(), - /*isSigned=*/true, "thread.id.x"); - auto x = builder->CreateURem(thread_id, tile_size); - auto y = builder->CreateUDiv(thread_id, tile_size); - return std::make_tuple(y, x); -} - -// Reads block_idx.x, casts it to type index_ty, and adds the assumption that -// it's in the range [0, num_blocks]. -llvm::Value* GetBlockIdx(llvm::IRBuilder<>* builder, llvm::Type* index_ty, - int64 num_blocks) { - llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, builder); - llvm_ir::AddRangeMetadata(0, num_blocks, - llvm::cast(block_id)); - return builder->CreateIntCast(block_id, index_ty, /*isSigned=*/true, - "block.id.x"); -} - -void EmitFullTile(const KernelMappingScheme* mapping_scheme, - const IrArray::Index& tile_origin_index, - llvm::IRBuilder<>* builder, llvm::Value* y, llvm::Value* x, - llvm::Type* index_ty, - const std::function& emit_elem_function) { +void EmitFullElementalTile( + const KernelMappingScheme* mapping_scheme, + const IrArray::Index& tile_origin_index, const string& loop_name, + KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, + llvm::Value* x, llvm::Type* index_ty, + const std::function& emit_elem_function) { int64 num_threads_x = mapping_scheme->GetNumberOfThreadsForDimensionX(); int64 num_threads_y = mapping_scheme->GetNumberOfThreadsForDimensionY(); int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); int64 tile_size_y = mapping_scheme->GetTileSizeForDimensionY(); - for (int64 i = 0; i < tile_size_y; i += num_threads_y) { - IrArray::Index source_idx_y = - tile_origin_index.AddOffsetToDim(llvm::ConstantInt::get(index_ty, i), - KernelMappingScheme::DimY, builder); - llvm::Value* y_loc = - 
builder->CreateAdd(llvm::ConstantInt::get(index_ty, i), y); - for (int64 j = 0; j < tile_size_x; j += num_threads_x) { - IrArray::Index source_idx = - source_idx_y.AddOffsetToDim(llvm::ConstantInt::get(index_ty, j), - KernelMappingScheme::DimX, builder); - llvm::Value* x_loc = - builder->CreateAdd(llvm::ConstantInt::get(index_ty, j), x); - emit_elem_function(source_idx, y_loc, x_loc); - } - } -} - -void EmitPartialTile( + ksl->For(loop_name + "_y", /*start=*/llvm::ConstantInt::get(index_ty, 0), + /*end=*/llvm::ConstantInt::get(index_ty, tile_size_y), + /*step=*/llvm::ConstantInt::get(index_ty, num_threads_y), + [&](llvm::Value* y_indvar) { + IrArray::Index source_idx_y = tile_origin_index.AddOffsetToDim( + y_indvar, KernelMappingScheme::DimY, builder); + llvm::Value* y_loc = builder->CreateAdd(y_indvar, y); + for (int64 j = 0; j < tile_size_x; j += num_threads_x) { + IrArray::Index source_idx = source_idx_y.AddOffsetToDim( + llvm::ConstantInt::get(index_ty, j), + KernelMappingScheme::DimX, builder); + llvm::Value* x_loc = + builder->CreateAdd(llvm::ConstantInt::get(index_ty, j), x); + emit_elem_function(source_idx, y_loc, x_loc); + } + }); +} + +void EmitPartialElementalTile( const KernelMappingScheme* mapping_scheme, const IrArray::Index& tile_origin_index, const string& loop_name, KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, @@ -2207,7 +2184,7 @@ void EmitPartialTile( llvm::Value* x_loc = builder->CreateAdd(llvm::ConstantInt::get(index_ty, j), x); - ksl->IfReturnVoid( + ksl->If( loop_name + "_x_in_tile", builder->CreateICmpULT(x_loc, tile_width), [&] { // tile_height_bound = @@ -2219,13 +2196,13 @@ void EmitPartialTile( llvm::Value* tile_height_bound = builder->CreateMul( ceiling_of_ratio, llvm::ConstantInt::get(index_ty, num_threads_y)); - ksl->ForReturnVoid( + ksl->For( loop_name, /*start=*/llvm::ConstantInt::get(index_ty, 0), /*end=*/tile_height_bound, /*step=*/llvm::ConstantInt::get(index_ty, num_threads_y), [&](llvm::Value* 
y_indvar) { llvm::Value* y_loc = builder->CreateAdd(y_indvar, y); - ksl->IfReturnVoid( + ksl->If( loop_name + "_y_in_tile", builder->CreateICmpULT(y_loc, tile_height), [&] { emit_elem_function( @@ -2257,7 +2234,7 @@ void EmitTiledElementalCodeWithBoundsCheck( int64 tile_size_y = mapping_scheme->GetTileSizeForDimensionY(); llvm::Type* index_ty = tile_width->getType(); - ksl->IfReturnVoid( + ksl->If( loop_name + "_full_tile", builder->CreateAnd( builder->CreateICmpEQ(llvm::ConstantInt::get(index_ty, tile_size_x), @@ -2265,13 +2242,13 @@ void EmitTiledElementalCodeWithBoundsCheck( builder->CreateICmpEQ(llvm::ConstantInt::get(index_ty, tile_size_y), tile_height)), [&] { - EmitFullTile(mapping_scheme, tile_origin_index, builder, y, x, index_ty, - emit_elem_function); + EmitFullElementalTile(mapping_scheme, tile_origin_index, loop_name, ksl, + builder, y, x, index_ty, emit_elem_function); }, [&] { - EmitPartialTile(mapping_scheme, tile_origin_index, loop_name, ksl, - builder, y, x, tile_height, tile_width, index_ty, - emit_elem_function); + EmitPartialElementalTile(mapping_scheme, tile_origin_index, loop_name, + ksl, builder, y, x, tile_height, tile_width, + index_ty, emit_elem_function); }); } } // namespace @@ -2381,14 +2358,14 @@ class ReductionCodegenInfo : public IrEmitterUnnested::KernelCodegenInfo { AddressVector* GetMutablePartialResultAddresses() { return &partial_result_addresses_; } - const AddressVector& GetPartialResultAddresses() const { + absl::Span GetPartialResultAddresses() const { return partial_result_addresses_; } AddressVector* GetMutableReductionInputAddresses() { return &reduction_input_addresses_; } - const AddressVector& GetReductionInputAddresses() const { + absl::Span GetReductionInputAddresses() const { return reduction_input_addresses_; } @@ -2401,7 +2378,7 @@ class ReductionCodegenInfo : public IrEmitterUnnested::KernelCodegenInfo { InlinedVector* GetMutableReductionOutputShapeIndices() { return &reduction_output_shape_indices_; } - const 
InlinedVector& GetReductionOutputShapeIndices() const { + absl::Span GetReductionOutputShapeIndices() const { return reduction_output_shape_indices_; } @@ -2556,8 +2533,8 @@ void IrEmitterUnnested::EmitPrologueForReduction( } void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForAllReduces( - const InlinedVector& reducers, - const AddressVector& partial_result_addresses) { + absl::Span reducers, + absl::Span partial_result_addresses) { for (int distance = 16; distance >= 1; distance /= 2) { for (int i = 0; i != reducers.size(); ++i) { llvm::Type* element_type = @@ -2589,11 +2566,11 @@ void IrEmitterUnnested::EmitEpilogueForReduction( ReductionCodegenInfo* reduction_info = static_cast(kernel_info); int num_reduces = reduction_info->GetNumberOfReduces(); - const AddressVector& partial_result_addresses = + absl::Span partial_result_addresses = reduction_info->GetPartialResultAddresses(); const InlinedVector& reducers = reduction_info->GetReducers(); - const InlinedVector& reduction_output_shape_indices = + absl::Span reduction_output_shape_indices = reduction_info->GetReductionOutputShapeIndices(); if (reduction_info->IsRowReduction()) { @@ -2713,9 +2690,9 @@ void IrEmitterUnnested::EmitTileElementForReduction( reduction_info->GetKernelMappingScheme()->GetUnnormalizedIndex( index, GetFirstReduceInstruction(output_instructions)->operand(0)->shape()); - const AddressVector& partial_reduction_result_addresses = + absl::Span partial_reduction_result_addresses = reduction_info->GetPartialResultAddresses(); - const AddressVector& reduction_input_addresses = + absl::Span reduction_input_addresses = reduction_info->GetReductionInputAddresses(); const InlinedVector& reducers = reduction_info->GetReducers(); @@ -2774,15 +2751,14 @@ void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile, Select(ICmpEQ(last_block_for_dim, block_id_for_dim), last_block_size_for_dim, block_size_for_dim); - ksl.ForReturnVoid( - loop_name, - /*start=*/index_typed_constant(0), - 
/*end=*/num_tiles_in_block, - /*step=*/1, [&](llvm::Value* block_dim_induction_var) { - IrArray::Index tile_index = starting_tile.AddOffsetToDim( - block_dim_induction_var, dim_id, &b_); - emit_next_block_dim(tile_index); - }); + ksl.For(loop_name, + /*start=*/index_typed_constant(0), + /*end=*/num_tiles_in_block, + /*step=*/1, [&](llvm::Value* block_dim_induction_var) { + IrArray::Index tile_index = starting_tile.AddOffsetToDim( + block_dim_induction_var, dim_id, &b_); + emit_next_block_dim(tile_index); + }); } }; @@ -2864,14 +2840,40 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( << llvm_ir::DumpToString(*param_shmem_buffers[id]); } - LaunchDimensions launch_dimensions = LaunchDimensions( - mapping_scheme->GetNumberOfBlocks(), mapping_scheme->GetThreadsPerTile()); - llvm::Type* index_ty = GetIndexTypeForKernel( - unnested_hlo, launch_dimensions.launch_bound(), &b_); + const ReductionCodegenInfo* reduction_info = + dynamic_cast(kernel_info); + bool is_column_reduction = + (reduction_info && !reduction_info->IsRowReduction()); + + LaunchDimensions launch_dimensions = + LaunchDimensions(mapping_scheme->GetNumberOfBlocks(), + mapping_scheme->GetThreadsPerBlock()); + + // TODO(b/110211620): Enable int32 index type for column reduction. + llvm::Type* index_ty = + is_column_reduction + ? b_.getInt64Ty() + : GetIndexTypeForKernel(unnested_hlo, + launch_dimensions.launch_bound(), &b_); + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { return llvm::ConstantInt::get(index_ty, c); }; + // For multioutput fusion, one thread needs to output a tuple with pointers to + // all the individual outputs. We could do this at any point in the kernel, + // but we do it at the beginning in the hopes of reducing register pressure, + // since we touch threadIdx.x and blockIdx.x at the beginning of the kernel + // *anyway*. 
+ if (!reduction_info && unnested_hlo->IsMultiOutputFusion()) { + KernelSupportLibrary{&b_}.If( + "emit_mof_tuple", IsBlock0Thread0(&b_), [&] { + llvm_ir::EmitTuple(GetIrArray(*unnested_hlo, *unnested_hlo), + ConstructIrArrayForOutputs(*unnested_hlo), &b_, + module_); + }); + } + // For each tiled parameter, cast its input IrArray to the corresponding // reduced shape and keep the reduced shape live during IR emission. std::vector param_in_reduced_shape_arrays; @@ -2985,15 +2987,6 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( block_epilogue_generator(unnested_hlo, kernel_info); } - // For multioutput fusion, emit a tuple with pointers to all the individual - // outputs. - if (unnested_hlo->IsMultiOutputFusion()) { - std::vector output_arrays = - ConstructIrArrayForOutputs(*unnested_hlo); - llvm_ir::EmitTuple(GetIrArray(*unnested_hlo, *unnested_hlo), output_arrays, - &b_, module_); - } - return launch_dimensions; } @@ -3260,15 +3253,17 @@ std::tuple GetReductionToVectorDimensions( return std::make_tuple(num_reduced_major, num_kept, num_reduced_minor); } -std::tuple ComputeMappingSchemeAndReductionKind( - const HloInstruction* first_reduce, llvm::IRBuilder<>* b) { +} // namespace + +std::tuple +IrEmitterUnnested::ComputeMappingSchemeAndReductionKind( + const HloInstruction* first_reduce) { int64 depth = 1; int64 height = 1; int64 width = 1; bool is_row_reduction = true; int64 tile_size_x = 1; int64 tile_size_y = 1; - int64 block_size_y = 1; int64 block_size_z = 1; int64 num_threads_x = 1; int64 num_threads_y = 1; @@ -3291,14 +3286,17 @@ std::tuple ComputeMappingSchemeAndReductionKind( height = num_reduced_major; width = num_kept; is_row_reduction = false; - tile_size_x = std::min(kWarpSize, num_kept); - // The old Column reduction algorithm uses kTileHeight = 128. We choose - // tile_size_y * block_size_y = 128 to match the value of kTileHeight. Using - // a non-trivial block_size_y here is a way to avoid unrolling all the 128 - // iterations. 
- tile_size_y = 32; - block_size_y = 4; + // Column reduction without transpose doesn't require communication among + // threads processing elements in the same tile. The current implementation + // only support the use of on hardware thread block to process one block of + // tiles in the KernelMappingScheme. We try to maximize the values of + // num_threads_x and tile_size_x to allow a bigger hardware thread block. + int64 hw_threads_per_block_limit = + ThreadsPerBlockLimit(ir_emitter_context_->device_description()); + tile_size_x = std::min(hw_threads_per_block_limit, num_kept); num_threads_x = tile_size_x; + int64 kNumElementsPerPartialSum = 128; + tile_size_y = kNumElementsPerPartialSum; } else { // Row reduction reduces inputs with dimension [depth, height, width], // where width is the most minor dimension, to dimension [height] . @@ -3321,15 +3319,13 @@ std::tuple ComputeMappingSchemeAndReductionKind( << " " << width; DimensionVector dims_in_elem{depth, height, width}; - DimensionVector req_block_sizes{block_size_z, block_size_y, 1}; - llvm_ir::KernelMappingScheme mapping_scheme(dims_in_elem, tile_size_y, - tile_size_x, req_block_sizes, - num_threads_y, num_threads_x, b); + DimensionVector req_block_sizes{block_size_z, 1, 1}; + llvm_ir::KernelMappingScheme mapping_scheme( + dims_in_elem, tile_size_y, tile_size_x, req_block_sizes, num_threads_y, + num_threads_x, &b_); return std::make_tuple(mapping_scheme, is_row_reduction); } -} // namespace - Status IrEmitterUnnested::EmitReductionToVector(HloInstruction* unnested_hlo) { VLOG(10) << "Emitting reduction to vector " << unnested_hlo->ToString(); @@ -3375,7 +3371,7 @@ Status IrEmitterUnnested::EmitReductionToVector(HloInstruction* unnested_hlo) { bool is_row_reduction; llvm_ir::KernelMappingScheme mapping_scheme; std::tie(mapping_scheme, is_row_reduction) = - ComputeMappingSchemeAndReductionKind(first_reduce, &b_); + ComputeMappingSchemeAndReductionKind(first_reduce); ReductionCodegenInfo 
reduction_info(&mapping_scheme, is_row_reduction); KernelCodeGenerator kernel_generator( /*tile_element_generator=*/ diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 85a0e5328c4e436d4522593b38421efc87c42d32..1ebea7ab48664e693937b45561d096f7ec15132f 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -215,6 +215,11 @@ class IrEmitterUnnested : public IrEmitter { // Prerequisite: `IsReductionToVector(*unnested_hlo)` Status EmitReductionToVector(HloInstruction* unnested_hlo); + // Computes the KernelMappingScheme for the reduce HLO and indicates whether + // the reduction is a row reduction. + std::tuple + ComputeMappingSchemeAndReductionKind(const HloInstruction* first_reduce); + // Emits code for an in-place scatter, modifying `thunk`s launch dimensions in // the process. `scatter` may be fused, scatter indices are taken from // `scatter_indices_gen`, updates from`updates_gen`. The output buffer is @@ -272,9 +277,8 @@ class IrEmitterUnnested : public IrEmitter { // For each reducer, emits the shuffle-down loop to accumulate the partial // result to the global result. void EmitFullWarpShuffleDownLoopForAllReduces( - const absl::InlinedVector& reducers, - const absl::InlinedVector& - partial_result_addresses); + absl::Span reducers, + absl::Span partial_result_addresses); // Generates the IrArray for each input of an hlo and returns a vector that // constains such IrArrays. 
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc index 364f69a69d47644b383af9cf6865c93360b82bab..bd53b90b42d8e657a3ee58e7ca03fb60522aae28 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc @@ -177,13 +177,6 @@ std::unique_ptr GetTargetMachine( } TargetOptions target_options = InitTargetOptionsFromCodeGenFlags(); - llvm_ir::SetTargetOptions( - /*fast_math_enabled=*/hlo_module_config.debug_options() - .xla_gpu_enable_fast_math(), - &target_options); - - // Enable FMA synthesis. - target_options.AllowFPOpFusion = FPOpFusion::Fast; // Set the verbose assembly options. target_options.MCOptions.AsmVerbose = false; @@ -206,8 +199,7 @@ std::unique_ptr GetTargetMachine( } return absl::WrapUnique(target->createTargetMachine( triple.str(), llvm_ir::AsStringRef(cpu_name), "+ptx60", target_options, - Optional(RelocModel), Optional(CMModel), - codegen_opt_level)); + getRelocModel(), getCodeModel(), codegen_opt_level)); } // Adds the standard LLVM optimization passes, based on the speed optimization @@ -401,8 +393,16 @@ StatusOr CompileModuleToPtx(llvm::Module* module, int32 opt_level = hlo_module_config.debug_options().xla_backend_optimization_level(); - CHECK_GE(opt_level, 2) - << "The XLA GPU backend doesn't support unoptimized code generation"; + if (opt_level < 2) { + LOG(ERROR) << std::string(80, '*'); + LOG(ERROR) << "The XLA GPU backend doesn't support unoptimized code " + "generation but "; + LOG(ERROR) << "--xla_backend_optimization_level is set to " << opt_level + << "!"; + LOG(ERROR) << "(Supported configuration is " + "--xla_backend_optimization_level >= 2.)"; + LOG(ERROR) << std::string(80, '*'); + } AddOptimizationPasses(opt_level, /*size_level=*/0, target_machine.get(), &module_passes, @@ -465,6 +465,9 @@ void GPUBackendInit(const 
HloModuleConfig& hlo_module_config) { // between those loads. FeedLLVMWithFlags({"-memdep-block-scan-limit=500"}); + // Use div.approx -- it matters for some float-division heavy benchmarks. + FeedLLVMWithFlags({"-nvptx-prec-divf32=0"}); + llvm_ir::InitializeLLVMCommandLineOptions(hlo_module_config); // Initialize the NVPTX target; it's the only target we link with, so call its diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 637b861f70235f17e8e739907a3f262b7004ee7c..60f2116e6088fd2c5d3400b4463cb7fa8bbadfdc 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -108,27 +108,33 @@ namespace { namespace tracing = tensorflow::tracing; -// Returns the directory containing nvvm libdevice files. config_cuda_data_dir -// should be equal to config().debug_options().xla_gpu_cuda_data_dir() of the -// HloModule being compiled. -string GetLibdeviceDir(const string& config_cuda_data_dir) { - std::vector potential_libdevice_dirs; - if (!config_cuda_data_dir.empty()) { - potential_libdevice_dirs.push_back(config_cuda_data_dir); - } - potential_libdevice_dirs.push_back(tensorflow::LibdeviceRoot()); - - // Tries all potential libdevice directories in the order they are inserted. - // Returns the first directory that exists in the file system. - for (const string& potential_libdevice_dir : potential_libdevice_dirs) { - if (tensorflow::Env::Default()->IsDirectory(potential_libdevice_dir).ok()) { - VLOG(2) << "Found libdevice dir " << potential_libdevice_dir; - return potential_libdevice_dir; +// Returns a vector of potential locations of the CUDA root directory. +std::vector GetCudaRootCandidates( + const HloModuleConfig& hlo_module_config) { + std::vector potential_cuda_roots = tensorflow::CandidateCudaRoots(); + + // CUDA location explicitly specified by user via --xla_gpu_cuda_data_dir has + // highest priority. 
+ string xla_gpu_cuda_data_dir = + hlo_module_config.debug_options().xla_gpu_cuda_data_dir(); + if (!xla_gpu_cuda_data_dir.empty()) { + potential_cuda_roots.insert(potential_cuda_roots.begin(), + xla_gpu_cuda_data_dir); + } + return potential_cuda_roots; +} + +// Returns the directory containing nvvm libdevice files. +string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) { + for (const string& cuda_root : GetCudaRootCandidates(hlo_module_config)) { + string libdevice_dir = + tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice"); + VLOG(2) << "Looking for libdevice at " << libdevice_dir; + if (tensorflow::Env::Default()->IsDirectory(libdevice_dir).ok()) { + VLOG(2) << "Found libdevice dir " << libdevice_dir; + return libdevice_dir; } - VLOG(2) << "Unable to find potential libdevice dir " - << potential_libdevice_dir; } - LOG(WARNING) << "Unable to find libdevice dir. Using '.'"; // Last resort: maybe in the current folder. return "."; @@ -143,7 +149,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, Compiler* compiler) { { HloPassPipeline pipeline("optimization"); - pipeline.AddPass(); pipeline.AddInvariantChecker(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); pipeline.AddPass(); @@ -175,6 +180,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); + pipeline.AddPass(); + // BatchNormExpander can create zero-sized ops, so zero-sized HLO // elimination has to come after that pass. pipeline.AddPass(); @@ -477,13 +484,19 @@ void WarnIfBadDriverJITVersion() { // Compiles the given PTX string using ptxas and returns the resulting machine // code (i.e. a cubin) as a byte array. 
-StatusOr> CompilePtx(const string& ptx, int cc_major, - int cc_minor) { +StatusOr> CompilePtx( + const string& ptx, int cc_major, int cc_minor, + const HloModuleConfig& hlo_module_config) { tracing::ScopedActivity activity("Compile PTX", /*is_expensive=*/true); - const string ptxas_path = - tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin", "ptxas"); - VLOG(2) << "Checking ptxas at " << ptxas_path; auto env = tensorflow::Env::Default(); + string ptxas_path; + for (const string& cuda_root : GetCudaRootCandidates(hlo_module_config)) { + ptxas_path = tensorflow::io::JoinPath(cuda_root, "bin", "ptxas"); + VLOG(2) << "Looking for ptxas at " << ptxas_path; + if (env->FileExists(ptxas_path).ok()) { + break; + } + } TF_RETURN_IF_ERROR(env->FileExists(ptxas_path)); VLOG(2) << "Using ptxas at " << ptxas_path; @@ -518,6 +531,9 @@ StatusOr> CompilePtx(const string& ptx, int cc_major, if (VLOG_IS_ON(2)) { ptxas_args.push_back("-v"); } + if (hlo_module_config.debug_options().xla_gpu_disable_ptxas_optimizations()) { + ptxas_args.push_back("-O0"); + } ptxas_info_dumper.SetProgram(ptxas_path, ptxas_args); ptxas_info_dumper.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE); @@ -680,12 +696,8 @@ StatusOr> NVPTXCompiler::RunBackend( // Find the directory containing libdevice. To avoid searching for it every // time, we have a one-element cache, keyed on the module's config's // cuda_data_dir. 
- const auto& config_cuda_data_dir = - module->config().debug_options().xla_gpu_cuda_data_dir(); - if (cached_libdevice_dir_.empty() || - cached_cuda_data_dir_ != config_cuda_data_dir) { - cached_cuda_data_dir_ = config_cuda_data_dir; - cached_libdevice_dir_ = GetLibdeviceDir(config_cuda_data_dir); + if (cached_libdevice_dir_.empty()) { + cached_libdevice_dir_ = GetLibdeviceDir(module->config()); } libdevice_dir = cached_libdevice_dir_; } @@ -739,7 +751,7 @@ StatusOr> NVPTXCompiler::RunBackend( } const std::vector cubin = - CompilePtxOrGetCachedResult(ptx, cc_major, cc_minor); + CompilePtxOrGetCachedResult(ptx, cc_major, cc_minor, module->config()); auto thunk_schedule = absl::make_unique( ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment), @@ -771,9 +783,9 @@ StatusOr> NVPTXCompiler::RunBackend( return std::unique_ptr(gpu_executable); } -std::vector NVPTXCompiler::CompilePtxOrGetCachedResult(const string& ptx, - int cc_major, - int cc_minor) { +std::vector NVPTXCompiler::CompilePtxOrGetCachedResult( + const string& ptx, int cc_major, int cc_minor, + const HloModuleConfig& hlo_module_config) { XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::CompilePtxOrGetCachedResult"); tracing::ScopedActivity activity("PTX->CUBIN", /*is_expensive=*/true); bool inserted; @@ -802,7 +814,7 @@ std::vector NVPTXCompiler::CompilePtxOrGetCachedResult(const string& ptx, CHECK(!cache_value->compilation_done); if (!ptx.empty()) { StatusOr> maybe_cubin = - CompilePtx(*cache_ptx, cc_major, cc_minor); + CompilePtx(*cache_ptx, cc_major, cc_minor, hlo_module_config); if (maybe_cubin.ok()) { cache_value->cubin_data = std::move(maybe_cubin).ValueOrDie(); VLOG(2) << "Compiled PTX size:" << ptx.size() diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h index f79ae2990ae7d6e6985b15727a72358289121aa9..b2077f42fd097330703fde063d80a20704fa48e2 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h +++ 
b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h @@ -97,8 +97,9 @@ class NVPTXCompiler : public LLVMCompiler { // Tries to compile the given ptx string to cubin. Returns a vector with the // compiled cubin. If compilation was unsuccessful, returns an empty vector. - std::vector CompilePtxOrGetCachedResult(const string& ptx, - int cc_major, int cc_minor); + std::vector CompilePtxOrGetCachedResult( + const string& ptx, int cc_major, int cc_minor, + const HloModuleConfig& hlo_module_config); // The compilation_cache_ map is a cache from {ptx string, cc_major, cc_minor} // -> cubin so we don't recompile the same ptx twice. This is important for diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc index 375f68a15957936151aee068582a714b62694af2..bfed4f5230dfe37bca48560ce83a2dd82c8950a4 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc @@ -39,6 +39,25 @@ std::ostream& operator<<(std::ostream& out, return out; } +int64 ThreadsPerBlockLimit(const se::DeviceDescription& device_desc) { + int64 threads_per_block = device_desc.threads_per_block_limit(); + if (threads_per_block == 0) { + static std::atomic log_count{0}; + if (log_count.fetch_add(1) < 8) { + LOG(WARNING) << "Attempting to calculate launch dimensions for GPU " + "without full information about its capabilities. " + "StreamExecutor's PopulateDeviceDescription should be " + "updated for this device."; + } + threads_per_block = device_desc.threads_per_warp(); + if (threads_per_block == 0) { + // Fall back to *something* if we can't even get num threads per warp. + threads_per_block = 32; + } + } + return threads_per_block; +} + // Calculates the launch dimensions used to invoke `hlo`. 
LaunchDimensions CalculateLaunchDimensions( const Shape& shape, const se::DeviceDescription& device_desc, @@ -62,21 +81,7 @@ LaunchDimensions CalculateLaunchDimensions( // // * = - int64 threads_per_block = device_desc.threads_per_block_limit(); - if (threads_per_block == 0) { - static std::atomic log_count{0}; - if (log_count.fetch_add(1) < 8) { - LOG(WARNING) << "Attempting to calculate launch dimensions for GPU " - "without full information about its capabilities. " - "StreamExecutor's PopulateDeviceDescription should be " - "updated for this device."; - } - threads_per_block = device_desc.threads_per_warp(); - if (threads_per_block == 0) { - // Fall back to *something* if we can't even get num threads per warp. - threads_per_block = 32; - } - } + int64 threads_per_block = ThreadsPerBlockLimit(device_desc); if (num_elements < threads_per_block) { threads_per_block = num_elements; diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.h b/tensorflow/compiler/xla/service/gpu/partition_assignment.h index 02471129e004b4876ce20a62cade34060c65b478..eb41dcccb938ccc088c2371def96ca73276771ab 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.h @@ -57,6 +57,9 @@ class LaunchDimensions { std::ostream& operator<<(std::ostream& out, const LaunchDimensions& launch_dims); +// Returns the maximum number of threads per block allowed by the device. 
+int64 ThreadsPerBlockLimit(const se::DeviceDescription& device_desc); + LaunchDimensions CalculateLaunchDimensions( const Shape& shape, const se::DeviceDescription& device_desc, int unroll_factor = 1); diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h index 1fc46bafa10e7ba6c896f081d5c836bd400886c9..92e4d6dbbc1bd564657f8a5de09d23d5ae81a93e 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ +#include "tensorflow/compiler/xla/layout.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc index d0ccd8619bde9ddd560989380b403efed5c5f42c..5e524faab18947f5793dc2ae34e9329a446d4235 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc @@ -75,16 +75,16 @@ class GpuFtzDisabledTest : public GpuFtzTest { // Check that we emit mul.ftz.f32 when in ftz mode, and plain mul.f32 otherwise. 
TEST_F(GpuFtzEnabledTest, MultiplyFtz) { CompileAndVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"( - CHECK-NOT: mul.f32 - CHECK: mul.ftz.f32 - CHECK-NOT: mul.f32 + CHECK-NOT: mul.rn.f32 + CHECK: mul.rn.ftz.f32 + CHECK-NOT: mul.rn.f32 )"); } TEST_F(GpuFtzDisabledTest, MultiplyFtz) { CompileAndVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"( - CHECK-NOT: mul.ftz.f32 - CHECK: mul.f32 - CHECK-NOT: mul.ftz.f32 + CHECK-NOT: mul.rn.ftz.f32 + CHECK: mul.rn.f32 + CHECK-NOT: mul.rn.ftz.f32 )"); } diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index ff122b529bdcdcc69d2245136e19101902dbf957..ca663b8b4a970900a4a899a7ad9d33dc45af9d99 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -711,8 +711,6 @@ bool HloComputation::operator==(const HloComputation& other) const { return eq(root_instruction(), other.root_instruction()); } -uint64 HloComputation::Hash() const { return root_instruction()->Hash(); } - Status HloComputation::ReplaceWithNewInstruction( HloInstruction* old_instruction, std::unique_ptr new_instruction) { @@ -797,7 +795,7 @@ Status HloComputation::AcceptWithOperandOrder( template Status HloComputation::AcceptOrdered( DfsHloVisitorBase* visitor, - const std::vector& order) const { + absl::Span order) const { VLOG(3) << "Accepting visitor with order."; for (HloInstruction* root : CollectUnreachableRoots()) { TF_RET_CHECK(std::find(order.begin(), order.end(), root) != order.end()) @@ -827,9 +825,9 @@ Status HloComputation::AcceptOrdered( // Explicit instantiations. 
template Status HloComputation::AcceptOrdered( - DfsHloVisitor*, const std::vector&) const; + DfsHloVisitor*, absl::Span) const; template Status HloComputation::AcceptOrdered( - ConstDfsHloVisitor*, const std::vector&) const; + ConstDfsHloVisitor*, absl::Span) const; Status HloComputation::Accept( const std::function& visitor_func) { diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index c584e4c7ca5770533f28352b0df9dadd9dbe1860..5467d0a68b18170891dcd9f67e44d3bb269bf920 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -264,12 +264,6 @@ class HloComputation { // Return whether `*this` and `other` are functionally equivalent. bool operator==(const HloComputation& other) const; - // Generates a hash value of an HLO computation. Hash considers - // information on opcode, shape, operands, and typically a root instruction. - // This function returns the same hash value for equivalent HLO computations, - // with respect to HloInstruction::Identical() method. - uint64 Hash() const; - // Replaces old instruction with newly created instruction. Removes old // instruction from computation. Updates uses and root instruction. Status ReplaceWithNewInstruction( @@ -307,7 +301,7 @@ class HloComputation { // be a topological sort of all instructions in the computation. template Status AcceptOrdered(DfsHloVisitorBase* visitor, - const std::vector& order) const; + absl::Span order) const; // Same as Accept() above, but the visitor is given as a function. 
Status Accept(const std::function& visitor_func); diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 8b50cfa9aed90091cfbedc1df902440ec9bf2a80..0361c87428f6e4c031d95492a5bc782ad388e5b5 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -20,19 +20,19 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -namespace op = xla::testing::opcode_matchers; - namespace xla { namespace { +namespace m = match; using ::testing::ElementsAre; using ::testing::UnorderedElementsAre; @@ -261,7 +261,7 @@ TEST_F(HloComputationTest, DeepCopyArray) { auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(constant).ValueOrDie(); - EXPECT_THAT(copy, op::Copy(constant)); + EXPECT_THAT(copy, GmockMatch(m::Copy(m::Op().Is(constant)))); } TEST_F(HloComputationTest, DeepCopyTuple) { @@ -278,8 +278,9 @@ TEST_F(HloComputationTest, DeepCopyTuple) { auto computation = module->AddEntryComputation(builder.Build()); auto tuple_copy = computation->DeepCopyInstruction(tuple).ValueOrDie(); - EXPECT_THAT(tuple_copy, op::Tuple(op::Copy(op::GetTupleElement(tuple)), - op::Copy(op::GetTupleElement(tuple)))); + EXPECT_THAT(tuple_copy, GmockMatch(m::Tuple( + 
m::Copy(m::GetTupleElement(m::Op().Is(tuple))), + m::Copy(m::GetTupleElement(m::Op().Is(tuple)))))); EXPECT_EQ(0, tuple_copy->operand(0)->operand(0)->tuple_index()); EXPECT_EQ(1, tuple_copy->operand(1)->operand(0)->tuple_index()); } @@ -297,7 +298,7 @@ TEST_F(HloComputationTest, DeepCopyArrayAtIndices) { ShapeTree indices_to_copy(constant->shape(), /*init_value=*/true); EXPECT_THAT(computation->DeepCopyInstruction(constant, &indices_to_copy) .ValueOrDie(), - op::Copy(constant)); + GmockMatch(m::Copy(m::Op().Is(constant)))); } { @@ -330,10 +331,11 @@ TEST_F(HloComputationTest, DeepCopyTupleAtIndices) { computation->DeepCopyInstruction(tuple, &indices_to_copy, &copies_added) .ValueOrDie(); - EXPECT_THAT(deep_copy, op::Tuple(op::Copy(op::GetTupleElement(tuple)), - op::Copy(op::GetTupleElement(tuple)))); - EXPECT_THAT(deep_copy, op::Tuple(copies_added.element({0}), - copies_added.element({1}))); + EXPECT_THAT(deep_copy, GmockMatch(m::Tuple( + m::Copy(m::GetTupleElement(m::Op().Is(tuple))) + .Is(copies_added.element({0})), + m::Copy(m::GetTupleElement(m::Op().Is(tuple))) + .Is(copies_added.element({1}))))); } { @@ -346,8 +348,9 @@ TEST_F(HloComputationTest, DeepCopyTupleAtIndices) { computation->DeepCopyInstruction(tuple, &indices_to_copy, &copies_added) .ValueOrDie(); - EXPECT_THAT(deep_copy, op::Tuple(op::GetTupleElement(tuple), - op::GetTupleElement(tuple))); + EXPECT_THAT(deep_copy, + GmockMatch(m::Tuple(m::GetTupleElement(m::Op().Is(tuple)), + m::GetTupleElement(m::Op().Is(tuple))))); EXPECT_TRUE(copies_added.element({}) == nullptr); EXPECT_TRUE(copies_added.element({0}) == nullptr); EXPECT_TRUE(copies_added.element({1}) == nullptr); @@ -363,8 +366,9 @@ TEST_F(HloComputationTest, DeepCopyTupleAtIndices) { computation->DeepCopyInstruction(tuple, &indices_to_copy, &copies_added) .ValueOrDie(); - EXPECT_THAT(deep_copy, op::Tuple(op::Copy(op::GetTupleElement(tuple)), - op::GetTupleElement(tuple))); + EXPECT_THAT(deep_copy, GmockMatch(m::Tuple( + 
m::Copy(m::GetTupleElement(m::Op().Is(tuple))), + m::GetTupleElement(m::Op().Is(tuple))))); EXPECT_TRUE(copies_added.element({}) == nullptr); EXPECT_TRUE(copies_added.element({0}) != nullptr); EXPECT_TRUE(copies_added.element({1}) == nullptr); @@ -381,7 +385,7 @@ TEST_F(HloComputationTest, DeepCopyToken) { auto copy = computation->DeepCopyInstruction(token).ValueOrDie(); // No copy should be added. - EXPECT_THAT(copy, op::AfterAll()); + EXPECT_THAT(copy, GmockMatch(m::AfterAll())); } TEST_F(HloComputationTest, DeepCopyTokenTuple) { @@ -399,8 +403,9 @@ TEST_F(HloComputationTest, DeepCopyTokenTuple) { // Only the array (second tuple element) should be copied. The token is passed // through transparently. - EXPECT_THAT(copy, op::Tuple(op::GetTupleElement(tuple), - op::Copy(op::GetTupleElement(tuple)))); + EXPECT_THAT(copy, GmockMatch(m::Tuple( + m::GetTupleElement(m::Op().Is(tuple)), + m::Copy(m::GetTupleElement(m::Op().Is(tuple)))))); } TEST_F(HloComputationTest, CycleDetection) { @@ -443,13 +448,15 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) { auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); - EXPECT_THAT(computation->root_instruction(), op::Negate(constant)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Negate(m::Op().Is(constant)))); EXPECT_EQ(negate, computation->root_instruction()); ASSERT_IS_OK(computation->RemoveInstructionAndUnusedOperands(dead_add)); EXPECT_EQ(2, computation->instruction_count()); - EXPECT_THAT(computation->root_instruction(), op::Negate(constant)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Negate(m::Op().Is(constant)))); EXPECT_EQ(negate, computation->root_instruction()); } diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 
4f81dc94e577a63c09ae4019e5e8158252c712ce..92b748d813c3efef83ef0155f1d5d3c637ce2c57 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -252,7 +252,7 @@ const char* const kConstantFoldLargePad = R"( HloModule ConstantFoldLargePad ENTRY r { - a = f32[1,1,1] constant(f32[1,1,1]{{{7}}}) + a = f32[1,1,1] constant({{{7}}}) b = f32[] constant(42) ROOT pad = f32[2048,2048,128] pad(a, b), padding=1024_1023x1024_1023x64_63 })"; diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index f7a1f19a6f52befd58a405d0e406d7d0d37a8e57..94de7c55dd2402e55ec344b79c24af2d8283fe73 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -1882,8 +1882,8 @@ TEST_P(HloDataflowAnalysisTest, AddDependency) { HloModule AddDependency ENTRY %AddDependency (p: f32[3]) -> f32[3] { %p = f32[3] parameter(0) - %token = token[] after-all() - ROOT %add_dep = f32[3] add-dependency(f32[3] %p, token[] %token) + %token0 = token[] after-all() + ROOT %add_dep = f32[3] add-dependency(f32[3] %p, token[] %token0) } )"; TF_ASSERT_OK_AND_ASSIGN( diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index acdb42128e3d9a1fb912a466c9c2c3cbbe3d3f83..fd4fb0246d8d42ab7329c05dc23e386303cdce3c 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -195,10 +195,10 @@ HloModule Module ENTRY entry { p0 = (f32[4]) parameter(0) a = f32[4] get-tuple-element(p0), index=0 - token = token[] after-all() - b = (f32[4], u32[], token[]) send(a, token), channel_id=1, sharding={maximal device=0} + token0 = token[] after-all() + b = (f32[4], u32[], token[]) send(a, token0), channel_id=1, sharding={maximal device=0} c = token[] send-done(b), 
channel_id=1, sharding={maximal device=0} - d = (f32[4], u32[], token[]) recv(token), channel_id=2, sharding={maximal device=0} + d = (f32[4], u32[], token[]) recv(token0), channel_id=2, sharding={maximal device=0} e = (f32[4], token[]) recv-done(d), channel_id=2, sharding={maximal device=0} e_element = f32[4] get-tuple-element(e), index=0, sharding={maximal device=0} f = f32[4] add(a, e_element) @@ -235,12 +235,12 @@ TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) { HloModule Module ENTRY entry { - token = token[] after-all(), sharding={maximal device=-1} - a = (f32[4], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=-1} + token0 = token[] after-all(), sharding={maximal device=-1} + a = (f32[4], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=-1} b = (f32[4], token[]) recv-done(a), channel_id=1, sharding={maximal device=-1} b_element = f32[4] get-tuple-element(b), index=0, sharding={maximal device=-1} c = f32[4] add(b_element, b_element), sharding={maximal device=-1} - d = (f32[4], u32[], token[]) send(c, token), channel_id=2, sharding={maximal device=-1} + d = (f32[4], u32[], token[]) send(c, token0), channel_id=2, sharding={maximal device=-1} ROOT e = token[] send-done(d), channel_id=2, sharding={maximal device=-1} } )"; @@ -259,12 +259,12 @@ TEST_F(HloDomainTest, CheckNormalizationOnPureIOComputation) { HloModule Module ENTRY entry { - token = token[] after-all(), sharding={maximal device=0} - a = (f32[4], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=0} + token0 = token[] after-all(), sharding={maximal device=0} + a = (f32[4], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=0} b = (f32[4], token[]) recv-done(a), channel_id=1, sharding={maximal device=0} b_element = f32[4] get-tuple-element(b), index=0, sharding={maximal device=0} c = f32[4] add(b_element, b_element) - d = (f32[4], u32[], token[]) send(c, token), channel_id=2, sharding={maximal device=0} + d = 
(f32[4], u32[], token[]) send(c, token0), channel_id=2, sharding={maximal device=0} ROOT e = token[] send-done(d), channel_id=2, sharding={maximal device=0} } )"; @@ -344,8 +344,8 @@ TEST_F(HloDomainTest, CheckNormalizationOnInfeedTuple) { HloModule Module ENTRY entry { - token = token[] after-all() - infeed = ((f32[4], f32[4]), token[]) infeed(token), + token0 = token[] after-all() + infeed = ((f32[4], f32[4]), token[]) infeed(token0), sharding={{maximal device=1}, {maximal device=0}, {maximal device=0}} infeed.data = (f32[4], f32[4]) get-tuple-element(infeed), index=0, sharding={{maximal device=1}, {maximal device=0}} diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc index c170e36c73ad2bef830e528de3ec72d38683d888..a3b56a44a0b02923585c1dcb69571479236188a3 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc @@ -57,10 +57,10 @@ TEST_F(HloElementTypeConverterTest, InfeedsOutfeedsNotConverted) { const string& hlo_string = R"( HloModule InfeedOutfeed ENTRY RoundTrip16MiBR1.v2 { - token = token[] after-all() - infeed = (bf16[4]{0}, token[]) infeed(token) + token0 = token[] after-all() + infeed = (bf16[4]{0}, token[]) infeed(token0) ROOT infeed.data = bf16[4]{0} get-tuple-element(infeed), index=0 - outfeed = token[] outfeed(infeed.data, token) + outfeed = token[] outfeed(infeed.data, token0) } )"; auto module = CreateModuleFromHloString(hlo_string); @@ -96,13 +96,13 @@ TEST_F(HloElementTypeConverterTest, BatchNormGradBF16Converted) { const string& hlo_string = R"( HloModule BatchNormGrad ENTRY BatchNormGrad.v6 { - constant.4 = bf16[2,2,2,1]{3,2,1,0} constant(bf16[2,2,2,1] { { /*i0=0*/ + constant.4 = bf16[2,2,2,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {0}, {0} }, { /*i1=1*/ {0}, {0} } }, { /*i0=1*/ { /*i1=0*/ {0}, {0} }, { /*i1=1*/ {0}, {0} } } }) constant.5 = bf16[2]{0} 
constant({1, 1}) constant.6 = bf16[2]{0} constant({0, 0}) constant.7 = bf16[2]{0} constant({1, 1}) - constant.8 = bf16[2,2,2,1]{3,2,1,0} constant(bf16[2,2,2,1] { { /*i0=0*/ + constant.8 = bf16[2,2,2,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} } }, { /*i0=1*/ { /*i1=0*/ {5}, {6} }, { /*i1=1*/ {7}, {8} } } }) ROOT batch-norm-grad = (bf16[2,2,2,1]{3,2,1,0}, bf16[2]{0}, bf16[2]{0}) diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 51a3fba1768aaf219b78ddc09a1c526448389d9e..934c082bb9f003b1d2d80835f09a8f4109c7e7fd 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -33,12 +33,14 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" @@ -396,6 +398,16 @@ StatusOr HloEvaluator::EvaluateDotOp( return Evaluate(cloned_instruction.get()); } +Status HloEvaluator::HandleBitcast(HloInstruction* bitcast) { + const Literal& operand_literal = GetEvaluatedLiteralFor(bitcast->operand(0)); + Literal result(bitcast->shape()); + TF_RET_CHECK(operand_literal.size_bytes() == result.size_bytes()); + memcpy(result.untyped_data(), operand_literal.untyped_data(), + operand_literal.size_bytes()); + 
evaluated_[bitcast] = std::move(result); + return Status::OK(); +} + Status HloEvaluator::HandleParameter(HloInstruction* parameter) { CHECK_LT(parameter->parameter_number(), arg_literals_.size()); const Literal* input_literal = arg_literals_[parameter->parameter_number()]; @@ -618,8 +630,11 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) { evaluated_[compare], Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); } break; - case F16: - return Unimplemented("unhandled primitive type: F16."); + case F16: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; case BF16: { TF_ASSIGN_OR_RETURN(evaluated_[compare], Compare(compare->shape(), opcode, @@ -1438,4 +1453,46 @@ template StatusOr HloEvaluator::Evaluate( template StatusOr HloEvaluator::Evaluate( HloInstruction* instruction, absl::Span arg_literals); +namespace { +template +std::unique_ptr> MatmulArray2DImpl( + const Array2D& lhs, const Array2D& rhs, + const std::function& impl_fn) { + CHECK_EQ(lhs.width(), rhs.height()); + int m = lhs.height(); + int n = rhs.width(); + int k = lhs.width(); + auto result = absl::make_unique>(m, n); + // Because Eigen is a header-oriented library, make sure that the Eigen code + // is the same as the code used by the CPU backend (otherwise the linker will + // randomly pick *some* definition). 
+ impl_fn( + /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m, + k, + /*transpose_lhs=*/0, + /*transpose_rhs=*/0); + return result; +} +} // namespace + +std::unique_ptr> HloEvaluator::MatmulArray2D( + const Array2D& lhs, const Array2D& rhs) { + return MatmulArray2DImpl( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF16); +} + +std::unique_ptr> HloEvaluator::MatmulArray2D( + const Array2D& lhs, const Array2D& rhs) { + return MatmulArray2DImpl( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF32); +} + +std::unique_ptr> HloEvaluator::MatmulArray2D( + const Array2D& lhs, const Array2D& rhs) { + return MatmulArray2DImpl( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF64); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index d847900010c697d7d280ed8e4a9502f1c465ee07..d363a51c63de6fd4246c4970f580b68f4a627df8 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -21,6 +21,7 @@ limitations under the License. #include "absl/container/node_hash_map.h" #include "absl/memory/memory.h" #include "absl/types/span.h" +#include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -119,6 +120,17 @@ class HloEvaluator : public DfsHloVisitorWithDefault { const PrecisionConfig& precision_config, const Literal& lhs, const Literal& rhs); + // Enable the fast path for certain operations like dot or convolution. + void set_use_fast_path(bool value) { use_fast_path_ = value; } + + // Returns the result of a matrix multiply `lhs x rhs`. 
+ static std::unique_ptr> MatmulArray2D( + const Array2D& lhs, const Array2D& rhs); + static std::unique_ptr> MatmulArray2D( + const Array2D& lhs, const Array2D& rhs); + static std::unique_ptr> MatmulArray2D( + const Array2D& lhs, const Array2D& rhs); + protected: // Make HloEvaluatorTypedVisitor a friend because it is logically part of this // class. @@ -144,6 +156,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Operations that are type-agnostic or always return a specific type, such as // HandleIsFinite where boolean is always returned. // + Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleParameter(HloInstruction* parameter) override; Status HandleConstant(HloInstruction* constant) override; @@ -215,6 +229,9 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // we cannot use flat_hash_map any more. absl::node_hash_map evaluated_; + // Use fast path that uses eigen in the evaluator. + bool use_fast_path_ = false; + private: template static StatusOr ElementWiseUnaryOpImpl( @@ -248,6 +265,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator); }; +std::unique_ptr> MatmulArray2D(const Array2D& lhs, + const Array2D& rhs); } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index d95b6ad04f2c446b423a3aaef4de333ed2968883..4eaaab20ea0add17d9b49b1b2b97991af0438dcc 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -35,6 +36,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -2765,6 +2767,33 @@ ENTRY main { EXPECT_TRUE(LiteralTestUtil::Equal(arg, actual)); } +TEST_P(HloEvaluatorTest, Bitcast) { + // Regression test for b/114735354. + constexpr absl::string_view hlo_text_base = R"( +HloModule Bitcast + +ENTRY main { + param = %s[32,121]{1,0} parameter(0) + ROOT bitcast = %s[121,32,1]{0,1,2} bitcast(%s[32,121]{1,0} param) +} +)"; + string hlo_text; + if (use_bfloat16_) { + hlo_text = absl::StrFormat(hlo_text_base, "bf16", "bf16", "bf16"); + } else { + hlo_text = absl::StrFormat(hlo_text_base, "f32", "f32", "f32"); + } + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); + Literal actual = Evaluate({&args[0]}); + if (use_bfloat16_) { + EXPECT_TRUE( + absl::c_equal(args[0].data(), actual.data())); + } else { + EXPECT_TRUE(absl::c_equal(args[0].data(), actual.data())); + } +} + INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest, ::testing::ValuesIn(use_bf16_params)); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index b87fc3e34012e75ee07bff6c1e113dce404f83cb..03d42990ce9dcd3f689831078354f878bcb0800f 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -23,6 +23,7 @@ limitations under the License. 
#include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/types/optional.h" +#include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" @@ -105,6 +106,12 @@ bool SafeLess(const NativeT& a, const NativeT& b) { template class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { private: + Status UnsupportedTypeError(HloInstruction* instruction) { + return InvalidArgument( + "Unsupported type for %s: %s", HloOpcodeString(instruction->opcode()), + PrimitiveType_Name(instruction->shape().element_type())); + } + // Get the value in the given literal static_cast as a double. template < typename NativeT, @@ -224,7 +231,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleRound(HloInstruction* round) { - return InvalidArgument("Unsupported type for Round"); + return UnsupportedTypeError(round); } Status HandleRound(HloInstruction* round) override { @@ -246,7 +253,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleCeil(HloInstruction* ceil) { - return InvalidArgument("Unsupported type for Ceil"); + return UnsupportedTypeError(ceil); } Status HandleCeil(HloInstruction* ceil) override { @@ -297,8 +304,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template < typename NativeT, typename std::enable_if::value>::type* = nullptr> - Status HandleExpm1(HloInstruction* floor) { - return InvalidArgument("Unsupported type for Expm1"); + Status HandleExpm1(HloInstruction* expm1) { + return UnsupportedTypeError(expm1); } Status HandleExpm1(HloInstruction* floor) override { @@ -321,7 +328,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename 
NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleFloor(HloInstruction* floor) { - return InvalidArgument("Unsupported type for Floor"); + return UnsupportedTypeError(floor); } Status HandleFloor(HloInstruction* floor) override { @@ -351,12 +358,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template < typename NativeT, typename std::enable_if::value>::type* = nullptr> - Status HandleLog1p(HloInstruction* floor) { - return InvalidArgument("Unsupported type for Log1p"); + Status HandleLog1p(HloInstruction* log1p) { + return UnsupportedTypeError(log1p); } - Status HandleLog1p(HloInstruction* floor) override { - return HandleLog1p(floor); + Status HandleLog1p(HloInstruction* log1p) override { + return HandleLog1p(log1p); } template ::value>::type* = nullptr> Status HandleNot(HloInstruction* not_) { - return InvalidArgument("Unsupported type for Not"); + return UnsupportedTypeError(not_); } Status HandleNot(HloInstruction* not_) override { @@ -476,7 +483,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> Status HandleAtan2(HloInstruction* atan2) { - return InvalidArgument("Unsupported type for Atan2"); + return UnsupportedTypeError(atan2); } Status HandleAtan2(HloInstruction* atan2) override { @@ -624,7 +631,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleMaximum(HloInstruction* maximum) { - return InvalidArgument("Unsupported type for Maximum"); + return UnsupportedTypeError(maximum); } Status HandleMaximum(HloInstruction* maximum) override { @@ -659,7 +666,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleMinimum(HloInstruction* minimum) { - return InvalidArgument("Unsupported type for Minimum"); + return UnsupportedTypeError(minimum); } Status 
HandleMinimum(HloInstruction* minimum) override { @@ -724,7 +731,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleRemainder(HloInstruction* remainder) { - return InvalidArgument("Unsupported type for Remainder"); + return UnsupportedTypeError(remainder); } Status HandleRemainder(HloInstruction* remainder) override { @@ -746,14 +753,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> Status HandleAnd(HloInstruction* and_) { - return InvalidArgument("Unsupported type for And"); + return UnsupportedTypeError(and_); } template < typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleAnd(HloInstruction* and_) { - return InvalidArgument("Unsupported type for And"); + return UnsupportedTypeError(and_); } Status HandleAnd(HloInstruction* and_) override { @@ -775,7 +782,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> Status HandleOr(HloInstruction* or_) { - return InvalidArgument("Unsupported type for Or"); + return UnsupportedTypeError(or_); } template < @@ -804,14 +811,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> Status HandleXor(HloInstruction* xor_) { - return InvalidArgument("Unsupported type for Xor"); + return UnsupportedTypeError(xor_); } template < typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleXor(HloInstruction* xor_) { - return InvalidArgument("Unsupported type for Xor"); + return UnsupportedTypeError(xor_); } Status HandleXor(HloInstruction* xor_) override { @@ -836,8 +843,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value || std::is_same::value>::type* = nullptr> - Status HandleShiftLeft(HloInstruction*) { - return InvalidArgument("Unsupported type for ShiftLeft"); + 
Status HandleShiftLeft(HloInstruction* shift) { + return UnsupportedTypeError(shift); } Status HandleShiftLeft(HloInstruction* shl) override { @@ -866,8 +873,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value || std::is_same::value>::type* = nullptr> - Status HandleShiftRightArithmetic(HloInstruction*) { - return InvalidArgument("Unsupported type for ShiftRightArithmetic"); + Status HandleShiftRightArithmetic(HloInstruction* shift) { + return UnsupportedTypeError(shift); } Status HandleShiftRightArithmetic(HloInstruction* shra) override { @@ -897,8 +904,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value || std::is_same::value>::type* = nullptr> - Status HandleShiftRightLogical(HloInstruction*) { - return InvalidArgument("Unsupported type for ShiftRightLogical"); + Status HandleShiftRightLogical(HloInstruction* shift) { + return UnsupportedTypeError(shift); } Status HandleShiftRightLogical(HloInstruction* shrl) override { @@ -923,8 +930,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template < typename NativeT, typename std::enable_if::value>::type* = nullptr> - Status HandleClamp(HloInstruction*) { - return InvalidArgument("Unsupported type for Clamp"); + Status HandleClamp(HloInstruction* clamp) { + return UnsupportedTypeError(clamp); } Status HandleClamp(HloInstruction* clamp) override { @@ -1148,6 +1155,78 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } Status HandleDot(HloInstruction* dot) override { + if (parent_->use_fast_path_) { + return HandleDot(dot); + } + return HandleDotSlowPath(dot); + } + + template ::value>::type* = nullptr> + Status HandleDot(HloInstruction* dot) { + const HloInstruction* lhs = dot->operand(0); + const HloInstruction* rhs = dot->operand(1); + CHECK(ShapeUtil::IsArray(dot->shape())); + CHECK(ShapeUtil::IsArray(lhs->shape())); + CHECK(ShapeUtil::IsArray(rhs->shape())); + + const 
auto& dnums = dot->dot_dimension_numbers(); + + const int64 lhs_rank = ShapeUtil::Rank(lhs->shape()); + const int64 rhs_rank = ShapeUtil::Rank(rhs->shape()); + + CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); + CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape())); + + // There must be 1 and only 1 Contracting dimension for lhs and rhs. + CHECK_EQ(dnums.lhs_contracting_dimensions_size(), 1); + CHECK_EQ(dnums.rhs_contracting_dimensions_size(), 1); + const int64 lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0); + const int64 rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0); + // Contracted dimension sizes must be the same. + CHECK_EQ(lhs->shape().dimensions(lhs_contracting_dimension), + rhs->shape().dimensions(rhs_contracting_dimension)) + << "lhs contracted dimension: " + << lhs->shape().dimensions(lhs_contracting_dimension) + << " rhs contracted dimension: " + << rhs->shape().dimensions(rhs_contracting_dimension); + + // The fast path is for a simple rank 2 dot with default layout operands. 
+ if (lhs_rank == 2 && rhs_rank == 2 && lhs_contracting_dimension == 1 && + rhs_contracting_dimension == 0 && + LayoutUtil::Equal(lhs->shape().layout(), + LayoutUtil::GetDefaultLayoutForR2()) && + LayoutUtil::Equal(rhs->shape().layout(), + LayoutUtil::GetDefaultLayoutForR2()) && + LayoutUtil::Equal(dot->shape().layout(), + LayoutUtil::GetDefaultLayoutForR2())) { + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + const int64 contracted_dimension_size = + lhs->shape().dimensions(lhs_contracting_dimension); + Array2D lhs_array(lhs->shape().dimensions(0), + contracted_dimension_size); + lhs_array.SetValues(lhs_literal.data()); + Array2D rhs_array(contracted_dimension_size, + rhs->shape().dimensions(1)); + rhs_array.SetValues(rhs_literal.data()); + std::unique_ptr> result_array = + HloEvaluator::MatmulArray2D(lhs_array, rhs_array); + Literal result(dot->shape()); + result.PopulateR2FromArray2D(*result_array); + parent_->evaluated_[dot] = std::move(result); + return Status::OK(); + } + return HandleDotSlowPath(dot); + } + + template ::value>::type* = nullptr> + Status HandleDot(HloInstruction* dot) { + return HandleDotSlowPath(dot); + } + + Status HandleDotSlowPath(HloInstruction* dot) { auto lhs = dot->operand(0); auto rhs = dot->operand(1); CHECK(ShapeUtil::IsArray(dot->shape())); @@ -1578,7 +1657,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::is_same::value>::type* = nullptr> Status HandleSort(HloInstruction* sort) { - return InvalidArgument("Unsupported type for Sort"); + return UnsupportedTypeError(sort); } Status HandleSort(HloInstruction* sort) override { @@ -2357,7 +2436,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::is_same::value || std::is_same::value)>::type* = nullptr> Status HandleClz(HloInstruction* clz) { - return InvalidArgument("Unsupported type for Clz"); + return UnsupportedTypeError(clz); } template 
::value || is_complex_t::value>::type* = nullptr> Status HandleSin(HloInstruction* sin) { - return InvalidArgument("Unsupported type for Sin"); + return UnsupportedTypeError(sin); } Status HandleSin(HloInstruction* sin) override { @@ -2425,7 +2504,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value || is_complex_t::value>::type* = nullptr> Status HandleCos(HloInstruction* cos) { - return InvalidArgument("Unsupported type for Cos"); + return UnsupportedTypeError(cos); } Status HandleCos(HloInstruction* cos) override { @@ -2534,7 +2613,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value || is_complex_t::value>::type* = nullptr> Status HandleReducePrecision(HloInstruction* reduce_precision) { - return InvalidArgument("Unsupported type for reduce precision"); + return UnsupportedTypeError(reduce_precision); } Status HandleReducePrecision(HloInstruction* reduce_precision) override { @@ -2543,15 +2622,27 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value || + std::is_same::value || std::is_integral::value || std::is_floating_point::value>::type* = nullptr> Status HandleIota(HloInstruction* instruction) { auto* iota = Cast(instruction); + const int64 iota_size = iota->shape().dimensions(iota->iota_dimension()); // Avoid using std::vector since std::vector does not convert to // absl::Span. - absl::InlinedVector data( - iota->shape().dimensions(iota->iota_dimension())); - std::iota(data.begin(), data.end(), 0); + absl::InlinedVector data(iota_size); + // We don't use std::iota for two reasons: + // + // (1) std::iota does not support bfloat16 and float16. + // + // (2) std::iota saturates for floating point types when the value is not + // representable, but the definition of HLO iota is the value as a + // 64-bit integer cast to the native type. + for (int64 i = 0; i < iota_size; ++i) { + // static_cast is required for Eigen::half (F16).
+ data[i] = static_cast(i); + } auto result = LiteralUtil::CreateR1(data); if (ShapeUtil::Rank(iota->shape()) > 1) { @@ -2567,10 +2658,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template ::value || + !(std::is_same::value || + std::is_same::value || + std::is_integral::value || std::is_floating_point::value)>::type* = nullptr> Status HandleIota(HloInstruction* iota) { - return InvalidArgument("Unsupported type for iota"); + return UnsupportedTypeError(iota); } Status HandleIota(HloInstruction* iota) override { return HandleIota(iota); diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc index 631b3ad735f369922d10b37d11e2a1b1ba117e6b..c919dbd82d3668c477bf37074f1d56f8cb7d9506 100644 --- a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc +++ b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc @@ -39,7 +39,7 @@ StatusOr ReplaceGetSize(HloInstruction* instr) { uint32 size = instr->operand(0)->shape().dimensions(instr->dimension()); HloInstruction* new_instr = computation->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(size))); - TF_RETURN_IF_ERROR(computation->ReplaceInstruction(instr, new_instr)); + TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr)); return true; } @@ -50,12 +50,7 @@ StatusOr HloGetDimensionSizeRewriter::Run(HloModule* module) { HloProto proto; *proto.mutable_hlo_module() = module->ToProto(); for (auto* computation : module->computations()) { - // Replacing instructions will change the instruction list in the - // computation. So instead of iterating computation->instructions() - // directly, we make a copy of the list to avoid use-after-free. 
- std::vector instrs(computation->instruction_count()); - absl::c_copy(computation->instructions(), instrs.begin()); - for (auto instruction : instrs) { + for (auto instruction : computation->instructions()) { TF_ASSIGN_OR_RETURN(bool replaced, ReplaceGetSize(instruction)); changed = changed || replaced; } diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 804feff290a1c0800a8e6bf209b042241b6cb759..5db21e47ca94af3b017e0401237692913365a48c 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -38,6 +39,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/window_util.h" @@ -111,11 +113,6 @@ class NodeFilter { result == kSomeUsersOmitted; } - bool ShowFusionSubcomputation(const HloInstruction* instr) const { - CHECK_EQ(instr->opcode(), HloOpcode::kFusion); - return Show(instr) && !SomeOrAllOperandsOmitted(instr); - } - private: std::function filter_; }; @@ -240,34 +237,28 @@ string HtmlLikeStringSanitize(absl::string_view s) { // it to a short string lets us tell the user what the subcomputation is without // drawing it as a graph. optional MatchTrivialComputation(const HloComputation* computation) { + namespace m = match; + if (computation->instruction_count() != 3) { return nullopt; } - HloInstruction* root = computation->root_instruction(); - if (root->operand_count() != 2) { - return nullopt; - } - - // Check that both of the operands to the root are parameters. 
- const HloInstruction* operand0 = root->operand(0); - const HloInstruction* operand1 = root->operand(1); - if (operand0->opcode() != HloOpcode::kParameter || - operand1->opcode() != HloOpcode::kParameter) { + const HloInstruction *param0, *param1; + if (!Match(root, m::Op() + .WithNumOperands(2) + .WithShape(m::Shape().IsEffectiveScalar()) + .WithBinaryOperandsAnyOrder( + m::Parameter(&param0, 0) + .WithShape(m::Shape().IsEffectiveScalar()), + m::Parameter(&param1, 1) + .WithShape(m::Shape().IsEffectiveScalar())))) { return nullopt; } - // Check that the two operands of root are param0 and param1. All of the - // opcodes we recognize are commutative, so we're OK with either order. - auto n0 = operand0->parameter_number(); - auto n1 = operand1->parameter_number(); - if (!(n0 == 0 && n1 == 1) && !(n1 == 0 && n0 == 1)) { - return nullopt; - } - - // If the params are reversed, check that the operation being performed is - // commutative. - if (n0 == 1) { + // If the params are reversed (i.e. operand0 is param1 and operand1 is + // param0), check that the operation being performed is commutative. + if (root->operand(0) == param1) { + CHECK_EQ(root->operand(1), param0); switch (root->opcode()) { case HloOpcode::kLe: case HloOpcode::kGe: @@ -279,13 +270,6 @@ optional MatchTrivialComputation(const HloComputation* computation) { } } - // Check that the root and params are all effective scalars. - if (!ShapeUtil::IsEffectiveScalar(root->shape()) || - !ShapeUtil::IsEffectiveScalar(operand0->shape()) || - !ShapeUtil::IsEffectiveScalar(operand1->shape())) { - return nullopt; - } - // If we recognize the root's opcode, we've successfully pattern-matched! switch (root->opcode()) { case HloOpcode::kAdd: @@ -578,7 +562,7 @@ bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) { // Show the subcomputation if we're showing any of its members.
return std::any_of( - computation_->instructions().begin(), computation_->instructions().end(), + subcomp->instructions().begin(), subcomp->instructions().end(), [&](const HloInstruction* instr) { return filter_.Show(instr); }); } @@ -1298,7 +1282,8 @@ namespace { // Gets a NodeFilter that includes roughly all instructions whose distance from // root is <= radius. -NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) { +NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, + int64 radius) { // First, find the neighborhood of nodes with distance from root <= radius. // These nodes are our initial set of "normal" nodes. std::unordered_map nodes; @@ -1405,6 +1390,56 @@ NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) { }); } +// Gets a node filter that includes nodes on all paths from `from` to `to`. If +// the all-paths set contains more than max_nodes elements, includes the nodes +// on the shortest paths and sets hit_limit to true. +NodeFilter MakeNodeFromToFilter(const HloInstruction* from, + const HloInstruction* to, int64 max_nodes, + bool* hit_limit) { + *hit_limit = false; + + // Elements in the queue are paths through the graph. + std::deque> queue; + queue.push_front({from}); + + // Compute the set of nodes we want to show using a slightly-modified + // Dijkstra's algorithm. The only real difference is, rather than stopping + // when we find a (shortest) path, we continue until we've found max_nodes + // nodes on some path.
+ std::unordered_set visited; + std::unordered_set to_display = {from, to}; + while (!queue.empty() && to_display.size() < max_nodes) { + std::vector path = std::move(queue.front()); + queue.pop_front(); + if (!visited.insert(path.back()).second) { + continue; + } + + for (const auto* user : path.back()->users()) { + if (user == to) { + auto it = path.begin(); + for (; it != path.end() && to_display.size() < max_nodes; ++it) { + to_display.insert(*it); + } + if (it != path.end()) { + *hit_limit = true; + } + } else if (!visited.count(user)) { + auto new_path = path; + new_path.push_back(user); + queue.push_back(std::move(new_path)); + } + } + } + + return NodeFilter([=](const HloInstruction* instr) { + if (instr == from || instr == to) { + return kHighlightNode; + } + return to_display.count(instr) ? kNormalNode : kHideNode; + }); +} + string SaveGraph(const string& graph, GraphRendererInterface::GraphKind graph_kind, const string& dest_path) { @@ -1439,14 +1474,15 @@ string ExportGraph(const string& graph, GraphRendererInterface::GraphKind graph_kind, const DebugOptions& debug_options) { string path = debug_options.xla_hlo_graph_path(); - if (!path.empty()) { + if (!path.empty() && !debug_options.xla_hlo_dump_as_html()) { return SaveGraph(graph, graph_kind, path); } else { auto graph_renderer = GraphRendererRegistry::Default()->GetDefaultRenderer(); CHECK(graph_renderer != nullptr) << "No registered renderer for the HLO graph. 
" - "Use --xla_hlo_graph_path=PATH to export to local file system"; + "Use --xla_hlo_graph_path=PATH --xla_hlo_dump_as_html=false to " + "export to local file system"; return graph_renderer->RenderGraph(graph, graph_kind, debug_options); } } @@ -1484,7 +1520,7 @@ string DumpNeighborhoodAround(const HloInstruction& node, int radius, auto debug_options = node.GetModule()->config().debug_options(); string label = StrCat("Neighborhood of ", radius, " nodes around ", node.name()); - NodeFilter filter = MakeNodeFilter(&node, radius); + NodeFilter filter = MakeNodeRadiusAroundFilter(&node, radius); string graph = HloDotDumper(node.parent(), label, debug_options, show_backend_config, /*profile=*/nullptr, filter) @@ -1492,6 +1528,29 @@ string DumpNeighborhoodAround(const HloInstruction& node, int radius, return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options); } +string DumpAllPathsFromTo(const HloInstruction& from, const HloInstruction& to, + int64 max_nodes, bool show_backend_config) { + CHECK_EQ(from.parent(), to.parent()) << "Nodes must be in same computation!"; + auto debug_options = from.GetModule()->config().debug_options(); + + bool hit_limit = false; + NodeFilter filter = MakeNodeFromToFilter(&from, &to, max_nodes, &hit_limit); + string label; + if (!hit_limit) { + label = StrCat("All paths from ", from.name(), " to ", to.name()); + } else { + label = StrCat(max_nodes, " nodes on the shortest paths from ", from.name(), + " to ", to.name(), + "

***SHOWING ONLY A SUBSET OF ALL PATHS BETWEEN " + "NODES***

"); + } + string graph = + HloDotDumper(from.parent(), label, debug_options, show_backend_config, + /*profile=*/nullptr, filter) + .Dump(); + return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options); +} + void DumpText(const HloModule& module, const string& label, const string& directory_path, bool do_prefix) { Env* env = Env::Default(); @@ -1531,5 +1590,143 @@ string MaybeDumpHloModule(const HloModule& module, const string& label, return graph_url; } +string WrapDotInHTML(const string& dot) { + static const char html_prefix[] = R"html( + + + + + + + + + + + +
+ + + +)html"; + + return html_prefix + dot + html_suffix; +} + +string RenderDotAsHTMLFile(const string& dot, + const DebugOptions& debug_options) { + string html = WrapDotInHTML(dot); + + auto env = tensorflow::Env::Default(); + std::vector dirs; + string output_dir = debug_options.xla_hlo_graph_path(); + if (output_dir.empty()) { + env->GetLocalTempDirectories(&dirs); + } else { + dirs.push_back(output_dir); + } + // Try each directory, as they might be full, have inappropriate + // permissions or have different problems at times. + string output; + for (const string& dir : dirs) { + string filename = tensorflow::io::JoinPath(dir, "graph-"); + if (env->CreateUniqueFileName(&filename, ".html")) { + output = filename; + break; + } + } + if (output.empty()) { + LOG(FATAL) << "Failed to create unique output file name."; + } + TF_CHECK_OK(tensorflow::WriteStringToFile(env, output, html)); + return "file://" + output; +} + } // namespace hlo_graph_dumper } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index 8d5945aba8cb0a7426597f07173e83c4574f3365..8e51454ef1cf992386cc7325e32705c08bf7712f 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -66,6 +66,12 @@ string DumpGraph(const HloComputation& computation, const string& label, string DumpNeighborhoodAround(const HloInstruction& node, int radius, bool show_backend_config = false); +// Dumps nodes on any of the paths from `from` to `to`. If there are more than +// max_nodes on all paths, restricts to the max_nodes nodes on the shortest +// paths. +string DumpAllPathsFromTo(const HloInstruction& from, const HloInstruction& to, + int64 max_nodes, bool show_backend_config = false); + // Dumps the HloModule::ToString() as a file into the provided directory path // suffixed with the provided label. 
// @@ -75,6 +81,12 @@ string DumpNeighborhoodAround(const HloInstruction& node, int radius, void DumpText(const HloModule& module, const string& label, const string& directory_path, bool do_prefix = true); +// Renders DOT graph as inline SVG and saves it in an HTML file in a temporary +// directory or directory specified via --xla_hlo_graph_path. Returns the file +// URI pointing to the file. +string RenderDotAsHTMLFile(const string& dot, + const DebugOptions& debug_options); + // Graph renderers may be added using a registration mechanism, e.g.: // XLA_REGISTER_GRAPH_RENDERER(AGraphRendererClass, 100) // The renderer with the highest numeric priority value is used. diff --git a/tensorflow/compiler/xla/service/hlo_graph_html_renderer.cc b/tensorflow/compiler/xla/service/hlo_graph_html_renderer.cc new file mode 100644 index 0000000000000000000000000000000000000000..84c4cf18df69816c611f4eb159ba247320ebc20e --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_graph_html_renderer.cc @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Implementation of a DOT graph renderer that uses Javascript to render DOT to +// SVG in a browser. 
+ +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace xla { +namespace hlo_graph_dumper { +namespace { + +class GraphHtmlRenderer : public GraphRendererInterface { + public: + string RenderGraph(const string& graph, GraphKind graph_kind, + const DebugOptions& debug_options) override { + switch (graph_kind) { + case DOT_GRAPH: + return RenderDotAsHTMLFile(graph, debug_options); + default: + LOG(FATAL) << "Only DOT graphs can be rendered"; + } + } +}; + +XLA_REGISTER_GRAPH_RENDERER(GraphHtmlRenderer); + +} // namespace +} // namespace hlo_graph_dumper +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 21b1dbc1676cccd2fe5b331a1f9d6ff5e3a73fcd..8b2ace1e82eff250f4d9f0d5630e9e6d646cfe6d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -569,6 +569,11 @@ StatusOr> HloInstruction::CreateFromProto( instruction->SetAndSanitizeName(proto.name()); instruction->metadata_ = proto.metadata(); instruction->backend_config_ = proto.backend_config(); + + TF_RET_CHECK(proto.id() >= 0) + << "Instruction with negative id: " << proto.id(); + TF_RET_CHECK(proto.id() <= INT_MAX) + << "Instruction with id > INT_MAX: " << proto.id(); instruction->unique_id_ = proto.id(); if (proto.has_sharding()) { @@ -914,12 +919,8 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, HloInstruction* operand, HloInstruction* update, HloInstruction* start_indices) { - auto instruction = absl::WrapUnique( - new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(update); - instruction->AppendOperand(start_indices); - return instruction; + return absl::make_unique( + shape, operand, update, start_indices); } /* static */ std::unique_ptr HloInstruction::CreateConcatenate( @@ -1760,7 
+1761,12 @@ bool HloInstruction::IdenticalSlowPath( return false; } -uint64 HloInstruction::Hash() const { +static uint64 HashOperand(const HloInstruction* hlo) { + return ShapeUtil::Hash(hlo->shape()); +} + +uint64 HloInstruction::Hash( + const std::function& hash_operand) const { using tensorflow::Hash64Combine; uint64 hash_value = Hash64Combine(0, static_cast(opcode())); @@ -1769,7 +1775,7 @@ uint64 HloInstruction::Hash() const { if (!IsCrossModuleAllReduce()) { if (!operands().empty()) { for (size_t i = 0; i < operands().size(); ++i) { - hash_value = Hash64Combine(hash_value, operand(i)->Hash()); + hash_value = Hash64Combine(hash_value, hash_operand(operand(i))); } } } @@ -1778,6 +1784,11 @@ uint64 HloInstruction::Hash() const { return hash_value; } +uint64 HloInstruction::Hash() const { + // Use HashOperand as an argument to prevent non-termination. + return Hash(HashOperand); +} + uint64 HloInstruction::InnerHash() const { return 13; } void HloInstruction::RemoveUser(HloInstruction* user) { @@ -2059,6 +2070,10 @@ bool HloInstruction::IsCrossModuleAllReduce() const { return opcode() == HloOpcode::kCrossReplicaSum && all_reduce_id(); } +bool HloInstruction::IsCrossReplicaAllReduce() const { + return opcode() == HloOpcode::kCrossReplicaSum && !all_reduce_id(); +} + string HloInstruction::ToStringWithCanonicalNameMap( const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index a54716217d6bbc5c0601f5d9ff7bf4072a6b30f5..dd77f101a049d7247dcf571d2d19cb4f74e2f8ea 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -909,6 +909,14 @@ class HloInstruction { // information on opcode, shape, operands, and typically a root instruction. // This function returns the same hash value for equivalent HLO instructions, // with respect to HloInstruction::Identical() method. 
+ // + // Uses hash_operand function to compute hash values of its operands. + // At the very top level, hash_operand should be non-recursive to prevent + // non-termination. + uint64 Hash( + const std::function& hash_operand) const; + + // Calls the above method with non-recursive hash_operand function. uint64 Hash() const; // Returns whether the instruction has a constant operand. @@ -1174,9 +1182,12 @@ class HloInstruction { // Returns true if this instruction is elementwise on all its operands. bool IsElementwise() const; - // Returns true if this is an cross module all-reduce instrucion. + // Returns true if this is a cross module all-reduce instruction. bool IsCrossModuleAllReduce() const; + // Returns true if this is a cross-replica all-reduce instruction. + bool IsCrossReplicaAllReduce() const; + // Returns true if this elementwise instruction implicitly broadcasts operand // `operand_idx`. // diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 1ea02cf9c03866a598bec0e5356f0eb31ad27755..5521e5bd9acefcd1cb7721ed55fe987189623404 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -905,7 +905,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( options.print_large_constants())) { // Literal::ToString emits multidimensional arrays over multiple // lines. Compact this into one line by stripping out white space. 
- string tmp = literal().ToString(); + string tmp = literal().ToStringWithoutShape(); std::replace(tmp.begin(), tmp.end(), '\n', ' '); std::vector v = absl::StrSplit(tmp, ' '); bool first = true; @@ -1372,8 +1372,14 @@ bool HloFusionInstruction::IdenticalSlowPath( other.fused_instructions_computation()); } +static uint64 HashOperandRecursive(const HloInstruction* hlo) { + return hlo->Hash(HashOperandRecursive); +} + uint64 HloFusionInstruction::InnerHash() const { - return fused_instructions_computation()->Hash(); + // Use HashOperandRecursive to recursively compute hash on inner operands. + return fused_instructions_computation()->root_instruction()->Hash( + HashOperandRecursive); } std::unique_ptr HloFusionInstruction::CloneWithNewOperandsImpl( @@ -1994,12 +2000,21 @@ std::unique_ptr HloPadInstruction::CloneWithNewOperandsImpl( HloDynamicSliceInstruction::HloDynamicSliceInstruction( const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, absl::Span slice_sizes) - : HloInstruction(HloOpcode::kDynamicSlice, shape), + : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape), dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) { AppendOperand(operand); AppendOperand(start_indices); } +HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* update, + HloInstruction* start_indices) + : HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) { + AppendOperand(operand); + AppendOperand(update); + AppendOperand(start_indices); +} + HloInstructionProto HloDynamicSliceInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); for (int64 slice_size : dynamic_slice_sizes_) { diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index b5c28137a145667a977d39c9d3c40c6d36a8436e..5420d4ce11f4bdd068e82f208a98e9943ad4479e 100644 --- 
a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -1171,7 +1171,14 @@ class HloPadInstruction : public HloInstruction { PaddingConfig padding_config_; }; -class HloDynamicSliceInstruction : public HloInstruction { +class HloDynamicIndexInstruction : public HloInstruction { + public: + explicit HloDynamicIndexInstruction(HloOpcode opcode, const Shape& shape) + : HloInstruction(opcode, shape) {} + virtual int64 index_operand_number() const = 0; +}; + +class HloDynamicSliceInstruction : public HloDynamicIndexInstruction { public: explicit HloDynamicSliceInstruction(const Shape& shape, HloInstruction* operand, @@ -1189,6 +1196,8 @@ class HloDynamicSliceInstruction : public HloInstruction { // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; + int64 index_operand_number() const override { return 1; } + private: std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; @@ -1206,6 +1215,16 @@ class HloDynamicSliceInstruction : public HloInstruction { std::vector dynamic_slice_sizes_; }; +class HloDynamicUpdateSliceInstruction : public HloDynamicIndexInstruction { + public: + explicit HloDynamicUpdateSliceInstruction(const Shape& shape, + HloInstruction* operand, + HloInstruction* update, + HloInstruction* start_indices); + + int64 index_operand_number() const override { return 2; } +}; + class HloGatherInstruction : public HloInstruction { public: explicit HloGatherInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index 1390537101e95a08e4ba4eef7ae8d6059a40e916..dc712e5e42c449737bf4415f3a5e3eb9d81d9be4 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "absl/strings/escaping.h" #include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -82,9 +83,23 @@ tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers( return tensorflow::RegexpStringPiece(begin, end - begin); } +TokKind HloLexer::LookAhead() { + if (GetKind() == TokKind::kEof || GetKind() == TokKind::kError) { + return GetKind(); + } + + const char* old_current_ptr = current_ptr_; + TokenState old_token_state = token_state_; + Lex(); + TokKind kind = GetKind(); + token_state_ = old_token_state; + current_ptr_ = old_current_ptr; + return kind; +} + TokKind HloLexer::LexToken() { while (true) { - token_start_ = current_ptr_; + token_state_.token_start = current_ptr_; int current_char = GetNextChar(); switch (current_char) { @@ -206,43 +221,37 @@ TokKind HloLexer::LexToken() { // dim_labels_pattern ::= [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} // identifiers ::= other cases that match [a-zA-Z_][a-zA-Z0-9_.-]* TokKind HloLexer::LexIdentifier() { - { - auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); - // 'consumable' will be advanced iff its prefix matches the pattern. - static LazyRE2 shape_pattern = { - R"(^(\w*\d*)\[([\d,\s]*)\](?:(dense|sparse)?{([\d,\s]+)})?)"}; - if (RE2::Consume(&consumable, *shape_pattern)) { - auto status_or_shape = ShapeUtil::ParseShapeString( - StringPieceFromPointers(token_start_, consumable.begin())); - if (status_or_shape.ok()) { - // This is a shape string. - shape_val_ = status_or_shape.ValueOrDie(); - current_ptr_ = consumable.begin(); - return TokKind::kShape; - } - } - } - while (IsIdentifierChar(PeekCurrentChar())) { current_ptr_++; } // If followed by ':', it's a name. 
if (PeekCurrentChar() == ':') { - str_val_.assign(token_start_, current_ptr_); + token_state_.str_val.assign(token_state_.token_start, current_ptr_); current_ptr_++; // skip ':' return TokKind::kName; } // If followed by '=', it's a attribute name. if (PeekCurrentChar() == '=') { - str_val_.assign(token_start_, current_ptr_); + token_state_.str_val.assign(token_state_.token_start, current_ptr_); current_ptr_++; // skip '=' return TokKind::kAttributeName; } absl::string_view identifier = - StringPieceFromPointers(token_start_, current_ptr_); + StringPieceFromPointers(token_state_.token_start, current_ptr_); + + // Primitive type strings are reserved words. The exception is 'tuple' whose + // type is represented using nested parentheses without the string 'tuple'. + if (primitive_util::IsPrimitiveTypeName(identifier)) { + PrimitiveType primitive_type = + primitive_util::StringToPrimitiveType(identifier).ValueOrDie(); + if (primitive_type != TUPLE) { + token_state_.primitive_type_val = primitive_type; + return TokKind::kPrimitiveType; + } + } // See if this is a keyword. 
#define KEYWORD(STR) \ @@ -261,21 +270,23 @@ TokKind HloLexer::LexIdentifier() { KEYWORD(ROOT); KEYWORD(maximal); KEYWORD(replicated); + KEYWORD(sparse); #undef KEYWORD { - auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + auto consumable = + RegexpStringPieceFromPointers(token_state_.token_start, buf_.end()); static LazyRE2 dim_labels_pattern = { R"([0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,})"}; if (RE2::Consume(&consumable, *dim_labels_pattern)) { current_ptr_ = consumable.begin(); - str_val_.assign(token_start_, current_ptr_); + token_state_.str_val.assign(token_state_.token_start, current_ptr_); return TokKind::kDimLabels; } } - str_val_ = string(identifier); + token_state_.str_val = string(identifier); return TokKind::kIdent; } @@ -289,7 +300,7 @@ TokKind HloLexer::LexPercent() { while (IsIdentifierChar(PeekCurrentChar())) { current_ptr_++; } - str_val_.assign(name_start, current_ptr_); + token_state_.str_val.assign(name_start, current_ptr_); return TokKind::kName; } return TokKind::kError; @@ -307,12 +318,14 @@ TokKind HloLexer::LexPercent() { // int ::= [-]?[0-9]+ // negative inf ::= '-inf' TokKind HloLexer::LexNumberOrPattern() { - auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + auto consumable = + RegexpStringPieceFromPointers(token_state_.token_start, buf_.end()); static LazyRE2 float_pattern = { R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"}; if (RE2::Consume(&consumable, *float_pattern)) { current_ptr_ = consumable.begin(); - CHECK(absl::SimpleAtod(string(token_start_, current_ptr_), &decimal_val_)); + CHECK(absl::SimpleAtod(string(token_state_.token_start, current_ptr_), + &token_state_.decimal_val)); return TokKind::kDecimal; } @@ -324,27 +337,28 @@ TokKind HloLexer::LexNumberOrPattern() { if (RE2::Consume(&consumable, *dim_labels_pattern)) { current_ptr_ = consumable.begin(); - str_val_.assign(token_start_, current_ptr_); + 
token_state_.str_val.assign(token_state_.token_start, current_ptr_); return TokKind::kDimLabels; } if (RE2::Consume(&consumable, *dxd_pattern)) { current_ptr_ = consumable.begin(); - str_val_.assign(token_start_, current_ptr_); + token_state_.str_val.assign(token_state_.token_start, current_ptr_); return TokKind::kDxD; } if (RE2::Consume(&consumable, *pad_pattern)) { current_ptr_ = consumable.begin(); - str_val_.assign(token_start_, current_ptr_); + token_state_.str_val.assign(token_state_.token_start, current_ptr_); return TokKind::kPad; } static LazyRE2 int_pattern = {R"([-]?\d+)"}; if (RE2::Consume(&consumable, *int_pattern)) { current_ptr_ = consumable.begin(); - auto slice = StringPieceFromPointers(token_start_, current_ptr_); - if (absl::SimpleAtoi(slice, &int64_val_)) { + auto slice = + StringPieceFromPointers(token_state_.token_start, current_ptr_); + if (absl::SimpleAtoi(slice, &token_state_.int64_val)) { return TokKind::kInt; } LOG(ERROR) << "Failed to parse int literal: " << slice; @@ -403,16 +417,17 @@ absl::string_view HloLexer::GetLine(LocTy loc) const { } // Lexes quoted string with escaping characters. If matched, the quoted string -// will be unescaped and stored to str_val_. +// will be unescaped and stored to token_state_.str_val. TokKind HloLexer::LexString() { - auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + auto consumable = + RegexpStringPieceFromPointers(token_state_.token_start, buf_.end()); static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"}; if (RE2::Consume(&consumable, *escaping_pattern)) { current_ptr_ = consumable.begin(); absl::string_view raw = - StringPieceFromPointers(token_start_ + 1, current_ptr_ - 1); + StringPieceFromPointers(token_state_.token_start + 1, current_ptr_ - 1); string error; - if (!absl::CUnescape(raw, &str_val_, &error)) { + if (!absl::CUnescape(raw, &token_state_.str_val, &error)) { LOG(ERROR) << "Failed unescaping string: " << raw << ". 
error: " << error; return TokKind::kError; } @@ -467,6 +482,10 @@ string TokKindToString(TokKind kind) { return "kw_inf"; case TokKind::kNegInf: return "kNegInf"; + case TokKind::kw_sparse: + return "kw_sparse"; + case TokKind::kPrimitiveType: + return "kPrimitiveType"; case TokKind::kName: return "kName"; case TokKind::kAttributeName: @@ -481,8 +500,6 @@ string TokKindToString(TokKind kind) { return "kIdent"; case TokKind::kString: return "kString"; - case TokKind::kShape: - return "kShape"; case TokKind::kInt: return "kInt"; case TokKind::kDecimal: diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h index d6a2b292a3916b2ff85f278cf5cb9f1567df88fa..41f5043904a2622814154693679a0e27cb92f642 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -19,7 +19,6 @@ limitations under the License. #include #include "absl/strings/string_view.h" -#include "tensorflow/compiler/xla/service/hlo_token.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -29,6 +28,57 @@ limitations under the License. namespace xla { +// Defines different kinds of tokens used by the HLO lexer. +// +// You shouldn't need to use this directly unless you're using HloLexer +// directly, and you probably don't need to do that. Use hlo_parser instead. +enum class TokKind { + // Markers + kEof, + kError, + + // Tokens with no info. + kEqual, // = + kComma, // , + kColon, // : + kLsquare, + kRsquare, // [ ] + kLbrace, + kRbrace, // { } + kLparen, + kRparen, // ( ) + + kArrow, // -> + + // Keywords + kw_HloModule, + kw_ENTRY, + kw_ROOT, + kw_true, + kw_false, + kw_maximal, + kw_replicated, + kw_nan, + kw_inf, + kw_sparse, + + kNegInf, // -inf + + // Typed tokens. + kPrimitiveType, // F32, PRED, etc. 
+ kName, // %foo + kAttributeName, // dimensions= + kDimLabels, // [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} + kDxD, // [0-9]+(x[0-9]+)+ + kPad, // [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* + kIdent, // other identifiers + kString, // "abcd\"\n" + kInt, // 42 + kDecimal, // 4.2 +}; + +string TokKindToString(TokKind kind); + // Lexer for the HloModule::ToString() format text. // // This class is meant to be used by hlo_parser.cc. You shouldn't need to use @@ -39,9 +89,9 @@ class HloLexer { current_ptr_ = buf_.begin(); } - TokKind Lex() { return current_kind_ = LexToken(); } + TokKind Lex() { return token_state_.current_kind = LexToken(); } - TokKind GetKind() const { return current_kind_; } + TokKind GetKind() const { return token_state_.current_kind; } string GetStrVal() const { switch (GetKind()) { case TokKind::kName: @@ -51,28 +101,28 @@ class HloLexer { case TokKind::kPad: case TokKind::kString: case TokKind::kIdent: - return str_val_; + return token_state_.str_val; default: LOG(FATAL) << "This token does not have string value"; } } - Shape GetShapeVal() const { - CHECK(GetKind() == TokKind::kShape); - return shape_val_; - } tensorflow::int64 GetInt64Val() const { CHECK(GetKind() == TokKind::kInt); - return int64_val_; + return token_state_.int64_val; } double GetDecimalVal() const { CHECK(GetKind() == TokKind::kDecimal); - return decimal_val_; + return token_state_.decimal_val; + } + PrimitiveType GetPrimitiveTypeVal() const { + CHECK(GetKind() == TokKind::kPrimitiveType); + return token_state_.primitive_type_val; } typedef const char* LocTy; // Returns the location of the current token. - LocTy GetLoc() const { return token_start_; } + LocTy GetLoc() const { return token_state_.token_start; } // Returns the line and column of a location in the buffer. std::pair GetLineAndColumn(LocTy location) const; @@ -80,6 +130,9 @@ class HloLexer { // Returns the whole line given the location. 
absl::string_view GetLine(LocTy loc) const; + // Looks ahead one token and returns it. Lexer state is unchanged. + TokKind LookAhead(); + private: // Returns the current character. If it's neither the end of input buffer nor // an invalid character, moves the pointer forward. @@ -112,12 +165,15 @@ class HloLexer { const char* current_ptr_; // Information about the current token. - const char* token_start_ = nullptr; - TokKind current_kind_; - string str_val_; - Shape shape_val_; - tensorflow::int64 int64_val_; - double decimal_val_; + struct TokenState { + const char* token_start = nullptr; + TokKind current_kind; + string str_val; + tensorflow::int64 int64_val; + double decimal_val; + PrimitiveType primitive_type_val; + }; + TokenState token_state_; struct LineNoCacheTy { const char* last_query; diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc index e0ae1173c6114f0bc6ef18b2cfff9d54ccfe2faf..436cccb1fb9ecf6f4efad772c700c611b28ce628 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc @@ -403,9 +403,9 @@ TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) { HloModule OutfeedLoop WhileBody { body_param = (s32[]) parameter(0) - token = token[] after-all() + token0 = token[] after-all() constant.2 = s32[] constant(2) - outfeed_tuple = (s32[]) outfeed(constant.2, token) + outfeed_tuple = (s32[]) outfeed(constant.2, token0) get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 constant.1 = s32[] constant(1) add = s32[] add(get-tuple-element.1, constant.1) @@ -436,9 +436,9 @@ TEST_F(HloLivenessAnalysisTest, NestedWhileWithOutfeed) { HloModule OutfeedLoop InnerWhileBody { body_param = (s32[]) parameter(0) - token = token[] after-all() + token0 = token[] after-all() constant.2 = s32[] constant(2) - outfeed_tuple = (s32[]) outfeed(constant.2, token) + outfeed_tuple = (s32[]) 
outfeed(constant.2, token0) get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 constant.1 = s32[] constant(1) add = s32[] add(get-tuple-element.1, constant.1) diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 235efb19ce4ed28a5cd9fe5ca52ae5d8e9e5ba3d..1fbcbdf98d68204b1c6269d51d9b19363761ee04 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -312,8 +312,8 @@ inline ::testing::Matcher Shape( } inline ::testing::Matcher Shape( absl::string_view shape) { - return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher( - ShapeUtil::ParseShapeString(shape).ValueOrDie())); + return ::testing::MakeMatcher( + new ::xla::testing::HloShapeMatcher(ParseShape(shape).ValueOrDie())); } inline ::testing::Matcher ShapeWithLayout( const class Shape& shape) { @@ -323,7 +323,7 @@ inline ::testing::Matcher ShapeWithLayout( inline ::testing::Matcher ShapeWithLayout( absl::string_view shape) { return ::testing::MakeMatcher(new ::xla::testing::HloShapeAndLayoutMatcher( - ShapeUtil::ParseShapeString(shape).ValueOrDie())); + ParseShape(shape).ValueOrDie())); } // Verifies the value of the HloSharing against the provided sharding object. diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 7b9cbf9a53a2201b1312405bbd7ed2b88f65c9be..f1310e4b270898a21dbb4f86123edde4ba8993d0 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -136,7 +136,9 @@ class HloModule { // information on opcode, shape, operands, and typically a root instruction. // This function returns the same hash value for equivalent HLO modules, // with respect to HloInstruction::Identical() method. 
- uint64 Hash() const { return entry_computation()->Hash(); } + uint64 Hash() const { + return entry_computation()->root_instruction()->Hash(); + } // Gets the computations in this module. // diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc index bf66cc6bc37a5e11c9ecfc07a62ba0ea5ca11a03..e535b7d74943943069b4d795cf999a3b1e963360 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc @@ -373,9 +373,9 @@ TEST_F(HloModuleDceTest, WhileWithOutfeed) { HloModule OutfeedLoop WhileBody { body_param = (s32[]) parameter(0) - token = token[] after-all() + token0 = token[] after-all() constant.2 = s32[] constant(2) - outfeed_tuple = (s32[]) outfeed(constant.2, token) + outfeed_tuple = (s32[]) outfeed(constant.2, token0) get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 constant.1 = s32[] constant(1) add = s32[] add(get-tuple-element.1, constant.1) diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 9b5bb5d0bd6af104ef62eaa5d3e53cedbe0213d3..29bb088f6de9a5113d253b7e5559a8e66e7e408b 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -74,6 +75,7 @@ class HloParser { string GetError() const { return StrJoin(error_, "\n"); } // Stand alone parsing utils for various aggregate data types. 
+ StatusOr ParseShapeOnly(); StatusOr ParseShardingOnly(); StatusOr ParseWindowOnly(); StatusOr ParseConvolutionDimensionNumbersOnly(); @@ -255,7 +257,9 @@ class HloParser { bool ParseName(string* result); bool ParseAttributeName(string* result); bool ParseString(string* result); + bool ParseDimensionSizes(std::vector* dimension_sizes); bool ParseShape(Shape* result); + bool ParseLayout(Layout* layout); bool ParseOpcode(HloOpcode* result); bool ParseFftType(FftType* result); bool ParseFusionKind(HloInstruction::FusionKind* result); @@ -279,9 +283,6 @@ class HloParser { // If the current token is 'kind', eats it (i.e. lexes the next token) and // returns true. bool EatIfPresent(TokKind kind); - // Parses a shape, and returns true if the result is compatible with the given - // shape. - bool EatShapeAndCheckCompatible(const Shape& shape); // Adds the instruction to the pool. Returns false and emits an error if the // instruction already exists. @@ -1697,11 +1698,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, } break; } - case TokKind::kShape: - // TODO(b/112302613): Left here for backward compatibility to ignore the - // removed tile shape data. 
- lexer_.Lex(); - break; case TokKind::kRbrace: break; default: @@ -1925,19 +1921,6 @@ bool HloParser::SetValueInLiteralHelper(ParsedElemT value, return true; } -bool HloParser::EatShapeAndCheckCompatible(const Shape& shape) { - Shape new_shape; - if (!ParseShape(&new_shape)) { - return TokenError(StrCat("expects shape ", ShapeUtil::HumanString(shape))); - } - if (!ShapeUtil::Compatible(shape, new_shape)) { - return TokenError(StrCat( - "expects shape ", ShapeUtil::HumanString(shape), - ", but sees a different shape: ", ShapeUtil::HumanString(new_shape))); - } - return true; -} - // literal // ::= tuple // ::= non_tuple @@ -1952,10 +1935,6 @@ bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) { // ::= /*empty*/ // ::= literal (',' literal)* bool HloParser::ParseTupleLiteral(Literal* literal, const Shape& shape) { - if (!EatShapeAndCheckCompatible(shape)) { - return TokenError(StrCat("expects tuple constant in shape ", - ShapeUtil::HumanString(shape))); - } if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) { return false; } @@ -1990,16 +1969,12 @@ bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) { return ParseSparseLiteral(literal, shape); } - CHECK(LayoutUtil::IsDenseArray(shape)); + CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ToString(true); return ParseDenseLiteral(literal, shape); } bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { const tensorflow::int64 rank = ShapeUtil::Rank(shape); - if (rank > 1 && !EatShapeAndCheckCompatible(shape)) { - return false; - } - // Create a literal with the given shape in default layout. 
*literal = LiteralUtil::CreateFromDimensions( shape.element_type(), AsInt64Slice(shape.dimensions())); @@ -2126,10 +2101,6 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { } bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) { - if (!EatShapeAndCheckCompatible(shape)) { - return false; - } - switch (shape.element_type()) { case PRED: return ParseSparseLiteralHelper(literal, shape); @@ -2994,6 +2965,39 @@ bool HloParser::ParseParamList() { return ParseToken(TokKind::kRparen, "expects ')' at the end of param list"); } +// dimension_sizes ::= '[' int64_list ']' +bool HloParser::ParseDimensionSizes(std::vector* dimension_sizes) { + auto parse_and_add_item = [&]() { + tensorflow::int64 i; + if (!ParseInt64(&i)) { + return false; + } + dimension_sizes->push_back(i); + return true; + }; + return ParseList(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma, + parse_and_add_item); +} + +// layout ::= '{' int64_list '}' +bool HloParser::ParseLayout(Layout* layout) { + std::vector minor_to_major; + auto parse_and_add_item = [&]() { + tensorflow::int64 i; + if (!ParseInt64(&i)) { + return false; + } + minor_to_major.push_back(i); + return true; + }; + if (!ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, + parse_and_add_item)) { + return false; + } + *layout = LayoutUtil::MakeLayout(minor_to_major); + return true; +} + // shape ::= shape_val_ // shape ::= '(' tuple_elements ')' // tuple_elements @@ -3017,19 +3021,61 @@ bool HloParser::ParseShape(Shape* result) { return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple."); } - if (lexer_.GetKind() != TokKind::kShape) { - return TokenError(absl::StrCat("expected shape, saw ", + if (lexer_.GetKind() != TokKind::kPrimitiveType) { + return TokenError(absl::StrCat("expected primitive type, saw ", TokKindToString(lexer_.GetKind()))); } - *result = lexer_.GetShapeVal(); + PrimitiveType primitive_type = lexer_.GetPrimitiveTypeVal(); lexer_.Lex(); + + 
std::vector dimension_sizes; + if (!ParseDimensionSizes(&dimension_sizes)) { + return false; + } + result->set_element_type(primitive_type); + *result->mutable_dimensions() = dimension_sizes; + LayoutUtil::SetToDefaultLayout(result); + + if (lexer_.GetKind() == TokKind::kw_sparse) { + lexer_.Lex(); + const string message = + "expects a brace-bracketed integer for sparse layout"; + tensorflow::int64 max_sparse_elements; + if (!ParseToken(TokKind::kLbrace, message) || + !ParseInt64(&max_sparse_elements) || + !ParseToken(TokKind::kRbrace, message)) { + return false; + } + *result->mutable_layout() = + LayoutUtil::MakeSparseLayout(max_sparse_elements); + return true; + } + + // We need to lookahead to see if a following open brace is the start of a + // layout. The specific problematic case is: + // + // ENTRY %foo (x: f32[42]) -> f32[123] { + // ... + // } + // + // The open brace could either be the start of a computation or the start of a + // layout for the f32[123] shape. We consider it the start of a layout if the + // next token after the open brace is a integer + if (lexer_.GetKind() == TokKind::kLbrace && + lexer_.LookAhead() == TokKind::kInt) { + Layout layout; + if (!ParseLayout(&layout)) { + return false; + } + *result->mutable_layout() = layout; + } return true; } bool HloParser::CanBeShape() { - // A non-tuple shape starts with a kShape token; a tuple shape starts with - // '('. - return lexer_.GetKind() == TokKind::kShape || + // A non-tuple shape starts with a kPrimitiveType token; a tuple shape starts + // with '('. 
+ return lexer_.GetKind() == TokKind::kPrimitiveType || lexer_.GetKind() == TokKind::kLparen; } @@ -3332,6 +3378,18 @@ bool HloParser::AddComputation(const string& name, HloComputation* computation, return true; } +StatusOr HloParser::ParseShapeOnly() { + lexer_.Lex(); + Shape shape; + if (!ParseShape(&shape)) { + return InvalidArgument("Syntax error:\n%s", GetError()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument("Syntax error:\nExtra content after shape"); + } + return shape; +} + StatusOr HloParser::ParseShardingOnly() { lexer_.Lex(); OpSharding op_sharding; @@ -3475,4 +3533,9 @@ StatusOr ParsePaddingConfig(absl::string_view str) { return parser.ParsePaddingConfigOnly(); } +StatusOr ParseShape(absl::string_view str) { + HloParser parser(str); + return parser.ParseShapeOnly(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index d830fa61438239005875f785f85cf2486123ebc9..450a54c54c156c2ae27475d145a8e83dc841b431 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -60,6 +60,9 @@ StatusOr ParseConvolutionDimensionNumbers( // Parses the result of PaddingConfigToString(), e.g. "0_0x1_1". StatusOr ParsePaddingConfig(absl::string_view str); +// Parses and returns a Shape::ToString-format string. 
+StatusOr ParseShape(absl::string_view str); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index ab71f011ac9d77d00ddfb41aca7a224d26d416b7..80882d490d6b477403f87a4eb266d3ba2fdb3378 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -82,7 +82,7 @@ ENTRY %constant_pred () -> pred[] { R"(HloModule module ENTRY %constant_pred_array () -> pred[2,3] { - ROOT %constant = pred[2,3]{1,0} constant(pred[2,3] { { 0, 1, 0 }, { 1, 0, 1 } }) + ROOT %constant = pred[2,3]{1,0} constant({ { 0, 1, 0 }, { 1, 0, 1 } }) } )" @@ -128,7 +128,7 @@ ENTRY %ConstantF32Empty.v4 () -> f32[0] { R"(HloModule ConstantF32R4Empty_module ENTRY %ConstantF32R4Empty.v4 () -> f32[2,0,4,3] { - ROOT %constant = f32[2,0,4,3]{3,2,1,0} constant(f32[2,0,4,3] { { /*i0=0*/ }, { /*i0=1*/ } }) + ROOT %constant = f32[2,0,4,3]{3,2,1,0} constant({ { /*i0=0*/ }, { /*i0=1*/ } }) } )" @@ -139,7 +139,7 @@ ENTRY %ConstantF32R4Empty.v4 () -> f32[2,0,4,3] { R"(HloModule Small_3x2x1x1_module ENTRY %Small_3x2x1x1.v1 () -> f32[3,2,1,1] { - ROOT %constant = f32[3,2,1,1]{3,2,1,0} constant(f32[3,2,1,1] { { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) + ROOT %constant = f32[3,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) } )" @@ -196,7 +196,7 @@ ENTRY %add_constants () -> f32[] { R"(HloModule TupleConstant_module ENTRY %TupleConstant.v1 () -> (f32[2,1], f32[2]) { - ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { {1}, {2} }, {2, 42} )) + ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant(( { {1}, {2} }, {2, 42} )) } )" @@ -295,11 
+295,11 @@ ENTRY %WhileWithScalarS32Result.v2 () -> s32[] { R"(HloModule TwoSendRecvBothWayRecvFist_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) { - %token = token[] after-all() - %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15, sharding={maximal device=1} + %token0 = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, sharding={maximal device=1} ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, sharding={maximal device=1} %constant = f32[] constant(2.1), sharding={maximal device=0} - %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv} + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv} %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, sharding={maximal device=0} } @@ -310,11 +310,11 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) { R"(HloModule HostTransferSendRecv_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) { - %token = token[] after-all() - %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15, is_host_transfer=true + %token0 = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, is_host_transfer=true ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, is_host_transfer=true %constant = f32[] constant(2.1), sharding={maximal device=0} - %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, is_host_transfer=true + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, is_host_transfer=true %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, is_host_transfer=true } @@ -327,7 +327,7 @@ R"(HloModule 
GetTupleElement_module ENTRY %GetTupleElement.v4 () -> s32[2,3] { %constant = f32[3]{0} constant({1, 2, 3}) - %constant.1 = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 4, 5, 6 } }) + %constant.1 = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 } }) %tuple = (f32[3]{0}, s32[2,3]{1,0}) tuple(f32[3]{0} %constant, s32[2,3]{1,0} %constant.1) ROOT %get-tuple-element = s32[2,3]{1,0} get-tuple-element((f32[3]{0}, s32[2,3]{1,0}) %tuple), index=1, sharding={maximal device=0} } @@ -434,7 +434,7 @@ ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f R"(HloModule Reverse4DFloatArrayOnDim01_module ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] { - %constant = f32[4,3,2,1]{0,1,2,3} constant(f32[4,3,2,1] { { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} }, { /*i1=2*/ {5}, {6} } }, { /*i0=1*/ { /*i1=0*/ {7}, {8} }, { /*i1=1*/ {9}, {10} }, { /*i1=2*/ {11}, {12} } }, { /*i0=2*/ { /*i1=0*/ {13}, {14} }, { /*i1=1*/ {15}, {16} }, { /*i1=2*/ {17}, {18} } }, { /*i0=3*/ { /*i1=0*/ {19}, {20} }, { /*i1=1*/ {21}, {22} }, { /*i1=2*/ {23}, {24} } } }) + %constant = f32[4,3,2,1]{0,1,2,3} constant({ { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} }, { /*i1=2*/ {5}, {6} } }, { /*i0=1*/ { /*i1=0*/ {7}, {8} }, { /*i1=1*/ {9}, {10} }, { /*i1=2*/ {11}, {12} } }, { /*i0=2*/ { /*i1=0*/ {13}, {14} }, { /*i1=1*/ {15}, {16} }, { /*i1=2*/ {17}, {18} } }, { /*i0=3*/ { /*i1=0*/ {19}, {20} }, { /*i1=1*/ {21}, {22} }, { /*i1=2*/ {23}, {24} } } }) ROOT %reverse = f32[4,3,2,1]{0,1,2,3} reverse(f32[4,3,2,1]{0,1,2,3} %constant), dimensions={0,1} } @@ -446,8 +446,8 @@ ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] { R"(HloModule Concat2x3With2x5_module ENTRY %Concat2x3With2x5.v3 () -> f32[2,8] { - %constant = f32[2,3]{1,0} constant(f32[2,3] { { 0, 1, 2 }, { 1000, 1001, 1002 } }) - %constant.1 = f32[2,5]{1,0} constant(f32[2,5] { { 64, 65, 66, 67, 68 }, { 1064, 1065, 1066, 1067, 1068 } }) + %constant = f32[2,3]{1,0} constant({ { 0, 1, 2 }, { 1000, 1001, 
1002 } }) + %constant.1 = f32[2,5]{1,0} constant({ { 64, 65, 66, 67, 68 }, { 1064, 1065, 1066, 1067, 1068 } }) ROOT %concatenate = f32[2,8]{1,0} concatenate(f32[2,3]{1,0} %constant, f32[2,5]{1,0} %constant.1), dimensions={1} } @@ -471,8 +471,8 @@ R"(HloModule R4F32OverlapSmall_module } ENTRY %R4F32OverlapSmall.v4 () -> f32[4,5,1,1] { - %constant = f32[4,5,1,1]{3,2,1,0} constant(f32[4,5,1,1] { { /*i0=0*/ { /*i1=0*/ {7} }, { /*i1=1*/ {2} }, { /*i1=2*/ {5} }, { /*i1=3*/ {3} }, { /*i1=4*/ {8} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {8} }, { /*i1=2*/ {9} }, { /*i1=3*/ {3} }, { /*i1=4*/ {4} } }, { /*i0=2*/ { /*i1=0*/ {1} }, { /*i1=1*/ {5} }, { /*i1=2*/ {7} }, { /*i1=3*/ {5} }, { /*i1=4*/ {6} } }, { /*i0=3*/ { /*i1=0*/ {0} }, { /*i1=1*/ {6} }, { /*i1=2*/ {2} }, { /*i1=3*/ {10} }, { /*i1=4*/ {2} } } }) - %constant.1 = f32[2,2,1,1]{3,2,1,0} constant(f32[2,2,1,1] { { /*i0=0*/ { /*i1=0*/ {2} }, { /*i1=1*/ {6} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {1} } } }) + %constant = f32[4,5,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {7} }, { /*i1=1*/ {2} }, { /*i1=2*/ {5} }, { /*i1=3*/ {3} }, { /*i1=4*/ {8} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {8} }, { /*i1=2*/ {9} }, { /*i1=3*/ {3} }, { /*i1=4*/ {4} } }, { /*i0=2*/ { /*i1=0*/ {1} }, { /*i1=1*/ {5} }, { /*i1=2*/ {7} }, { /*i1=3*/ {5} }, { /*i1=4*/ {6} } }, { /*i0=3*/ { /*i1=0*/ {0} }, { /*i1=1*/ {6} }, { /*i1=2*/ {2} }, { /*i1=3*/ {10} }, { /*i1=4*/ {2} } } }) + %constant.1 = f32[2,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {2} }, { /*i1=1*/ {6} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {1} } } }) %constant.2 = f32[] constant(0) ROOT %select-and-scatter = f32[4,5,1,1]{3,2,1,0} select-and-scatter(f32[4,5,1,1]{3,2,1,0} %constant, f32[2,2,1,1]{3,2,1,0} %constant.1, f32[] %constant.2), window={size=2x3x1x1 stride=2x2x1x1}, select=%ge_F32.v3, scatter=%add_F32.v3 } @@ -523,7 +523,7 @@ ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { R"(HloModule Slice3x3x3_To_1x3x3_F32_module ENTRY 
%Slice3x3x3_To_1x3x3_F32.v2 () -> f32[1,3,3] { - %constant = f32[3,3,3]{2,1,0} constant(f32[3,3,3] { { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } }, { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } }, { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } } }) + %constant = f32[3,3,3]{2,1,0} constant({ { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } }, { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } }, { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } } }) ROOT %slice = f32[1,3,3]{2,1,0} slice(f32[3,3,3]{2,1,0} %constant), slice={[0:1], [0:3], [0:3]} } @@ -547,7 +547,7 @@ ENTRY %SliceR0.v2 () -> s32[] { R"(HloModule Transpose_module ENTRY %Transpose.v2 () -> s32[1,2,3] { - %constant = s32[1,2,3]{2,1,0} constant(s32[1,2,3] { { { 1, 2, 3 }, { 4, 5, 6 } } }) + %constant = s32[1,2,3]{2,1,0} constant({ { { 1, 2, 3 }, { 4, 5, 6 } } }) ROOT %transpose = s32[1,2,3]{2,1,0} transpose(s32[1,2,3]{2,1,0} %constant), dimensions={0,1,2} } @@ -588,7 +588,7 @@ ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_ R"(HloModule BasicTraining_module ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) { - %constant = f32[2,2,1,2]{3,2,1,0} constant(f32[2,2,1,2] { { /*i0=0*/ { /*i1=0*/ { 1, 2 } }, { /*i1=1*/ { 3, 4 } } }, { /*i0=1*/ { /*i1=0*/ { 5, 6 } }, { /*i1=1*/ { 7, 8 } } } }) + %constant = f32[2,2,1,2]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ { 1, 2 } }, { /*i1=1*/ { 3, 4 } } }, { /*i0=1*/ { /*i1=0*/ { 5, 6 } }, { /*i1=1*/ { 7, 8 } } } }) %constant.1 = f32[2]{0} constant({2, 3}) %constant.2 = f32[2]{0} constant({1, 2}) ROOT %batch-norm-training = (f32[2,2,1,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-training(f32[2,2,1,2]{3,2,1,0} %constant, f32[2]{0} %constant.1, f32[2]{0} %constant.2), epsilon=0.001, feature_index=3 @@ -728,7 +728,7 @@ R"(HloModule fusion_module } ENTRY %fusion.v3 () -> f32[3,2,1,1] { - %constant = f32[3,2,1,1]{3,2,1,0} constant(f32[3,2,1,1] { { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, 
{ /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) + %constant = f32[3,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) %constant.1 = f32[2]{0} constant({3.14, 4.25}) ROOT %fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %constant, f32[2]{0} %constant.1), kind=kLoop, calls=%fused_computation } @@ -740,7 +740,7 @@ ENTRY %fusion.v3 () -> f32[3,2,1,1] { R"(HloModule sparse_f32 ENTRY %sparse () -> f32[2,3,4] { - ROOT %foo = f32[2,3,4]sparse{10} constant(f32[2,3,4]{[0, 1, 2]: 1, [1, 2, 3]: 2, [2, 3, 4]: 3}) + ROOT %foo = f32[2,3,4]sparse{10} constant({[0, 1, 2]: 1, [1, 2, 3]: 2, [2, 3, 4]: 3}) } )" @@ -750,7 +750,7 @@ ENTRY %sparse () -> f32[2,3,4] { R"(HloModule sparse_f32_empty ENTRY %sparse_f32_empty () -> f32[2,3,4] { - ROOT %foo = f32[2,3,4]sparse{10} constant(f32[2,3,4]{}) + ROOT %foo = f32[2,3,4]sparse{10} constant({}) } )" @@ -760,7 +760,7 @@ ENTRY %sparse_f32_empty () -> f32[2,3,4] { R"(HloModule sparse_f32_r1 ENTRY %sparse_f32_r1 () -> f32[9] { - ROOT %foo = f32[9]sparse{10} constant(f32[9]{1: 2, 3: 4, 5: 6}) + ROOT %foo = f32[9]sparse{10} constant({1: 2, 3: 4, 5: 6}) } )" @@ -931,11 +931,11 @@ ENTRY reduce_entry { R"(HloModule outfeed_module ENTRY InfeedToOutfeed { - token = token[] after-all() - infeed = ((u32[3]{0}, pred[]), token[]) infeed(token) + token0 = token[] after-all() + infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0) infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0 - outfeed = token[] outfeed(infeed.data, token) - ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token) + outfeed = token[] outfeed(infeed.data, token0) + ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token0) infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0 infeed.1.token = token[] get-tuple-element(infeed.1), index=1 outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token) 
@@ -1266,8 +1266,8 @@ R"(HloModule AddDependency ENTRY AddDependency { p = f32[] parameter(0) neg = f32[] negate(p) - token = token[] after-all(neg) - p_after_token = f32[] add-dependency(p, token) + token0 = token[] after-all(neg) + p_after_token = f32[] add-dependency(p, token0) exp = f32[] exponential(p_after_token) ROOT sum = f32[] add(neg, exp) } @@ -1419,7 +1419,7 @@ TEST_F(HloParserTest, MoreConstants) { ENTRY %SelectScalarS32True.v4 () -> s32[] { %constant.2 = pred[] constant(true) - %constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,2]1,2,3,4} + %constant.1 = s32[] constant(-42), sharding={devices=[2,2]1,2,3,4} %constant = s32[] constant(42) %select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant) } @@ -1462,7 +1462,7 @@ TEST_F(HloParserTest, LiteralDimensionsMismatch_2) { const string original = R"(HloModule some_2x3_module ENTRY %some_2x3 () -> f32[2,3] { - ROOT %constant = f32[2,3]{1,0} constant(f32[2,3] {1, 2, 3, 4, 5, 6}) + ROOT %constant = f32[2,3]{1,0} constant({1, 2, 3, 4, 5, 6}) } )"; @@ -1476,7 +1476,7 @@ TEST_F(HloParserTest, LiteralDimensionsMismatch_3) { const string original = R"(HloModule some_2x3x2_module ENTRY %some_2x3x2 () -> f32[2,3,2] { - ROOT %constant = f32[2,3,2]{2,1,0} constant(f32[2,3,2] {{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}, {11, 12}}}) + ROOT %constant = f32[2,3,2]{2,1,0} constant({{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}, {11, 12}}}) } )"; @@ -1594,11 +1594,11 @@ TEST_F(HloParserTest, UnexpectedAttribute) { const string original = R"(HloModule unexpected_attr_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %token = token[] after-all() - %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15 + %token0 = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15 %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15 ROOT %constant = f32[] constant(2.1) - %send = (f32[], u32[], token[]) send(f32[] 
%constant, token[] %token), channel_id=16, calls=%recv + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, calls=%recv %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16 } @@ -1611,11 +1611,11 @@ TEST_F(HloParserTest, MissingAttribute) { const string original = R"(HloModule missing_attr_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %token = token[] after-all() - %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15 + %token0 = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15 %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15 ROOT %constant = f32[] constant(-2.1) - %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token) + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0) %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16 } @@ -1628,11 +1628,11 @@ TEST_F(HloParserTest, PredecessorUndefined) { const string original = R"(HloModule pre_not_found_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %token = token[] after-all() - %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15 + %token0 = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15 %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15 ROOT %constant = f32[] constant(2.1) - %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, control-predecessors={%done} + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, control-predecessors={%done} %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16 } @@ -1940,8 +1940,8 @@ TEST_F(HloParserTest, ParsePaddingConfigInteriorPaddingImplicitZeroDim) { TEST_F(HloParserTest, NontupleInfeed) { const string original = R"(HloModule nontuple_infeed: ENTRY nontuple_infeed { - 
token = token[] after-all() - ROOT infeed = pred[] infeed(token) + token0 = token[] after-all() + ROOT infeed = pred[] infeed(token0) })"; ExpectHasSubstr(ParseHloString(original).status().error_message(), "infeed must have a non-empty tuple shape"); @@ -2239,7 +2239,7 @@ HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> f32[2,2] { %p = f32[2,2] parameter(0) - %constant.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.1 = f32[2,2] constant({{1, 2}, {3, 4}}) ROOT %add.1 = f32[2,2] add(f32[2,2] %p, f32[2,5] %constant.1) } )"; @@ -2249,7 +2249,85 @@ ENTRY %entrycomp (p: f32[2,2]) -> f32[2,2] { " with the shape of the operand instruction f32[2,2]{1,0}."); } -// custom call incompatible shape. +TEST_F(HloParserTest, ParseShapeStringR2F32) { + string shape_string = "f32[123,456]"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = ShapeUtil::MakeShape(F32, {123, 456}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseShapeStringTupleOfArrays) { + string shape_string = "(f32[1572864],s8[5120,1024])"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}), + ShapeUtil::MakeShape(S8, {5120, 1024})}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseShapeStringNestedTuple) { + string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {1}), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}), + ShapeUtil::MakeOpaqueShape(), + ShapeUtil::MakeShape(F32, {3}), + }); + 
ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseShapeStringWithLayout) { + string shape_string = "f32[123,456]{0,1}"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseShapeStringWithSparseLayout) { + string shape_string = "f32[123,456]sparse{10}"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = ShapeUtil::MakeShapeWithSparseLayout(F32, {123, 456}, 10); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseOpaqueType) { + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape("opaque[]")); + Shape expected = ShapeUtil::MakeOpaqueShape(); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseTokenType) { + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape("token[]")); + Shape expected = ShapeUtil::MakeTokenShape(); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseInvalidShapeString) { + string shape_strings[] = { + "f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}", + "f32[123,456]dense{foo}", "f32[123,456]sparse{foo}", + }; + for (const string& shape_string : shape_strings) { + StatusOr result = ParseShape(shape_string); + ASSERT_FALSE(result.ok()) << "shape: " << shape_string; + } +} } // namespace } 
// namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 312b5d020c398feb7738d14a9cfa0928d5178948..33ce7e23a82d840676bba5f1ca9c0ffc4433465d 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -77,6 +77,11 @@ std::vector HloPassPipeline::GetEnabledPasses( auto repeated_field = debug_options.xla_disable_hlo_passes(); absl::flat_hash_set disabled_pass_names(repeated_field.begin(), repeated_field.end()); + if (debug_options.xla_disable_all_hlo_passes()) { + VLOG(1) << "*All* passes disabled by --xla_disable_all_hlo_passes."; + return {}; + } + if (!disabled_pass_names.empty()) { VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: " << absl::StrJoin(disabled_pass_names, ", "); @@ -113,7 +118,7 @@ void HloPassPipeline::MaybeDumpHlo(const HloModule& module, } const string message = - StrCat("after ", after_pass_name, ", before ", before_pass_name); + absl::StrCat("after ", after_pass_name, ", before ", before_pass_name); hlo_graph_dumper::MaybeDumpHloModule(module, message); VLOG(3) << "HLO " << message << ":"; VLOG(3) << module.entry_computation_layout().ToString(); diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc index 981d06ce101644ecce587c4bd2f7a12c8edf6548..3a9ee57e5551ae5b608f02d9f8bd0428ff16db13 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc @@ -39,6 +39,7 @@ HloProto MakeHloProto(const HloModule& module) { StatusOr> CreateModuleFromProto( const HloModuleProto& proto, const HloModuleConfig& module_config) { + VLOG(4) << proto.ShortDebugString(); TF_ASSIGN_OR_RETURN(std::unique_ptr module, HloModule::CreateFromProto(proto, module_config)); TF_RETURN_IF_ERROR( diff --git a/tensorflow/compiler/xla/service/hlo_token.h b/tensorflow/compiler/xla/service/hlo_token.h 
deleted file mode 100644 index 4458c251dee4af365e39027dd4289925c8890efd..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/hlo_token.h +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { - -// Defines different kinds of tokens in a hlo module string. -// -// You shouldn't need to use this directly unless you're using HloLexer -// directly, and you probably don't need to do that. Use hlo_parser instead. -enum class TokKind { - // Markers - kEof, - kError, - - // Tokens with no info. - kEqual, // = - kComma, // , - kColon, // : - kLsquare, - kRsquare, // [ ] - kLbrace, - kRbrace, // { } - kLparen, - kRparen, // ( ) - - kArrow, // -> - - // Keywords - kw_HloModule, - kw_ENTRY, - kw_ROOT, - kw_true, - kw_false, - kw_maximal, - kw_replicated, - kw_nan, - kw_inf, - - kNegInf, // -inf - - // Typed tokens. 
- kName, // %foo - kAttributeName, // dimensions= - kDimLabels, // [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} - kDxD, // [0-9]+(x[0-9]+)+ - kPad, // [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* - kIdent, // other identifiers - kString, // "abcd\"\n" - kShape, // f32[2,3]{1,0} - kInt, // 42 - kDecimal, // 4.2 -}; - -string TokKindToString(TokKind kind); - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 77db7b098a38ff4efdcc7447935fae61561c9ff4..ace854ed6a243c3788a46333f41cb85d90c8e174 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -481,7 +481,9 @@ Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) { const Shape& operand_shape_with_layout = custom_call->operand_shapes_with_layout()[i]; TF_RET_CHECK(ShapeUtil::Compatible(custom_call->operand(i)->shape(), - operand_shape_with_layout)); + operand_shape_with_layout)) + << custom_call->operand(i)->shape().ToString() << " operand " + << operand_shape_with_layout.ToString(); TF_RET_CHECK(LayoutUtil::HasLayout(operand_shape_with_layout)); } } diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index 98246d5403e4aebc2f4d81e52145706355ddd9a9..295465c8481bcb7d1385192febe0d09614e393b3 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -99,7 +99,7 @@ TEST_F(IndexedArrayAnalysisTest, SimpleOneToOneConstantGather) { HloModule SimpleGather ENTRY main { - operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + operand = s32[3,3] constant({{1,2,3},{1,2,3},{1,2,3}}) indices = s32[5] parameter(0) ROOT gather = s32[5,3] gather(operand, indices), offset_dims={1}, @@ -119,7 +119,7 @@ TEST_F(IndexedArrayAnalysisTest, 
GatherIsNotScalarIndexed0) { HloModule SimpleGather ENTRY main { - operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + operand = s32[3,3] constant({{1,2,3},{1,2,3},{1,2,3}}) indices = s32[5,2] parameter(0) ROOT gather = s32[5] gather(operand, indices), offset_dims={}, @@ -195,7 +195,7 @@ TEST_F(IndexedArrayAnalysisTest, GatherOfGather_OneToOne) { HloModule SimpleGather ENTRY main { - operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + operand = s32[3,3] constant({{1,2,3},{1,2,3},{1,2,3}}) indices_a = s32[5] parameter(0) indices_b = s32[2] parameter(1) gather_a = s32[5,3] gather(operand, indices_a), @@ -309,7 +309,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather0) { HloModule ReshapeOfGather ENTRY main { - operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + operand = s32[3,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4}}) indices = s32[5] parameter(0) gather = s32[5,4] gather(operand, indices), offset_dims={1}, @@ -330,7 +330,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather1) { HloModule ReshapeOfGather ENTRY main { - operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + operand = s32[3,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4}}) indices = s32[5,7] parameter(0) gather = s32[5,4,7] gather(operand, indices), offset_dims={1}, @@ -352,7 +352,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather2) { HloModule ReshapeOfGather ENTRY main { - operand = s32[3,2,6] constant(s32[3,2,6]{ + operand = s32[3,2,6] constant({ {{1,2,3,4,5,6},{1,2,3,4,5,6}}, {{1,2,3,4,5,6},{1,2,3,4,5,6}}, {{1,2,3,4,5,6},{1,2,3,4,5,6}}}) @@ -377,7 +377,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather3) { HloModule ReshapeOfGather ENTRY main { - operand = s32[2,6] constant(s32[2,6]{ + operand = s32[2,6] constant({ {1,2,3,4,5,6},{1,2,3,4,5,6}}) indices = s32[1] parameter(0) gather = s32[1,6] gather(operand, indices), @@ -405,7 +405,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather4) { HloModule ReshapeOfGather ENTRY main { - 
operand = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 1, 2, 3 } }) + operand = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 1, 2, 3 } }) i.0 = s64[1,3]{1,0} parameter(0) g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), offset_dims={2}, @@ -438,7 +438,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather5) { HloModule ReshapeOfGather ENTRY main { - operand = s32[1,6] constant(s32[1,6]{{1,2,3,4,5,6}}) + operand = s32[1,6] constant({{1,2,3,4,5,6}}) indices = s32[1] parameter(0) gather = s32[1,6] gather(operand, indices), offset_dims={1}, @@ -465,7 +465,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather6) { HloModule ReshapeOfGather ENTRY main { - operand = s32[1,2,6] constant(s32[1,2,6]{{ + operand = s32[1,2,6] constant({{ {1,2,3,4,5,6},{1,2,3,4,5,6}}}) indices = s32[1] parameter(0) gather = s32[1,1,6] gather(operand, indices), @@ -496,7 +496,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather7) { HloModule ReshapeOfGather ENTRY main { - operand = s32[2,6] constant(s32[2,6]{ + operand = s32[2,6] constant({ {1,2,3,4,5,6},{1,2,3,4,5,6}}) indices = s32[1,5] parameter(0) gather = s32[1,5,6] gather(operand, indices), @@ -527,7 +527,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold0) { HloModule ReshapeOfGather ENTRY main { - operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + operand = s32[3,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4}}) indices = s32[5,6] parameter(0) gather = s32[5,4,6] gather(operand, indices), offset_dims={1}, @@ -556,7 +556,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold1) { HloModule ReshapeOfGather ENTRY main { - operand = s32[3,5,2] constant(s32[3,5,2]{ + operand = s32[3,5,2] constant({ {{1,2},{3,4},{5,6},{7,8},{9,10}}, {{1,2},{3,4},{5,6},{7,8},{9,10}}, {{1,2},{3,4},{5,6},{7,8},{9,10}}}) @@ -588,7 +588,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold2) { HloModule ReshapeOfGather ENTRY main { - operand = s32[3,4,1] constant(s32[3,4,1]{ + operand = s32[3,4,1] constant({ {{1},{2},{3},{4}}, 
{{1},{2},{3},{4}}, {{1},{2},{3},{4}}}) @@ -620,7 +620,7 @@ TEST_F(IndexedArrayAnalysisTest, UnaryOpOfGather) { HloModule UnaryOpOfGather ENTRY main { - operand = f32[3,4] constant(f32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + operand = f32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) indices = s32[5] parameter(0) gather = f32[5,4] gather(operand, indices), offset_dims={1}, @@ -645,7 +645,7 @@ TEST_F(IndexedArrayAnalysisTest, AddBroadcastedScalarWithGather) { HloModule AddBroadcastedScalarWithGather ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) constant = s32[] constant(5) constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} indices = s32[5] parameter(0) @@ -673,7 +673,7 @@ TEST_F(IndexedArrayAnalysisTest, HloModule SubtractBroadcastedScalarWithGather ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) constant = s32[] constant(5) constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} indices = s32[5] parameter(0) @@ -701,7 +701,7 @@ TEST_F(IndexedArrayAnalysisTest, HloModule SubtractBroadcastedScalarWithGather ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) constant = s32[] constant(5) constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} indices = s32[5] parameter(0) @@ -728,7 +728,7 @@ TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather) { HloModule AddBroadcastedVectorWithGather ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) constant_vect = s32[4] constant({10,11,12,13}) constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={1} indices = s32[5] 
parameter(0) @@ -755,7 +755,7 @@ TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather_Negative) { HloModule AddBroadcastedVectorWithGather ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) constant_vect = s32[5] constant({10,11,12,13,14}) constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={0} indices = s32[5] parameter(0) @@ -804,8 +804,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpBasic_0) { HloModule DotOp ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) - dot_rhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_rhs_constant = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) indices = s32[5] parameter(0) dot_lhs = s32[5,4] gather(gather_operand, indices), offset_dims={1}, @@ -831,8 +831,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpBasic_1) { HloModule DotOp ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) - dot_rhs_constant = s32[3,3] constant(s32[3,3]{{1,2,3},{4,5,6},{7,8,9}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_rhs_constant = s32[3,3] constant({{1,2,3},{4,5,6},{7,8,9}}) indices = s32[5] parameter(0) dot_lhs = s32[3,5] gather(gather_operand, indices), offset_dims={0}, @@ -859,8 +859,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpBasic_2) { HloModule DotOp ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) - dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_lhs_constant = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) indices = s32[5] parameter(0) dot_rhs = s32[3,5] gather(gather_operand, indices), offset_dims={0}, @@ -888,8 +888,8 
@@ TEST_F(IndexedArrayAnalysisTest, DotOpBasic_3) { HloModule DotOp ENTRY main { - gather_operand = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) - dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + gather_operand = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + dot_lhs_constant = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) indices = s32[5] parameter(0) dot_rhs = s32[5,3] gather(gather_operand, indices), offset_dims={1}, @@ -917,8 +917,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpWithBatch) { HloModule DotOp ENTRY main { - gather_operand = s32[2,3,2] constant(s32[2,3,2]{{{1,2},{3,4},{5,6}},{{7,8},{9,10},{11,12}}}) - dot_lhs_constant = s32[2,2,3] constant(s32[2,2,3]{{{1,2,3},{4,5,6}},{{7,8,9},{10,11,12}}}) + gather_operand = s32[2,3,2] constant({{{1,2},{3,4},{5,6}},{{7,8},{9,10},{11,12}}}) + dot_lhs_constant = s32[2,2,3] constant({{{1,2,3},{4,5,6}},{{7,8,9},{10,11,12}}}) indices = s32[4] parameter(0) dot_rhs = s32[2,3,4] gather(gather_operand, indices), offset_dims={0,1}, @@ -948,8 +948,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpNegative) { HloModule DotOp ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) - dot_rhs_constant = s32[2,3] constant(s32[2,3]{{1,2,3},{4,5,6}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_rhs_constant = s32[2,3] constant({{1,2,3},{4,5,6}}) indices = s32[2] parameter(0) dot_lhs = s32[3,2] gather(gather_operand, indices), offset_dims={0}, diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 2297edcbe1d167f0752423f76b795b3592e85c47..3ea0b81d0d0c1e3edaf8fc2221e0c55a8086e110 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/fusion_queue.h" @@ -457,8 +458,13 @@ StatusOr InstructionFusion::Run(HloModule* module) { computation_ = computation; reachability_ = HloReachabilityMap::Build(computation_); - HloInstructionSet do_not_duplicate = - ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder()); + HloInstructionSet do_not_duplicate; + // If we allow duplications, we need to compute which instructions we do not + // want to duplicate based on a global analysis of the graph. + if (may_duplicate_) { + do_not_duplicate = + ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder()); + } auto fusion_queue = GetFusionQueue(computation_); // Instruction fusion effectively fuses edges in the computation graph @@ -565,19 +571,42 @@ HloInstruction* InstructionFusion::FuseIntoMultiOutput( bool InstructionFusion::MultiOutputFusionCreatesCycle( HloInstruction* producer, HloInstruction* consumer) { - auto is_reachable = [&](const HloInstruction* a, const HloInstruction* b) { - // A consumer operand may have been multii-output fused into a parallel - // consumer and thus be missing from the oridinal reachability map. - if (!reachability_->IsPresent(a) || !reachability_->IsPresent(b)) { - reachability_ = HloReachabilityMap::Build(consumer->parent()); + absl::flat_hash_set operands; + for (const HloInstruction* operand : consumer->operands()) { + if (operand == producer) { + continue; + } + + // If the reachability map already contains the producer and the operand of + // the consumer, and the producer can reach the operand, then we know for + // sure MultiOutputFusion would create a cycle. If not, we need to do a DFS + // traversal of the computation to verify that this multioutput fusion would + // not create a cycle. 
+ if (reachability_->IsPresent(producer) && + reachability_->IsPresent(operand) && + reachability_->IsReachable(producer, operand)) { + return true; + } + operands.insert(operand->unique_id()); + } + + // Do a DFS on the producer to see if any of the other consumer operands are + // reachable in the current state of the graph. + std::vector worklist = producer->users(); + absl::flat_hash_set visits; + while (!worklist.empty()) { + const HloInstruction* user = worklist.back(); + worklist.pop_back(); + if (operands.count(user->unique_id()) != 0) { + return true; } - return reachability_->IsReachable(a, b); - }; - return absl::c_any_of(consumer->operands(), - [&](const HloInstruction* consumer_operand) { - return consumer_operand != producer && - is_reachable(producer, consumer_operand); - }); + if (visits.count(user->unique_id()) == 0) { + visits.insert(user->unique_id()); + worklist.insert(worklist.end(), user->users().begin(), + user->users().end()); + } + } + return false; } bool InstructionFusion::ShouldFuse(HloInstruction* consumer, diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 6b483126499fe1e635a7d13cf597ec5d089c5b24..611cfd404d7622f561f0acc86fc9b05e16eea22e 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -259,8 +259,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) { add = f32[4,3]{1,0} add(p0, p0) abs1 = f32[4,3]{1,0} abs(add) log = f32[4,3]{1,0} log(abs1) - token = token[] after-all() - send = f32[4,3]{1,0} send(log, token), channel_id=0 + token0 = token[] after-all() + send = f32[4,3]{1,0} send(log, token0), channel_id=0 abs2 = f32[4,3]{1,0} abs(log) ROOT root = f32[4,3]{1,0} subtract(abs2, add) })") @@ -290,8 +290,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) { p0 = f32[4,3]{1,0} parameter(0) add1 = f32[4,3]{1,0} 
add(p0, p0) log = f32[4,3]{1,0} log(p0) - token = token[] after-all() - send = f32[4,3]{1,0} send(log, token), channel_id=0 + token0 = token[] after-all() + send = f32[4,3]{1,0} send(log, token0), channel_id=0 add2 = f32[4,3]{1,0} add(log, add1) ROOT root = f32[4,3]{1,0} subtract(add1, add2) })") @@ -324,8 +324,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) { add1 = f32[4,3]{1,0} add(p0, p0) add2 = f32[4,3]{1,0} add(add1, add1) log = f32[4,3]{1,0} log(add2) - token = token[] after-all() - send = f32[4,3]{1,0} send(log, token), channel_id=0 + token0 = token[] after-all() + send = f32[4,3]{1,0} send(log, token0), channel_id=0 sub1 = f32[4,3]{1,0} subtract(log, add2) sub2 = f32[4,3]{1,0} subtract(add2, add1) ROOT root = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub1, sub2) @@ -394,6 +394,56 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) { .ValueOrDie()); } +TEST_F(InstructionFusionTest, FuseDiamondGraphsNoDuplication) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY Test { + p0 = f32[100] parameter(0) + p1 = f32[100] parameter(1) + add = f32[100] add(p0, p1) + slice1 = f32[99] slice(add), slice={[0:99:1]} + slice2 = f32[99] slice(add), slice={[1:100:1]} + ROOT add2 = f32[99] add(slice1, slice2) + })") + .ValueOrDie(); + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/false) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + + HloInstruction* root = module->entry_computation()->root_instruction(); + // 'add' would originally need to be duplicated if fused. However after its + // two users 'slice1' and 'slice2' are fused into 'add2', 'add' has only one + // user and can now be also fused. 
+ EXPECT_THAT(root, op::Fusion(op::Parameter(), op::Parameter())); +} + +TEST_F(InstructionFusionTest, FuseDiamondGraphsAllowDuplication) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY Test { + p0 = f32[100] parameter(0) + p1 = f32[100] parameter(1) + add = f32[100] add(p0, p1) + slice1 = f32[99] slice(add), slice={[0:99:1]} + slice2 = f32[99] slice(add), slice={[1:100:1]} + ROOT add2 = f32[99] add(slice1, slice2) + })") + .ValueOrDie(); + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + + HloInstruction* root = module->entry_computation()->root_instruction(); + // 'add' would originally need to be duplicated if fused. However after its + // two users 'slice1' and 'slice2' are fused into 'add2', 'add' has only one + // user and can now be also fused. + EXPECT_THAT(root, op::Fusion(op::Parameter(), op::Parameter())); +} + TEST_F(InstructionFusionTest, WideningConvertsAreAlwaysDuplicableIntoConsumers) { auto module = ParseHloString(R"( diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index 3a5177c418e3af8253df228a51f2fc0901d10041..d37ae94bf6c4c697bbf30390c02a5099271e00a4 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -76,9 +76,12 @@ StatusOr> InterpreterCompiler::RunBackend( // need to compile anything // Create executable from only the Hlo module. 
+ auto evaluator = absl::make_unique(); + evaluator->set_use_fast_path( + hlo_module->config().debug_options().xla_hlo_evaluator_use_fast_path()); std::unique_ptr executable = - absl::make_unique( - std::move(hlo_module), absl::make_unique()); + absl::make_unique(std::move(hlo_module), + std::move(evaluator)); return std::move(executable); } diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 7635fbfed6f6a51fc9d203251d9bebf43cc63fd9..de9204011ce5ba8a9fc2871c6bd7120b6ed371b5 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -85,6 +85,7 @@ StatusOr InterpreterExecutable::ExecuteOnStream( Literal result_literal; { tensorflow::mutex_lock lock(evaluator_lock_); + evaluator_->ResetVisitStates(); TF_ASSIGN_OR_RETURN(result_literal, evaluator_->Evaluate( *computation, arg_literals)); } diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 311bd7890545b5b2cbec920d2d12ddd482d0d53c..9fe8c3accbf283f3b3eebbefbac8739c37df16bc 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -27,7 +27,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" @@ -848,12 +847,12 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { ENTRY entry_computation { param = (f32[2,2]) parameter(0) gte = f32[2,2] get-tuple-element(param), index=0 - token = token[] after-all() - recv = (f32[2,2], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=1} + token0 = token[] after-all() + recv = (f32[2,2], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=1} recv-done = (f32[2,2], token[]) recv-done(recv), channel_id=1, sharding={maximal device=1} ROOT root = f32[2,2] get-tuple-element(recv-done), index=0 - send = (f32[2,2], u32[], token[]) send(gte, token), channel_id=1, + send = (f32[2,2], u32[], token[]) send(gte, token0), channel_id=1, sharding={maximal device=0} send-done = token[] send-done(send), channel_id=1, sharding={maximal device=0} } @@ -898,7 +897,7 @@ TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) { ar.0 = f32[2,2] cross-replica-sum(gte), all_reduce_id=1, replica_groups={{0}}, to_apply=add, sharding={maximal device=0} - const = f32[2,2] constant(f32[2,2]{{0,1},{2,3}}) + const = f32[2,2] constant({{0,1},{2,3}}) ROOT ar.1 = f32[2,2] cross-replica-sum(const), all_reduce_id=1, replica_groups={{0}}, to_apply=add, sharding={maximal device=1} diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc index bd0139f85b6a5c5dc23dad962263038451921e65..5eeb29c478a371dae83251771f2dc4844672d3e9 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc +++ 
b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc @@ -18,28 +18,29 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" namespace xla { -Status KernelSupportLibrary::For( +Status KernelSupportLibrary::ForWithStatus( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { - return If(b_->CreateICmpSLT(start, end), [&]() -> Status { + return IfWithStatus(b_->CreateICmpSLT(start, end), [&]() -> Status { TF_RETURN_IF_ERROR(for_body_generator(start, /*is_first_iteration=*/true)); - return For(name, b_->CreateAdd(start, step), end, step, - [&](llvm::Value* iv) { return for_body_generator(iv, false); }); + return ForWithStatus( + name, b_->CreateAdd(start, step), end, step, + [&](llvm::Value* iv) { return for_body_generator(iv, false); }); }); } -Status KernelSupportLibrary::For( +Status KernelSupportLibrary::ForWithStatus( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, bool peel_first_iteration, const std::function& for_body_generator) { if (peel_first_iteration) { - return For(name, start, end, step, true, - [&](llvm::Value* indvar, bool is_first_iteration) -> Status { - return for_body_generator(indvar, - b_->getInt1(is_first_iteration)); - }); + return ForWithStatus( + name, start, end, step, true, + [&](llvm::Value* indvar, bool is_first_iteration) -> Status { + return for_body_generator(indvar, b_->getInt1(is_first_iteration)); + }); } else { std::unique_ptr loop = llvm_ir::ForLoop::EmitForLoop( name, start, end, step, b_, @@ -55,7 +56,7 @@ Status KernelSupportLibrary::For( } } -Status KernelSupportLibrary::If( +Status KernelSupportLibrary::IfWithStatus( absl::string_view name, llvm::Value* condition, const std::function& true_block_generator, const std::function& false_block_generator) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h 
b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index 43fec311f150d6054f6ad24f99db332f90ff94a3..612b839cfa15711061e1ae53358a72d5220e1801 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -48,41 +48,42 @@ class KernelSupportLibrary { // for (i64 i = `start` + `step`; i s< `end`; i += `step`) // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/false)`; // } - Status For( + Status ForWithStatus( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator); - void ForReturnVoid( + void For( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { CHECK_EQ(Status::OK(), - For(name, start, end, step, + ForWithStatus( + name, start, end, step, [&](llvm::Value* ind_var, bool is_first_iteration) -> Status { for_body_generator(ind_var, is_first_iteration); return Status::OK(); })); } - Status For(absl::string_view name, int64 start, int64 end, int64 step, - const std::function& - for_body_generator) { - return For(name, /*start=*/b_->getInt64(start), - /*end=*/b_->getInt64(end), - /*step=*/b_->getInt64(step), for_body_generator); + Status ForWithStatus( + absl::string_view name, int64 start, int64 end, int64 step, + const std::function& for_body_generator) { + return ForWithStatus(name, /*start=*/b_->getInt64(start), + /*end=*/b_->getInt64(end), + /*step=*/b_->getInt64(step), for_body_generator); } - void ForReturnVoid( + void For( absl::string_view name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - ForReturnVoid(name, /*start=*/b_->getInt64(start), - /*end=*/b_->getInt64(end), - /*step=*/b_->getInt64(step), for_body_generator); + For(name, /*start=*/b_->getInt64(start), + /*end=*/b_->getInt64(end), + /*step=*/b_->getInt64(step), for_body_generator); } // Generates the following 
control flow structure if `peel_first_iteration` is @@ -99,19 +100,19 @@ class KernelSupportLibrary { // for (i64 i = `start`; i s< `end`; i += `step`) // `for_body_generator(/*ind_var=*/,i, // /*is_first_iteration=*/,(i != `start`))`; - Status For(absl::string_view name, llvm::Value* start, llvm::Value* end, - llvm::Value* step, bool peel_first_iteration, - const std::function& - for_body_generator); + Status ForWithStatus( + absl::string_view name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, bool peel_first_iteration, + const std::function& + for_body_generator); - void ForReturnVoid(absl::string_view name, llvm::Value* start, - llvm::Value* end, llvm::Value* step, - bool peel_first_iteration, - const std::function& - for_body_generator) { - TF_CHECK_OK(For( + void For(absl::string_view name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, bool peel_first_iteration, + const std::function& + for_body_generator) { + TF_CHECK_OK(ForWithStatus( name, start, end, step, peel_first_iteration, [&](llvm::Value* ind_var, llvm::Value* is_first_iteration) -> Status { for_body_generator(ind_var, is_first_iteration); @@ -119,80 +120,81 @@ class KernelSupportLibrary { })); } - Status For(absl::string_view name, llvm::Value* start, llvm::Value* end, - int64 step, bool peel_first_iteration, - const std::function& - for_body_generator) { - return For(name, /*start=*/start, /*end=*/end, - /*step=*/llvm::ConstantInt::get(start->getType(), step), - peel_first_iteration, for_body_generator); + Status ForWithStatus( + absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, + bool peel_first_iteration, + const std::function& + for_body_generator) { + return ForWithStatus( + name, /*start=*/start, /*end=*/end, + /*step=*/llvm::ConstantInt::get(start->getType(), step), + peel_first_iteration, for_body_generator); } - void ForReturnVoid(absl::string_view name, llvm::Value* start, - llvm::Value* end, int64 step, bool peel_first_iteration, - const 
std::function& - for_body_generator) { - ForReturnVoid(name, /*start=*/start, /*end=*/end, - /*step=*/llvm::ConstantInt::get(start->getType(), step), - peel_first_iteration, for_body_generator); + void For(absl::string_view name, llvm::Value* start, llvm::Value* end, + int64 step, bool peel_first_iteration, + const std::function& + for_body_generator) { + For(name, /*start=*/start, /*end=*/end, + /*step=*/llvm::ConstantInt::get(start->getType(), step), + peel_first_iteration, for_body_generator); } - Status For( + Status ForWithStatus( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { - return For(name, start, end, step, - /*peel_first_iteration=*/false, - [&](llvm::Value* indvar, llvm::Value*) -> Status { - return for_body_generator(indvar); - }); + return ForWithStatus(name, start, end, step, + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) -> Status { + return for_body_generator(indvar); + }); } - void ForReturnVoid( + void For( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { - ForReturnVoid(name, start, end, step, - /*peel_first_iteration=*/false, - [&](llvm::Value* indvar, llvm::Value*) { - return for_body_generator(indvar); - }); + For(name, start, end, step, + /*peel_first_iteration=*/false, [&](llvm::Value* indvar, llvm::Value*) { + return for_body_generator(indvar); + }); } - Status For( + Status ForWithStatus( absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, const std::function& for_body_generator) { - return For(name, start, end, llvm::ConstantInt::get(start->getType(), step), - /*peel_first_iteration=*/false, - [&](llvm::Value* indvar, llvm::Value*) -> Status { - return for_body_generator(indvar); - }); + return ForWithStatus(name, start, end, + llvm::ConstantInt::get(start->getType(), step), + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, 
llvm::Value*) -> Status { + return for_body_generator(indvar); + }); } - void ForReturnVoid( + void For( absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, const std::function& for_body_generator) { - ForReturnVoid(name, start, end, - llvm::ConstantInt::get(start->getType(), step), - for_body_generator); + For(name, start, end, llvm::ConstantInt::get(start->getType(), step), + for_body_generator); } - Status For( + Status ForWithStatus( absl::string_view name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - return For(name, /*start=*/b_->getInt64(start), - /*end=*/b_->getInt64(end), - /*step=*/b_->getInt64(step), for_body_generator); + return ForWithStatus(name, /*start=*/b_->getInt64(start), + /*end=*/b_->getInt64(end), + /*step=*/b_->getInt64(step), for_body_generator); } - void ForReturnVoid( + void For( absl::string_view name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - ForReturnVoid(name, /*start=*/b_->getInt64(start), - /*end=*/b_->getInt64(end), - /*step=*/b_->getInt64(step), for_body_generator); + For(name, /*start=*/b_->getInt64(start), + /*end=*/b_->getInt64(end), + /*step=*/b_->getInt64(step), for_body_generator); } // Generates the following control flow structure: @@ -201,38 +203,43 @@ class KernelSupportLibrary { // `true_block_generator()`; // else // `false_block_generator()`; - Status If(absl::string_view name, llvm::Value* condition, - const std::function& true_block_generator, - const std::function& false_block_generator = - []() -> Status { return Status::OK(); }); + Status IfWithStatus( + absl::string_view name, llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = []() -> Status { + return Status::OK(); + }); - Status If(llvm::Value* condition, - const std::function& true_block_generator, - const std::function& false_block_generator = - []() -> Status { return Status::OK(); }) { - return 
If("", condition, true_block_generator, false_block_generator); + Status IfWithStatus( + llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = []() -> Status { + return Status::OK(); + }) { + return IfWithStatus("", condition, true_block_generator, + false_block_generator); } - void IfReturnVoid(llvm::Value* condition, - const std::function& true_block_generator, - const std::function& false_block_generator = []() { - }) { - IfReturnVoid("", condition, true_block_generator, false_block_generator); + void If( + llvm::Value* condition, const std::function& true_block_generator, + const std::function& false_block_generator = []() {}) { + If("", condition, true_block_generator, false_block_generator); } - void IfReturnVoid(absl::string_view name, llvm::Value* condition, - const std::function& true_block_generator, - const std::function& false_block_generator = []() { - }) { - TF_CHECK_OK(If(name, condition, - [&]() { - true_block_generator(); - return Status::OK(); - }, - [&]() { - false_block_generator(); - return Status::OK(); - })); + void If( + absl::string_view name, llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = []() {}) { + TF_CHECK_OK(IfWithStatus( + name, condition, + [&]() { + true_block_generator(); + return Status::OK(); + }, + [&]() { + false_block_generator(); + return Status::OK(); + })); } using ArgumentVector = absl::Span; diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc index 1aa85eb8d2d206bf0537deb659e779b24fffbb0a..cebbc4290163d4e98003cd7b5df6ec906509a446 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc @@ -170,14 +170,16 @@ IrArray::Index KernelMappingScheme::EmitBlockIndex(llvm::Type* index_ty) { IrArray::Index 
KernelMappingScheme::GetTileIndexForBlockOrigin( const IrArray::Index& block_index) { - IrArray::Index tile_index = block_index; + DCHECK_EQ(block_index.size(), block_sizes_.size()); + std::vector multidim; + multidim.reserve(block_sizes_.size()); for (int i = 0; i < block_sizes_.size(); ++i) { - tile_index[i] = b_->CreateMul( + multidim.push_back(b_->CreateMul( block_index[i], llvm::ConstantInt::get(block_index[i]->getType(), block_sizes_[i]), - "block_origin." + std::to_string(i)); + "block_origin." + std::to_string(i))); } - return tile_index; + return IrArray::Index(multidim, block_index[0]->getType()); } IrArray::Index KernelMappingScheme::GetElementIndexForTileOrigin( @@ -217,14 +219,14 @@ KernelMappingScheme::EmitThreadYXCoordinate(llvm::Type* index_ty) { // defined by (num_thread_y, num_thread_x) from thread_id. llvm::CallInst* thread_id_raw = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b_); - llvm_ir::AddRangeMetadata(0, GetThreadsPerTile(), thread_id_raw); + llvm_ir::AddRangeMetadata(0, GetThreadsPerBlock(), thread_id_raw); llvm::Value* thread_id_int = b_->CreateIntCast(thread_id_raw, index_ty, /*isSigned=*/true, "thread.id.x"); llvm::Value* num_thread_x = llvm::ConstantInt::get(index_ty, GetNumberOfThreadsForDimensionX()); - llvm::Value* x = b_->CreateURem(thread_id_int, num_thread_x); - llvm::Value* y = b_->CreateUDiv(thread_id_int, num_thread_x); + llvm::Value* x = b_->CreateURem(thread_id_int, num_thread_x, "thread.x"); + llvm::Value* y = b_->CreateUDiv(thread_id_int, num_thread_x, "thread.y"); return std::make_tuple(y, x); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h index 7277aeac8ad2086a2f6419b1fdb60c4872841adc..fb633b12e60d1a9f3103fb2919ad2c3f3f14de20 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h @@ -142,7 +142,7 @@ class KernelMappingScheme { 
int64 GetNumberOfThreadsForDimensionX() const { return num_threads_x_; } int64 GetNumberOfThreadsForDimensionY() const { return num_threads_y_; } - int64 GetThreadsPerTile() const { + int64 GetThreadsPerBlock() const { return GetNumberOfThreadsForDimensionX() * GetNumberOfThreadsForDimensionY(); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc index e22c2173c271fc9571be1ddb0759d2b31562dc98..6a9406bfebafcc02dc2e144b62284a9e83c3edeb 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc @@ -108,7 +108,7 @@ void EmitCompareLoopBody( // if (is_smaller_index && index_is_inbounds) KernelSupportLibrary ksl(b); - ksl.IfReturnVoid("smaller_comparison_index", do_comparison, [&]() { + ksl.If("smaller_comparison_index", do_comparison, [&]() { auto key1 = read_element(0, current_keys_index); auto key2 = read_element(0, compare_keys_index); auto compare_key1 = key1; @@ -155,7 +155,7 @@ void EmitCompareLoopBody( is_smaller_than = b->CreateOr( is_smaller_than, b->CreateAnd(keys_equal, index_is_smaller_than)); } - ksl.IfReturnVoid("is_smaller_than", is_smaller_than, [&]() { + ksl.If("is_smaller_than", is_smaller_than, [&]() { // Swap key1 with key2. write_element(0, current_keys_index, key2); write_element(0, compare_keys_index, key1); @@ -192,7 +192,7 @@ void EmitTiledCompareLoop( b->CreateShl(tiled_keys_index[dimension_to_sort], value_one); // We want to copy two adjacent elements. We first check whether the // first index position is within bounds. - ksl.IfReturnVoid( + ksl.If( "smaller_keys_index", b->CreateICmpSLT(current_keys_index, tiled_keys_index.GetConstantWithIndexType( @@ -203,15 +203,14 @@ void EmitTiledCompareLoop( // Increment to go the next index position. current_keys_index = b->CreateAdd(current_keys_index, value_one); // Here we check whether the next index position is within bounds. 
- ksl.IfReturnVoid( - "inner_smaller_keys_index", - b->CreateICmpSLT(current_keys_index, - tiled_keys_index.GetConstantWithIndexType( - dimension_to_sort_bound)), - [&]() { - cache_index = b->CreateAdd(cache_index, value_one); - read_or_write(cache_index, current_keys_index); - }); + ksl.If("inner_smaller_keys_index", + b->CreateICmpSLT(current_keys_index, + tiled_keys_index.GetConstantWithIndexType( + dimension_to_sort_bound)), + [&]() { + cache_index = b->CreateAdd(cache_index, value_one); + read_or_write(cache_index, current_keys_index); + }); }); }; @@ -253,7 +252,7 @@ void EmitTiledCompareLoop( if (dimension_to_sort_bound % tile_size) { // Otherwise we need a bounds check for the last tile. The last tile has // size 'dimension_to_sort_bound' % 'tile_size'. - ksl.IfReturnVoid( + ksl.If( "is_last_tile", b->CreateICmpUGE( b->CreateMul(tiled_keys_index[dimension_to_sort], diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index ac2f79674feceff436c0e9c65338967f498e4473..daa718879ddd45afb02725b557380b2f49fe833e 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -17,6 +17,7 @@ limitations under the License. #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -42,6 +43,7 @@ NameUniquer::NameUniquer(const string& separator) { if (name.empty()) { return ""; } + string result = name; char c = static_cast(result[0]); if (!isalpha(c) && c != '_') { @@ -52,6 +54,13 @@ NameUniquer::NameUniquer(const string& separator) { result[i] = '_'; } } + + // HLO primitive type names (with the exception of 'tuple') are keywords in + // the HLO text representation and cannot be names, so append an underscore if + // the name is a primitive type. 
+ if (primitive_util::IsPrimitiveTypeName(result) && result != "tuple") { + result += "_"; + } return result; } diff --git a/tensorflow/compiler/xla/service/name_uniquer_test.cc b/tensorflow/compiler/xla/service/name_uniquer_test.cc index 3e2592c6ac626143f1421e545a31d9be91e376bc..d0d04147e0c29c66cba447550c0a9c703f35573a 100644 --- a/tensorflow/compiler/xla/service/name_uniquer_test.cc +++ b/tensorflow/compiler/xla/service/name_uniquer_test.cc @@ -104,5 +104,21 @@ TEST_F(NameUniquerTest, KeepNamesInRandomOrder) { EXPECT_EQ("foo.3", uniquer.GetUniqueName("foo.3")); } +TEST_F(NameUniquerTest, AvoidKeywords) { + NameUniquer uniquer("."); + + EXPECT_EQ("f32_", uniquer.GetUniqueName("f32")); + EXPECT_EQ("s64_", uniquer.GetUniqueName("s64")); + EXPECT_EQ("pred_", uniquer.GetUniqueName("pred")); + + // Though a primitive type, "tuple" is not a keyword. + EXPECT_EQ("tuple", uniquer.GetUniqueName("tuple")); + + // Keywords are not capitalized. + EXPECT_EQ("F32", uniquer.GetUniqueName("F32")); + EXPECT_EQ("S32", uniquer.GetUniqueName("S32")); + EXPECT_EQ("Pred", uniquer.GetUniqueName("Pred")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index fb1645d9b2ebeae77190a950ebd023979c567016..81db3bb643a989cafb6c6a8bcbd35e218fdcaf44 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -64,6 +64,9 @@ namespace xla { // e.g. IsConstantScalar() or IsConstantScalar(42). // - WithFusionKind // - WithTupleIndex: get-tuple-element operations with the given tuple index +// - WithOneUse: Instruction is used as an operand exactly once. +// - WithOneUser: Instruction is used by exactly one other instruction, but +// is possibly used more than once as an operand (e.g. multiply(x,x)). 
// // Shape(): // - EqualTo @@ -1133,6 +1136,13 @@ inline const HloInstruction* HloOperand(const HloInstruction* instr, return instr->operand(idx); } +// Pretty-printer for HloInstruction. Sort of like ToShortString, but with +// fewer %s and more shapes. +inline string InstToString(const HloInstruction* inst) { + return inst->ToString( + HloPrintOptions().set_print_metadata(false).set_print_percent(false)); +} + template class HloInstructionPattern; @@ -1187,14 +1197,14 @@ class HloInstructionIsImpl { bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { if (inst != inst_) { EXPLAIN << "HloInstruction " << inst << " is not " << inst_ << " (" - << inst_->ToShortString() << ")"; + << InstToString(inst_) << ")"; return false; } return true; } void DescribeTo(std::ostream* os, int64 indent = 0) const { - *os << "which is " << inst_ << " (" << inst_->ToShortString() << ")"; + *os << "which is " << inst_ << " (" << InstToString(inst_) << ")"; } private: @@ -1603,6 +1613,64 @@ class HloInstructionPatternParameterNumImpl { int64 parameter_num_; }; +// Superclass that contains common code used by Op::WithOneUse() and +// Op::WithOneUser(). 
+class HloInstructionPatternOneUseOrUserImpl { + protected: + bool MatchOneUser(const HloInstruction* inst, MatchOption option) const { + if (inst->user_count() != 1) { + EXPLAIN << "HloInstruction has " << inst->user_count() + << " users, but expected exactly one."; + if (inst->user_count() > 1) { + EXPLAIN << "\nAll users:"; + for (const HloInstruction* user : inst->users()) { + EXPLAIN << "\n - " << InstToString(user); + } + } + return false; + } + return true; + } +}; + +class HloInstructionPatternOneUseImpl + : public HloInstructionPatternOneUseOrUserImpl { + public: + bool Match(const HloInstruction* inst, MatchOption option) const { + if (!MatchOneUser(inst, option)) { + return false; + } + + int64 use_count = absl::c_count_if( + inst->users()[0]->operands(), + [&](const HloInstruction* operand) { return operand == inst; }); + if (use_count != 1) { + EXPLAIN << "HloInstruction is used " << use_count + << " times by its user, but is expected to be used just once: " + << InstToString(inst->users()[0]); + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "which has exactly one use"; + } +}; + +class HloInstructionPatternOneUserImpl + : public HloInstructionPatternOneUseOrUserImpl { + public: + bool Match(const HloInstruction* inst, MatchOption option) const { + return MatchOneUser(inst, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "which has exactly one user (but possibly is used multiple times by " + "that instruction)"; + } +}; + // Matches a constant scalar or effective scalar, optionally with a given value. 
template class HloConstantScalarImpl { @@ -1669,7 +1737,8 @@ class HloConstantScalarImpl { literal_r0_as_val_ty_or.ValueOrDie() == val_literal && literal_r0 == val_as_literal_ty; if (!rv) { - EXPLAIN << "HloInstruction's constant value " << literal_r0.ToString() + EXPLAIN << "HloInstruction's constant value " + << literal_r0.ToStringWithoutShape() << " did not match expected value " << *val_; } return rv; @@ -1706,10 +1775,7 @@ class HloInstructionPattern { return true; } if (inst != nullptr) { - EXPLAIN << "\nin " - << inst->ToString(HloPrintOptions() - .set_print_metadata(false) - .set_print_percent(false)); + EXPLAIN << "\nin " << InstToString(inst); } return false; } @@ -1722,10 +1788,7 @@ class HloInstructionPattern { } return true; } - EXPLAIN << "\nin " - << inst->ToString(HloPrintOptions() - .set_print_metadata(false) - .set_print_percent(false)); + EXPLAIN << "\nin " << InstToString(inst); return false; } @@ -1877,6 +1940,22 @@ class HloInstructionPattern { return AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num)); } + // Modifies the pattern to match if the instruction is used exactly once. + // Does not match if the instruction is used twice by the same user (e.g. + // multiply(x,x)). + constexpr auto WithOneUse() const + -> decltype(this->AppendImpl(HloInstructionPatternOneUseImpl())) { + return AppendImpl(HloInstructionPatternOneUseImpl()); + } + + // Modifies the pattern to match if the instruction is used by exactly one + // other instruction. Will match if the instruction is used twice, so long as + // it's by the same user (e.g. multiply(x,x)). 
+ constexpr auto WithOneUser() const + -> decltype(this->AppendImpl(HloInstructionPatternOneUserImpl())) { + return AppendImpl(HloInstructionPatternOneUserImpl()); + } + void DescribeTo(std::ostream* os, int64 indent = 0) const { impl_.DescribeTo(os, indent); } @@ -1922,6 +2001,7 @@ Op(::xla::HloInstruction** matched_inst) { XLA_NULLOP_PATTERN(Constant) XLA_NULLOP_PATTERN(Parameter) XLA_NULLOP_PATTERN(Iota) +XLA_NULLOP_PATTERN(Rng) #undef XLA_NULLOP_PATTERN // Helpers for unary instructions. @@ -2028,10 +2108,10 @@ XLA_UNOP_PATTERN(Transpose) } \ template \ inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \ - ->decltype(NAME##AnyOrder( \ + ->decltype(NAME##AnyOrder( \ nullptr, std::forward(lhs), std::forward(rhs))) { \ - return NAME##AnyOrder(nullptr, std::forward(lhs), \ - std::forward(rhs)); \ + return NAME##AnyOrder( \ + nullptr, std::forward(lhs), std::forward(rhs)); \ } XLA_COMMUTATIVE_BINOP_PATTERN(Add) XLA_BINOP_PATTERN(Atan2) @@ -2053,6 +2133,7 @@ XLA_COMMUTATIVE_BINOP_PATTERN(Ne) XLA_BINOP_PATTERN(Outfeed) XLA_BINOP_PATTERN(Pad) XLA_BINOP_PATTERN(Power) +XLA_BINOP_PATTERN(ReduceWindow) XLA_BINOP_PATTERN(Remainder) XLA_BINOP_PATTERN(Send) XLA_BINOP_PATTERN(Subtract) @@ -2099,6 +2180,7 @@ XLA_BINOP_PATTERN(ShiftRightLogical) .WithOperand(2, std::forward(arg2)); \ } XLA_TERNOP_PATTERN(Clamp); +XLA_TERNOP_PATTERN(Scatter); XLA_TERNOP_PATTERN(Select); #undef XLA_TERNOP_PATTERN @@ -2151,8 +2233,10 @@ inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg, // We could implement all ops as "variadic" ops, but it would make the // already-bad compile errors even worse. 
+XLA_VARIADIC_OP_PATTERN(AfterAll); XLA_VARIADIC_OP_PATTERN(Concatenate); XLA_VARIADIC_OP_PATTERN(CustomCall); +XLA_VARIADIC_OP_PATTERN(Map) XLA_VARIADIC_OP_PATTERN(Reduce); XLA_VARIADIC_OP_PATTERN(Tuple); diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index 13886fa6f5b7b55283e6e420734a22312987d8a6..5c3c009a68bffbda8642fceedfb724879fbf1530 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -242,8 +242,8 @@ TEST(PatternMatcherTest, ConstantScalar) { HloModule test_module ENTRY test { a = s32[] constant(1) - b = s32[1,1] constant(s32[1,1]{{2}}) - c = s32[1,2] constant(s32[1,2]{{2,2}}) + b = s32[1,1] constant({{2}}) + c = s32[1,2] constant({{2,2}}) d = f32[] constant(1) e = f32[] constant(1.25) ROOT tuple = (s32[], s32[1,1], s32[1,2], f32[], f32[]) tuple(a,b,c,d,e) @@ -767,10 +767,11 @@ TEST(PatternMatcherTest, HloInstructionDescribeToAndExplain) { "in c = f64[] constant(2.25)"); EXPECT_DESC_AND_EXPLANATION( constant, m::Op().Is(iota.get()), - absl::StrCat("an HloInstruction which is 0x", absl::Hex(iota.get()), " (", - iota->ToShortString(), ")"), + absl::StrCat("an HloInstruction which is 0x", absl::Hex(iota.get()), + " (i = s32[42]{0} iota(), iota_dimension=0)"), absl::StrCat("HloInstruction 0x", absl::Hex(constant.get()), " is not 0x", - absl::Hex(iota.get()), " (", iota->ToShortString(), ")\n", + absl::Hex(iota.get()), + " (i = s32[42]{0} iota(), iota_dimension=0)\n" "in c = s32[] constant(0)")); } @@ -875,5 +876,60 @@ TEST(PatternMatcherTest, Parameter) { "in p0 = f32[] parameter(0)"); } +TEST(PatternMatcherTest, OneUseAndOneUser) { + auto param = + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0"); + + EXPECT_FALSE(Match(param.get(), m::Op().WithOneUse())); + EXPECT_DESC_AND_EXPLANATION( + param, m::Op().WithOneUse(), + "an HloInstruction which has exactly one use", + 
"HloInstruction has 0 users, but expected exactly one.\n" + "in p0 = f32[] parameter(0)"); + + EXPECT_FALSE(Match(param.get(), m::Op().WithOneUser())); + EXPECT_DESC_AND_EXPLANATION( + param, m::Op().WithOneUser(), + "an HloInstruction which has exactly one user (but possibly is used " + "multiple times by that instruction)", + "HloInstruction has 0 users, but expected exactly one.\n" + "in p0 = f32[] parameter(0)"); + + { + auto reshape = + SetName("r", HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {1}), param.get())); + EXPECT_TRUE(Match(param.get(), m::Op().WithOneUse())); + EXPECT_TRUE(Match(param.get(), m::Op().WithOneUser())); + + auto reshape1 = + SetName("r1", HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {1}), param.get())); + EXPECT_FALSE(Match(param.get(), m::Op().WithOneUse())); + EXPECT_FALSE(Match(param.get(), m::Op().WithOneUser())); + + const char* kMultipleUserExplanation = + "HloInstruction has 2 users, but expected exactly one.\n" + "All users:\n" + " - r = f32[1]{0} reshape(f32[] p0)\n" + " - r1 = f32[1]{0} reshape(f32[] p0)\n" + "in p0 = f32[] parameter(0)"; + EXPECT_EQ(Explanation(param.get(), m::Op().WithOneUse()), + kMultipleUserExplanation); + EXPECT_EQ(Explanation(param.get(), m::Op().WithOneUser()), + kMultipleUserExplanation); + } + + auto add = SetName("add", HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, + param.get(), param.get())); + EXPECT_TRUE(Match(param.get(), m::Op().WithOneUser())); + EXPECT_FALSE(Match(param.get(), m::Op().WithOneUse())); + EXPECT_EQ(Explanation(param.get(), m::Op().WithOneUse()), + "HloInstruction is used 2 times by its user, but is expected to be " + "used just once: add = f32[] add(f32[] p0, f32[] p0)\n" + "in p0 = f32[] parameter(0)"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 
3b336d5c9db80ff2ca8d0e45396dca66a29a0494..ae5bd93e7c56117cc78ecc729d370250787efac6 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -746,9 +746,9 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, } if (available_device_count < arg->device_count() * replica_count) { return ResourceExhausted( - "Requested device count (%d) exceeds the number of available devices " - "on the target (%d)", - arg->device_count(), available_device_count); + "Requested logical device count (%d) with replica count (%d) exceeds " + "the number of available physical devices on the target (%d)", + arg->device_count(), replica_count, available_device_count); } for (int64 i = 0; i < arg->device_count(); ++i) { @@ -1078,9 +1078,11 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, ProgramShape program_shape(arg->computation().host_program_shape()); TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result())); + absl::optional output_layout; if (arg->has_output_layout()) { + output_layout = Layout::CreateFromProto(arg->output_layout()); TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape( - arg->output_layout(), program_shape.result())); + *output_layout, program_shape.result())); } HloModuleConfig config(program_shape); @@ -1096,8 +1098,8 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, // relayout here. // // TODO(b/77824332): Make HloEvaluator take care of the re-layout. 
- if (arg->has_output_layout()) { - result_literal = result_literal.Relayout(arg->output_layout()); + if (output_layout.has_value()) { + result_literal = result_literal.Relayout(*output_layout); } *result->mutable_literal() = result_literal.ToProto(); diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 17cdaa74fc328d156292f5af828d4222a9a01f1f..3ca53edc8171a134f2bfb9a36beacfd2d2e0d425 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -139,9 +139,9 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) { HloModule FoldDotTransposeConstant ENTRY entry_computation { - constant = f32[2,1]{1,0} constant(f32[2,1] { { 1 }, { 2 } }) + constant = f32[2,1]{1,0} constant({ { 1 }, { 2 } }) transpose = f32[1,2]{1,0} transpose(constant), dimensions={1,0} - constant.1 = f32[3,2]{1,0} constant(f32[3,2] { { 1, 2 }, { 3, 4 }, { 5, 6 } }) + constant.1 = f32[3,2]{1,0} constant({ { 1, 2 }, { 3, 4 }, { 5, 6 } }) transpose.1 = f32[2,3]{1,0} transpose(constant.1), dimensions={1,0} ROOT dot = f32[1,3]{1,0} dot(transpose, transpose.1), lhs_contracting_dims={1}, rhs_contracting_dims={0} } diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc index 75d406435b6f58faecc86b82c33e9e2dd6bccbea..3bcf5c38309a86e9e3cab3268f3f065005f7a923 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc @@ -129,7 +129,7 @@ condition { ENTRY entry { const_0 = f32[2] constant({1, 2}) - const_1 = (f32[2], f32[2]) constant((f32[2], f32[2]) ({2, 1},{3,1})) + const_1 = (f32[2], f32[2]) constant(({2, 1},{3,1})) while_init = (f32[2],(f32[2],f32[2])) tuple(const_0, const_1) ROOT while = (f32[2],(f32[2],f32[2])) while(while_init), condition=condition, 
body=body } @@ -206,8 +206,8 @@ body { p_body.0 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=0 p_body.1 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=1 - token = token[] after-all() - outfeed = token[] outfeed(p_body.0, token) + token0 = token[] after-all() + outfeed = token[] outfeed(p_body.0, token0) ROOT root = (f32[2],f32[2],f32[2]) tuple(p_body.0, p_body.1, p_body.1) } @@ -305,7 +305,7 @@ condition { ENTRY entry { const_0 = f32[] constant(0) - const_1 = (f32[], f32[]) constant((f32[], f32[]) (1, 10)) + const_1 = (f32[], f32[]) constant((1, 10)) while_init = (f32[],(f32[],f32[])) tuple(const_0, const_1) ROOT while = (f32[],(f32[],f32[])) while(while_init), condition=condition, body=body } diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 4950e8269e9cf0723d717bd1734518d104c0c9f2..3713989ca2f64ee1d94c9f77255017909d957da2 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -554,8 +555,7 @@ TEST_F(WhileLoopSimplifierTest, FlattenNestedTuple) { HloInstruction* new_while = FindFirstWhile(m.get()); Shape flat_tuple = - ShapeUtil::ParseShapeString("(s32[1], s32[2], s32[3], s32[4])") - .ValueOrDie(); + ParseShape("(s32[1], s32[2], s32[3], s32[4])").ValueOrDie(); SCOPED_TRACE(m->ToString()); EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), flat_tuple)); EXPECT_TRUE(ShapeUtil::Equal( @@ -567,8 +567,7 @@ TEST_F(WhileLoopSimplifierTest, FlattenNestedTuple) { flat_tuple)); EXPECT_TRUE(ShapeUtil::Equal( m->entry_computation()->root_instruction()->shape(), - ShapeUtil::ParseShapeString("((s32[1]), (s32[2], s32[3], (s32[4])))") - .ValueOrDie())); + ParseShape("((s32[1]), (s32[2], s32[3], (s32[4])))").ValueOrDie())); } // Edge-case: All elements of the loop carry are constants which can be removed, @@ -641,8 +640,7 @@ TEST_F(WhileLoopSimplifierTest, RemoveConstantFromLoopCarry) { EXPECT_TRUE(TupleSimplifier().Run(m.get()).ok()); HloInstruction* new_while = FindFirstWhile(m.get()); - Shape new_while_shape = - ShapeUtil::ParseShapeString("(s32[1], s32[3])").ValueOrDie(); + Shape new_while_shape = ParseShape("(s32[1], s32[3])").ValueOrDie(); EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), new_while_shape)); EXPECT_TRUE(ShapeUtil::Equal( new_while->while_body()->root_instruction()->shape(), new_while_shape)); @@ -652,9 +650,9 @@ TEST_F(WhileLoopSimplifierTest, RemoveConstantFromLoopCarry) { EXPECT_TRUE(ShapeUtil::Equal( new_while->while_condition()->parameter_instruction(0)->shape(), new_while_shape)); - EXPECT_TRUE(ShapeUtil::Equal( - 
m->entry_computation()->root_instruction()->shape(), - ShapeUtil::ParseShapeString("(s32[1], s32[2], s32[3])").ValueOrDie())); + EXPECT_TRUE( + ShapeUtil::Equal(m->entry_computation()->root_instruction()->shape(), + ParseShape("(s32[1], s32[2], s32[3])").ValueOrDie())); EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple(_, op::Constant(), _)); } @@ -712,7 +710,7 @@ TEST_F(WhileLoopSimplifierTest, MergeInductionVariables_Simple) { // We should have added a new loop counter for s32[] to the end of the tuple. SCOPED_TRACE(m->ToString()); Shape new_while_shape = - ShapeUtil::ParseShapeString("(s32[], s32[], s32[], s32[])").ValueOrDie(); + ParseShape("(s32[], s32[], s32[], s32[])").ValueOrDie(); EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), new_while_shape)); EXPECT_TRUE(ShapeUtil::Equal( new_while->while_body()->root_instruction()->shape(), new_while_shape)); diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc index 5e6941933330fde29bc9c779aae4bb3c36914660..d92b9870f373564ae8fd904c8bf9f0d1afbff9c4 100644 --- a/tensorflow/compiler/xla/service/while_util_test.cc +++ b/tensorflow/compiler/xla/service/while_util_test.cc @@ -180,8 +180,8 @@ body { cond { param.c = (s32[], s32[]) parameter(0) - token = token[] after-all() - infeed = (pred[], token[]) infeed(token) + token0 = token[] after-all() + infeed = (pred[], token[]) infeed(token0) ROOT condition = pred[] get-tuple-element(infeed), index=0 } diff --git a/tensorflow/compiler/xla/shape.cc b/tensorflow/compiler/xla/shape.cc index 746ab9e9977b1b10cdb0cb57197027d65bd50f55..b206345db2ac2940b1f139c82fa03a93538b5ccd 100644 --- a/tensorflow/compiler/xla/shape.cc +++ b/tensorflow/compiler/xla/shape.cc @@ -32,7 +32,7 @@ Shape::Shape(const ShapeProto& shape_proto) { *add_tuple_shapes() = Shape(element_shape); } if (shape_proto.has_layout()) { - *mutable_layout() = shape_proto.layout(); + *mutable_layout() = 
Layout::CreateFromProto(shape_proto.layout()); } } @@ -48,7 +48,7 @@ ShapeProto Shape::ToProto() const { *proto.add_tuple_shapes() = shape.ToProto(); } if (has_layout()) { - *proto.mutable_layout() = layout(); + *proto.mutable_layout() = layout().ToProto(); } return proto; } diff --git a/tensorflow/compiler/xla/shape.h b/tensorflow/compiler/xla/shape.h index 7f6b14ab4286c696dce64d2250a3fe8a57e4865b..7643f64d8a5f0450be1cddad35cf7422afb89048 100644 --- a/tensorflow/compiler/xla/shape.h +++ b/tensorflow/compiler/xla/shape.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/types/optional.h" +#include "tensorflow/compiler/xla/layout.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/types.h" @@ -76,21 +77,10 @@ class Shape { std::vector* mutable_tuple_shapes() { return &tuple_shapes_; } // Methods for accessing the layout field. - bool has_layout() const { return layout_.has_value(); } - const Layout& layout() const { - if (layout_.has_value()) { - return *layout_; - } else { - return Layout::default_instance(); - } - } - Layout* mutable_layout() { - if (!layout_.has_value()) { - layout_ = Layout(); - } - return &layout_.value(); - } - void clear_layout() { layout_.reset(); } + bool has_layout() const { return layout_.format() != INVALID_FORMAT; } + const Layout& layout() const { return layout_; } + Layout* mutable_layout() { return &layout_; } + void clear_layout() { layout_.Clear(); } void Swap(Shape* other) { using std::swap; @@ -101,7 +91,7 @@ class Shape { element_type_ = PRIMITIVE_TYPE_INVALID; dimensions_.clear(); tuple_shapes_.clear(); - layout_.reset(); + clear_layout(); } string SerializeAsString() const { return ToProto().SerializeAsString(); } @@ -118,8 +108,8 @@ class Shape { // The tuple element subshapes. This is nonempty only for tuple shapes. std::vector tuple_shapes_; - // The array layout of the shape. This is present only for array shapes. 
- absl::optional layout_; + // The layout of the shape. Only relevant for arrays. + Layout layout_; }; // Shape of the parameters and output of an XLA computation. This is analogous diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index b95fabf488291b0a7f393cb9f7f4a5dc9eb7c7eb..be7d71ada009535a5c08aa3d3d062fa451cfeef3 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -164,9 +164,9 @@ StatusOr MakeShapeWithLayoutInternal( TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::MakeValidatedShape(element_type, dimensions)); auto min2maj = shape.mutable_layout()->mutable_minor_to_major(); - min2maj->Clear(); + min2maj->clear(); for (int64 value : minor_to_major) { - min2maj->Add(value); + min2maj->push_back(value); } if (!shape.has_layout()) { return InvalidArgument("Shape has no layout."); @@ -234,7 +234,7 @@ StatusOr MakeShapeWithLayoutInternal( /* static */ StatusOr ShapeUtil::MakeValidatedShape( PrimitiveType element_type, absl::Span dimensions) { - CHECK(IsArrayPrimitiveType(element_type)); + CHECK(IsArrayPrimitiveType(element_type)) << element_type; Shape result; TF_RETURN_IF_ERROR(PopulateShape(element_type, dimensions, &result)); return result; @@ -480,54 +480,6 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return IsScalar(shape) && shape.element_type() == element_type; } -namespace { - -// Class to memoize the computation of -// absl::AsciiStrToLower(PrimitiveType_Name(p)) -// for all PrimitiveType values "p" -class PrimitiveTypeNameGenerator { - public: - PrimitiveTypeNameGenerator() { - for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { - if (PrimitiveType_IsValid(i)) { - lowercase_name_[i] = absl::AsciiStrToLower( - PrimitiveType_Name(static_cast(i))); - } - } - } - const string& LowercaseName(PrimitiveType t) { - return lowercase_name_[static_cast(t)]; - } - - private: - string lowercase_name_[PrimitiveType_ARRAYSIZE]; -}; - -const string& 
LowercasePrimitiveTypeName(PrimitiveType s) { - static PrimitiveTypeNameGenerator* gen = new PrimitiveTypeNameGenerator(); - return gen->LowercaseName(s); -} - -StatusOr StringToPrimitiveType(const string& name) { - static std::unordered_map* name_to_type = [] { - static auto* map = new std::unordered_map; - for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { - if (PrimitiveType_IsValid(i)) { - auto value = static_cast(i); - (*map)[LowercasePrimitiveTypeName(value)] = value; - } - } - return map; - }(); - auto found = name_to_type->find(name); - if (found == name_to_type->end()) { - return InvalidArgument("Invalid element type string: \"%s\".", name); - } - return found->second; -} - -} // namespace - /* static */ string ShapeUtil::HumanString(const Shape& shape) { if (IsTuple(shape)) { string text = "("; @@ -539,8 +491,9 @@ StatusOr StringToPrimitiveType(const string& name) { text += ")"; return text; } - return StrCat(LowercasePrimitiveTypeName(shape.element_type()), "[", - absl::StrJoin(shape.dimensions(), ","), "]"); + return StrCat( + primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "[", + absl::StrJoin(shape.dimensions(), ","), "]"); } /* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { @@ -554,7 +507,8 @@ StatusOr StringToPrimitiveType(const string& name) { text += ")"; return text; } - string result = StrCat(LowercasePrimitiveTypeName(shape.element_type()), "["); + string result = StrCat( + primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "["); for (int i = 0; i < shape.dimensions().size(); i++) { StrAppend(&result, (i > 0) ? "," : "", shape.dimensions(i)); } @@ -580,116 +534,6 @@ StatusOr StringToPrimitiveType(const string& name) { HumanString(program_shape.result())); } -namespace { -// Parses shapes with simple recursive descent structure -- consumes from the -// front of s and passes that view recursively as required. 
-StatusOr ParseShapeStringInternal(absl::string_view* s) { - *s = StripLeadingAsciiWhitespace(*s); - - if (absl::ConsumePrefix(s, "(")) { // Tuple. - std::vector shapes; - bool must_end = false; - while (true) { - if (absl::ConsumePrefix(s, ")")) { - break; - } else if (must_end) { - return InvalidArgument("Expected end of tuple; got: \"%s\"", *s); - } - shapes.emplace_back(); - TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s)); - *s = StripLeadingAsciiWhitespace(*s); - must_end = !absl::ConsumePrefix(s, ","); - } - return ShapeUtil::MakeTupleShape(shapes); - } - - string element_type_string; - string dimensions_string; - string format_string; - string layout_string; - // absl::string_view is not compatible with internal RE2 StringPiece, so - // we convert in to the RE2-consumable type and then consume the corresponding - // amount from our string_view type. - static LazyRE2 shape_pattern = { - "^(\\w*\\d*)\\[([\\d,\\s]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,\\s]+)})" - "?"}; - tensorflow::RegexpStringPiece s_consumable(s->data(), s->size()); - if (RE2::Consume(&s_consumable, *shape_pattern, &element_type_string, - &dimensions_string, &format_string, &layout_string)) { - size_t consumed = s->size() - s_consumable.size(); - s->remove_prefix(consumed); - auto string_to_int64 = [&s](absl::string_view input) -> StatusOr { - int64 element; - if (!absl::SimpleAtoi(input, &element)) { - return InvalidArgument( - "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"", input, - *s); - } - return element; - }; - - auto comma_list_to_int64s = - [string_to_int64](const string& input) -> StatusOr> { - std::vector results; - for (const auto& piece : absl::StrSplit(input, ',', absl::SkipEmpty())) { - TF_ASSIGN_OR_RETURN(int64 element, string_to_int64(piece)); - results.push_back(element); - } - return results; - }; - - // Extract the dimensions. 
- TF_ASSIGN_OR_RETURN(std::vector dimensions, - comma_list_to_int64s(dimensions_string)); - - // Extract the primitive element type. - TF_ASSIGN_OR_RETURN(const PrimitiveType primitive_type, - StringToPrimitiveType(element_type_string)); - if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE) { - return InvalidArgument("Invalid element type string: \"%s\".", - element_type_string); - } - - Shape result; - if (primitive_type == OPAQUE) { - result = ShapeUtil::MakeOpaqueShape(); - } else if (primitive_type == TOKEN) { - result = ShapeUtil::MakeTokenShape(); - } else if (format_string.empty() && layout_string.empty()) { - // Create a shape without a layout set. - TF_ASSIGN_OR_RETURN( - result, ShapeUtil::MakeValidatedShape(primitive_type, dimensions)); - } else if (format_string == "sparse") { - TF_ASSIGN_OR_RETURN(int64 max_elements, string_to_int64(layout_string)); - result = ShapeUtil::MakeShapeWithSparseLayout(primitive_type, dimensions, - max_elements); - } else if (format_string.empty() || format_string == "dense") { - // Extract the layout minor-to-major and set it. - TF_ASSIGN_OR_RETURN(std::vector min2maj, - comma_list_to_int64s(layout_string)); - TF_ASSIGN_OR_RETURN(result, MakeShapeWithLayoutInternal( - primitive_type, dimensions, min2maj)); - } else { - // This should not be reached. 
- LOG(FATAL) << "Unhandled condition when parsing shape; format: \"" - << format_string << "\", layout: \"" << layout_string << "\""; - } - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(result)); - return std::move(result); - } - - return InvalidArgument("Invalid shape string to parse: \"%s\"", *s); -} -} // namespace - -/* static */ StatusOr ShapeUtil::ParseShapeString(absl::string_view s) { - TF_ASSIGN_OR_RETURN(Shape shape, ParseShapeStringInternal(&s)); - if (!s.empty()) { - return InvalidArgument("Invalid shape string to parse: \"%s\"", s); - } - return shape; -} - /* static */ bool ShapeUtil::SameDimensions(const Shape& lhs, const Shape& rhs) { CHECK(ShapeUtil::IsArray(lhs)); @@ -867,13 +711,13 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { if (shape.dimensions_size() != 0) { return InvalidArgument( "shape has %s element type, but has dimensions field: %s", - LowercasePrimitiveTypeName(shape.element_type()), + primitive_util::LowercasePrimitiveTypeName(shape.element_type()), shape.ShortDebugString()); } if (shape.has_layout()) { return InvalidArgument( "shape has %s element type, but has layout field: %s", - LowercasePrimitiveTypeName(shape.element_type()), + primitive_util::LowercasePrimitiveTypeName(shape.element_type()), shape.ShortDebugString()); } return Status::OK(); @@ -1067,6 +911,11 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { return absl::c_linear_search(shape.dimensions(), 1); } +/* static */ Shape ShapeUtil::DropDegenerateDimensions(const Shape& shape) { + return FilterDimensions( + [&](int64 dim) -> bool { return shape.dimensions()[dim] != 1; }, shape); +} + namespace { // Helper for ForEachSubshape which visits the subshapes of the given shape in @@ -1168,7 +1017,7 @@ Status ForEachMutableSubshapeHelper( // Let the argument `permutation` be P. This is a permutation over `shape`'s // dimensions, so our return value will be a shape with dims P.I = P. 
Our // goal is to construct a layout permutation L* that we can apply to P such - // that that the physical dimension ordering of the returned shape is the same + // that the physical dimension ordering of the returned shape is the same // as that of the original shape, namely L'. // // Our returned shape has dims P and layout L*, so its in-memory layout is @@ -1618,10 +1467,10 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, if (LayoutUtil::HasLayout(shape)) { Layout* layout = shape.mutable_layout(); layout->set_format(DENSE); - for (size_t i = 0; i < layout->minor_to_major().size();) { + for (int64 i = 0; i < layout->minor_to_major().size();) { if (layout->minor_to_major(i) == dim_to_delete) { layout->mutable_minor_to_major()->erase( - layout->minor_to_major().begin() + i); + layout->mutable_minor_to_major()->begin() + i); continue; } if (layout->minor_to_major(i) > dim_to_delete) { diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 84a27f662a57ba274562e2e9be57b7e971c9b477..6b7a9cd34f25f2088bdb8d2c7f0412e5d8519d23 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -241,10 +241,6 @@ class ShapeUtil { // (param_name: f32[42x12], ...) -> f32[24x42] static string HumanString(const ProgramShape& program_shape); - // Parses a ShapeUtil::HumanString-format shape string back into a shape - // object. - static StatusOr ParseShapeString(absl::string_view s); - // Returns whether the LHS and RHS shapes have the same dimensions; note: does // not check element type. // Precondition: IsArray(lhs) && IsArray(rhs) @@ -551,6 +547,9 @@ class ShapeUtil { // (dimensions with bound 1). static bool HasDegenerateDimensions(const Shape& shape); + // Drops any degenerate dimensions (i.e. 
dimensions of size 1) + static Shape DropDegenerateDimensions(const Shape& shape); + // Permutes the dimensions by the given permutation, so // return_value.dimensions[permutation[i]] = argument.dimensions[i]. // diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 60bdbe302045e6f3b4bae500c50bc68fb217525d..0a3081f5161f80ac97e864ba08d186df4fbdb55d 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -82,102 +82,6 @@ TEST(ShapeUtilTest, Rank4DimensionIndexing) { ASSERT_EQ(3, shape.dimensions(0)); } -TEST(ShapeUtilTest, ParseShapeStringR2F32) { - string shape_string = "f32[123,456]"; - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString(shape_string)); - Shape expected = ShapeUtil::MakeShape(F32, {123, 456}); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) { - string shape_string = "(f32[1572864],s8[5120,1024])"; - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString(shape_string)); - Shape expected = - ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}), - ShapeUtil::MakeShape(S8, {5120, 1024})}); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseShapeStringNestedTuple) { - string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])"; - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString(shape_string)); - Shape expected = ShapeUtil::MakeTupleShape({ - ShapeUtil::MakeShape(F32, {1}), - ShapeUtil::MakeTupleShape( - {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}), - ShapeUtil::MakeOpaqueShape(), - ShapeUtil::MakeShape(F32, {3}), - }); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << 
"expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseShapeStringWithLayout) { - string shape_string = "f32[123,456]{0,1}"; - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString(shape_string)); - Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseShapeStringWithExplicitDenseLayout) { - string shape_string = "f32[123,456]dense{0,1}"; - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString(shape_string)); - Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseShapeStringWithSparseLayout) { - string shape_string = "f32[123,456]sparse{10}"; - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString(shape_string)); - Shape expected = ShapeUtil::MakeShapeWithSparseLayout(F32, {123, 456}, 10); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseOpaqueType) { - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString("opaque[]")); - Shape expected = ShapeUtil::MakeOpaqueShape(); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseTokenType) { - TF_ASSERT_OK_AND_ASSIGN(Shape actual, ShapeUtil::ParseShapeString("token[]")); - Shape expected = ShapeUtil::MakeTokenShape(); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << 
"actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseInvalidShapeString) { - string shape_strings[] = { - "f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}", - "f32[123,456]dense{foo}", "f32[123,456]sparse{foo}", - }; - for (const string& shape_string : shape_strings) { - StatusOr result = ShapeUtil::ParseShapeString(shape_string); - ASSERT_FALSE(result.ok()) << "shape: " << shape_string; - } -} - TEST(ShapeUtilTest, CompatibleIdenticalShapes) { Shape shape1 = ShapeUtil::MakeShape(F32, {3, 2}); Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2}); diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 2c18e2fd10105b6f0c146cad1842c7723699c8d9..0300b64ed59a3d4d8b0cd161109c97cabfdc6734 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -1,6 +1,13 @@ # Description: # Base testing infrastructure for XLA. +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites", "generate_backend_test_macros", "xla_test", "xla_test_library") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") + licenses(["notice"]) # Apache 2.0 package( @@ -23,17 +30,6 @@ filegroup( ]), ) -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test_library") -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites") -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_test_macros") -load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load( - "//tensorflow/core:platform/default/build_config_root.bzl", - "tf_cuda_tests_tags", -) - # Generate test_suites for all backends, named "${backend}_tests". 
generate_backend_suites() @@ -303,10 +299,31 @@ xla_test( name = "conv_depthwise_test", timeout = "long", srcs = ["conv_depthwise_test.cc"], + shard_count = 50, + deps = [ + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/service:bfloat16_normalization", + "//tensorflow/compiler/xla/service:despecializer", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/types:optional", + ], +) + +xla_test( + name = "grouped_convolution_test", + timeout = "long", + srcs = ["grouped_convolution_test.cc"], blacklisted_backends = [ # disabled because of a break b/119590850. - "cpu", "gpu", + # disabled because it times out. + "cpu", ], shard_count = 50, deps = [ @@ -1327,6 +1344,7 @@ xla_test( xla_test( name = "custom_call_test", srcs = ["custom_call_test.cc"], + backends = ["cpu"], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 0615f9425c1289d666641f4d581946b44b4895ce..915b456b52215f8d6a9eb6c5b933f3502f1d3d2c 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -329,13 +329,13 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { Literal b_literal = LiteralUtil::CreateR1({b_values}); std::unique_ptr b_data = client_->TransferToServer(b_literal).ConsumeValueOrDie(); - auto b_constant = Parameter(&builder, 1, a_literal.shape(), "b_param"); - auto b_param = ConstantR1(&builder, b_values); + auto b_param = Parameter(&builder, 1, a_literal.shape(), 
"b_param"); + auto b_constant = ConstantR1(&builder, b_values); - auto sum1 = Add(a_constant, b_constant); - auto sum2 = Add(a_constant, b_param); - auto sum3 = Add(a_param, b_constant); - auto sum4 = Add(a_param, b_param); + auto sum1 = Add(a_constant, b_param); + auto sum2 = Add(a_constant, b_constant); + auto sum3 = Add(a_param, b_param); + auto sum4 = Add(a_param, b_constant); auto sum = Add(sum1, sum2); sum = Add(sum, sum3); @@ -350,6 +350,44 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { error_spec_); } +// TODO(b/119692968): This test runs OOM on the GPU and CPU backend. +XLA_TEST_F(ArrayElementwiseOpTest, + DISABLED_ON_GPU(DISABLED_ON_CPU(DeeplyNestedAddWithSlices))) { + XlaBuilder builder(TestName()); + std::vector values(30, 0.0); + auto a_literal = LiteralUtil::CreateR1(values); + auto a = Parameter(&builder, 0, a_literal.shape(), "x"); + auto b_literal = LiteralUtil::CreateR1(values); + auto b = Parameter(&builder, 1, b_literal.shape(), "x"); + + // Construct a sequence of diamond-shaped gadgets like this: + // + // add + // / \ + // slice slice + // \ / + // add + // + // Each 'left' slice removes the last element, each 'right' slice removes the + // first element. In this way, we index into the add with different + // multi-dimensional index arrays, which defeats the caching we use to avoid + // exponential compile time. 
+ std::function generate_recursive = + [&](int64 slice_size) -> XlaOp { + if (slice_size == values.size()) { + return Add(a, b); + } + XlaOp param = generate_recursive(slice_size + 1); + auto slice1 = Slice(param, {0}, {slice_size}, {1}); + auto slice2 = Slice(param, {1}, {slice_size + 1}, {1}); + return Add(slice1, slice2); + }; + generate_recursive(1); + auto a_data = client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto b_data = client_->TransferToServer(b_literal).ConsumeValueOrDie(); + ComputeAndCompareR1(&builder, {0.0}, {a_data.get(), b_data.get()}); +} + XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 12c029983336cc9aed0fde4ce6881c9a00a9869e..697236dc6236738df08205fa3631a2919dd361c5 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -74,6 +74,9 @@ ClientLibraryTestBase::ClientLibraryTestBase( // default. 
execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( "constant_folding"); + + execution_options_.mutable_debug_options() + ->set_xla_hlo_evaluator_use_fast_path(true); } ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform) @@ -88,6 +91,9 @@ ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform) execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( "constant_folding"); + + execution_options_.mutable_debug_options() + ->set_xla_hlo_evaluator_use_fast_path(true); } string ClientLibraryTestBase::TestName() const { diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 34148e5886d3806b19fc5bee90806c5678df345e..65a23dd883594b9bf9c37494a37e9be39b197788 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -76,7 +76,7 @@ class ClientLibraryTestBase : public ::testing::Test { void SetFastMathDisabled(bool disabled) { auto* opts = execution_options_.mutable_debug_options(); opts->set_xla_cpu_enable_fast_math(!disabled); - opts->set_xla_gpu_enable_fast_math(!disabled); + opts->set_xla_gpu_enable_fast_min_max(!disabled); } void SetSeed(uint64 seed) { execution_options_.set_seed(seed); } diff --git a/tensorflow/compiler/xla/tests/conv_depthwise_test.cc b/tensorflow/compiler/xla/tests/conv_depthwise_test.cc index 60ce576ceb20b89b59e72d821e63b0ccdee51b0b..627a17a0ca114085240dbaf28211bb3511cf0cab 100644 --- a/tensorflow/compiler/xla/tests/conv_depthwise_test.cc +++ b/tensorflow/compiler/xla/tests/conv_depthwise_test.cc @@ -50,9 +50,9 @@ class DepthwiseConvolution2DTest static std::vector GetConv2DTestCases() { std::vector config_set; std::vector> config_options = { - {128, 6, 3, 64}, {256, 5, 3, 256}, {256, 5, 2, 144}, {144, 5, 3, 64}, - {144, 5, 2, 256}, {8, 48, 17, 8}, {128, 20, 6, 64}, {128, 1, 2, 144}, - {256, 1, 2, 64}, {64, 14, 12, 172}, {16, 9, 4, 
16}}; + {128, 6, 3, 64}, {256, 5, 3, 256}, {256, 5, 2, 144}, {144, 5, 3, 64}, + {144, 5, 2, 256}, {8, 48, 17, 8}, {128, 20, 6, 64}, {64, 14, 12, 172}, + {16, 9, 4, 16}, {128, 1, 2, 144}, {256, 1, 2, 64}}; for (auto option : config_options) { int64 feature = option[0]; @@ -136,7 +136,7 @@ string BuildHloTextDepthwiseConvolution2D( if (spec.activation_dims[1] == 1 && spec.kernel_dims[1] == 2) { return absl::StrFormat( R"( - HloModule TensorFlowDepthwiseConv, is_scheduled=true + HloModule TensorFlowDepthwiseConv ENTRY main { activation = %s[%s]{%s} parameter(0) @@ -161,7 +161,7 @@ string BuildHloTextDepthwiseConvolution2D( } else if (spec.stride == -1) { return absl::StrFormat( R"( - HloModule TensorFlowDepthwiseConv, is_scheduled=true + HloModule TensorFlowDepthwiseConv ENTRY main { activation = %s[%s]{%s} parameter(0) @@ -185,7 +185,7 @@ string BuildHloTextDepthwiseConvolution2D( } else { return absl::StrFormat( R"( - HloModule TensorFlowDepthwiseConv, is_scheduled=true + HloModule TensorFlowDepthwiseConv ENTRY main { activation = %s[%s]{%s} parameter(0) @@ -215,13 +215,13 @@ XLA_TEST_P(DepthwiseConvolution2DTest, DoIt) { const string hlo_text = BuildHloTextDepthwiseConvolution2D(spec, use_bfloat16); - EXPECT_TRUE(RunAndCompareNoHloPasses( - hlo_text, ErrorSpec{0.01, 0.01}, [](HloModule* module) -> Status { - BFloat16MixedPrecisionRemoval remover; - TF_RETURN_IF_ERROR(remover.Run(module).status()); - Despecializer despecializer; - return despecializer.Run(module).status(); - })); + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01}, + [](HloModule* module) -> Status { + BFloat16MixedPrecisionRemoval remover; + TF_RETURN_IF_ERROR(remover.Run(module).status()); + Despecializer despecializer; + return despecializer.Run(module).status(); + })); } INSTANTIATE_TEST_CASE_P( diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 3622f2c1e84639baed13059b21b20609d1347da6..df005a67097bb8aaf070c57d1c51acd1909fee12 
100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -133,7 +133,9 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { // Reverse the minor-to-major order of the literal. Layout* literal_layout = literal.mutable_shape_do_not_use()->mutable_layout(); ASSERT_EQ(2, literal_layout->minor_to_major_size()); - literal_layout->mutable_minor_to_major()->SwapElements(0, 1); + // Swap the first and second elements. + *literal_layout->mutable_minor_to_major() = { + literal_layout->minor_to_major(1), literal_layout->minor_to_major(0)}; HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 738b6442354b01364278e3e3c713aa2cdb5cf47d..cad43d1b5547d74701760fa623e50466fc15c263 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -54,11 +54,20 @@ void Add1ToValues(float* out, float** in) { out[2] = array[2] + 1; out[3] = array[3] + 1; } + +void F32TupleSwap(float** out, float** in) { + TF_ANNOTATE_MEMORY_IS_INITIALIZED(in[0], sizeof(float)); + TF_ANNOTATE_MEMORY_IS_INITIALIZED(in[1], sizeof(float)); + *out[0] = *in[1]; + *out[1] = *in[0]; +} + } // namespace REGISTER_CUSTOM_CALL_TARGET(R0F32Add2); REGISTER_CUSTOM_CALL_TARGET(R2F32ReduceSum); REGISTER_CUSTOM_CALL_TARGET(Add1ToValues); +REGISTER_CUSTOM_CALL_TARGET(F32TupleSwap); namespace xla { namespace { @@ -69,7 +78,7 @@ class CustomCallTest : public HloTestBase { Shape r2f32_ = ShapeUtil::MakeShape(F32, {2, 2}); }; -XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { +XLA_TEST_F(CustomCallTest, CustomCallR0F32Add2) { auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); @@ -84,7 +93,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { 
LiteralTestUtil::ExpectR0Near(44.0f, result, error_spec_); } -XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { +XLA_TEST_F(CustomCallTest, CustomCallR2F32Reduce) { auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); @@ -105,7 +114,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { LiteralTestUtil::ExpectR0Near(10.0f, result, error_spec_); } -XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(UsedInOtherComputations)) { +XLA_TEST_F(CustomCallTest, UsedInOtherComputations) { auto module = CreateNewUnverifiedModule(); auto b = HloComputation::Builder(TestName()); @@ -129,7 +138,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(UsedInOtherComputations)) { Array3D{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result); } -XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(InputAndOutputLayoutDiffer)) { +XLA_TEST_F(CustomCallTest, InputAndOutputLayoutDiffer) { auto module = CreateNewUnverifiedModule(); auto b = HloComputation::Builder(TestName()); @@ -151,7 +160,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(InputAndOutputLayoutDiffer)) { LiteralTestUtil::ExpectR2Equal({{2.f, 4.f}, {3.f, 5.f}}, result); } -XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(LayoutConstrained)) { +XLA_TEST_F(CustomCallTest, LayoutConstrained) { // The argument and result of the computation are set to different layouts, // but the custom call is layout constrained to a fixed operand and result // layout, so the correct result should be produced. 
@@ -176,6 +185,26 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(LayoutConstrained)) { LiteralTestUtil::ExpectR2Equal({{2.f, 3.f}, {4.f, 5.f}}, result); } +XLA_TEST_F(CustomCallTest, TupleOutput) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT %custom-call = (f32[], f32[]) custom-call(f32[] %p0, f32[] %p1), custom_call_target="F32TupleSwap", operand_layout_constraints={f32[], f32[]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + Literal arg0 = LiteralUtil::CreateR0(7.f); + Literal arg1 = LiteralUtil::CreateR0(42.f); + + Literal expected = LiteralUtil::MakeTuple({&arg1, &arg0}); + Literal result = ExecuteAndTransfer(std::move(module), {&arg0, &arg1}); + EXPECT_EQ(result, expected); +} + class CustomCallClientAPITest : public ClientLibraryTestBase {}; // When using the client API, CustomCall targets can't begin with '$' -- these diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 25091b8d5d5498edf3ce86efe225cd0e2fd8ff6b..c5d8b663f4abe77e05ec213d2e4e075c260a8655 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" namespace xla { namespace { diff --git a/tensorflow/compiler/xla/tests/grouped_convolution_test.cc b/tensorflow/compiler/xla/tests/grouped_convolution_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8f7049910e70c4e591636a47c1b6ba72cf2c234f --- /dev/null +++ b/tensorflow/compiler/xla/tests/grouped_convolution_test.cc @@ -0,0 +1,245 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" +#include "tensorflow/compiler/xla/service/despecializer.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +string GetFloatDataType(bool use_bfloat16) { + return use_bfloat16 ? "bf16" : "f32"; +} + +struct GroupedConvolution2DSpec { + int64 input_feature, output_feature, window, stride, pad, lhs_dilate; + int64 group_size, group_count; + std::vector activation_dims; + std::vector activation_layout; + std::vector kernel_dims; + std::vector kernel_layout; + std::vector output_dims; + std::vector output_layout; +}; + +class GroupedConvolution2DTest + : public HloTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> {}; + +static std::vector GetConv2DTestCases() { + std::vector config_set; + // Add to this set if you want a new test configuration. 
+ // Rule : the penultimate number must be divisible by the last number. + std::vector> config_options = {{8, 2, 2, 1, 1024, 128}, + {512, 3, 3, 144, 1024, 16}, + {256, 3, 3, 129, 512, 64}, + {64, 1, 2, 127, 32, 8}, + {256, 3, 3, 256, 1024, 4}}; + + for (auto option : config_options) { + int64 output_feature = option[0]; + int64 activation_size = option[1]; + int64 kernel_size = option[2]; + int64 batch = option[3]; + int64 input_feature = option[4]; + int64 group_size = option[5]; + + std::vector kernel_layout = {3, 2, 1, 0}; + GroupedConvolution2DSpec config; + config.group_size = group_size; + config.group_count = input_feature / group_size; + config.output_feature = output_feature; + config.window = kernel_size; + + config.activation_dims = {batch, activation_size, activation_size, + input_feature}; + config.activation_layout = {3, 0, 2, 1}; + + config.kernel_dims = {kernel_size, kernel_size, group_size, output_feature}; + config.kernel_layout = {3, 2, 1, 0}; + + if (activation_size == 1 && kernel_size == 2) { + // Test for outer dim. + config.output_dims = {batch, activation_size + kernel_size - 1, + activation_size + kernel_size, output_feature}; + } else if (output_feature == 256) { + // Restrict dilation-based tests only to one feature configuration. + config.stride = activation_size - 1; + config.pad = 0; + config.lhs_dilate = output_feature / 32; + config.output_dims = {batch, output_feature / 32, + activation_size - kernel_size + 1, output_feature}; + } else { + config.stride = config.pad = config.lhs_dilate = -1; + config.output_dims = {batch, activation_size - kernel_size + 1, + activation_size - kernel_size + 1, output_feature}; + } + + // Try this layout for all kernel shapes. + config.output_layout = {3, 0, 2, 1}; + config_set.push_back(config); + + // Try other layouts only for certain kernel shapes. 
+ if (kernel_size % 2 == 0) { + config.activation_layout = {0, 3, 2, 1}; + config_set.push_back(config); + + config.output_layout = {0, 3, 2, 1}; + config_set.push_back(config); + + config.activation_layout = {3, 0, 2, 1}; + config_set.push_back(config); + } + } + + return config_set; +} + +string GroupedConvolution2DTestDataToString( + const ::testing::TestParamInfo< + ::testing::tuple>& data) { + const auto& spec = ::testing::get<0>(data.param); + const string data_type = GetFloatDataType(::testing::get<1>(data.param)); + string str = absl::StrCat( + "activation_dims_", absl::StrJoin(spec.activation_dims, "x"), + "_activation_layout_", absl::StrJoin(spec.activation_layout, "_"), + "_kernel_dims_", absl::StrJoin(spec.kernel_dims, "x"), "_kernel_layout_", + absl::StrJoin(spec.kernel_layout, "_"), "_output_dims_", + absl::StrJoin(spec.output_dims, "x"), "_output_layout_", + absl::StrJoin(spec.output_layout, "_"), data_type); + // -1 indicates non-existence. + if (spec.stride != -1) { + absl::StrAppend(&str, "_lhs_dilation_", spec.lhs_dilate, "x1"); + } + + // Test names are not allowed to contain the '-' character. + absl::c_replace(str, '-', 'n'); + return str; +} + +string BuildHloTextGroupedConvolution2D(const GroupedConvolution2DSpec& spec, + bool use_bfloat16) { + const string data_type = GetFloatDataType(use_bfloat16); + if (spec.activation_dims[1] == 1 && spec.kernel_dims[1] == 2) { + // Check for outer dim. 
+ return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv + + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d pad=1_1x%d_%d rhs_dilate=1x%d}, dim_labels=b01f_01io->b01f, + feature_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.window, spec.window, spec.window, spec.group_count); + + } else if (spec.stride == -1) { + // Check for basic, non-dilated cases. + return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv + + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d}, dim_labels=b01f_01io->b01f, + feature_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.group_count); + } else { + // Check for base dilations. 
+ return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv + + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d stride=%dx1 pad=%d_%dx0_0 lhs_dilate=%dx1}, + dim_labels=b01f_01io->b01f, feature_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.stride, 0, 0, spec.lhs_dilate, spec.group_count); + } +} + +XLA_TEST_P(GroupedConvolution2DTest, DoIt) { + const GroupedConvolution2DSpec& spec = ::testing::get<0>(GetParam()); + bool use_bfloat16 = ::testing::get<1>(GetParam()); + const string hlo_text = BuildHloTextGroupedConvolution2D(spec, use_bfloat16); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01}, + [](HloModule* module) -> Status { + BFloat16MixedPrecisionRemoval remover; + TF_RETURN_IF_ERROR(remover.Run(module).status()); + Despecializer despecializer; + return despecializer.Run(module).status(); + })); +} + +INSTANTIATE_TEST_CASE_P( + GroupedConvolution2DTestWithRandomIndices, GroupedConvolution2DTest, + ::testing::Combine(::testing::ValuesIn(GetConv2DTestCases()), + ::testing::Bool()), + GroupedConvolution2DTestDataToString); + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 989a7c705a8254f99e5cc0e97dfde5942f146964..d57846e19bb80c5b9c87d50596da2915f9aef317 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ 
b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -181,6 +181,7 @@ DebugOptions HloTestBase::GetDebugOptionsForTest() { // TODO(b/38354253): Change tests to use Parameters instead of Constants. debug_options.add_xla_disable_hlo_passes("constant_folding"); debug_options.set_xla_gpu_max_kernel_unroll_factor(1); + debug_options.set_xla_hlo_evaluator_use_fast_path(true); return debug_options; } diff --git a/tensorflow/compiler/xla/tests/iota_test.cc b/tensorflow/compiler/xla/tests/iota_test.cc index 65205f53ddc582ae477d67705f161fef1e31b857..37b2c635eebe57590e1ba73c62f015ccf399b548 100644 --- a/tensorflow/compiler/xla/tests/iota_test.cc +++ b/tensorflow/compiler/xla/tests/iota_test.cc @@ -80,7 +80,7 @@ TEST_P(IotaR2Test, DoIt) { } INSTANTIATE_TEST_CASE_P(IotaR2TestInstantiation, IotaR2Test, - ::testing::Combine(::testing::Values(F32, S32), + ::testing::Combine(::testing::Values(F32, S32, BF16), ::testing::Range(/*start=*/10, /*end=*/1001, /*step=*/10), diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index b6f9b8156b51144e4f74d285b1e4111d098f13c2..ea9b3037cf482e41238413179888f125822d161c 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -89,11 +89,11 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { Literal literal = Literal::CreateFromProto(literal_proto).ConsumeValueOrDie(); if (result.find("expected") != string::npos) { - EXPECT_EQ("2", literal.ToString()); + EXPECT_EQ("f32[] 2", literal.ToString()); } else if (result.find("actual") != string::npos) { - EXPECT_EQ("4", literal.ToString()); + EXPECT_EQ("f32[] 4", literal.ToString()); } else if (result.find("mismatches") != string::npos) { - EXPECT_EQ("true", literal.ToString()); + EXPECT_EQ("pred[] true", literal.ToString()); } else { FAIL() << "unknown file in temporary directory: " << result; } @@ -105,9 +105,9 @@ 
TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) { auto actual = LiteralUtil::CreateR1({4, 5, 6}); ::testing::AssertionResult result = LiteralTestUtil::Equal(expected, actual); EXPECT_THAT(result.message(), - ::testing::HasSubstr("Expected literal:\n{1, 2, 3}")); + ::testing::HasSubstr("Expected literal:\ns32[3] {1, 2, 3}")); EXPECT_THAT(result.message(), - ::testing::HasSubstr("Actual literal:\n{4, 5, 6}")); + ::testing::HasSubstr("Actual literal:\ns32[3] {4, 5, 6}")); } TEST(LiteralTestUtilTest, NearComparatorR1) { diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index e8f5d7a9a79ebddea3cb989dbe8eab90b630d5e7..448a66cfdd897b17cce1c87c050520a2f2eb0ea2 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -61,11 +61,11 @@ XLA_TEST_F(TestUtilsTest, Token) { R"(HloModule outfeed_module ENTRY InfeedToOutfeed { - token = token[] parameter(0) - infeed = ((u32[3]{0}, pred[]), token[]) infeed(token) + token0 = token[] parameter(0) + infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0) infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0 - outfeed = token[] outfeed(infeed.data, token) - ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token) + outfeed = token[] outfeed(infeed.data, token0) + ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token0) infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0 infeed.1.token = token[] get-tuple-element(infeed.1), index=1 outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token) diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc index 601c6b06938fef1f1ae809b33209ae59b24c70a2..b77cf38ed8e29973985406015c0a3936916ad5e6 100644 --- a/tensorflow/compiler/xla/tests/token_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -214,8 +214,8 @@ ENTRY %AddDependency (p0: f32[], p1: f32[]) -> 
f32[] { %forty_two = f32[] constant(42.0) %add = f32[] add(f32[] %p0, f32[] %forty_two) - %token = token[] after-all(f32[] %add) - %p1_after_token = f32[] add-dependency(f32[] %p1, token[] %token) + %token0 = token[] after-all(f32[] %add) + %p1_after_token = f32[] add-dependency(f32[] %p1, token[] %token0) %neg = f32[] negate(f32[] %p1_after_token) ROOT %product = f32[] multiply(f32[] %add, f32[] %neg) } @@ -236,8 +236,8 @@ HloModule AddDependencyOfConstant, is_scheduled=true ENTRY %AddDependency (p0: f32[]) -> f32[] { %p0 = f32[] parameter(0) %forty_two = f32[] constant(42.0) - %token = token[] after-all(f32[] %p0) - %forty_two_after_token = f32[] add-dependency(f32[] %forty_two, token[] %token) + %token0 = token[] after-all(f32[] %p0) + %forty_two_after_token = f32[] add-dependency(f32[] %forty_two, token[] %token0) ROOT %product = f32[] multiply(f32[] %p0, f32[] %forty_two_after_token) } )"; @@ -255,8 +255,8 @@ HloModule AddDependencyAsRoot, is_scheduled=true ENTRY %AddDependency (p: f32[3]) -> f32[3] { %p = f32[3] parameter(0) %neg = f32[3] negate(f32[3] %p) - %token = token[] after-all() - ROOT %add_dep = f32[3] add-dependency(f32[3] %neg, token[] %token) + %token0 = token[] after-all() + ROOT %add_dep = f32[3] add-dependency(f32[3] %neg, token[] %token0) } )"; TF_ASSERT_OK_AND_ASSIGN( @@ -274,9 +274,9 @@ ENTRY %TupleShapedAddDependency (p0: f32[3], p1: f32[3]) -> f32[3] { %p0 = f32[3] parameter(0) %p1 = f32[3] parameter(1) %forty_two = f32[] constant(42.0) - %token = token[] after-all() - %tuple = (f32[3], token[], f32[3], f32[]) tuple(f32[3] %p0, token[] %token, f32[3] %p1, f32[] %forty_two) - %add_dep = (f32[3], token[], f32[3], f32[]) add-dependency((f32[3], token[], f32[3], f32[]) %tuple, token[] %token) + %token0 = token[] after-all() + %tuple = (f32[3], token[], f32[3], f32[]) tuple(f32[3] %p0, token[] %token0, f32[3] %p1, f32[] %forty_two) + %add_dep = (f32[3], token[], f32[3], f32[]) add-dependency((f32[3], token[], f32[3], f32[]) %tuple, token[] 
%token0) %elem0 = f32[3] get-tuple-element((f32[3], token[], f32[3], f32[]) %add_dep), index=0 %elem2 = f32[3] get-tuple-element((f32[3], token[], f32[3], f32[]) %add_dep), index=2 ROOT %diff = f32[3] subtract(f32[3] %elem0, f32[3] %elem2) diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 27ce243e9bd4afbdcc1fdc5b6873d4968086e459..9c586bdeb05afb7378e92caed1f3edc408e051bf 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -555,8 +555,8 @@ XLA_TEST_F(TupleHloTest, s = (f32[2],f32[2]) tuple-select(cond, tup0, tup1) gte = f32[2] get-tuple-element(s), index=0 tuple = (f32[2]) tuple(gte) - token = token[] after-all() - ROOT outfeed = token[] outfeed(tuple, token) + token0 = token[] after-all() + ROOT outfeed = token[] outfeed(tuple, token0) } )"; auto module = diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc index cdde88c1359416d423685f330e9cbdf77948034f..c78ec522aa5f13556c6d4602267544694887f250 100644 --- a/tensorflow/compiler/xla/text_literal_reader.cc +++ b/tensorflow/compiler/xla/text_literal_reader.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "absl/strings/string_view.h" #include "absl/strings/strip.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -66,7 +67,7 @@ StatusOr TextLiteralReader::ReadAllLines() { } absl::StripAsciiWhitespace(&shape_string); - TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::ParseShapeString(shape_string)); + TF_ASSIGN_OR_RETURN(Shape shape, ParseShape(shape_string)); if (shape.element_type() != F32) { return Unimplemented( "unsupported element type for text literal reading: %s", diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index ff2c3399928c0e6339304323c4f93e212933a340..27a8dd13308b29da9a5013ac9f696613981d68bb 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -118,7 +118,12 @@ StatusOr ReplayComputation(const HloSnapshot& module, std::vector> global_data_arguments; std::vector argument_ptrs; if (opts.use_fake_data) { - global_data_arguments = MakeFakeArgumentsOrDie(computation, client); + // Run fake computations with debug options ignoring XLA_FLAGS. Users very + // likely want XLA_FLAGS only to apply to the "real" computation being run, + // not to the fake computations we use for generating arguments. 
+ auto debug_opts = DefaultDebugOptionsIgnoringFlags(); + global_data_arguments = + MakeFakeArgumentsOrDie(computation, client, &debug_opts); for (const auto& data : global_data_arguments) { argument_ptrs.push_back( client->GlobalDataToShapedBuffer(data->handle(), /*device_ordinal=*/0) @@ -140,8 +145,7 @@ StatusOr ReplayComputation(const HloSnapshot& module, bool provide_infeed = false; Shape infeed_shape; if (!opts.fake_infeed_shape.empty()) { - StatusOr shape_status = - ShapeUtil::ParseShapeString(opts.fake_infeed_shape); + StatusOr shape_status = ParseShape(opts.fake_infeed_shape); TF_CHECK_OK(shape_status.status()); infeed_shape = std::move(shape_status).ValueOrDie(); provide_infeed = true; diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index f745fb850655edaba8c95ba0cd3af3cc765b99e6..0e8fa73f8170addfa5061b33f3d6882a13890bce 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -100,6 +100,14 @@ message DebugOptions { // names as specified by the HloPassInterface::name() method. repeated string xla_disable_hlo_passes = 30; + // Disables all HLO passes. Notes that some passes are necessary for + // correctness and the invariants that must be satisfied by "fully optimized" + // HLO are different for different devices and may change over time. The only + // "guarantee", such as it is, is that if you compile XLA and dump the + // optimized HLO for some graph, you should be able to run it again on the + // same device with the same build of XLA. + bool xla_disable_all_hlo_passes = 104; + // Numerical optimization level for the XLA compiler backend; the specific // interpretation of this value is left to the backends. int32 xla_backend_optimization_level = 31; @@ -193,7 +201,11 @@ message DebugOptions { // - Assuming that operations never produce or consume NaN or +/- Inf. // - Assuming that +0 and -0 are indistinguishable. 
bool xla_cpu_enable_fast_math = 99; - bool xla_gpu_enable_fast_math = 100; + + // When true we lower the Minimum and Maximum hlos in the GPU backend such + // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NotNaN. In other words, if flag + // this is true we don't propagate NaNs through Min and Max. + bool xla_gpu_enable_fast_min_max = 100; // Crashes the program when any kind of verification fails, instead of just // logging the failures. One example is cross checking of convolution results @@ -209,6 +221,17 @@ message DebugOptions { // the host that run models in parallel across multiple devices. int32 xla_force_host_platform_device_count = 102; + // If set to true XLA:GPU invokes `ptxas` with -O0 (default is -O3). + bool xla_gpu_disable_ptxas_optimizations = 103; + + // Dump HLO graphs as an HTML (DOT -> SVG inlined in HTML) + bool xla_hlo_dump_as_html = 105; + + // Enable fast math with eigen in the HLO evaluator. + bool xla_hlo_evaluator_use_fast_path = 106; + + // Next id: 107 + // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. map xla_backend_extra_options = 500; @@ -382,7 +405,7 @@ message WaitForExecutionResponse { message ComputeConstantGraphRequest { HloModuleProto computation = 1; - Layout output_layout = 2; + LayoutProto output_layout = 2; } message ComputeConstantResponse { diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 85ec83437a10d973687a7fb84285c2e2541a53c7..e9c86abe5094244988d3465ef7c949509deaec37 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -100,6 +100,8 @@ message PaddingConfig { // A format specifies the method used by a layout to store an array in memory. enum Format { + // TODO(b/120869032): Rename this to FORMAT_NONE or something else which + // better corresponds to its meaning. 
INVALID_FORMAT = 0; // The default layout, with exactly one storage location per element. DENSE = 1; @@ -109,8 +111,9 @@ enum Format { } // Describes a tile used in tiling-based layout. Refer to -// g3doc/layout_with_tiling.md for details about tiling-based layout. -message Tile { +// g3doc/third_party/tensorflow/compiler/xla/g3doc/layout_with_tiling.md for +// details about tiling-based layout. +message TileProto { // Number of elements in each dimension of the tile. It's ordered from the // most major dimension of the tile to the most minor dimension of the tile. // The dimensions correspond to a suffix of the dimensions of the shape being @@ -128,7 +131,7 @@ message Tile { // See the XLA documentation for more information on shapes and layouts. // // LINT.IfChange -message Layout { +message LayoutProto { // The method used to store the data in memory. The format determines which of // the other fields are used by the layout. Format format = 4; @@ -153,7 +156,7 @@ message Layout { // // TODO(b/119839262): implement tiling in each backend or add Unimplemented // error. - repeated Tile tiles = 6; + repeated TileProto tiles = 6; // Bit size of each element. If the size is bigger than what the element // type requires, the value is stored in the least significant @@ -196,7 +199,7 @@ message ShapeProto { repeated ShapeProto tuple_shapes = 4; // The layout used to back this shape. 
- Layout layout = 5; + LayoutProto layout = 5; // Important: if any field is added, be sure to modify ShapeUtil::Equal(), // ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account for diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc index 8c6191ddc06ea7d85f5fd21a7d4058c669ffdeb2..751329eefc33f3372335c805233dafabbf42bf36 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc @@ -228,14 +228,35 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) { TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( shaped_buffer, device_ref.backend(), device_ref.device_ordinal(), &output_tuple)); - - Tensor* output_tensor; - TF_RETURN_IF_ERROR( - context->allocate_output(0, TensorShape({}), &output_tensor)); - int64 key; - TF_RETURN_IF_ERROR(output_tuple->Intern(rm, &key)); - output_tensor->scalar()() = key; - + if (config_proto.return_exploded_tuple() && + xla::ShapeUtil::IsTuple(output_tuple->on_device_shape())) { + int64 tuple_element_count = + xla::ShapeUtil::TupleElementCount(output_tuple->on_device_shape()); + Tensor* output_tensor; + TF_RETURN_IF_ERROR(context->allocate_output( + 0, TensorShape({tuple_element_count}), &output_tensor)); + + for (int64 i = 0; i < tuple_element_count; ++i) { + xla::ShapeIndex shape_index; + shape_index.push_back(i); + + XRTTupleAllocation* suballocation; + TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer( + output_tuple, shape_index, &suballocation, + /*alias_parent_allocation=*/false)); + int64 key; + TF_RETURN_IF_ERROR(suballocation->Intern(rm, &key)); + output_tensor->vec()(i) = key; + } + output_tuple->Unref(); + } else { + Tensor* output_tensor; + TF_RETURN_IF_ERROR( + context->allocate_output(0, TensorShape({}), &output_tensor)); + int64 key; + TF_RETURN_IF_ERROR(output_tuple->Intern(rm, &key)); + output_tensor->scalar()() = key; + } return Status::OK(); } diff --git 
a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc index ffea592491d43788b876a51866dc8a6611e8c734..1a5bfac337baf773b84b92af5f88ef7a4c8ba81f 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc @@ -87,6 +87,19 @@ REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral") .HostMemory("literal"), XRTReadLiteralOp); +REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral") + .Device(DEVICE_XLA_GPU) + .HostMemory("handle") + .HostMemory("literal") + .HostMemory("output_handle"), + XRTWriteLiteralOp); +REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral") + .Device(DEVICE_XLA_CPU) + .HostMemory("handle") + .HostMemory("literal") + .HostMemory("output_handle"), + XRTWriteLiteralOp); + REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease") .Device(DEVICE_XLA_GPU) .HostMemory("handle") @@ -107,4 +120,9 @@ REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle") .HostMemory("handle"), XRTReleaseAllocationOp); +REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllAllocations").Device(DEVICE_XLA_GPU), + XRTReleaseAllAllocationsOp); +REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllAllocations").Device(DEVICE_XLA_CPU), + XRTReleaseAllAllocationsOp); + } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h index 54b06558adcd8ef1f8f1bee52d210d558801afea..e3b292e7907bfb82f1efc8ed0f27462c682848ce 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h @@ -393,6 +393,56 @@ class XRTReadLiteralOp : public OpKernel { } }; +// Op that writes a new literal value into device-resident memory. 
+template +class XRTWriteLiteralOp : public OpKernel { + public: + explicit XRTWriteLiteralOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + ~XRTWriteLiteralOp() override = default; + XRTWriteLiteralOp(const XRTWriteLiteralOp&) = delete; + XRTWriteLiteralOp& operator=(const XRTWriteLiteralOp&) = delete; + + void Compute(OpKernelContext* ctx) override { + VLOG(1) << "XRTWriteLiteralOp::Compute"; + + const Tensor& handle_tensor = ctx->input(0); + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()), + errors::Internal("computation input should be an int64 scalar")); + int64 allocation_handle = handle_tensor.scalar()(); + + const Tensor& literal_info = ctx->input(1); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(literal_info.shape()), + errors::Internal("literal input should be a string scalar")); + xla::LiteralProto literal_proto; + OP_REQUIRES(ctx, + literal_proto.ParseFromString(literal_info.scalar()()), + errors::InvalidArgument( + "Unable to parse allocation input to LiteralProto")); + xla::Literal literal; + OP_REQUIRES_OK(ctx, XRTStateHelpers::MakeLiteral(literal_proto, &literal)); + + ResourceMgr* rm; + OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); + + XRTTupleAllocation* allocation; + OP_REQUIRES_OK( + ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation)); + core::ScopedUnref allocation_unref(allocation); + // We are guaranteed that the underlying device object won't be deleted out + // from under us, while the ScopedRef is live. + typename DeviceAccessor::ScopedRef device_ref; + OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef( + ctx, allocation->device_ordinal(), &device_ref)); + OP_REQUIRES_OK(ctx, + allocation->WriteLiteral(device_ref.backend(), literal)); + + Tensor output(DT_INT64, TensorShape({})); + output.scalar()() = allocation_handle; + ctx->set_output(0, output); + } +}; + // Op that discards a handle to device memory. 
template class XRTReleaseAllocationOp : public OpKernel { @@ -419,6 +469,26 @@ class XRTReleaseAllocationOp : public OpKernel { } }; +// Op that discards a handle to device memory. +template +class XRTReleaseAllAllocationsOp : public OpKernel { + public: + explicit XRTReleaseAllAllocationsOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + ~XRTReleaseAllAllocationsOp() override = default; + XRTReleaseAllAllocationsOp(const XRTReleaseAllAllocationsOp&) = delete; + XRTReleaseAllAllocationsOp& operator=(const XRTReleaseAllAllocationsOp&) = + delete; + + void Compute(OpKernelContext* ctx) override { + VLOG(1) << "XRTReleaseAllAllocationsOp::Compute"; + + ResourceMgr* rm; + OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); + OP_REQUIRES_OK(ctx, XRTTupleAllocation::ReleaseAllAllocations(rm)); + } +}; + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_ diff --git a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc index 07d025ce343f229097b557d33ad41bf9612b0696..fe6bee0dacf5dc2050613fc9ad34d3235b5a7b63 100644 --- a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc @@ -95,6 +95,20 @@ Copies an allocated tuple from device memory and returns it as a literal. 'literal' is a serialized xla::LiteralProto proto. )"); +REGISTER_OP("XRTWriteLiteral") + .Input("handle: int64") + .Input("literal: string") + .Output("output_handle: int64") + .SetShapeFn(tensorflow::shape_inference::ScalarShape) + .Doc( + R"( +Copies the input literal into the device memory pointed to by handle. +Returns the handle itself. + +'handle' is the id returned from the Op that produced the on-device allocation. +'literal' is a serialized xla::LiteralProto proto to be written to device memory. +)"); + REGISTER_OP("XRTReadLiteralAndRelease") .Input("handle: int64") .Output("literal: string") @@ -119,4 +133,11 @@ used. 
'handle' is the id returned from the Op that produced the on-device allocation. )"); +REGISTER_OP("XRTReleaseAllAllocations") + .SetShapeFn(tensorflow::shape_inference::NoOutputs) + .Doc( + R"( +Discards all the XRT allocations. All the client held handles will be invalid. +)"); + } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc index b9262c1843a7ae48af49acbef5ba4ef58ec0f050..730a2271677c91afecaf252f4a3d1a989a1ccfba 100644 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc @@ -102,7 +102,7 @@ bool CompareLiteralProtos(const xla::LiteralProto& a, auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie(); bool equal = l_a == l_b; if (!equal) { - LOG(INFO) << "LiteralProtos don't match " << a.DebugString() + LOG(INFO) << "LiteralProtos don't match: " << a.DebugString() << " != " << b.DebugString(); } return equal; @@ -175,6 +175,18 @@ xla::XlaComputation AddAndTuple() { return builder.Build().ValueOrDie(); } +xla::XlaComputation AddAndSubTuple() { + xla::XlaBuilder builder("AddAndSubTuple"); + auto p0 = xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {}), + "P0"); + auto p1 = xla::Parameter(&builder, 1, xla::ShapeUtil::MakeShape(xla::F32, {}), + "P1"); + auto sum = xla::Add(p0, p1); + auto sub = xla::Sub(p0, p1); + xla::Tuple(&builder, {sum, sub}); + return builder.Build().ValueOrDie(); +} + void StoreComputationSnapshot(const xla::XlaComputation& computation, xla::HloSnapshot* dst) { auto snapshot = computation.Snapshot().ValueOrDie(); @@ -203,6 +215,87 @@ xla::ProgramShape XlaCompiledProgramShape( ->ComputeProgramShape(); } +TEST(RawApiTest, AllocAndRewrite) { + xrt::XLAAllocation alloc; + alloc.set_device_ordinal(0); + *alloc.mutable_value() = + xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto(); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto value = + 
ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString()); + auto handle = ops::XRTAllocate(root, value); + auto read_back = ops::XRTReadLiteral(root, handle); + TF_ASSERT_OK(root.status()); + + tensorflow::ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back, handle}, &outputs)); + EXPECT_EQ(outputs.size(), 2); + + int64 allocation_handle = outputs[1].scalar()(); + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response)); + outputs.clear(); + + xla::LiteralProto new_literal = + xla::LiteralUtil::CreateR2({{9, 2}, {4, 1}}).ToProto(); + auto new_value = ops::Const(root.WithDevice("/device:CPU:0"), + new_literal.SerializeAsString()); + auto write_op = + ops::XRTWriteLiteral(root, Input(allocation_handle), new_value); + TF_ASSERT_OK(root.status()); + TF_EXPECT_OK(session.Run({write_op}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + EXPECT_EQ(allocation_handle, outputs[0].scalar()()); + outputs.clear(); + + auto read_after_write = ops::XRTReadLiteral(root, Input(allocation_handle)); + TF_EXPECT_OK(session.Run({read_after_write}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + xla::LiteralProto new_response; + EXPECT_TRUE(new_response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(CompareLiteralProtos(new_literal, new_response)); + + auto release = + ops::XRTReleaseAllocationHandle(root, Input(allocation_handle)); + TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release}, + &outputs)); +} + +TEST(RawApiTest, AllocAndClearAll) { + xrt::XLAAllocation alloc; + alloc.set_device_ordinal(0); + *alloc.mutable_value() = + xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto(); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto value = + ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString()); + auto handle = ops::XRTAllocate(root, value); + 
TF_ASSERT_OK(root.status()); + + tensorflow::ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({handle}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + int64 allocation_handle = outputs[0].scalar()(); + + auto clear_all = ops::XRTReleaseAllAllocations(root); + + outputs.clear(); + TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, + {clear_all}, &outputs)); + EXPECT_EQ(outputs.size(), 0); + + auto read_after_clear = ops::XRTReadLiteral(root, Input(allocation_handle)); + EXPECT_EQ(session.Run({read_after_clear}, &outputs).code(), + tensorflow::error::Code::NOT_FOUND); +} + TEST(RawApiTest, ReadAndWriteState) { xrt::XLAAllocation alloc; alloc.set_device_ordinal(0); @@ -681,6 +774,70 @@ TEST(RawApiTest, CompileAndExecuteReturnTuple) { EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); } +TEST(RawApiTest, CompileAndExecuteReturnExplodedTuple) { + xrt::XLAAllocation p0; + p0.set_device_ordinal(0); + *p0.mutable_value() = xla::LiteralUtil::CreateR0(12.0f).ToProto(); + + xrt::XLAAllocation p1; + p1.set_device_ordinal(0); + *p1.mutable_value() = xla::LiteralUtil::CreateR0(3.0f).ToProto(); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto(); + *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto(); + *shapes->mutable_result() = + xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {}), + xla::ShapeUtil::MakeShape(xla::F32, {})}) + .ToProto(); + StoreComputationSnapshot(AddAndSubTuple(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + e.set_return_exploded_tuple(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto e_config = + ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); + auto computation = + 
ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = + ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString()); + auto p0_handle = ops::XRTAllocate(root, p0_value); + auto p1_value = + ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()); + auto p1_handle = ops::XRTAllocate(root, p1_value); + auto result = ops::XRTExecute(root, c_handle.handle, e_config, + {Output(p0_handle), Output(p1_handle)}); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({result}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + auto handles_vec = outputs.front().vec(); + EXPECT_EQ(handles_vec.size(), 2); + + const float kResults[2] = {15.0f, 9.0f}; + for (int64 i = 0; i < handles_vec.size(); ++i) { + auto read_back = ops::XRTReadLiteralAndRelease(root, Input(handles_vec(i))); + std::vector voutputs; + TF_EXPECT_OK(session.Run({read_back}, &voutputs)); + EXPECT_EQ(voutputs.size(), 1); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(voutputs[0].scalar()())); + + auto expected = xla::LiteralUtil::CreateR0(kResults[i]); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); + } +} + TEST(RawApiTest, LeakCompilationReference) { xrt::XLAComputation c; auto config = c.mutable_config(); diff --git a/tensorflow/compiler/xrt/xrt.proto b/tensorflow/compiler/xrt/xrt.proto index e149f2f43593ea412ef279b2c99dabac285cdac4..378bb9246f27b8106310d565435404d7ac260a87 100644 --- a/tensorflow/compiler/xrt/xrt.proto +++ b/tensorflow/compiler/xrt/xrt.proto @@ -101,4 +101,8 @@ message XRTExecutionConfig { bool release_input_handles = 5; // If true, release the handle to the computation after running. 
bool release_compilation_handle = 6; + // If set to true, and the result shape is a tuple, then instead of returning + // a single tuple allocation the execution will return a vector of + // allocations, one for each of the first-level elements of the result tuple. + bool return_exploded_tuple = 7; } diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc index 3a99820d7aa9e9546cc95385fd98c05f28988e9e..343460ff107fa81be127950837f786fe4eeadf26 100644 --- a/tensorflow/compiler/xrt/xrt_state.cc +++ b/tensorflow/compiler/xrt/xrt_state.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xrt/xrt_state.h" #include +#include #include #include #include @@ -34,6 +35,7 @@ limitations under the License. #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/stream_executor/stream_executor.h" @@ -41,6 +43,34 @@ namespace tensorflow { namespace { +class BufferAllocStats { + public: + struct Stats { + int64 count = 0; + int64 size = 0; + }; + + Stats ReportAlloc(int64 device, int64 msize) { + mutex_lock lock(lock_); + Stats* device_stats = &stats_[device]; + device_stats->count += 1; + device_stats->size += msize; + return *device_stats; + } + + Stats ReportFree(int64 device, int64 msize) { + mutex_lock lock(lock_); + Stats* device_stats = &stats_[device]; + device_stats->count -= 1; + device_stats->size -= msize; + return *device_stats; + } + + private: + mutable mutex lock_; + std::map stats_; +}; + const char* kTupleContainer = "tuples"; int64 get_uid() { @@ -48,6 +78,11 @@ int64 get_uid() { return static_cast(unsigned_rand); } +BufferAllocStats* GetAllocStats() { + static BufferAllocStats* stats = new BufferAllocStats(); + return stats; +} + Status AllocateScopedShapedBuffer( xla::Backend* backend, int device_ordinal, 
const xla::Shape& shape, std::unique_ptr* buffer) { @@ -100,9 +135,19 @@ XRTBufferAllocation::XRTBufferAllocation(const se::DeviceMemoryBase& allocation, xla::DeviceMemoryAllocator* allocator) : allocation_(allocation), device_ordinal_(device_ordinal), - allocator_(allocator) {} + allocator_(allocator) { + if (VLOG_IS_ON(2)) { + auto stats = + GetAllocStats()->ReportAlloc(device_ordinal_, allocation_.size()); + LOG(INFO) << "XRT Allocation Stats: device=" << device_ordinal_ + << " count=" << stats.count << " size=" << stats.size; + } +} XRTBufferAllocation::~XRTBufferAllocation() { + if (VLOG_IS_ON(2)) { + GetAllocStats()->ReportFree(device_ordinal_, allocation_.size()); + } // Deallocate explicitly allows allocation_ to be null. Status s = allocator_->Deallocate(device_ordinal_, allocation_); // Nothing to do but check fail here if memory datastructures are corrupted. @@ -183,6 +228,20 @@ Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal, return Status::OK(); } +Status XRTTupleAllocation::WriteLiteral(xla::Backend* backend, + const xla::Literal& literal) { + if (!xla::ShapeUtil::Equal(literal.shape(), on_host_shape())) { + return errors::InvalidArgument( + "New literal shape not matching the existing one: literal=", + xla::ShapeUtil::HumanStringWithLayout(literal.shape()), + " device=", xla::ShapeUtil::HumanStringWithLayout(on_host_shape())); + } + auto transfer_manager = backend->transfer_manager(); + TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal())); + return transfer_manager->TransferLiteralToDevice(stream.get(), literal, + ToShapedBuffer()); +} + void XRTTupleAllocation::DiscardAllocation( const xla::ShapeIndex& buffer_index) { buffers_.element(buffer_index)->DiscardAllocation(); @@ -213,6 +272,11 @@ const se::DeviceMemoryBase& XRTTupleAllocation::root_allocation() { return rm->Delete(kTupleContainer, key_string); } +/* static */ Status XRTTupleAllocation::ReleaseAllAllocations(ResourceMgr* rm) { + VLOG(1) 
<< "Releasing all XRT held device memory"; + return rm->Cleanup(kTupleContainer); +} + // Helper typedef to make ShapeTree ForEach helper lambda signatures more // readable. They need a type of const T& where in this case T is the // following pointer. diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h index 73b5584e38f781343fe6793af7ad28232fbfc184..3e3d5024124e13b87eed6f79596d50cd64325914 100644 --- a/tensorflow/compiler/xrt/xrt_state.h +++ b/tensorflow/compiler/xrt/xrt_state.h @@ -129,6 +129,10 @@ class XRTTupleAllocation : public ResourceBase { // Deletes the reference in the rm to an allocation interned under key. static Status DeleteFromResourceManager(ResourceMgr* rm, int64 key); + // Releases all the device memory allocated by XRT within the resource + // manager. + static Status ReleaseAllAllocations(ResourceMgr* rm); + // Adds the allocation to a ResourceMgr and returns the key that will be used // to retrieve it. Transfers a reference on *this to rm. Status Intern(ResourceMgr* rm, int64* key); @@ -137,6 +141,9 @@ class XRTTupleAllocation : public ResourceBase { Status ToLiteral(xla::Backend* backend, int device_ordinal, xla::Literal* literal); + // Write a new literal value to the allocation. + Status WriteLiteral(xla::Backend* backend, const xla::Literal& literal); + // True if none of the buffers in the allocation are aliased by any other live // handle. 
bool IsExclusiveOwner(); diff --git a/tensorflow/contrib/android/cmake/build.gradle b/tensorflow/contrib/android/cmake/build.gradle index 17a57b99fd6c9efc09bda0ce1249b1f51bd5af5c..ddec08894f34f96b080610f1d27a6a436f7ffa91 100644 --- a/tensorflow/contrib/android/cmake/build.gradle +++ b/tensorflow/contrib/android/cmake/build.gradle @@ -22,8 +22,8 @@ android { } externalNativeBuild { cmake { - arguments '-DANDROID_TOOLCHAIN=gcc', - '-DANDROID_STL=gnustl_static' + arguments '-DANDROID_TOOLCHAIN=clang', + '-DANDROID_STL=c++_static' } } } @@ -70,7 +70,7 @@ if (ndkDir == null || ndkDir == "") { ndkDir = System.getenv('ANDROID_NDK_HOME') } -if(! Os.isFamily(Os.FAMILY_WINDOWS)) { +if (!Os.isFamily(Os.FAMILY_WINDOWS)) { // This script is for non-Windows OS. For Windows OS, MANUALLY build // (or copy the built) libs/headers to the // ${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/gen diff --git a/tensorflow/contrib/bigtable/README.md b/tensorflow/contrib/bigtable/README.md index 2c44abed5e1955cc666273e97e6b2378766f13d2..79052bee35c7895cb4048b10c1f73acb036d1587 100644 --- a/tensorflow/contrib/bigtable/README.md +++ b/tensorflow/contrib/bigtable/README.md @@ -51,25 +51,18 @@ BIGTABLE_TABLE_NAME = '' PREFIX = 'train-' def main(): + tf.enable_eager_execution() + client = tf.contrib.cloud.BigtableClient(GCP_PROJECT_ID, BIGTABLE_INSTANCE_ID) table = client.table(BIGTABLE_TABLE_NAME) dataset = table.keys_by_prefix_dataset(PREFIX) - iterator = dataset.make_initializable_iterator() - get_next_op = iterator.get_next() - with tf.Session() as sess: - print('Initializing the iterator.') - sess.run(iterator.initializer) - print('Retrieving rows:') - row_index = 0 - while True: - try: - row_key = sess.run(get_next_op) - print('Row key %d: %s' % (row_index, row_key)) - row_index += 1 - except tf.errors.OutOfRangeError: - print('Finished reading data!') - break + print('Retrieving rows:') + row_index = 0 + for row_key in dataset: + print('Row key %d: %s' % (row_index, row_key)) + 
row_index += 1 + print('Finished reading data!') if __name__ == '__main__': main() diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc index e95dc577184f7e81d942755b41065f52131ce9f6..3fe71a2ea730cc9b60b2e2088a0d80a08b38d1a9 100644 --- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc @@ -399,6 +399,17 @@ BigtableTestClient::AsyncMutateRows( return nullptr; } +std::unique_ptr> +BigtableTestClient::AsyncCheckAndMutateRow( + grpc::ClientContext* context, + const google::bigtable::v2::CheckAndMutateRowRequest& request, + grpc::CompletionQueue* cq) { + LOG(WARNING) << "Call to InMemoryDataClient::" << __func__ + << "(); this will likely cause a crash!"; + return nullptr; +} + std::shared_ptr BigtableTestClient::Channel() { LOG(WARNING) << "Call to InMemoryDataClient::Channel(); this will likely " "cause a crash!"; diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h index c4a1f06bc504c3565c7bb09b42e48e7fbddb9cc6..85705904573e9e7710912e3f4ff30dd8fed5bf85 100644 --- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h @@ -80,6 +80,13 @@ class BigtableTestClient : public ::google::cloud::bigtable::DataClient { const ::google::bigtable::v2::MutateRowsRequest& request, ::grpc::CompletionQueue* cq, void* tag) override; + std::unique_ptr> + AsyncCheckAndMutateRow( + grpc::ClientContext* context, + const google::bigtable::v2::CheckAndMutateRowRequest& request, + grpc::CompletionQueue* cq) override; + std::shared_ptr Channel() override; private: diff --git a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py 
b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py index 316da9ebe152ef52c7e7f846cf8c3eb1555ee8a6..197f5578eb010bee5a3aad7c05446393193f99e2 100644 --- a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py +++ b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py @@ -57,7 +57,7 @@ class BigtableOpsTest(test.TestCase): sess.run(write_op) def runReadKeyTest(self, read_ds): - itr = read_ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(read_ds) n = itr.get_next() expected = list(self.COMMON_ROW_KEYS) expected.reverse() @@ -78,7 +78,7 @@ class BigtableOpsTest(test.TestCase): self.runReadKeyTest(self._table.keys_by_range_dataset("r1", "r4")) def runScanTest(self, read_ds): - itr = read_ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(read_ds) n = itr.get_next() expected_keys = list(self.COMMON_ROW_KEYS) expected_keys.reverse() @@ -120,7 +120,7 @@ class BigtableOpsTest(test.TestCase): def testLookup(self): ds = self._table.keys_by_prefix_dataset("r") ds = ds.apply(self._table.lookup_columns(cf1="c1")) - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) n = itr.get_next() expected_keys = list(self.COMMON_ROW_KEYS) expected_values = list(self.COMMON_VALUES) @@ -141,7 +141,7 @@ class BigtableOpsTest(test.TestCase): def testSampleKeys(self): ds = self._table.sample_keys() - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) n = itr.get_next() expected_key = self.COMMON_ROW_KEYS[0] with self.cached_session() as sess: @@ -161,7 +161,7 @@ class BigtableOpsTest(test.TestCase): sess.run(n) def runSampleKeyPairsTest(self, ds, expected_key_pairs): - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) n = itr.get_next() with self.cached_session() as sess: self._writeCommonValues(sess) @@ -218,7 +218,7 @@ class BigtableOpsTest(test.TestCase): def 
testSampleKeyPairsPrefixAndStartKey(self): ds = bigtable_api._BigtableSampleKeyPairsDataset( self._table, prefix="r", start="r1", end="") - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run(itr.initializer) @@ -226,14 +226,14 @@ class BigtableOpsTest(test.TestCase): def testSampleKeyPairsPrefixAndEndKey(self): ds = bigtable_api._BigtableSampleKeyPairsDataset( self._table, prefix="r", start="", end="r3") - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run(itr.initializer) def testParallelScanPrefix(self): ds = self._table.parallel_scan_prefix(prefix="r", cf1="c1") - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) n = itr.get_next() with self.cached_session() as sess: self._writeCommonValues(sess) @@ -251,7 +251,7 @@ class BigtableOpsTest(test.TestCase): def testParallelScanRange(self): ds = self._table.parallel_scan_range(start="r1", end="r4", cf1="c1") - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) n = itr.get_next() with self.cached_session() as sess: self._writeCommonValues(sess) diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py index 7c87b0daeb09950cc44c51f49c16534d413f0376..b6cdc7aab0320fe5f457288ada03a46e18a694cc 100644 --- a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py +++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py @@ -35,8 +35,8 @@ from tensorflow.contrib.util import loader from tensorflow.python.data.experimental.ops import interleave_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import structure from 
tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.platform import resource_loader @@ -111,8 +111,7 @@ class BigtableClient(object): class BigtableTable(object): - """BigtableTable is the entrypoint for reading and writing data in Cloud - Bigtable. + """Entry point for reading and writing data in Cloud Bigtable. This BigtableTable class is the Python representation of the Cloud Bigtable table within TensorFlow. Methods on this class allow data to be read from and @@ -222,7 +221,7 @@ class BigtableTable(object): A `tf.data.Dataset`. containing `tf.string` Tensors corresponding to all of the row keys matching that prefix. """ - return _BigtablePrefixKeyDataset(self, prefix) + return dataset_ops.DatasetV1Adapter(_BigtablePrefixKeyDataset(self, prefix)) def sample_keys(self): """Retrieves a sampling of row keys from the Bigtable table. @@ -234,7 +233,7 @@ class BigtableTable(object): Returns: A `tf.data.Dataset` returning string row keys. """ - return _BigtableSampleKeysDataset(self) + return dataset_ops.DatasetV1Adapter(_BigtableSampleKeysDataset(self)) def scan_prefix(self, prefix, probability=None, columns=None, **kwargs): """Retrieves row (including values) from the Bigtable service. @@ -279,7 +278,8 @@ class BigtableTable(object): """ probability = _normalize_probability(probability) normalized = _normalize_columns(columns, kwargs) - return _BigtableScanDataset(self, prefix, "", "", normalized, probability) + return dataset_ops.DatasetV1Adapter( + _BigtableScanDataset(self, prefix, "", "", normalized, probability)) def scan_range(self, start, end, probability=None, columns=None, **kwargs): """Retrieves rows (including values) from the Bigtable service. 
@@ -324,7 +324,8 @@ class BigtableTable(object): """ probability = _normalize_probability(probability) normalized = _normalize_columns(columns, kwargs) - return _BigtableScanDataset(self, "", start, end, normalized, probability) + return dataset_ops.DatasetV1Adapter( + _BigtableScanDataset(self, "", start, end, normalized, probability)) def parallel_scan_prefix(self, prefix, @@ -380,7 +381,8 @@ class BigtableTable(object): """ probability = _normalize_probability(probability) normalized = _normalize_columns(columns, kwargs) - ds = _BigtableSampleKeyPairsDataset(self, prefix, "", "") + ds = dataset_ops.DatasetV1Adapter( + _BigtableSampleKeyPairsDataset(self, prefix, "", "")) return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability, normalized) @@ -442,7 +444,8 @@ class BigtableTable(object): """ probability = _normalize_probability(probability) normalized = _normalize_columns(columns, kwargs) - ds = _BigtableSampleKeyPairsDataset(self, "", start, end) + ds = dataset_ops.DatasetV1Adapter( + _BigtableSampleKeyPairsDataset(self, "", start, end)) return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability, normalized) @@ -589,16 +592,8 @@ class _BigtableKeyDataset(dataset_ops.DatasetSource): self._table = table @property - def output_classes(self): - return ops.Tensor - - @property - def output_shapes(self): - return tensor_shape.TensorShape([]) - - @property - def output_types(self): - return dtypes.string + def _element_structure(self): + return structure.TensorStructure(dtypes.string, []) class _BigtablePrefixKeyDataset(_BigtableKeyDataset): @@ -658,16 +653,9 @@ class _BigtableLookupDataset(dataset_ops.DatasetSource): self._columns = [i[1] for i in normalized] @property - def output_classes(self): - return tuple([ops.Tensor] * self._num_outputs) - - @property - def output_shapes(self): - return tuple([tensor_shape.TensorShape([])] * self._num_outputs) - - @property - def output_types(self): - return tuple([dtypes.string] * 
self._num_outputs) + def _element_structure(self): + return structure.NestedStructure(tuple( + [structure.TensorStructure(dtypes.string, [])] * self._num_outputs)) def _as_variant_tensor(self): # pylint: disable=protected-access @@ -693,16 +681,9 @@ class _BigtableScanDataset(dataset_ops.DatasetSource): self._num_outputs = len(normalized) + 1 # 1 for row key @property - def output_classes(self): - return tuple([ops.Tensor] * self._num_outputs) - - @property - def output_shapes(self): - return tuple([tensor_shape.TensorShape([])] * self._num_outputs) - - @property - def output_types(self): - return tuple([dtypes.string] * self._num_outputs) + def _element_structure(self): + return structure.NestedStructure(tuple( + [structure.TensorStructure(dtypes.string, [])] * self._num_outputs)) def _as_variant_tensor(self): return gen_bigtable_ops.bigtable_scan_dataset( @@ -726,16 +707,10 @@ class _BigtableSampleKeyPairsDataset(dataset_ops.DatasetSource): self._end = end @property - def output_classes(self): - return (ops.Tensor, ops.Tensor) - - @property - def output_shapes(self): - return (tensor_shape.TensorShape([]), tensor_shape.TensorShape([])) - - @property - def output_types(self): - return (dtypes.string, dtypes.string) + def _element_structure(self): + return structure.NestedStructure( + (structure.TensorStructure(dtypes.string, []), + structure.TensorStructure(dtypes.string, []))) def _as_variant_tensor(self): # pylint: disable=protected-access diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 9fdc2fc0c2c7b85502f7a3f9ae7c85cf05d5916c..a5951fb7377d48748f5eb578c034176517df7749 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -614,13 +614,19 @@ class GradientBoostedDecisionTreeModel(object): predictions_dict[NUM_TREES_ATTEMPTED] % 
self._logits_dimension) return constant_op.constant(-1, dtype=dtypes.int32) - def update_stats(self, loss, predictions_dict): + def update_stats(self, loss, predictions_dict, gradients=None, hessians=None): """Update the accumulators with stats from this batch. Args: loss: A scalar tensor representing average loss of examples. predictions_dict: Dictionary of Rank 2 `Tensor` representing information about predictions per example. + gradients: A tensor with the gradients with the respect to logits from + predictions_dict. If not provided, tensorflow will do + autodifferentiation. + hessians: A tensor with the hessians with the respect to logits from + predictions_dict. If not provided, tensorflow will do + autodifferentiation. Returns: Three values: @@ -642,13 +648,14 @@ class GradientBoostedDecisionTreeModel(object): predictions = predictions_dict[PREDICTIONS] partition_ids = predictions_dict[PARTITION_IDS] ensemble_stamp = predictions_dict[ENSEMBLE_STAMP] - gradients = gradients_impl.gradients( - loss, - predictions, - name="Gradients", - colocate_gradients_with_ops=False, - gate_gradients=0, - aggregation_method=None)[0] + if gradients is None: + gradients = gradients_impl.gradients( + loss, + predictions, + name="Gradients", + colocate_gradients_with_ops=False, + gate_gradients=0, + aggregation_method=None)[0] strategy = self._learner_config.multi_class_strategy class_id = self._get_class_id(predictions_dict) @@ -657,17 +664,20 @@ class GradientBoostedDecisionTreeModel(object): # We build one vs rest trees. if self._logits_dimension == 1: # We have only 1 score, gradients is of shape [batch, 1]. 
- hessians = gradients_impl.gradients( - gradients, - predictions, - name="Hessian", - colocate_gradients_with_ops=False, - gate_gradients=0, - aggregation_method=None)[0] + if hessians is None: + hessians = gradients_impl.gradients( + gradients, + predictions, + name="Hessian", + colocate_gradients_with_ops=False, + gate_gradients=0, + aggregation_method=None)[0] squeezed_gradients = array_ops.squeeze(gradients, axis=[1]) squeezed_hessians = array_ops.squeeze(hessians, axis=[1]) else: + if hessians is not None: + raise ValueError("Providing hessians is not yet supported here.") hessian_list = self._diagonal_hessian(gradients, predictions) # Assemble hessian list into a tensor. hessians = array_ops.stack(hessian_list, axis=1) @@ -678,6 +688,8 @@ class GradientBoostedDecisionTreeModel(object): squeezed_hessians = array_ops.squeeze( _get_column_by_index(hessians, class_id)) else: + if hessians is not None: + raise ValueError("Providing hessians is not yet supported here.") # Other multiclass strategies. if strategy == learner_pb2.LearnerConfig.FULL_HESSIAN: hessian_list = self._full_hessian(gradients, predictions) @@ -835,9 +847,9 @@ class GradientBoostedDecisionTreeModel(object): stats_update_ops.append( control_flow_ops.cond( continue_centering, - self._make_update_bias_stats_fn( - ensemble_stamp, predictions, gradients, - bias_stats_accumulator), control_flow_ops.no_op)) + self._make_update_bias_stats_fn(ensemble_stamp, predictions, + gradients, bias_stats_accumulator, + hessians), control_flow_ops.no_op)) # Update handler stats. handler_reads = collections.OrderedDict() @@ -1162,7 +1174,8 @@ class GradientBoostedDecisionTreeModel(object): def get_max_tree_depth(self): return self._max_tree_depth - def train(self, loss, predictions_dict, labels): + def train(self, loss, predictions_dict, labels, gradients=None, + hessians=None): """Updates the accumalator stats and grows the ensemble. 
Args: @@ -1171,6 +1184,12 @@ class GradientBoostedDecisionTreeModel(object): about predictions per example. labels: Rank 2 `Tensor` representing labels per example. Has no effect on the training and is only kept for backward compatibility. + gradients: A tensor with the gradients with the respect to logits from + predictions_dict. If not provided, tensorflow will do + autodifferentiation. + hessians: A tensor with the hessians with the respect to logits from + predictions_dict. If not provided, tensorflow will do + autodifferentiation. Returns: An op that adds a new tree to the ensemble. @@ -1179,7 +1198,8 @@ class GradientBoostedDecisionTreeModel(object): ValueError: if inputs are not valid. """ del labels # unused; kept for backward compatibility. - update_op, _, training_state = self.update_stats(loss, predictions_dict) + update_op, _, training_state = self.update_stats(loss, predictions_dict, + gradients, hessians) with ops.control_dependencies(update_op): return self.increment_step_counter_and_maybe_update_ensemble( predictions_dict, training_state) @@ -1271,21 +1291,28 @@ class GradientBoostedDecisionTreeModel(object): ps_ops=ps_ops, ps_strategy=ps_strategy) - def _make_update_bias_stats_fn(self, ensemble_stamp, predictions, gradients, - bias_stats_accumulator): + def _make_update_bias_stats_fn(self, + ensemble_stamp, + predictions, + gradients, + bias_stats_accumulator, + hessians=None): """A method to create the function which updates the bias stats.""" def _update_bias_stats(): """A method to update the bias stats.""" # Get reduced gradients and hessians. 
grads_sum = math_ops.reduce_sum(gradients, 0) - hess = gradients_impl.gradients( - grads_sum, - predictions, - name="Hessians", - colocate_gradients_with_ops=False, - gate_gradients=0, - aggregation_method=None)[0] + if hessians is not None: + hess = hessians + else: + hess = gradients_impl.gradients( + grads_sum, + predictions, + name="Hessians", + colocate_gradients_with_ops=False, + gate_gradients=0, + aggregation_method=None)[0] hess_sum = math_ops.reduce_sum(hess, 0) # Accumulate gradients and hessians. diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index a63366e1361effe20787c197eddd66b5c0c96410..2ad9ae42a16f690d38b8e2652e853012ec1dd267 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -3,16 +3,16 @@ cmake_minimum_required(VERSION 3.5) if(WIN32) if(${CMAKE_VERSION} VERSION_LESS "3.8") - message(WARNING "Your current cmake version is ${CMAKE_VERSION} which does not support setting the toolset architecture to x64. This may cause \"compiler out of heap space\" errors when building. Consider upgrading your cmake to > 3.8 and using the flag -Thost=x64 when running cmake.") + message(WARNING "Your current cmake version is ${CMAKE_VERSION} which does not support setting the toolset architecture to x64. This may cause \"compiler out of heap space\" errors when building. Consider upgrading your cmake to > 3.8 and using the flag -Thost=x64 when running cmake. Ignore this if you are on CMake GUI.") else() if(NOT CMAKE_VS_PLATFORM_TOOLSET_HOST_ARCHITECTURE OR NOT "${CMAKE_VS_PLATFORM_TOOLSET_HOST_ARCHITECTURE}" STREQUAL "x64") - message(WARNING "Your current cmake generator is set to use 32 bit toolset architecture. This may cause \"compiler out of heap space\" errors when building. Consider using the flag -Thost=x64 when running cmake.") + message(WARNING "Your current cmake generator is set to use 32 bit toolset architecture. 
This may cause \"compiler out of heap space\" errors when building. Consider using the flag -Thost=x64 when running cmake. Ignore this if you are on CMake GUI.") endif() endif() endif() # Project -project(tensorflow C CXX) +project(tensorflow VERSION 1.12.0 LANGUAGES C CXX) # Set C++14 as standard for the whole project set(CMAKE_CXX_STANDARD 14) @@ -52,11 +52,17 @@ option(tensorflow_OPTIMIZE_FOR_NATIVE_ARCH "Enable compiler optimizations for th option(tensorflow_ENABLE_SNAPPY_SUPPORT "Enable SNAPPY compression support" ON) option(tensorflow_DISABLE_EIGEN_FORCEINLINE "Disable forceinline, to speed up build on windows." OFF) +if (WIN32) +SET(tensorflow_WIN_CPU_SIMD_OPTIONS "/arch:AVX" CACHE STRING "Enables CPU SIMD instructions") +SET_PROPERTY(CACHE tensorflow_WIN_CPU_SIMD_OPTIONS PROPERTY STRINGS /arch:AVX) +endif() + # SIMD, MKL and MKLDNN options option(tensorflow_WIN_CPU_SIMD_OPTIONS "Enables CPU SIMD instructions" OFF) option(tensorflow_ENABLE_MKL_SUPPORT "Enable Intel MKL support" OFF) option(tensorflow_ENABLE_MKLDNN_SUPPORT "Enable Intel MKLDNN support, requires MKL enabled" OFF) + # GPU, CUDA and cuDNN options option(tensorflow_ENABLE_GPU "Enable GPU support" OFF) @@ -79,6 +85,11 @@ if (NOT WIN32) # option's default value is OFF. Fill it with real default values set(tensorflow_CUDNN_INCLUDE /usr/include) endif (NOT tensorflow_CUDNN_INCLUDE) + option(tensorflow_NCCL_INCLUDE "nccl.h header install path" /usr/include/) + if (NOT tensorflow_NCCL_INCLUDE) + # option's default value is OFF. Fill it with real default values + set(tensorflow_NCCL_INCLUDE /usr/include) + endif (NOT tensorflow_NCCL_INCLUDE) option(tensorflow_PATH_CUDNN_LIB "Override PATH_CUDA_LIB for cudnn" ${tensorflow_PATH_CUDA_LIB}) if (NOT tensorflow_PATH_CUDNN_LIB) # option's default value is OFF. 
Fill it with real default values @@ -193,6 +204,7 @@ if(WIN32) set(CMAKE_SUPPRESS_REGENERATION ON) endif() + if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-exceptions -std=c++11") endif() @@ -281,6 +293,14 @@ else (systemlib_ZLIB) ${zlib_STATIC_LIBRARIES}) endif (systemlib_ZLIB) +if (systemlib_ABSEIL_CPP) + set(tensorflow_EXTERNAL_LIBRARIES ${tensorflow_EXTERNAL_LIBRARIES} + ${abseil_cpp_LIBRARIES}) +else (systemlib_ABSEIL_CPP) + set(tensorflow_EXTERNAL_LIBRARIES ${tensorflow_EXTERNAL_LIBRARIES} + ${abseil_cpp_STATIC_LIBRARIES}) +endif (systemlib_ABSEIL_CPP) + set(tensorflow_EXTERNAL_DEPENDENCIES zlib_copy_headers_to_destination gif_copy_headers_to_destination @@ -378,8 +398,8 @@ if (tensorflow_ENABLE_GPU) list(APPEND CMAKE_LIBRARY_PATH "${tensorflow_CUDA_LIBRARY_PATH}/stubs") endif (NOT WIN32) - # minimum 9.1 in cuda version - find_package(CUDA 9.1 REQUIRED) + # minimum 9.0 in cuda version + find_package(CUDA 9.0 REQUIRED) if(NOT CUDA_FOUND) message(FATAL_ERROR "CUDA not found.") endif() @@ -394,6 +414,7 @@ if (tensorflow_ENABLE_GPU) set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};--include-path ${PROJECT_BINARY_DIR}/$\{build_configuration\};--expt-relaxed-constexpr) set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-ftz=true) # Flush denormals to zero set(CUDA_INCLUDE ${CUDA_TOOLKIT_TARGET_DIR} ${CUDA_TOOLKIT_TARGET_DIR}/extras/CUPTI/include) + include_directories(${CUDA_INCLUDE}) if (WIN32) add_definitions(-DGOOGLE_CUDA=1 -DTF_EXTRA_CUDA_CAPABILITIES=3.7,5.2,6.0,6.1,7.0) @@ -546,14 +567,20 @@ if (tensorflow_ENABLE_GPU) cudnn_version_number=${tensorflow_CUDNN_VERSION}) endif(WIN32) else(tensorflow_ENABLE_GPU) - set(tensorflow_BUILD_INFO_FLAGS --build_config cpu --key_value - msvcp_dll_name=msvcp140.dll) + if(WIN32) + set(tensorflow_BUILD_INFO_FLAGS --build_config cpu --key_value + msvcp_dll_name=msvcp140.dll) + else() + set(tensorflow_BUILD_INFO_FLAGS --build_config cpu) + endif() endif(tensorflow_ENABLE_GPU) -# Find python executable 
-include(FindPythonInterp) -if(NOT ${PYTHONINTERP_FOUND}) - message(FATAL_ERROR "CMake was unable to find a python interpreter.") +if(tensorflow_BUILD_PYTHON_BINDINGS) + # Find python executable + include(FindPythonInterp) + if(NOT ${PYTHONINTERP_FOUND}) + message(FATAL_ERROR "CMake was unable to find a python interpreter.") + endif() endif() # Let's get to work! @@ -574,6 +601,7 @@ include(tf_cc_ops.cmake) include(tf_c.cmake) include(tf_grappler.cmake) include(tf_core_profiler.cmake) +include(tf_core_eager_runtime.cmake) if(tensorflow_BUILD_CC_EXAMPLE) include(tf_tutorials.cmake) include(tf_label_image_example.cmake) @@ -587,4 +615,4 @@ if(tensorflow_BUILD_SHARED_LIB) endif() if(tensorflow_BUILD_CC_TESTS OR tensorflow_BUILD_PYTHON_TESTS) include(tf_tests.cmake) -endif() +endif() \ No newline at end of file diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md index 84c679162c3ed8ffc9babcd3af583b26fb62c2d6..df8b48dfc46124d3b9454d92ffb70dbcf1bc4217 100644 --- a/tensorflow/contrib/cmake/README.md +++ b/tensorflow/contrib/cmake/README.md @@ -5,10 +5,10 @@ CMAKE build is deprecated for TensorFlow. Please use `bazel` to build TF for all platforms. For details, see the [TensorFlow install guide](https://www.tensorflow.org/install/). -This directory contains CMake files for building TensorFlow on Microsoft -Windows. [CMake](https://cmake.org) is a cross-platform tool that can -generate build scripts for multiple build systems, including Microsoft -Visual Studio. +This directory contains CMake files for building TensorFlow on Microsoft Windows +and Linux. [CMake](https://cmake.org) is a cross-platform tool that can generate +build scripts for multiple build systems, including Microsoft Visual Studio and +GCC. "The method has not been tested on Mac OS X. **N.B.** We provide Linux build instructions primarily for the purpose of testing the build. We recommend using the standard Bazel-based build on @@ -17,12 +17,17 @@ Linux. 
Current Status -------------- -CMake can be used to build TensorFlow on Windows. See the [getting started documentation](https://www.tensorflow.org/install/source_windows) -for instructions on how to install a pre-built TensorFlow package on Windows. +CMake can be used to build TensorFlow on all platforms. See the +[getting started documentation](https://www.tensorflow.org/install/install_windows) +for instructions on how to install a pre-built TensorFlow package on Windows and +Linux. The procedure in MacOS is similar to the Linux build. ### Current known limitations -* It is not possible to load a custom Op library. -* GCS file system is not supported. + +* It is not possible to load a custom Op library. +* GCS file system is not supported. +* Debug build is not available since Python for Windows is no longer + distributed with a debug library. ## Building with CMake @@ -32,70 +37,88 @@ bindings. ### Prerequisites -* CMake version 3.5 or later. +* CMake version 3.5 or later. + +* [Git](https://git-scm.com) + +* [SWIG](http://www.swig.org/download.html) + +* [Perl](https://www.perl.org/get.html) (optional, for SSL support build) + +* [Go](https://golang.org/) (optional, for SSL support build) + +* [NASM](http://www.nasm.us/)/[YASM](http://yasm.tortall.net/) (optional, for + SSL support build) + +* Additional pre-requisites for Microsoft Windows: + + - Visual Studio 2015 (latest version of MSVC 2017 is not supported by CUDA + yet, try it on your own risk) -* [Git](https://git-scm.com) + - Python 3.5 -* [SWIG](http://www.swig.org/download.html) +* Additional prerequisites for Linux: -* Additional prerequisites for Microsoft Windows: - - Visual Studio 2015 - - Python 3.5 + - Python 2.7 or later + - [Docker](https://www.docker.com/) (for automated testing) -* Additional prerequisites for Linux: - - Python 2.7 or later - - [Docker](https://www.docker.com/) (for automated testing) +* Python dependencies: -* Python dependencies: - - wheel - - NumPy 1.11.0 or later + - 
wheel + - NumPy 1.11.0 or later ### Known-good configurations -* Microsoft Windows 10 - - Microsoft Visual Studio Enterprise 2015 with Visual C++ 2015 - - [Anaconda 4.1.1 (Python 3.5 64-bit)](https://www.anaconda.com/download/) - - [Git for Windows version 2.9.2.windows.1](https://git-scm.com/download/win) - - [swigwin-3.0.10](http://www.swig.org/download.html) - - [NVidia CUDA Toolkit 8.0](https://developer.nvidia.com/cuda-downloads) - - [NVidia CUDNN 5.1](https://developer.nvidia.com/cudnn) - - [CMake 3.6](https://cmake.org/files/v3.6/cmake-3.6.3-win64-x64.msi) +* Microsoft Windows 10 -* Ubuntu 14.04 - - Makefile generator - - Docker 1.9.1 (for automated testing) + - Microsoft Visual Studio Enterprise/ Community 2015 with Visual C++ 2015 + - [Anaconda 4.1.1 (Python 3.5 64-bit)](https://www.anaconda.com/download/) + - [Git for Windows version 2.9.2.windows.1](https://git-scm.com/download/win) + - [swigwin-3.0.10](http://www.swig.org/download.html) + - [NVidia CUDA Toolkit 9.0](https://developer.nvidia.com/cuda-downloads) + - [NVidia CUDNN 7](https://developer.nvidia.com/cudnn) + - [CMake 3.6](https://cmake.org/files/v3.6/cmake-3.6.3-win64-x64.msi) + +* Ubuntu 14.04 + + - Makefile generator + - Docker 1.9.1 (for automated testing) ### Current known limitations - - The Python package supports **Python 3.5 only**, because that is the only - version for which standard Python binaries exist and those binaries are - compatible with the TensorFlow runtime. (On Windows, the standard Python + +- The Python package supports **Python 3.5/3.6 only**, because these are the + only versions for which standard Python binaries exist and those binaries + are compatible with the TensorFlow runtime. (On Windows, the standard Python binaries for versions earlier than 3.5 were compiled with older compilers that do not have all of the features (e.g. C++11 support) needed to compile - TensorFlow. 
We welcome patches for making TensorFlow work with Python 2.7 - on Windows, but have not yet committed to supporting that configuration.) - - - The following Python APIs are not currently implemented: - * Loading custom op libraries via `tf.load_op_library()`. In order to use your - custom op, please put the source code under the tensorflow/core/user_ops - directory, and a shape function is required (not optional) for each op. - * Path manipulation functions (such as `tf.gfile.ListDirectory()`) are not - functional. - - - The `tf.contrib` libraries are not currently included in the PIP package. - - - The following operations are not currently implemented: - * `DepthwiseConv2dNative` - * `Digamma` - * `Erf` - * `Erfc` - * `Igamma` - * `Igammac` - * `ImmutableConst` - * `Lgamma` - * `Polygamma` - * `Zeta` - - - Google Cloud Storage support is not currently implemented. The GCS library + TensorFlow. We welcome patches for making TensorFlow work with Python 2.7 on + Windows, but have not yet committed to supporting that configuration.) + +- The following Python APIs are not currently implemented: + + * Loading custom op libraries via `tf.load_op_library()`. In order to use + your custom op, please put the source code under the + tensorflow/core/user_ops directory, and a shape function is required + (not optional) for each op. + * Path manipulation functions (such as `tf.gfile.ListDirectory()`) are not + functional. + +- The `tf.contrib` libraries are not currently included in the PIP package. + +- The following operations are not currently implemented: + + * `DepthwiseConv2dNative` + * `Digamma` + * `Erf` + * `Erfc` + * `Igamma` + * `Igammac` + * `ImmutableConst` + * `Lgamma` + * `Polygamma` + * `Zeta` + +- Google Cloud Storage support is not currently implemented. The GCS library currently depends on `libcurl` and `boringssl`, and the Windows version could use standard Windows APIs for making HTTP requests and cryptography (for OAuth). 
Contributions are welcome for this feature. @@ -104,9 +127,211 @@ We are actively working on improving CMake and Windows support, and addressing these limitations. We would appreciate pull requests that implement missing ops or APIs. +# CMake GUI build (all platforms) + +Building from the CMake GUI is a convenient way to generate C++ build projects. +The software supports Windows, MacOS and Linux, while the posix platform +provides an extra ccmake binary to run a command line GUI. The working principles +of cmake, ccmake and cmake-gui are the same; the only difference is the +interface they provide for project configuration and dependency setting. + +1. Pre-build checklist: The following binaries/libraries should be on the + system path, otherwise you need to set them manually via cmake. + * Compiler (GCC for Linux, MSVC for Windows) + * Make sure compiler directory has been set to system path + * CUDA 9.0 (GPU build) + * CUDNN (GPU build) + * NCCL (GPU build on Linux) + * SWIG (python binding) + * Perl (required if you need ssl support, optional) + * Go (required if you need ssl support, optional) + * NASM/YASM (required by grpc for ssl support, optional) +2. Start CMake GUI +3. Click on `Browse Source` and navigate to the folder + `/tensorflow/contrib/cmake` +4. Click on `Browse Build` and specify the location where you want tensorflow to + be built +5. Click on `Configure`, a new window will pop up, specify the + generator mode for the project generation. For Windows, choose `Visual + Studio Win64`, for Linux, choose `Unix Makefiles`, then + press `Finish`. Wait for a moment, the default project dependencies will be + generated automatically. +6. There are a few options that you can use to customize your own build. **The settings + here are crucial for a successful build, please check all items carefully.** + + * `tensorflow_BUILD_ALL_KERNELS` should always be `on` + * `tensorflow_BUILD_CC_EXAMPLE` is default to be `on`.
This can help you + to test the build (optional) + * `tensorflow_BUILD_CONTRIB_KERNELS` is default to be `on`, but it won't + affect tensorflow function, turn it to `off` if you want a slim build. + (optional) + * `tensorflow_BUILD_PYTHON_BINDING` is default to be `on`. Set to `off` if + you don't need the python interface. If SWIG is not in system path, you + need to set it manually. (optional) + * `tensorflow_BUILD_SHARED_LIB` is default to be `off`. Set to `on` if you + want the c++ interface. (optional) + * `tensorflow_ENABLE_GPU` is default to be `off`. Set to `on` if you want + GPU support. It will search for the CUDA and CUDNN dependencies if you have set + them to system path, otherwise CMake will report an error and request you + to set them manually. (optional) + * `tensorflow_ENABLE_GRPC_SUPPORT` is default to be `on`. For Linux build, + this option must always be `on`. This needs to be `on` for a gpu build. + Note that Perl, Go and NASM/YASM are required for this option if you + want to build grpc with official SSL support. + * `tensorflow_ENABLE_POSITION_INDEPENDENT_CODE` should always be `on` + * `tensorflow_ENABLE_SNAPPY_SUPPORT` should always be `on` + * `tensorflow_OPTIMIZE_FOR_NATIVE_ARCH` should always be `on` + * `CMAKE_INSTALL_PREFIX` is the location where the final package will be + installed. You may change it to your own preferred path (optional) + +7. After changing the configuration in step 6, press `Configure` again + +8. If no error is found, press `Generate` + +#### Windows + +1. Open `tensorflow.sln` in the build folder (Windows). Change build type from + `Debug` to `Release`. Choose `Build`->`Build Solution`. This may take + several hours of compilation. If everything is alright, the output window will + show no errors. + + ##### Python + + In solution explorer, right click on `tf_python_build_pip_package` -> + `build`. It will generate the wheel file in + `/tf_python/dist`.
Install with following command: + + `pip install --upgrade tensorflow-.whl` + + ***The wheel name varies depends on you config. Change to your own wheel + filename.*** + + Reminded that some pip installation requires administrator right command + prompt. + + ##### C++ + + You can directly use the build folder tree for C++ interface with cmake. If + you want to do installation for api releasing, right click on `Install` -> + `build`. The headers and library will be installed in the directory specify + by `CMAKE_INSTALL_PREFIX` during configuration. + +1. For smaller RAM computer, it is noticed that out of heap space error + appears. Change to command prompt build is an alternative to do step 1. + + Open `VS2015 x64 Native Tools Command Prompt`. You can open it by press + `Start`, then type the binary name. Use `VS2017 x64 Native Tools Command + Prompt` if you are using MSVC 2017. + + ##### Python + + Directly build python wheel package by following command: + + `MSBuild /p:Configuration=Release + ` + + Remember to change `` to the + actual path of the file, it can be found at the root of build directory + + Install the wheel file generated as instructed by step 1. + + ##### C++ interface + + Build from VS native toolchain with following command: `MSBuild + /p:Configuration=Release ` + + Headers are discretely located in the build folders. Tensorflow library can + be found at `/Release`, namely `tensorflow.dll` and + `tensorflow.lib`. + + * Build to install for api release (optional): `MSBuild + /p:Configuration=Release ` + + Remember to change `` and + `` to the actual path of the file, it can be found + at the root of build directory. + +#### Linux/MacOS (command line GNU build) + +1. Open the terminal, change working directory to the one specified in step 3. + +2. Type the following command: + + `make -sj all` + + ##### Python + + **Important Note** CMake generated python wheel for Linux/MacOs is currently + under development. Please use bazel build. 
+ + The following is the expected Linux/MacOS python package build after + development work is completed. + + ``` + make -sj tf_python_build_pip_package + cd tf_python + pip install --upgrade tensorflow-.whl + ``` + + ##### C++ interface + + `make -sj install` + + Where `` is the number of threads used for the compilation, change + to any integer less than or equal to your computer's maximum thread count. + + Headers are discretely located in the build folders. Tensorflow library can + be found at ``, namely `tensorflow.so` (Linux) or + `tensorflow.dylib` (MacOS). + +#### Start a Tensorflow C++ project with CMake + +Here we assume that you have basic knowledge of gathering dependencies with +`CMakeLists.txt`. Here we introduce how the C++ api works with the +[official hello world tutorial](https://www.tensorflow.org/api_guides/cc/guide). + +1. Create a new working directory and create a new text file named + `CMakeLists.txt` and the c++ file `main.cxx` +2. Fill in the `main.cxx` with the code provided in + [official c++ api basic](https://www.tensorflow.org/api_guides/cc/guide). +3. Fill in the `CMakeLists.txt` with the following code: ``` cmake + cmake_minimum_required (VERSION 2.6) project (tf_hello) + + # Tensorflow + + find_package(Tensorflow REQUIRED) + include_directories(${TENSORFLOW_INCLUDE_DIRS}) + + # compiler setting required by tensorflow, to be tested on all compilers + + # currently only tested on MSVC and GCC + + if (${CMAKE_CXX_COMPILER_ID} STREQUAL MSVC) add_definitions(-DCOMPILER_MSVC) + elseif (${CMAKE_CXX_COMPILER_ID} STREQUAL GNU) if + (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS "3") + add_definitions(-DCOMPILER_GCC3) else() add_definitions(-D__GNUC__) endif() + else() message(ERROR " compiler ${CMAKE_CXX_COMPILER_ID} not supported by + this CMakeList.txt, under development") endif() + + add_executable(tf_hello main.cxx) target_link_libraries(tf_hello + ${TENSORFLOW_LIBRARIES}) ``` + +4.
Configure the folder with cmake-gui, an error should appear, + requesting you to locate the folder containing `TensorflowConfig.cmake`. + This file can be found at `` or `` (for + those who ran the install build in the previous steps). + +5. Configure again, generate the project. + +6. Compile the project with `Release` config (Windows). For Linux users, just + compile the project. + +7. Copy the `tensorflow.dll`(Windows)/`tensorflow.so`(Linux) from build + directory to the build folder containing `tf_hello` binary. + +8. Run `tf_hello` binary -Step-by-step Windows build -========================== +# Step-by-step Windows build (command prompt) 1. Install the prerequisites detailed above, and set up your environment. diff --git a/tensorflow/contrib/cmake/TensorflowConfig.cmake.in b/tensorflow/contrib/cmake/TensorflowConfig.cmake.in new file mode 100644 index 0000000000000000000000000000000000000000..cc04db6e952f53b8bb5416dde60b8173e60bf60e --- /dev/null +++ b/tensorflow/contrib/cmake/TensorflowConfig.cmake.in @@ -0,0 +1,16 @@ +# - Config file for the Tensorflow package +# It defines the following variables +# TENSORFLOW_INCLUDE_DIRS - include directories for Tensorflow +# TENSORFLOW_LIBRARIES - libraries to link against + +# Compute paths +get_filename_component(TENSORFLOW_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH) +set(TENSORFLOW_INCLUDE_DIRS "@CONF_INCLUDE_DIRS@") + +# Our library dependencies (contains definitions for IMPORTED targets) +if(NOT TENSORFLOW_BINARY_DIR) + include("${TENSORFLOW_CMAKE_DIR}/TensorflowTargets.cmake") +endif() + +# These are IMPORTED targets created by TensorflowTargets.cmake +set(TENSORFLOW_LIBRARIES tensorflow) \ No newline at end of file diff --git a/tensorflow/contrib/cmake/TensorflowConfigVersion.cmake.in b/tensorflow/contrib/cmake/TensorflowConfigVersion.cmake.in new file mode 100644 index 0000000000000000000000000000000000000000..2a9609ddb9c4ca864651818bdfae0f8fe290de31 --- /dev/null +++
b/tensorflow/contrib/cmake/TensorflowConfigVersion.cmake.in @@ -0,0 +1,11 @@ +set(PACKAGE_VERSION "@TENSORFLOW_VERSION@") + +# Check whether the requested PACKAGE_FIND_VERSION is compatible +if("${PACKAGE_VERSION}" VERSION_LESS "${PACKAGE_FIND_VERSION}") + set(PACKAGE_VERSION_COMPATIBLE FALSE) +else() + set(PACKAGE_VERSION_COMPATIBLE TRUE) + if ("${PACKAGE_VERSION}" VERSION_EQUAL "${PACKAGE_FIND_VERSION}") + set(PACKAGE_VERSION_EXACT TRUE) + endif() +endif() \ No newline at end of file diff --git a/tensorflow/contrib/cmake/external/abseil_cpp.cmake b/tensorflow/contrib/cmake/external/abseil_cpp.cmake index 4546dbdecc0dbc36f17cc727345e0762718b5165..46a193971c5084523d432065f265fa7a9909f595 100644 --- a/tensorflow/contrib/cmake/external/abseil_cpp.cmake +++ b/tensorflow/contrib/cmake/external/abseil_cpp.cmake @@ -31,27 +31,24 @@ if (systemlib_ABSEIL_CPP) message(STATUS " abseil_cpp includes: ${ABSEIL_CPP_INCLUDE_DIR}") message(STATUS " abseil_cpp libraries: ${ABSEIL_CPP_LIBRARIES}") - add_custom_target(abseil_cpp_build) - list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp_build) + add_custom_target(abseil_cpp) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp) else (systemlib_ABSEIL_CPP) include (ExternalProject) - set(abseil_cpp_INCLUDE_DIR ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp_build) + set(abseil_cpp_INCLUDE_DIR ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp) set(abseil_cpp_URL https://github.com/abseil/abseil-cpp/archive/e01d95528ea2137a4a27a88d1f57c6cb260aafed.tar.gz) set(abseil_cpp_HASH SHA256=84043ed402d2a2a6ba4cdddb7e85118b1158fd81fe4ac3a14adc343d054c1e2e) - set(abseil_cpp_BUILD ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp_build) + set(abseil_cpp_BUILD ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp-build) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") set(abseil_cpp_STATIC_LIBRARIES ${abseil_cpp_BUILD}/absl/base/Release/absl_base.lib - ${abseil_cpp_BUILD}/absl/base/Release/absl_spinlock_wait.lib 
${abseil_cpp_BUILD}/absl/base/Release/absl_dynamic_annotations.lib - ${abseil_cpp_BUILD}/absl/base/Release/absl_malloc_internal.lib - ${abseil_cpp_BUILD}/absl/base/Release/absl_throw_delegate.lib - ${abseil_cpp_BUILD}/absl/numeric/Release/absl_int128.lib + ${abseil_cpp_BUILD}/absl/base/Release/absl_internal_malloc_internal.lib ${abseil_cpp_BUILD}/absl/strings/Release/absl_strings.lib ${abseil_cpp_BUILD}/absl/strings/Release/str_format_internal.lib ${abseil_cpp_BUILD}/absl/types/Release/absl_bad_optional_access.lib) @@ -80,15 +77,12 @@ else (systemlib_ABSEIL_CPP) ${abseil_cpp_BUILD}/absl/types/libabsl_bad_optional_access.a) endif() - ExternalProject_Add(abseil_cpp_build + ExternalProject_Add(abseil_cpp PREFIX abseil_cpp URL ${abseil_cpp_URL} URL_HASH ${abseil_cpp_HASH} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" - BUILD_IN_SOURCE 1 BUILD_BYPRODUCTS ${abseil_cpp_STATIC_LIBRARIES} - BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release - COMMAND ${CMAKE_COMMAND} --build . --config Release INSTALL_COMMAND "" CMAKE_CACHE_ARGS -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} @@ -99,6 +93,6 @@ else (systemlib_ABSEIL_CPP) include_directories(${abseil_cpp_INCLUDE_DIR}) list(APPEND tensorflow_EXTERNAL_LIBRARIES ${abseil_cpp_STATIC_LIBRARIES}) - list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp_build) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp) -endif (systemlib_ABSEIL_CPP) +endif (systemlib_ABSEIL_CPP) \ No newline at end of file diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake index b1e64aa55c80ad59cfdc0f4767c0282b4f73367f..e570c09ecb5e64130ed6f3375a51d74850cc3989 100644 --- a/tensorflow/contrib/cmake/external/grpc.cmake +++ b/tensorflow/contrib/cmake/external/grpc.cmake @@ -17,7 +17,7 @@ include (ExternalProject) set(GRPC_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/include) set(GRPC_URL https://github.com/grpc/grpc.git) set(GRPC_BUILD 
${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc) -set(GRPC_TAG d184fa229d75d336aedea0041bd59cb93e7e267f) +set(GRPC_TAG 69b6c047bc767b4d80e7af4d00ccb7c45b683dae) if(WIN32) # We use unsecure gRPC because boringssl does not build on windows @@ -26,9 +26,9 @@ if(WIN32) set(grpc_SSL_PROVIDER NONE) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") set(grpc_STATIC_LIBRARIES - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/grpc++_unsecure.lib - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/grpc_unsecure.lib - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/gpr.lib) + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/$(Configuration)/grpc++_unsecure.lib + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/$(Configuration)/grpc_unsecure.lib + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/$(Configuration)/gpr.lib) else() set(grpc_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/grpc++_unsecure.lib @@ -43,8 +43,9 @@ else() ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libaddress_sorting.a + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgpr.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/third_party/cares/cares/lib/libcares.a - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgpr.a) + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/third_party/zlib/libz.a) endif() add_definitions(-DGRPC_ARES=0) @@ -66,7 +67,7 @@ ExternalProject_Add(grpc -DPROTOBUF_INCLUDE_DIRS:STRING=${PROTOBUF_INCLUDE_DIRS} -DPROTOBUF_LIBRARIES:STRING=${protobuf_STATIC_LIBRARIES} -DZLIB_ROOT:STRING=${ZLIB_INSTALL} - -DgRPC_SSL_PROVIDER:STRING=${grpc_SSL_PROVIDER} + -DgRPC_SSL_PROVIDER:STRING=${grpc_SSL_PROVIDER} ) # grpc/src/core/ext/census/tracing.c depends on the existence of openssl/rand.h. 
diff --git a/tensorflow/contrib/cmake/external/png.cmake b/tensorflow/contrib/cmake/external/png.cmake index 1a147e9c8e5a9fee17a81e37c9babe3c9ec0290b..32e6d78e508e25f76bd263e9d52b6574ca315f6c 100644 --- a/tensorflow/contrib/cmake/external/png.cmake +++ b/tensorflow/contrib/cmake/external/png.cmake @@ -59,6 +59,7 @@ ExternalProject_Add(png -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_INSTALL_PREFIX:STRING=${png_INSTALL} -DZLIB_ROOT:STRING=${ZLIB_INSTALL} + -DPNG_TESTS:BOOL=OFF ) ## put png includes in the directory where they are expected diff --git a/tensorflow/contrib/cmake/external/protobuf.cmake b/tensorflow/contrib/cmake/external/protobuf.cmake index 56a57a2340ddc7f923c611c222a0399e279ad58a..773c37b309b1dff4ed28d24cd7d6140a63ec5bc6 100644 --- a/tensorflow/contrib/cmake/external/protobuf.cmake +++ b/tensorflow/contrib/cmake/external/protobuf.cmake @@ -16,7 +16,18 @@ include (ExternalProject) set(PROTOBUF_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src) set(PROTOBUF_URL https://github.com/google/protobuf.git) -set(PROTOBUF_TAG v3.6.1) + +# Selectable protobuf version (cache option); each offered version maps to the git tag/commit set in PROTOBUF_TAG below +SET(PROTOBUF_VERSION "3.6.1" CACHE STRING "Protobuf version") +SET_PROPERTY(CACHE PROTOBUF_VERSION PROPERTY STRINGS "3.4.0" "3.5.0" "3.6.1") + +if(${PROTOBUF_VERSION} STREQUAL "3.6.1") + set(PROTOBUF_TAG v3.6.1) +elseif(${PROTOBUF_VERSION} STREQUAL "3.5.0") + set(PROTOBUF_TAG 2761122b810fe8861004ae785cc3ab39f384d342) +elseif(${PROTOBUF_VERSION} STREQUAL "3.4.0") + set(PROTOBUF_TAG b04e5cba356212e4e8c66c61bbe0c3a20537c5b9) +endif() if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") diff --git a/tensorflow/contrib/cmake/modules/FindAbseilCpp.cmake b/tensorflow/contrib/cmake/modules/FindAbseilCpp.cmake index d4f8bb1bec9ae8eff58dfe78168d8e71319c85e1..944ae3997a9489c13f65f93d9a7e61c21dd975c1 100644 --- a/tensorflow/contrib/cmake/modules/FindAbseilCpp.cmake +++ b/tensorflow/contrib/cmake/modules/FindAbseilCpp.cmake @@ -24,10 +24,10 @@ if(EXISTS "${ABSEIL_CPP_INCLUDE_DIR}" AND
NOT "${ABSEIL_CPP_INCLUDE_DIR}" STREQU # search all libraries if no COMPONENTS was requested set(AbseilCpp_FIND_COMPONENTS "absl_algorithm;absl_any;absl_bad_any_cast" - "absl_bad_optional_access;absl_base absl_container;absl_debugging" + "absl_bad_optional_access;absl_base;absl_container;absl_debugging" "absl_dynamic_annotations;absl_examine_stack;absl_failure_signal_handler" - "absl_int128;absl_leak_check;absl_malloc_internal;absl_memory;absl_meta" - "absl_numeric;absl_optional;absl_span;absl_spinlock_wait;absl_stack_consumption" + "absl_int128;absl_leak_check;absl_internal_malloc_internal;absl_memory;absl_meta" + "absl_numeric;absl_optional;absl_span;absl_internal_spinlock_wait;absl_stack_consumption" "absl_stacktrace;absl_str_format;absl_strings;absl_symbolize;absl_synchronization" "absl_throw_delegate;absl_time;absl_utility;str_format_extension_internal" "str_format_internal;test_instance_tracker_lib") diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake index 7a30eb94f54b18a2a517615a315e23e09e1170d0..a04142bd249ed5e16beba11057d0efc1e191e31b 100644 --- a/tensorflow/contrib/cmake/tf_c.cmake +++ b/tensorflow/contrib/cmake/tf_c.cmake @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== + ######################################################## # tf_c_framework library ######################################################## diff --git a/tensorflow/contrib/cmake/tf_cc_ops.cmake b/tensorflow/contrib/cmake/tf_cc_ops.cmake index 6c90cf398c69c8c1b22ea75e0c407f258e2535f9..6514ae50a4a35b35ba100af6997079294c22f9b8 100644 --- a/tensorflow/contrib/cmake/tf_cc_ops.cmake +++ b/tensorflow/contrib/cmake/tf_cc_ops.cmake @@ -149,11 +149,7 @@ add_library(tf_cc OBJECT ${tf_cc_srcs}) add_dependencies(tf_cc tf_cc_framework tf_cc_ops) if (WIN32) - if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") - set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/pywrap_tensorflow_internal.lib") - else() - set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.lib") - endif() + set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/$(Configuration)/pywrap_tensorflow_internal.lib") else (WIN32) set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal${CMAKE_SHARED_LIBRARY_SUFFIX}") endif (WIN32) diff --git a/tensorflow/contrib/cmake/tf_core_cpu.cmake b/tensorflow/contrib/cmake/tf_core_cpu.cmake index a54cbff33b66d63d7229fa2f50b8a4ca962111ed..d8884d464fb5974d77506561a9ed36110a3804c0 100644 --- a/tensorflow/contrib/cmake/tf_core_cpu.cmake +++ b/tensorflow/contrib/cmake/tf_core_cpu.cmake @@ -39,6 +39,8 @@ file(GLOB_RECURSE tf_core_cpu_exclude_srcs "${tensorflow_source_dir}/tensorflow/core/*test*.h" "${tensorflow_source_dir}/tensorflow/core/*test*.cc" "${tensorflow_source_dir}/tensorflow/core/*main.cc" + "${tensorflow_source_dir}/tensorflow/core/common_runtime/eager/*.cc" + "${tensorflow_source_dir}/tensorflow/core/common_runtime/eager/*.h" "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/*.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu_device_factory.cc" 
"${tensorflow_source_dir}/tensorflow/core/common_runtime/direct_session.cc" diff --git a/tensorflow/contrib/cmake/tf_core_eager_runtime.cmake b/tensorflow/contrib/cmake/tf_core_eager_runtime.cmake new file mode 100644 index 0000000000000000000000000000000000000000..78e4c0d3035cdaefa1d0950f4270d60152c805af --- /dev/null +++ b/tensorflow/contrib/cmake/tf_core_eager_runtime.cmake @@ -0,0 +1,57 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +######################################################## +# tf_core_eager_runtime library +######################################################## +file(GLOB_RECURSE tf_core_eager_runtime_srcs + "${tensorflow_source_dir}/tensorflow/core/common_runtime/eager/*.cc" + "${tensorflow_source_dir}/tensorflow/core/common_runtime/eager/*.h" +) + +file(GLOB_RECURSE tf_core_eager_runtime_exclude_srcs + "${tensorflow_source_dir}/tensorflow/core/common_runtime/eager/*test*.h" + "${tensorflow_source_dir}/tensorflow/core/common_runtime/eager/*test*.cc" +) + +list(REMOVE_ITEM tf_core_eager_runtime_srcs ${tf_core_eager_runtime_exclude_srcs}) + +add_library(tf_core_eager_runtime OBJECT ${tf_core_eager_runtime_srcs}) +add_dependencies( + tf_core_eager_runtime + tf_c + tf_core_lib) + + +file(GLOB_RECURSE tf_c_eager_srcs + "${tensorflow_source_dir}/tensorflow/c/eager/*.cc" + "${tensorflow_source_dir}/tensorflow/c/eager/*.h" +) + +file(GLOB_RECURSE tf_c_eager_exlclude_srcs + "${tensorflow_source_dir}/tensorflow/c/eager/*test*.h" + "${tensorflow_source_dir}/tensorflow/c/eager/*test*.cc" +) + +list(REMOVE_ITEM tf_c_eager_srcs ${tf_c_eager_exlclude_srcs}) + +add_library(tf_c_eager OBJECT ${tf_c_eager_srcs}) +add_dependencies( + tf_c_eager + tf_core_eager_runtime + tf_c + tf_cc_framework + tf_cc_while_loop + tf_core_lib + tf_protos_cc) \ No newline at end of file diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index 7e806685b8448cbd629985cdc00ed1193857abe6..d7b2a1339e047aba0a9424a53a63726805e89721 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -140,16 +140,19 @@ set(tf_proto_text_srcs "tensorflow/core/example/example.proto" "tensorflow/core/example/feature.proto" "tensorflow/core/framework/allocation_description.proto" + "tensorflow/core/framework/api_def.proto" 
"tensorflow/core/framework/attr_value.proto" "tensorflow/core/framework/cost_graph.proto" "tensorflow/core/framework/device_attributes.proto" "tensorflow/core/framework/function.proto" "tensorflow/core/framework/graph.proto" "tensorflow/core/framework/graph_transfer_info.proto" + "tensorflow/core/framework/iterator.proto" "tensorflow/core/framework/kernel_def.proto" "tensorflow/core/framework/log_memory.proto" "tensorflow/core/framework/node_def.proto" "tensorflow/core/framework/op_def.proto" + "tensorflow/core/framework/reader_base.proto" "tensorflow/core/framework/remote_fused_graph_execute_info.proto" "tensorflow/core/framework/resource_handle.proto" "tensorflow/core/framework/step_stats.proto" @@ -159,6 +162,7 @@ set(tf_proto_text_srcs "tensorflow/core/framework/tensor_shape.proto" "tensorflow/core/framework/tensor_slice.proto" "tensorflow/core/framework/types.proto" + "tensorflow/core/framework/variable.proto" "tensorflow/core/framework/versions.proto" "tensorflow/core/lib/core/error_codes.proto" "tensorflow/core/protobuf/cluster.proto" @@ -204,10 +208,10 @@ file(GLOB tf_core_platform_srcs "${tensorflow_source_dir}/tensorflow/core/framework/resource_handle.h" "${tensorflow_source_dir}/tensorflow/core/framework/resource_handle.cc") if (NOT tensorflow_ENABLE_GPU) - file(GLOB tf_core_platform_gpu_srcs + file(GLOB tf_core_platform_gpu_srcs_exclude "${tensorflow_source_dir}/tensorflow/core/platform/cuda_libdevice_path.*" "${tensorflow_source_dir}/tensorflow/core/platform/default/cuda_libdevice_path.*") - list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_gpu_srcs}) + list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_gpu_srcs_exclude}) else() file(GLOB tf_core_platform_srcs_exclude "${tensorflow_source_dir}/tensorflow/core/platform/default/device_tracer.cc") diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 9cfa8b90749280b6aa815cc210941c75bd5e16c5..310eed4ecbfdd30a3b3bdd4728c030fe70930797 100644 
--- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -13,13 +13,14 @@ # limitations under the License. # ============================================================================== set(tf_op_lib_names - "audio_ops" "array_ops" + "audio_ops" "batch_ops" "bitwise_ops" "boosted_trees_ops" "candidate_sampling_ops" "checkpoint_ops" + "collective_ops" "control_flow_ops" "ctc_ops" "cudnn_rnn_ops" @@ -27,13 +28,14 @@ set(tf_op_lib_names "dataset_ops" "decode_proto_ops" "encode_proto_ops" + "function_ops" "functional_ops" "image_ops" "io_ops" "linalg_ops" "list_ops" - "lookup_ops" "logging_ops" + "lookup_ops" "manip_ops" "math_ops" "nn_ops" @@ -43,10 +45,11 @@ set(tf_op_lib_names "remote_fused_graph_ops" "resource_variable_ops" "rpc_ops" + "scoped_allocator_ops" "script_ops" "sdca_ops" - "set_ops" "sendrecv_ops" + "set_ops" "sparse_ops" "spectral_ops" "state_ops" @@ -54,6 +57,7 @@ set(tf_op_lib_names "string_ops" "summary_ops" "training_ops" + "word2vec_ops" ) foreach(tf_op_lib_name ${tf_op_lib_names}) diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index df7b854afcca1a0bed660624152f465d4bf3b25f..8faccf8d55902e6701ebb4ce534b84705304fd5f 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -313,15 +313,14 @@ function(GENERATE_PYTHON_OP_LIB tf_python_op_lib_name) ${GENERATE_PYTHON_OP_LIB_DESTINATION} PARENT_SCOPE) endfunction() -GENERATE_PYTHON_OP_LIB("audio_ops") GENERATE_PYTHON_OP_LIB("array_ops") +GENERATE_PYTHON_OP_LIB("audio_ops") GENERATE_PYTHON_OP_LIB("batch_ops") GENERATE_PYTHON_OP_LIB("bitwise_ops") GENERATE_PYTHON_OP_LIB("boosted_trees_ops") -GENERATE_PYTHON_OP_LIB("math_ops") -GENERATE_PYTHON_OP_LIB("functional_ops") GENERATE_PYTHON_OP_LIB("candidate_sampling_ops") GENERATE_PYTHON_OP_LIB("checkpoint_ops") +GENERATE_PYTHON_OP_LIB("collective_ops") GENERATE_PYTHON_OP_LIB("control_flow_ops" ADDITIONAL_LIBRARIES $) 
GENERATE_PYTHON_OP_LIB("ctc_ops") @@ -332,14 +331,18 @@ GENERATE_PYTHON_OP_LIB("decode_proto_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/proto/python/ops/gen_decode_proto_op.py) GENERATE_PYTHON_OP_LIB("encode_proto_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/proto/python/ops/gen_encode_proto_op.py) +GENERATE_PYTHON_OP_LIB("function_ops") +GENERATE_PYTHON_OP_LIB("functional_ops") GENERATE_PYTHON_OP_LIB("image_ops") GENERATE_PYTHON_OP_LIB("io_ops") GENERATE_PYTHON_OP_LIB("linalg_ops") GENERATE_PYTHON_OP_LIB("list_ops") GENERATE_PYTHON_OP_LIB("logging_ops") GENERATE_PYTHON_OP_LIB("lookup_ops") -GENERATE_PYTHON_OP_LIB("nn_ops") GENERATE_PYTHON_OP_LIB("manip_ops") +GENERATE_PYTHON_OP_LIB("math_ops") +GENERATE_PYTHON_OP_LIB("nn_ops") +GENERATE_PYTHON_OP_LIB("no_op") GENERATE_PYTHON_OP_LIB("parsing_ops") GENERATE_PYTHON_OP_LIB("random_ops") GENERATE_PYTHON_OP_LIB("remote_fused_graph_ops" @@ -347,17 +350,21 @@ GENERATE_PYTHON_OP_LIB("remote_fused_graph_ops" GENERATE_PYTHON_OP_LIB("resource_variable_ops") GENERATE_PYTHON_OP_LIB("rpc_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/rpc/python/ops/gen_rpc_op.py) +GENERATE_PYTHON_OP_LIB("scoped_allocator_ops") GENERATE_PYTHON_OP_LIB("script_ops") GENERATE_PYTHON_OP_LIB("sdca_ops") +GENERATE_PYTHON_OP_LIB("sendrecv_ops") GENERATE_PYTHON_OP_LIB("set_ops") -GENERATE_PYTHON_OP_LIB("state_ops") GENERATE_PYTHON_OP_LIB("sparse_ops") GENERATE_PYTHON_OP_LIB("spectral_ops") +GENERATE_PYTHON_OP_LIB("state_ops") +GENERATE_PYTHON_OP_LIB("stateless_random_ops") GENERATE_PYTHON_OP_LIB("string_ops") GENERATE_PYTHON_OP_LIB("summary_ops") GENERATE_PYTHON_OP_LIB("user_ops") GENERATE_PYTHON_OP_LIB("training_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/training/gen_training_ops.py) +GENERATE_PYTHON_OP_LIB("word2vec_ops") GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_model_ops" DESTINATION 
${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_model_ops.py) @@ -391,11 +398,8 @@ GENERATE_PYTHON_OP_LIB("contrib_layers_sparse_feature_cross_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/layers/ops/gen_sparse_feature_cross_op.py) GENERATE_PYTHON_OP_LIB("contrib_memory_stats_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/memory_stats/ops/gen_memory_stats_ops.py) -GENERATE_PYTHON_OP_LIB("contrib_nccl_ops" - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/nccl/ops/gen_nccl_ops.py) GENERATE_PYTHON_OP_LIB("contrib_periodic_resample_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/periodic_resample/python/ops/gen_periodic_resample_op.py) - GENERATE_PYTHON_OP_LIB("contrib_nearest_neighbor_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/nearest_neighbor/ops/gen_nearest_neighbor_ops.py) GENERATE_PYTHON_OP_LIB("contrib_resampler_ops" @@ -420,8 +424,6 @@ GENERATE_PYTHON_OP_LIB("contrib_bigquery_reader_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cloud/python/ops/gen_bigquery_reader_ops.py) GENERATE_PYTHON_OP_LIB("contrib_gcs_config_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cloud/python/ops/gen_gcs_config_ops.py) -GENERATE_PYTHON_OP_LIB("stateless_random_ops" - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/stateless/gen_stateless_random_ops.py) GENERATE_PYTHON_OP_LIB("debug_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/debug/ops/gen_debug_ops.py) @@ -524,11 +526,13 @@ if(WIN32) add_library(pywrap_tensorflow_internal_static STATIC ${pywrap_tensorflow_internal_src} $ + $ $ $ $ $ $ + $ $ $ $ @@ -581,11 +585,13 @@ endif(WIN32) add_library(pywrap_tensorflow_internal SHARED ${pywrap_tensorflow_internal_src} $ + $ $ $ $ $ $ + $ $ $ $ @@ -615,13 +621,28 @@ target_include_directories(pywrap_tensorflow_internal 
PUBLIC ${NUMPY_INCLUDE_DIR} ) -target_link_libraries(pywrap_tensorflow_internal PRIVATE +if(CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 5.0) + # There is a bug in GCC 5 resulting in undefined reference to a __cpu_model function when + # linking to the tensorflow library. Adding the following libraries fixes it. + # See issue on github: https://github.com/tensorflow/tensorflow/issues/9593 + target_link_libraries(pywrap_tensorflow_internal PRIVATE ${tf_core_gpu_kernels_lib} ${tensorflow_EXTERNAL_LIBRARIES} tf_protos_cc tf_python_protos_cc ${PYTHON_LIBRARIES} + gcc_s + gcc ) +else() + target_link_libraries(pywrap_tensorflow_internal PRIVATE + ${tf_core_gpu_kernels_lib} + ${tensorflow_EXTERNAL_LIBRARIES} + tf_protos_cc + tf_python_protos_cc + ${PYTHON_LIBRARIES} +) +endif() if(WIN32) @@ -806,10 +827,10 @@ add_dependencies(tf_python_api tf_python_ops) ######################################################## # Parse tensorflow/python/tools/api/generator/BUILD to get list of generated files. 
-FILE(READ ${tensorflow_source_dir}/tensorflow/python/tools/api/generator/api_gen.bzl api_generator_BUILD_text) -STRING(REGEX MATCH "# BEGIN GENERATED ESTIMATOR FILES.*# END GENERATED ESTIMATOR FILES" api_init_files_text ${api_generator_BUILD_text}) -string(REPLACE "# BEGIN GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text}) -string(REPLACE "# END GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text}) +FILE(READ ${tensorflow_source_dir}/tensorflow/python/tools/api/generator/api_init_files.bzl api_generator_BUILD_text) +STRING(REGEX MATCH "# BEGIN GENERATED FILES.*# END GENERATED FILES" api_init_files_text ${api_generator_BUILD_text}) +string(REPLACE "# BEGIN GENERATED FILES" "" api_init_files_text ${api_init_files_text}) +string(REPLACE "# END GENERATED FILES" "" api_init_files_text ${api_init_files_text}) string(REPLACE "," ";" api_init_files_list ${api_init_files_text}) set(api_init_files "") diff --git a/tensorflow/contrib/cmake/tf_shared_lib.cmake b/tensorflow/contrib/cmake/tf_shared_lib.cmake index fdf522f1fd90ffc64acbe82381ef57a389645d61..62005dd113bfb80fbdf23afb6d4aa5f90a1e32de 100644 --- a/tensorflow/contrib/cmake/tf_shared_lib.cmake +++ b/tensorflow/contrib/cmake/tf_shared_lib.cmake @@ -23,6 +23,8 @@ if(WIN32) # we need. # add_library(tensorflow_static STATIC + $ + $ $ $ $ @@ -65,6 +67,8 @@ endif(WIN32) # tensorflow is a shared library containing all of the # TensorFlow runtime and the standard ops and kernels. 
add_library(tensorflow SHARED + $ + $ $ $ $ @@ -96,6 +100,27 @@ if(CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 5.0) target_link_libraries(tensorflow PRIVATE gcc_s gcc) endif() +# Offer the user the choice of overriding the installation directories +set(INSTALL_LIB_DIR lib CACHE PATH "Installation directory for libraries") +set(INSTALL_BIN_DIR bin CACHE PATH "Installation directory for executables") +set(INSTALL_INCLUDE_DIR include CACHE PATH + "Installation directory for header files") +if(WIN32 AND NOT CYGWIN) + set(DEF_INSTALL_CMAKE_DIR cmake) +else() + set(DEF_INSTALL_CMAKE_DIR lib/cmake) +endif() +set(INSTALL_CMAKE_DIR ${DEF_INSTALL_CMAKE_DIR} CACHE PATH + "Installation directory for CMake files") + +# Make relative paths absolute (needed later on) +foreach(p LIB BIN INCLUDE CMAKE) + set(var INSTALL_${p}_DIR) + if(NOT IS_ABSOLUTE "${${var}}") + set(${var} "${CMAKE_INSTALL_PREFIX}/${${var}}") + endif() +endforeach() + if(WIN32) add_dependencies(tensorflow tensorflow_static) endif(WIN32) @@ -103,14 +128,57 @@ endif(WIN32) target_include_directories(tensorflow PUBLIC $) -install(TARGETS tensorflow EXPORT tensorflow_export - RUNTIME DESTINATION bin - LIBRARY DESTINATION lib - ARCHIVE DESTINATION lib) +# Add all targets to build-tree export set +export(TARGETS tensorflow + FILE ${PROJECT_BINARY_DIR}/TensorflowTargets.cmake) + +# Export the package for use from the build-tree +export(PACKAGE Tensorflow) + +# Create the TensorflowConfig.cmake and TensorflowConfigVersion files +file(RELATIVE_PATH REL_INCLUDE_DIR "${INSTALL_CMAKE_DIR}" + "${INSTALL_INCLUDE_DIR}") +# for the build tree +set(CONF_INCLUDE_DIRS "${tensorflow_source_dir}" + "${PROJECT_BINARY_DIR}" + "${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src" + "${CMAKE_CURRENT_BINARY_DIR}/nsync/install/include" # Please if there is a better directory + "${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/Eigen/" + "${CMAKE_CURRENT_BINARY_DIR}/external/eigen_archive/" + 
"${tensorflow_source_dir}/third_party/eigen3/" + "${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/unsupported/Eigen/") +configure_file(TensorflowConfig.cmake.in + "${PROJECT_BINARY_DIR}/TensorflowConfig.cmake" @ONLY) +# for the install tree, yet to be complete +set(CONF_INCLUDE_DIRS "\${TENSORFLOW_CMAKE_DIR}/${REL_INCLUDE_DIR}") +configure_file(TensorflowConfig.cmake.in + "${PROJECT_BINARY_DIR}/${CMAKE_FILES_DIRECTORY}/TensorflowConfig.cmake" @ONLY) +# for both +configure_file(TensorflowConfigVersion.cmake.in + "${PROJECT_BINARY_DIR}/TensorflowConfigVersion.cmake" @ONLY) + +# install(TARGETS tensorflow EXPORT tensorflow_export +# RUNTIME DESTINATION ${INSTALL_BIN_DIR} +# LIBRARY DESTINATION ${INSTALL_LIB_DIR} +# ARCHIVE DESTINATION ${INSTALL_LIB_DIR}) + +# install(EXPORT tensorflow_export +# FILE TensorflowConfig.cmake +# DESTINATION ${INSTALL_CMAKE_DIR}) -install(EXPORT tensorflow_export - FILE TensorflowConfig.cmake - DESTINATION lib/cmake) +install(FILES + "${PROJECT_BINARY_DIR}/${CMAKE_FILES_DIRECTORY}/TensorflowConfig.cmake" + "${PROJECT_BINARY_DIR}/TensorflowConfigVersion.cmake" + DESTINATION "${INSTALL_CMAKE_DIR}" COMPONENT dev) + +# install the export set for use with the install-tree +install(EXPORT TensorflowTargets + DESTINATION ${INSTALL_CMAKE_DIR}) + +install(TARGETS tensorflow EXPORT TensorflowTargets + RUNTIME DESTINATION ${INSTALL_BIN_DIR} + LIBRARY DESTINATION ${INSTALL_LIB_DIR} + ARCHIVE DESTINATION ${INSTALL_LIB_DIR}) # install necessary headers # tensorflow headers @@ -145,6 +213,10 @@ install(DIRECTORY ${tensorflow_source_dir}/third_party/eigen3/ # unsupported Eigen directory install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/unsupported/Eigen/ DESTINATION include/unsupported/Eigen) +# absl directory +install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/abseil_cpp/src/abseil_cpp/absl/ + DESTINATION include/absl + FILES_MATCHING PATTERN "*.h") # mkl if (tensorflow_ENABLE_MKL_SUPPORT) install(DIRECTORY 
${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/include/ diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index 656633f0bf21a4d46cb85547241ef0fd42807ed6..40e159b8fcbd1864284e208cb15d9ed96119f840 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -38,12 +38,12 @@ tf_unary_scores, tf_sequence_lengths, tf_transition_params, _ = session.run( [unary_scores, sequence_lengths, transition_params, train_op]) for tf_unary_scores_, tf_sequence_length_ in zip(tf_unary_scores, tf_sequence_lengths): -# Remove padding. -tf_unary_scores_ = tf_unary_scores_[:tf_sequence_length_] + # Remove padding. + tf_unary_scores_ = tf_unary_scores_[:tf_sequence_length_] -# Compute the highest score and its tag sequence. -tf_viterbi_sequence, tf_viterbi_score = tf.contrib.crf.viterbi_decode( - tf_unary_scores_, tf_transition_params) + # Compute the highest score and its tag sequence. + tf_viterbi_sequence, tf_viterbi_score = tf.contrib.crf.viterbi_decode( + tf_unary_scores_, tf_transition_params) """ from __future__ import absolute_import diff --git a/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py index 0456463a1928cf226010670b90a5d574579e0411..6c5f8c6b00975b3fba041271309a93cecd9f5057 100644 --- a/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py @@ -46,7 +46,7 @@ class AssertElementShapeTest(test_base.DatasetTestBase): result = dataset.apply(batching.assert_element_shape(expected_shapes)) self.assertEqual(expected_shapes, result.output_shapes) - iterator = result.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(result) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -88,7 +88,7 @@ class 
AssertElementShapeTest(test_base.DatasetTestBase): result = dataset.apply(batching.assert_element_shape(expected_shapes)) self.assertEqual(expected_shapes, result.output_shapes) - iterator = result.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(result) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -115,9 +115,8 @@ class AssertElementShapeTest(test_base.DatasetTestBase): wrong_shapes = (tensor_shape.TensorShape(2), tensor_shape.TensorShape((3, 10))) - iterator = ( - dataset.apply(batching.assert_element_shape(wrong_shapes)) - .make_initializable_iterator()) + iterator = dataset_ops.make_initializable_iterator( + dataset.apply(batching.assert_element_shape(wrong_shapes))) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -142,7 +141,7 @@ class AssertElementShapeTest(test_base.DatasetTestBase): tensor_shape.TensorShape((3, 4))) self.assertEqual(actual_shapes, result.output_shapes) - iterator = result.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(result) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -184,7 +183,7 @@ class AssertElementShapeTest(test_base.DatasetTestBase): result = dataset.apply(batching.assert_element_shape(expected_shapes)) self.assertEqual(expected_shapes, result.output_shapes) - iterator = result.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(result) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -211,9 +210,8 @@ class AssertElementShapeTest(test_base.DatasetTestBase): wrong_shapes = (tensor_shape.TensorShape(2), tensor_shape.TensorShape((None, 10))) - iterator = ( - dataset.apply(batching.assert_element_shape(wrong_shapes)) - .make_initializable_iterator()) + iterator = dataset_ops.make_initializable_iterator( + 
dataset.apply(batching.assert_element_shape(wrong_shapes))) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py index d2a72272db159755ac2d741bcdbce9ec646d928e..b9840b1ff1a3df5a05db0e64f436637220f49f80 100644 --- a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py @@ -23,6 +23,7 @@ import shutil from tensorflow.contrib.data.python.ops import readers from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -48,7 +49,7 @@ class LMDBDatasetTest(test_base.DatasetTestBase): num_repeats = 2 dataset = readers.LMDBDataset(filenames).repeat(num_repeats) - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) init_op = iterator.initializer get_next = iterator.get_next() diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py index c5a786232252432481566e3cde23e9310df172cc..2527706709fae8e459aca3489324d4db3c784be6 100644 --- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py @@ -63,13 +63,13 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> # RepeatDataset(count) -> # _SlideDataset(window_size, window_shift, window_stride). 
- iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) .repeat(count).apply( sliding.sliding_window_batch( window_size=window_size_t, window_shift=window_shift_t, - window_stride=window_stride_t)).make_initializable_iterator()) + window_stride=window_stride_t))) init_op = iterator.initializer get_next = iterator.get_next() @@ -127,13 +127,13 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> # RepeatDataset(count) -> _SlideDataset(window_size, stride, window_stride). - iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) .repeat(count).apply( sliding.sliding_window_batch( window_size=window_size_t, stride=stride_t, - window_stride=window_stride_t)).make_initializable_iterator()) + window_stride=window_stride_t))) init_op = iterator.initializer get_next = iterator.get_next() @@ -173,12 +173,12 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): window_shift_t = array_ops.placeholder(dtypes.int64, shape=[]) window_stride_t = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.range(10).map(lambda x: x).repeat(count_t).apply( sliding.sliding_window_batch( window_size=window_size_t, window_shift=window_shift_t, - window_stride=window_stride_t)).make_initializable_iterator()) + window_stride=window_stride_t))) init_op = iterator.initializer with self.cached_session() as sess: @@ -204,9 +204,9 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): return sparse_tensor.SparseTensorValue( indices=[[0]], values=(i * [1]), dense_shape=[1]) - iterator = dataset_ops.Dataset.range(10).map(_sparse).apply( - sliding.sliding_window_batch( - window_size=5, window_shift=3)).make_initializable_iterator() + 
iterator = dataset_ops.make_initializable_iterator( + dataset_ops.Dataset.range(10).map(_sparse).apply( + sliding.sliding_window_batch(window_size=5, window_shift=3))) init_op = iterator.initializer get_next = iterator.get_next() @@ -233,9 +233,9 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): values=array_ops.fill([math_ops.to_int32(i)], i), dense_shape=[i]) - iterator = dataset_ops.Dataset.range(10).map(_sparse).apply( - sliding.sliding_window_batch( - window_size=5, window_shift=3)).make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator( + dataset_ops.Dataset.range(10).map(_sparse).apply( + sliding.sliding_window_batch(window_size=5, window_shift=3))) init_op = iterator.initializer get_next = iterator.get_next() @@ -265,11 +265,10 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): return sparse_tensor.SparseTensorValue( indices=[[0]], values=(i * [1]), dense_shape=[1]) - iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.range(10).map(_sparse).apply( sliding.sliding_window_batch(window_size=4, window_shift=2)).apply( - sliding.sliding_window_batch(window_size=3, window_shift=1)) - .make_initializable_iterator()) + sliding.sliding_window_batch(window_size=3, window_shift=1))) init_op = iterator.initializer get_next = iterator.get_next() @@ -305,11 +304,10 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): yield [4.0, 5.0, 6.0] yield [7.0, 8.0, 9.0, 10.0] - iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.from_generator( generator, dtypes.float32, output_shapes=[None]).apply( - sliding.sliding_window_batch(window_size=3, window_shift=1)) - .make_initializable_iterator()) + sliding.sliding_window_batch(window_size=3, window_shift=1))) next_element = iterator.get_next() with self.cached_session() as sess: diff --git a/tensorflow/contrib/data/python/ops/BUILD 
b/tensorflow/contrib/data/python/ops/BUILD index 34dc2379d0cb38f8f6962fa42efe21b793bc8d65..0fb406f1167053a128646c5c692986b0ce016f1e 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -188,8 +188,7 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:function", "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", + "//tensorflow/python/data/util:structure", ], ) diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 4601376dff47e161962e92678883039c4b88bab7..c0152156a1ba70297adb7054622b15ca04f859cd 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -21,10 +21,9 @@ from tensorflow.python.data.experimental.ops import optimization from tensorflow.python.data.experimental.ops import readers from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers as core_readers -from tensorflow.python.data.util import nest +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_experimental_dataset_ops from tensorflow.python.util import deprecation @@ -355,7 +354,7 @@ def read_batch_features(file_pattern, shuffle=randomize_input, num_epochs=num_epochs, shuffle_buffer_size=capacity) - iterator = dataset.make_one_shot_iterator() + iterator = dataset_ops.make_one_shot_iterator(dataset) outputs = iterator.get_next() return outputs @@ -379,15 +378,13 @@ class LMDBDataset(dataset_ops.DatasetSource): (key value) pairs sequentially. 
For example: ```python + tf.enable_eager_execution() + dataset = tf.contrib.lmdb.LMDBDataset("/foo/bar.mdb") - iterator = dataset.make_one_shot_iterator() - next_element = iterator.get_next() + # Prints the (key, value) pairs inside a lmdb file. - while True: - try: - print(sess.run(next_element)) - except tf.errors.OutOfRangeError: - break + for key, value in dataset: + print(key, value) ``` Args: filenames: A `tf.string` tensor containing one or more filenames. @@ -398,18 +395,10 @@ class LMDBDataset(dataset_ops.DatasetSource): def _as_variant_tensor(self): return gen_experimental_dataset_ops.experimental_lmdb_dataset( - self._filenames, - output_types=nest.flatten(self.output_types), - output_shapes=nest.flatten(self.output_shapes)) - - @property - def output_classes(self): - return ops.Tensor, ops.Tensor - - @property - def output_shapes(self): - return (tensor_shape.TensorShape([]), tensor_shape.TensorShape([])) + self._filenames, **dataset_ops.flat_structure(self)) @property - def output_types(self): - return dtypes.string, dtypes.string + def _element_structure(self): + return structure.NestedStructure( + (structure.TensorStructure(dtypes.string, []), + structure.TensorStructure(dtypes.string, []))) diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py index bcc383587c54bd89502313f9328bc06c49046a87..5c6ee6bfdc7167d14b292f8f763adafca4e3a72c 100644 --- a/tensorflow/contrib/data/python/ops/sliding.py +++ b/tensorflow/contrib/data/python/ops/sliding.py @@ -18,11 +18,10 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops 
import gen_experimental_dataset_ops as ged_ops from tensorflow.python.util import deprecation @@ -40,8 +39,13 @@ class _SlideDataset(dataset_ops.UnaryDataset): self._window_shift = ops.convert_to_tensor( window_shift, dtype=dtypes.int64, name="window_shift") + input_structure = structure.convert_legacy_structure( + input_dataset.output_types, input_dataset.output_shapes, + input_dataset.output_classes) + self._structure = input_structure._batch(None) # pylint: disable=protected-access + def _as_variant_tensor(self): - return gen_dataset_ops.slide_dataset( + return ged_ops.experimental_sliding_window_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access window_size=self._window_size, window_shift=self._window_shift, @@ -49,20 +53,8 @@ class _SlideDataset(dataset_ops.UnaryDataset): **dataset_ops.flat_structure(self)) @property - def output_classes(self): - return self._input_dataset.output_classes - - @property - def output_shapes(self): - input_shapes = self._input_dataset.output_shapes - return nest.pack_sequence_as(input_shapes, [ - tensor_shape.vector(None).concatenate(s) - for s in nest.flatten(self._input_dataset.output_shapes) - ]) - - @property - def output_types(self): - return self._input_dataset.output_types + def _element_structure(self): + return self._structure @deprecation.deprecated_args( diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 249258def3c4e52604b63764d8a7b5f238b45daa..4c9c35da5a36aa8149d15c8d1c25e4dfaa6a07c1 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -591,6 +591,7 @@ py_library( "//tensorflow/contrib/distribute/python:tpu_strategy", "//tensorflow/python:client_testlib", "//tensorflow/python:training", + "//tensorflow/python/eager:test", "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/keras", "//third_party/py/numpy", diff --git 
a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py index 6e9f9facd0a209146d1ad8d101f0b8c41d77752a..346513dc586f208315fd777dc7ddfa500c82f0d7 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -67,30 +67,31 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): def __init__(self, container_strategy, num_gpus_per_worker): distribute_lib.DistributionStrategyExtended.__init__( self, container_strategy) + self._cross_device_ops = None self._num_gpus_per_worker = num_gpus_per_worker - self._initialize_local_worker(container_strategy, num_gpus_per_worker) + self._initialize_local_worker(num_gpus_per_worker) + assert isinstance(self._get_cross_device_ops(), + cross_device_ops_lib.CollectiveAllReduce) - def _initialize_local_worker(self, container_strategy, num_gpus_per_worker): + def _initialize_local_worker(self, num_gpus_per_worker): """Initializes the object for local training.""" self._is_chief = True self._num_workers = 1 if num_gpus_per_worker: - local_devices = [ + local_devices = tuple( "/device:GPU:%d" % i for i in range(num_gpus_per_worker) - ] + ) else: - local_devices = ["/device:CPU:0"] + local_devices = ("/device:CPU:0",) self._worker_device = device_util.canonicalize("/device:CPU:0") self._collective_keys = cross_device_utils.CollectiveKeys() - super(CollectiveAllReduceExtended, self).__init__( - container_strategy, - devices=local_devices, - cross_device_ops=cross_device_ops_lib.CollectiveAllReduce( - num_workers=1, - num_gpus_per_worker=num_gpus_per_worker, - collective_keys=self._collective_keys)) + self._initialize_local(local_devices) + self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( + num_workers=self._num_workers, + num_gpus_per_worker=num_gpus_per_worker, + collective_keys=self._collective_keys) 
self._cluster_spec = None self._task_type = None @@ -99,13 +100,13 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): logging.info("CollectiveAllReduceStrategy with local_devices = %r", local_devices) - def _initialize_multi_worker(self, container_strategy, num_gpus_per_worker, - cluster_spec, task_type, task_id): + def _initialize_multi_worker(self, num_gpus_per_worker, cluster_spec, + task_type, task_id): """Initializes the object for multi-worker training.""" if task_type is None or task_id is None: raise ValueError("When `cluster_spec` is given, you must also specify " "`task_type` and `task_id`") - if task_type not in ["chief", "worker"]: + if task_type not in ("chief", "worker"): raise ValueError( "Unrecognized task_type: %r, valid task types are: \"chief\", " "\"worker\"." % task_type) @@ -120,21 +121,19 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): self._worker_device = "/job:%s/task:%d" % (task_type, task_id) if num_gpus_per_worker: - local_devices = [ + local_devices = tuple( "%s/device:GPU:%d" % (self._worker_device, i) for i in range(num_gpus_per_worker) - ] + ) else: - local_devices = [self._worker_device] + local_devices = (self._worker_device,) self._collective_keys = cross_device_utils.CollectiveKeys() - super(CollectiveAllReduceExtended, self).__init__( - container_strategy, - devices=local_devices, - cross_device_ops=cross_device_ops_lib.CollectiveAllReduce( - num_workers=self._num_workers, - num_gpus_per_worker=num_gpus_per_worker, - collective_keys=self._collective_keys)) + self._initialize_local(local_devices) + self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( + num_workers=self._num_workers, + num_gpus_per_worker=num_gpus_per_worker, + collective_keys=self._collective_keys) # Add a default device so that ops without specified devices will not end up # on other workers. 
@@ -268,9 +267,10 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): # If a `cluster_spec` is already passed in, do nothing here. # TODO(yuefengz): check `cluster_spec` is the same if this object has # already been initialized with a `cluster_spec`. - self._initialize_multi_worker( - self._container_strategy(), self._num_gpus_per_worker, cluster_spec, - task_type, task_id) + self._initialize_multi_worker(self._num_gpus_per_worker, cluster_spec, + task_type, task_id) + assert isinstance(self._get_cross_device_ops(), + cross_device_ops_lib.CollectiveAllReduce) if session_config: session_config.CopyFrom(self._update_config_proto(session_config)) diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py index eba3585a55375ee1db561a459e079256c53a85cc..6d7cd14ed5ad8a283e3d0d3405efc58fe670f9cd 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -82,7 +82,8 @@ class CollectiveAllReduceStrategyTestBase( instance_key_with_id_start=num_gpus * 10000 + CollectiveAllReduceStrategyTestBase.collective_key_base) distribution.extended._collective_keys = collective_keys - distribution.extended._cross_device_ops._collective_keys = collective_keys + distribution.extended._cross_device_ops._collective_keys = ( + collective_keys) if task_type and task_id is not None: return distribution, 'grpc://' + self._cluster_spec[task_type][ task_id], session_config diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index c5ce29a43632918be555db865891fdbb5d22e941..365ce5cdec79f1914f0c9ccdf59a7dc59e6f819e 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -192,7 +192,7 @@ def 
_augment_with_special_arguments(test_method): kwargs_to_pass[arg] = kwargs[arg] if mode == "eager": - with ops.Graph().as_default(), context.eager_mode(): + with context.eager_mode(): if distribution: kwargs_to_pass["distribution"] = distribution.strategy test_method(**kwargs_to_pass) diff --git a/tensorflow/contrib/distribute/python/cross_device_ops_test.py b/tensorflow/contrib/distribute/python/cross_device_ops_test.py index 3602cc92094ff607187f19e9e1c0ebde45aa6787..d6e9521c1c1115ffdbdcf375ad4017bacb962832 100644 --- a/tensorflow/contrib/distribute/python/cross_device_ops_test.py +++ b/tensorflow/contrib/distribute/python/cross_device_ops_test.py @@ -392,18 +392,16 @@ class MultiWorkerCrossDeviceOpsTest(multi_worker_test_base.MultiWorkerTestBase, # pylint: disable=g-long-lambda combinations.NamedDistribution( "CoreMirroredCPU", - lambda: mirrored_strategy.CoreMirroredStrategy( - num_gpus_per_worker=0), + lambda: mirrored_strategy.CoreMirroredStrategy(["/device:CPU:0"]), required_gpus=0), combinations.NamedDistribution( "CoreMirrored1GPU", - lambda: mirrored_strategy.CoreMirroredStrategy( - num_gpus_per_worker=1), + lambda: mirrored_strategy.CoreMirroredStrategy(["/device:GPU:0"]), required_gpus=1), combinations.NamedDistribution( "CoreMirrored2GPUs", lambda: mirrored_strategy.CoreMirroredStrategy( - num_gpus_per_worker=2), + ["/device:GPU:0", "/device:GPU:1"]), required_gpus=2), ], mode=["graph"]) diff --git a/tensorflow/contrib/distribute/python/estimator_training_test.py b/tensorflow/contrib/distribute/python/estimator_training_test.py index 0f35657a8099523b6ba5b8f0a1a2f289c06b531a..3f55a8a1c8b88d1b8e4031547fa3fbe519983630 100644 --- a/tensorflow/contrib/distribute/python/estimator_training_test.py +++ b/tensorflow/contrib/distribute/python/estimator_training_test.py @@ -24,7 +24,6 @@ import json import os import sys import tempfile -import threading from absl.testing import parameterized import numpy as np @@ -70,57 +69,19 @@ PS = dc._TaskType.PS 
original_run_std_server = dc._run_std_server -class MockOsEnv(dict): - - def __init__(self, *args): - self._thread_local = threading.local() - super(MockOsEnv, self).__init__(*args) - - def get(self, key, default): - if not hasattr(self._thread_local, "dict"): - self._thread_local.dict = dict() - if key == "TF_CONFIG": - return dict.get(self._thread_local.dict, key, default) - else: - return dict.get(self, key, default) - - def __getitem__(self, key): - if not hasattr(self._thread_local, "dict"): - self._thread_local.dict = dict() - if key == "TF_CONFIG": - return dict.__getitem__(self._thread_local.dict, key) - else: - return dict.__getitem__(self, key) - - def __setitem__(self, key, val): - if not hasattr(self._thread_local, "dict"): - self._thread_local.dict = dict() - if key == "TF_CONFIG": - return dict.__setitem__(self._thread_local.dict, key, val) - else: - return dict.__setitem__(self, key, val) - - -class DistributeCoordinatorIntegrationTest(test.TestCase, - parameterized.TestCase): +class DistributeCoordinatorIntegrationTest( + multi_worker_test_base.IndependentWorkerTestBase, parameterized.TestCase): @classmethod def setUpClass(cls): """Create a local cluster with 2 workers.""" + super(DistributeCoordinatorIntegrationTest, cls).setUpClass() cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( num_workers=3, num_ps=2, has_eval=True) def setUp(self): self._model_dir = tempfile.mkdtemp() - self._mock_os_env = MockOsEnv() - self._mock_context = test.mock.patch.object(os, "environ", - self._mock_os_env) super(DistributeCoordinatorIntegrationTest, self).setUp() - self._mock_context.__enter__() - - def tearDown(self): - self._mock_context.__exit__(None, None, None) - super(DistributeCoordinatorIntegrationTest, self).tearDown() def dataset_input_fn(self, x, y, batch_size, shuffle): @@ -143,8 +104,8 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, def _extract_loss_and_global_step(self, event_folder): """Returns the loss and global 
step in last event.""" event_paths = glob.glob(os.path.join(event_folder, "events*")) - self.assertGreater(len(event_paths), 0, - msg="Event file not found in dir %s" % event_folder) + self.assertNotEmpty( + event_paths, msg="Event file not found in dir %s" % event_folder) loss = None global_step_count = None @@ -287,6 +248,12 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, ]) self.assertAllEqual((BATCH_SIZE, LABEL_DIMENSION), predicted_proba.shape) + def _get_strategy_object(self, strategy_cls): + if strategy_cls == mirrored_strategy.CoreMirroredStrategy: + return strategy_cls(mirrored_strategy.all_local_devices()) + else: + return strategy_cls(num_gpus_per_worker=context.num_gpus()) + @combinations.generate( combinations.combine( mode=["graph"], @@ -305,12 +272,10 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, required_gpus=[0, 1])) def test_complete_flow_standalone_client(self, train_distribute_cls, eval_distribute_cls): - train_distribute = train_distribute_cls( - num_gpus_per_worker=context.num_gpus()) + train_distribute = self._get_strategy_object(train_distribute_cls) if eval_distribute_cls: - eval_distribute = eval_distribute_cls( - num_gpus_per_worker=context.num_gpus()) + eval_distribute = self._get_strategy_object(eval_distribute_cls) else: eval_distribute = None @@ -337,12 +302,10 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, required_gpus=[0, 1])) def test_estimator_standalone_client(self, train_distribute_cls, eval_distribute_cls): - train_distribute = train_distribute_cls( - num_gpus_per_worker=context.num_gpus()) + train_distribute = self._get_strategy_object(train_distribute_cls) if eval_distribute_cls: - eval_distribute = eval_distribute_cls( - num_gpus_per_worker=context.num_gpus()) + eval_distribute = self._get_strategy_object(eval_distribute_cls) else: eval_distribute = None @@ -362,47 +325,15 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, self._barrier.wait() return ret - def 
_task_thread(self, train_distribute, eval_distribute, tf_config): - os.environ["TF_CONFIG"] = json.dumps(tf_config) + def _independent_worker_fn( + self, + train_distribute, + eval_distribute, + ): with test.mock.patch.object(dc, "_run_std_server", self._mock_run_std_server): self._complete_flow(train_distribute, eval_distribute) - def _run_task_in_thread(self, cluster_spec, task_type, task_id, - train_distribute, eval_distribute): - if task_type: - tf_config = { - "cluster": cluster_spec, - "task": { - "type": task_type, - "index": task_id - } - } - else: - tf_config = { - "cluster": cluster_spec, - "task": { - "type": task_type, - "index": task_id - } - } - t = threading.Thread( - target=self._task_thread, - args=(train_distribute, eval_distribute, tf_config)) - t.start() - return t - - def _run_multiple_tasks_in_threads(self, cluster_spec, train_distribute, - eval_distribute): - threads = {} - for task_type in cluster_spec.keys(): - threads[task_type] = [] - for task_id in range(len(cluster_spec[task_type])): - t = self._run_task_in_thread(cluster_spec, task_type, task_id, - train_distribute, eval_distribute) - threads[task_type].append(t) - return threads - @combinations.generate( combinations.combine( mode=["graph"], @@ -418,16 +349,14 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, required_gpus=[0, 1])) def test_complete_flow_indepedent_worker_between_graph( self, train_distribute_cls, eval_distribute_cls): - train_distribute = train_distribute_cls( - num_gpus_per_worker=context.num_gpus()) - if (context.num_gpus() < 2 and eval_distribute_cls == collective_all_reduce_strategy.CollectiveAllReduceStrategy): self.skipTest("`CollectiveAllReduceStrategy` needs at least two towers.") + train_distribute = self._get_strategy_object(train_distribute_cls) + if eval_distribute_cls: - eval_distribute = eval_distribute_cls( - num_gpus_per_worker=context.num_gpus()) + eval_distribute = self._get_strategy_object(eval_distribute_cls) else: eval_distribute = 
None @@ -443,13 +372,16 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, # 3 workers and 1 evaluator. self._barrier = dc._Barrier(4) - threads = self._run_multiple_tasks_in_threads( - cluster_spec, train_distribute, eval_distribute) + threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn, + cluster_spec, train_distribute, + eval_distribute) + threads_to_join = [] for task_type, ts in threads.items(): if task_type == PS: continue for t in ts: - t.join() + threads_to_join.append(t) + self.join_independent_workers(threads_to_join) estimator = self._get_estimator(train_distribute, eval_distribute) self._inspect_train_and_eval_events(estimator) @@ -469,12 +401,10 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, required_gpus=[0, 1])) def test_complete_flow_indepedent_worker_in_graph(self, train_distribute_cls, eval_distribute_cls): - train_distribute = train_distribute_cls( - num_gpus_per_worker=context.num_gpus()) + train_distribute = self._get_strategy_object(train_distribute_cls) if eval_distribute_cls: - eval_distribute = eval_distribute_cls( - num_gpus_per_worker=context.num_gpus()) + eval_distribute = self._get_strategy_object(eval_distribute_cls) else: eval_distribute = None @@ -482,10 +412,10 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, num_workers=3, num_ps=0, has_eval=True) # 3 workers and 1 evaluator. 
self._barrier = dc._Barrier(4) - threads = self._run_multiple_tasks_in_threads( - cluster_spec, train_distribute, eval_distribute) - threads[WORKER][0].join() - threads[EVALUATOR][0].join() + threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn, + cluster_spec, train_distribute, + eval_distribute) + self.join_independent_workers([threads[WORKER][0], threads[EVALUATOR][0]]) estimator = self._get_estimator(train_distribute, eval_distribute) self._inspect_train_and_eval_events(estimator) @@ -522,7 +452,7 @@ class RunConfigTest(test.TestCase): run_config_lib.RunConfig( experimental_distribute=DistributeConfig( train_distribute=mirrored_strategy.CoreMirroredStrategy( - num_gpus_per_worker=2))) + ["/device:GPU:0", "/device:GPU:1"]))) def test_should_run_distribute_coordinator(self): """Tests that should_run_distribute_coordinator return a correct value.""" @@ -546,11 +476,11 @@ class RunConfigTest(test.TestCase): config_with_train_distribute = run_config_lib.RunConfig( experimental_distribute=DistributeConfig( train_distribute=mirrored_strategy.CoreMirroredStrategy( - num_gpus_per_worker=2))) + ["/device:GPU:0", "/device:GPU:1"]))) config_with_eval_distribute = run_config_lib.RunConfig( experimental_distribute=DistributeConfig( eval_distribute=mirrored_strategy.CoreMirroredStrategy( - num_gpus_per_worker=2))) + ["/device:GPU:0", "/device:GPU:1"]))) self.assertTrue( dc_training.should_run_distribute_coordinator( config_with_train_distribute)) @@ -564,7 +494,7 @@ class RunConfigTest(test.TestCase): config = run_config_lib.RunConfig( experimental_distribute=DistributeConfig( train_distribute=mirrored_strategy.CoreMirroredStrategy( - num_gpus_per_worker=2))) + ["/device:GPU:0", "/device:GPU:1"]))) self.assertFalse(dc_training.should_run_distribute_coordinator(config)) def test_init_run_config_duplicate_distribute(self): diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py 
index 8b6487252df54dc18cc0763fb1c58a190faad88a..60fda996642464135fe1fb8c314bcf7f04d19362 100644 --- a/tensorflow/contrib/distribute/python/examples/keras_mnist.py +++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py @@ -20,6 +20,10 @@ from __future__ import print_function import tensorflow as tf +from tensorflow.python.distribute import mirrored_strategy +from tensorflow.python.keras.optimizer_v2 import rmsprop + + NUM_CLASSES = 10 @@ -109,10 +113,10 @@ def main(_): # Instantiate the MirroredStrategy object. If we don't specify `num_gpus` or # the `devices` argument then all the GPUs available on the machine are used. - strategy = tf.contrib.distribute.MirroredStrategy(['/gpu:0', '/cpu:0']) + # TODO(priyag): Use `tf.distribute.MirroredStrategy` once available. + strategy = mirrored_strategy.MirroredStrategy(['/gpu:0', '/cpu:0']) - # TODO(priyag): Use RMSPropOptimizer when it works with eager mode. - optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001) + optimizer = rmsprop.RMSProp(learning_rate=0.001) # Compile the model by passing the distribution strategy object to the # `distribute` argument. 
`fit`, `evaluate` and `predict` will be distributed diff --git a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py index 6dfd85bcc4f3784e2744fd876a7190cc9581d96a..8c596549c4e20754675f69861d4c7f14f7c3c126 100644 --- a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py @@ -18,24 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import shutil -import tempfile from absl.testing import parameterized import numpy as np -import six from tensorflow.contrib.distribute.python import combinations -from tensorflow.core.protobuf import config_pb2 from tensorflow.python import keras -from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import distribution_strategy_context as ds_context -from tensorflow.python.estimator import run_config -from tensorflow.python.estimator import training -from tensorflow.python.estimator.canned import dnn_linear_combined -from tensorflow.python.estimator.canned import prediction_keys -from tensorflow.python.estimator.export import export -from tensorflow.python.estimator.inputs import numpy_io -from tensorflow.python.feature_column import feature_column_lib as feature_column from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -44,103 +32,7 @@ from tensorflow.python.keras.optimizer_v2 import gradient_descent from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -from tensorflow.python.platform import gfile from tensorflow.python.platform import test -from tensorflow.python.summary.writer import writer_cache - - -class KerasOptimizerV2IntegrationTest(test.TestCase, parameterized.TestCase): - - def setUp(self): - 
self._model_dir = tempfile.mkdtemp() - - def dataset_input_fn(self, x, y, batch_size): - - def input_fn(): - dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) - dataset = dataset.repeat(1).batch(batch_size) - return dataset - - return input_fn - - @combinations.generate( - combinations.combine( - mode=['graph'], - distribution=[ - combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_two_gpus - ], - use_train_and_evaluate=[True, False])) - def test_complete_flow_with_mode(self, distribution, use_train_and_evaluate): - label_dimension = 2 - input_dimension = label_dimension - batch_size = 10 - data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) - data = data.reshape(batch_size, label_dimension) - train_input_fn = self.dataset_input_fn( - x={'x': data}, - y=data, - batch_size=batch_size // distribution.num_replicas_in_sync) - eval_input_fn = self.dataset_input_fn( - x={'x': data}, - y=data, - batch_size=batch_size // distribution.num_replicas_in_sync) - predict_input_fn = numpy_io.numpy_input_fn( - x={'x': data}, batch_size=batch_size, shuffle=False) - - linear_feature_columns = [ - feature_column.numeric_column('x', shape=(input_dimension,)) - ] - dnn_feature_columns = [ - feature_column.numeric_column('x', shape=(input_dimension,)) - ] - feature_columns = linear_feature_columns + dnn_feature_columns - session_config = config_pb2.ConfigProto( - log_device_placement=True, allow_soft_placement=True) - estimator = dnn_linear_combined.DNNLinearCombinedRegressor( - linear_feature_columns=linear_feature_columns, - dnn_hidden_units=(2, 2), - dnn_feature_columns=dnn_feature_columns, - label_dimension=label_dimension, - model_dir=self._model_dir, - dnn_optimizer=adam.Adam(0.001), - linear_optimizer=adam.Adam(0.001), - config=run_config.RunConfig( - 
train_distribute=distribution, - eval_distribute=distribution, - session_config=session_config)) - - num_steps = 2 - if use_train_and_evaluate: - scores, _ = training.train_and_evaluate( - estimator, training.TrainSpec(train_input_fn, max_steps=num_steps), - training.EvalSpec(eval_input_fn)) - else: - estimator.train(train_input_fn, steps=num_steps) - scores = estimator.evaluate(eval_input_fn) - - self.assertIn('loss', six.iterkeys(scores)) - - predictions = np.array([ - x[prediction_keys.PredictionKeys.PREDICTIONS] - for x in estimator.predict(predict_input_fn) - ]) - self.assertAllEqual((batch_size, label_dimension), predictions.shape) - - feature_spec = feature_column.make_parse_example_spec(feature_columns) - serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( - feature_spec) - export_dir = estimator.export_savedmodel(tempfile.mkdtemp(), - serving_input_receiver_fn) - self.assertTrue(gfile.Exists(export_dir)) - - def tearDown(self): - if self._model_dir: - writer_cache.FileWriterCache.clear() - shutil.rmtree(self._model_dir) def get_model(): @@ -162,7 +54,9 @@ class MirroredStrategyOptimizerV2Test(test.TestCase, parameterized.TestCase): var = variables.Variable( 2.0, name='var', aggregation=variable_scope.VariableAggregation.SUM) # grad for cpu is 1, grad for gpu is 2, avg grad is 1.5. 
- loss = math_ops.cast(_replica_id() + 1, dtype=dtypes.float32) * var + def loss(): + return math_ops.cast(_replica_id() + 1, dtype=dtypes.float32) * var + optimizer = adam.Adam(learning_rate=0.01, beta_1=0.2, beta_2=0.2) train_op = optimizer.minimize(loss, var_list=[var]) m = optimizer.get_slot(var, 'm') diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index 1d002819745f1959b535ffa534be8f1a6b93b31d..c53e76f922372d8c7937e05fde61772d0b064674 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -165,7 +165,9 @@ def get_multi_inputs_multi_outputs_data(): return (train_data, test_data) -def batch_wrapper(dataset, batch_size, distribution): +def batch_wrapper(dataset, batch_size, distribution, repeat=None): + if repeat: + dataset = dataset.repeat(repeat) # TPUs currently require fully defined input shapes, drop_remainder ensures # the input will have fully defined shapes. if isinstance(distribution, tpu_strategy.TPUStrategy): @@ -212,9 +214,11 @@ def multi_input_output_model(): return model -def get_correctness_test_inputs(use_numpy, with_distribution, +def get_correctness_test_inputs(use_numpy, use_validation_data, + with_distribution, x_train, y_train, x_predict): """Generates the inputs for correctness check when enable Keras with DS.""" + training_epochs = 2 global_batch_size = 64 batch_size = global_batch_size # TODO(b/118776054): Use global batch size for Keras/DS support. 
@@ -230,14 +234,19 @@ def get_correctness_test_inputs(use_numpy, with_distribution, 'batch_size': batch_size, 'x': x_train, 'y': y_train, - 'epochs': 1, + 'epochs': training_epochs, 'shuffle': False, } - eval_inputs = { - 'batch_size': batch_size, - 'x': x_train, - 'y': y_train, - } + + if use_validation_data: + eval_inputs = None + training_inputs['validation_data'] = (x_train, y_train) + else: + eval_inputs = { + 'batch_size': batch_size, + 'x': x_train, + 'y': y_train, + } predict_inputs = { 'x': np.array(x_predict, dtype=np.float32), } @@ -246,22 +255,32 @@ def get_correctness_test_inputs(use_numpy, with_distribution, # keras.fit/evaluate/predict. The batch size is part of the dataset. train_dataset = dataset_ops.Dataset.from_tensor_slices( (x_train, y_train)) - x = batch_wrapper(train_dataset, batch_size, with_distribution) + x = batch_wrapper( + train_dataset, batch_size, with_distribution, repeat=training_epochs) training_inputs = { 'batch_size': None, 'x': x, 'y': None, - 'epochs': 1, + 'epochs': training_epochs, 'shuffle': False, 'steps_per_epoch': len(x_train) // global_batch_size, } - eval_inputs = { - 'batch_size': None, - 'x': x, - 'y': None, - 'steps': 20, - } + if use_validation_data: + eval_inputs = None # Remove the eval_inputs + eval_dataset = dataset_ops.Dataset.from_tensor_slices( + (x_train, y_train)) + x = batch_wrapper(eval_dataset, batch_size, with_distribution) + training_inputs['validation_data'] = x + training_inputs['validation_steps'] = 5 + else: + eval_inputs = { + 'batch_size': None, + 'x': x, + 'y': None, + 'steps': 20, + } + predict_batch_size = len(x_predict) if use_per_core_batch_size: predict_batch_size //= with_distribution.num_replicas_in_sync @@ -276,47 +295,66 @@ def get_correctness_test_inputs(use_numpy, with_distribution, return training_inputs, eval_inputs, predict_inputs -strategies = [combinations.default_strategy, - combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu, - 
combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_two_gpus, - combinations.tpu_strategy, # steps_per_run=2 - combinations.tpu_strategy_one_step] +strategies_minus_tpu = [ + combinations.default_strategy, + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus] + +tpu_strategies = [ + combinations.tpu_strategy, # steps_per_run=2 + combinations.tpu_strategy_one_step] def strategy_minus_tpu_combinations(): return combinations.combine( - distribution=[combinations.default_strategy, - combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_two_gpus], - mode=['graph']) + distribution=strategies_minus_tpu, + mode=['graph', 'eager']) -def strategy_combinations(): +def tpu_strategy_combinations(): return combinations.combine( - distribution=strategies, + distribution=tpu_strategies, mode=['graph']) -def strategy_and_optimizer_combinations(): - return combinations.combine( - distribution=strategies, - optimizer=[combinations.adagrad_optimizer_v1_fn, - combinations.adam_optimizer_v1_fn, - combinations.gradient_descent_optimizer_v1_fn, - combinations.rmsprop_optimizer_v1_fn], - mode=['graph']) +def all_strategy_combinations(): + return strategy_minus_tpu_combinations() + tpu_strategy_combinations() -def strategy_and_inputs(): +# TODO(priyag): Add v2 optimizers here. 
+def strategy_and_optimizer_combinations(): + return combinations.times( + all_strategy_combinations(), + combinations.combine( + optimizer=[combinations.adagrad_optimizer_v1_fn, + combinations.adam_optimizer_v1_fn, + combinations.gradient_descent_optimizer_v1_fn, + combinations.rmsprop_optimizer_v1_fn])) + + +def strategy_and_input_combinations(): + return ( + combinations.times( + combinations.combine(distribution=strategies_minus_tpu), + combinations.combine(mode=['graph'], + use_numpy=[True, False], + use_validation_data=[True, False]) + + combinations.combine(mode=['eager'], + use_numpy=[False], + use_validation_data=[False])) + + combinations.times( + combinations.combine(distribution=tpu_strategies), + combinations.combine(mode=['graph'], + use_numpy=[True, False], + use_validation_data=[True, False]))) + + +def strategy_for_numpy_input_combinations(): return combinations.combine( - distribution=strategies, - use_numpy=[True, False], + distribution=strategies_minus_tpu + tpu_strategies, mode=['graph']) @@ -337,7 +375,9 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, @combinations.generate(combinations.combine( distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_two_gpus], mode=['graph'])) def test_train_functional_with_distribution_strategy(self, distribution): @@ -365,7 +405,9 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, @combinations.generate(combinations.combine( distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_two_gpus], mode=['graph'])) def test_train_sequential_with_distribution_strategy(self, distribution): @@ -392,8 +434,8 @@ class 
TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, @combinations.generate(combinations.combine( distribution=[ - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], mode=['graph'])) def test_multi_inputs_multi_outputs_with_input_fn_as_dict(self, distribution): train_data, test_data = get_multi_inputs_multi_outputs_data() @@ -444,8 +486,8 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, @combinations.generate(combinations.combine( distribution=[ - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], mode=['graph'])) def test_keras_optimizer_with_distribution_strategy(self, distribution): keras_model = simple_sequential_model() @@ -471,7 +513,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, class TestDistributionStrategyWithNumpyArrays(test.TestCase, parameterized.TestCase): - @combinations.generate(strategy_combinations()) + @combinations.generate(strategy_for_numpy_input_combinations()) def test_creating_var_with_numpy_arrays(self, distribution): with self.cached_session(): x = np.asarray(np.random.random((64, 3)), dtype=np.float32) @@ -480,7 +522,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, # Verify that the numpy value is copied to the variable. 
self.assertAllEqual(x, val) - @combinations.generate(strategy_combinations()) + @combinations.generate(strategy_for_numpy_input_combinations()) def test_calculating_input_params_no_steps_no_batch_size(self, distribution): # Calculate the per_replica_batch_size scaling factor for strategies # that use per_core_batch_size @@ -511,7 +553,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, distributed_training_utils.get_input_params( distribution, input_63_samples, steps=None, batch_size=None) - @combinations.generate(strategy_combinations()) + @combinations.generate(strategy_for_numpy_input_combinations()) def test_calculating_input_params_with_steps_no_batch_size(self, distribution): # Calculate the per_replica_batch_size scaling factor for strategies @@ -557,7 +599,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, distributed_training_utils.get_input_params( distribution, input_63_samples, steps=1, batch_size=None) - @combinations.generate(strategy_combinations()) + @combinations.generate(strategy_for_numpy_input_combinations()) def test_calculating_input_params_no_steps_with_batch_size(self, distribution): # Calculate the per_replica_batch_size scaling factor for strategies @@ -591,7 +633,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, distributed_training_utils.get_input_params( distribution, input_64_samples, steps=None, batch_size=3) - @combinations.generate(strategy_combinations()) + @combinations.generate(strategy_for_numpy_input_combinations()) def test_calculating_input_params_with_steps_with_batch_size(self, distribution): with self.cached_session(): @@ -608,7 +650,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, distributed_training_utils.get_input_params( distribution, input_64_samples, steps=10, batch_size=13) - @combinations.generate(strategy_combinations()) + @combinations.generate(strategy_for_numpy_input_combinations()) def test_calling_model_with_numpy_arrays(self, distribution): with 
self.cached_session(): model = get_model() @@ -639,7 +681,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, # with batch_size model.predict(inputs, batch_size=8) - @combinations.generate(strategy_combinations()) + @combinations.generate(strategy_for_numpy_input_combinations()) def test_calling_model_with_nested_numpy_arrays(self, distribution): with self.cached_session(): model = multi_input_output_model() @@ -673,7 +715,8 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, # with batch_size model.predict(inputs, batch_size=8) - @combinations.generate(strategy_minus_tpu_combinations()) + @combinations.generate(combinations.combine( + distribution=strategies_minus_tpu, mode=['graph'])) def test_numpy_with_sample_weights(self, distribution): model = get_model() optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) @@ -687,7 +730,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, model.fit(inputs, targets, sample_weight=sample_weights, epochs=1, steps_per_epoch=2, verbose=1) - @combinations.generate(strategy_combinations()) + @combinations.generate(strategy_for_numpy_input_combinations()) def test_flatten_predict_outputs(self, distribution): with self.cached_session(): model = multi_input_output_model() @@ -715,7 +758,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, class TestDistributionStrategyWithDatasets(test.TestCase, parameterized.TestCase): - @combinations.generate(strategy_combinations()) + @combinations.generate(all_strategy_combinations()) def test_calling_model_on_same_dataset(self, distribution): with self.cached_session(): model = get_model() @@ -734,7 +777,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase, validation_data=dataset, validation_steps=2) model.predict(get_predict_dataset(distribution), steps=2) - @combinations.generate(strategy_combinations()) + @combinations.generate(all_strategy_combinations()) def test_model_interleaved_eval_same_as_direct_eval(self, 
distribution): with self.cached_session(): user_controlled_model = get_model() @@ -782,7 +825,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase, distribution=[ combinations.mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph'])) + mode=['graph', 'eager'])) def test_fit_with_tuple_and_dict_dataset_inputs(self, distribution): with self.cached_session(): model = multi_input_output_model() @@ -814,7 +857,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase, model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1) - @combinations.generate(strategy_combinations()) + @combinations.generate(all_strategy_combinations()) def test_fit_eval_and_predict_methods_on_dataset(self, distribution): with self.cached_session(): model = get_model() @@ -865,10 +908,12 @@ class TestDistributionStrategyWithDatasets(test.TestCase, @combinations.generate(combinations.combine( distribution=[ - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], - mode=['graph'])) - def test_dataset_wrong_input_shape(self, distribution): + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + # TODO(b/120943676, b/120957836): Re-enable once the validation code is + # restored. + def DISABLED_test_dataset_wrong_input_shape(self, distribution): with self.cached_session(): model = get_model() @@ -888,9 +933,11 @@ class TestDistributionStrategyWithDatasets(test.TestCase, model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) @combinations.generate(combinations.combine( - distribution=[combinations.mirrored_strategy_with_two_gpus], - mode=['graph'])) - def test_dataset_no_batch_input_validation(self, distribution): + distribution=[combinations.mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + # TODO(b/120943676, b/120957836): Re-enable once the validation code is + # restored. 
+ def DISABLED_test_dataset_no_batch_input_validation(self, distribution): with self.cached_session(): model = get_model() @@ -928,9 +975,11 @@ class TestDistributionStrategyWithDatasets(test.TestCase, @combinations.generate(combinations.combine( distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_two_gpus], - mode=['graph'])) + mode=['graph', 'eager'])) def test_learning_phase_value(self, distribution): # TODO(anjalisridhar): Modify this test to use Lambdas since we can compare # meaningful values. Currently we don't pass the learning phase if the @@ -1002,7 +1051,7 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): distribution=[ combinations.mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph'])) + mode=['graph', 'eager'])) def test_validating_dataset_input_tensors_with_shape_mismatch(self, distribution): with self.cached_session(): @@ -1025,7 +1074,7 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): distribution=[ combinations.mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph'])) + mode=['graph', 'eager'])) def test_validating_dataset_input_tensors_with_dtype_mismatch(self, distribution): with self.cached_session(): @@ -1046,9 +1095,9 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): @combinations.generate(combinations.combine( distribution=[ - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], - mode=['graph'])) + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) def test_unsupported_features(self, distribution): with self.cached_session(): model = get_model() @@ 
-1095,9 +1144,9 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): @combinations.generate(combinations.combine( distribution=[ - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], - mode=['graph'])) + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) def test_calling_with_unsupported_predefined_callbacks(self, distribution): with self.cached_session(): model = get_model() @@ -1122,12 +1171,6 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): 'using'): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, callbacks=[keras.callbacks.ReduceLROnPlateau()]) - with self.assertRaisesRegexp(ValueError, - 'histogram_freq in the TensorBoard callback ' - 'is not supported when using ' - 'DistributionStrategy.'): - model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, - callbacks=[keras.callbacks.TensorBoard(histogram_freq=10)]) class TestDistributionStrategyWithLossMasking(test.TestCase, @@ -1137,9 +1180,9 @@ class TestDistributionStrategyWithLossMasking(test.TestCase, # work for TPU due to some invalid datatype. 
@combinations.generate(combinations.combine( distribution=[ - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], - mode=['graph'])) + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) def test_masking(self, distribution): with self.cached_session(): np.random.seed(1337) @@ -1163,7 +1206,7 @@ class TestDistributionStrategyWithLossMasking(test.TestCase, class TestDistributionStrategyWithNormalizationLayer( test.TestCase, parameterized.TestCase): - @combinations.generate(strategy_combinations()) + @combinations.generate(all_strategy_combinations()) def test_batchnorm_correctness(self, distribution): with self.cached_session(): model = keras.models.Sequential() @@ -1195,7 +1238,7 @@ class TestDistributionStrategyWithNormalizationLayer( class TestDistributionStrategyCorrectness(test.TestCase, parameterized.TestCase): - @combinations.generate(strategy_combinations()) + @combinations.generate(all_strategy_combinations()) def test_metric_correctness(self, distribution): with self.cached_session(): keras.backend.set_image_data_format('channels_last') @@ -1224,18 +1267,57 @@ class TestDistributionStrategyCorrectness(test.TestCase, train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) train_dataset = batch_wrapper(train_dataset, batch_size, distribution) - history = model.fit(x=train_dataset, epochs=1, steps_per_epoch=10) - self.assertEqual(history.history['binary_accuracy'], [1.0]) + history = model.fit(x=train_dataset, epochs=2, steps_per_epoch=10) + self.assertEqual(history.history['binary_accuracy'], [1.0, 1.0]) - @combinations.generate(strategy_and_inputs()) - def test_correctness(self, distribution, use_numpy): + @combinations.generate(all_strategy_combinations()) + def test_eval_metrics_correctness(self, distribution): with self.cached_session(): - tolerance = 1e-5 + model = keras.Sequential() + model.add( + 
keras.layers.Dense( + 3, activation='relu', input_dim=4, kernel_initializer='ones')) + model.add( + keras.layers.Dense( + 1, activation='sigmoid', kernel_initializer='ones')) + model.compile( + loss='mae', + metrics=['accuracy', keras.metrics.BinaryAccuracy()], + optimizer=gradient_descent.GradientDescentOptimizer(0.001), + distribute=distribution) + + # verify correctness of stateful and stateless metrics. + x = np.ones((100, 4)).astype('float32') + y = np.ones((100, 1)).astype('float32') + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat() + dataset = batch_wrapper(dataset, 4, distribution) + outs = model.evaluate(dataset, steps=10) + self.assertEqual(outs[1], 1.) + self.assertEqual(outs[2], 1.) + + y = np.zeros((100, 1)).astype('float32') + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat() + dataset = batch_wrapper(dataset, 4, distribution) + outs = model.evaluate(dataset, steps=10) + self.assertEqual(outs[1], 0.) + self.assertEqual(outs[2], 0.) + + @combinations.generate(strategy_and_input_combinations()) + def test_correctness(self, distribution, use_numpy, use_validation_data): + + with self.cached_session(): + default_tolerance = 1e-5 + tol_table = {} if isinstance(distribution, (mirrored_strategy.MirroredStrategy, mirrored_strategy.CoreMirroredStrategy)): - # TODO(b/119257215): use the default one once the flakyness is fixed. - tolerance = 1e-4 + # TODO(b/119257215): Weights are not exactly the same, so use larger + # tolerance for now. Predict should be related to weights. + tol_table = { + 'weights_1': 1e-4, + 'weights_2': 1e-4, + 'predict_result_1': 1e-4, + } keras.backend.set_image_data_format('channels_last') np.random.seed(_RANDOM_SEED) @@ -1256,54 +1338,75 @@ class TestDistributionStrategyCorrectness(test.TestCase, # This is used to initialize the model for both the distribution and # non-distribution run. In addition, we add few non-linear layers to make # it non-trivial. 
- model = keras.Sequential() - model.add(keras.layers.Dense(10, activation='relu', input_shape=(1,))) - model.add(keras.layers.Dense(10, activation='relu')) - model.add(keras.layers.Dense(10, activation='relu')) - model.add(keras.layers.Dense(1)) + def _create_model(): + model = keras.Sequential() + model.add(keras.layers.Dense(10, activation='relu', input_shape=(1,))) + model.add(keras.layers.Dense(10, activation='relu')) + model.add(keras.layers.Dense(10, activation='relu')) + model.add(keras.layers.Dense(1)) + return model + + model = _create_model() initial_weights = model.get_weights() + del model # avoid accident usage. def fit_eval_and_predict(with_distribution=None): + model = _create_model() # We have initialized the model to the same weight for the distribution # and non-distribution run. model.set_weights(initial_weights) - # TODO(b/120245072): Also use gradient_descent_keras.SGD for - # TPUStrategy. - # pylint: disable=line-too-long - if with_distribution and with_distribution.__class__.__name__ == 'TPUStrategy': - # pylint: enable=line-too-long - optimizer = gradient_descent.GradientDescentOptimizer(0.5) - else: - optimizer = gradient_descent_keras.SGD(0.5) model.compile( loss=keras.losses.mean_squared_error, - optimizer=optimizer, + optimizer=gradient_descent_keras.SGD(0.5), + metrics=['mse'], distribute=with_distribution) training_inputs, eval_inputs, predict_inputs = ( - get_correctness_test_inputs(use_numpy, with_distribution, + get_correctness_test_inputs(use_numpy, use_validation_data, + with_distribution, x_train, y_train, x_predict)) - model.fit(**training_inputs) - eval_result = model.evaluate(**eval_inputs) - weights = model.get_weights() - predict_result = model.predict(**predict_inputs) - - return weights, eval_result, predict_result - - wts_with_ds, eval_with_ds, predict_with_ds = fit_eval_and_predict( - with_distribution=distribution) - wts_without_ds, eval_without_ds, predict_without_ds = ( - fit_eval_and_predict(with_distribution=None)) 
- - # Verify that the weights, eval results, predict outputs are the same - # within some limits of tolerance. - self.assertAllClose( - wts_with_ds, wts_without_ds, atol=tolerance, rtol=tolerance) - self.assertAllClose( - eval_with_ds, eval_without_ds, atol=tolerance, rtol=tolerance) - self.assertAllClose( - predict_with_ds, predict_without_ds, atol=tolerance, rtol=tolerance) + result = {} + result['training_history_1'] = model.fit(**training_inputs).history + + if eval_inputs is not None: + result['eval_result_1'] = model.evaluate(**eval_inputs) + + result['weights_1'] = model.get_weights() + result['predict_result_1'] = model.predict(**predict_inputs) + + # Train and eval again to mimic user's flow. + + result['training_history_2'] = model.fit(**training_inputs).history + + if eval_inputs is not None: + result['eval_result_2'] = model.evaluate(**eval_inputs) + + result['weights_2'] = model.get_weights() + + return result + + results_with_ds = fit_eval_and_predict(with_distribution=distribution) + results_without_ds = fit_eval_and_predict(with_distribution=None) + + # Verify that the weights, training history, eval results, predict outputs + # are the same within some limits of tolerance. + for key in results_with_ds: + if (key.startswith('training_history') and + isinstance(distribution, tpu_strategy.TPUStrategy) and + distribution.extended.steps_per_run > 1): + # TODO(b/119894254): Enable this test for all cases once the + # underlying bug is fixed. 
+ continue + + tolerance = tol_table.get(key, default_tolerance) + + self.assertAllClose( + results_with_ds[key], + results_without_ds[key], + atol=tolerance, + rtol=tolerance, + msg='Fail to assert {}.'.format(key)) if __name__ == '__main__': diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index dcc9df4cda51b87e95fb166a726170a8817715fc..f09483cb56b66fd4720ee71085203c14f1ccadc3 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -232,7 +232,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): fetches = distribution.unwrap( distribution.call_for_each_replica(model_fn, args=inputs)) if update_ops_in_cross_replica_mode: - fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS) + fetches += tuple(ops.get_collection(ops.GraphKeys.UPDATE_OPS)) return control_flow_ops.group(fetches) iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) @@ -443,7 +443,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): step_fn, iterator, iterations=2, initial_loop_values=initial_loop_values) - self.assertEqual({key1: [value1]}, ctx.non_tensor_outputs) + self.assertEqual({key1: (value1,)}, ctx.non_tensor_outputs) self._verify_loss_output( initial_loss(), loss_output=ctx.last_step_outputs["replica_loss_reduced"], diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 4a594f056e96a2a48563d9902bdeed8458b847e4..24399db6522c325722b95399fd002eed9fd955f2 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -28,8 +28,8 @@ from tensorflow.python.distribute import values # pylint: disable=protected-access,invalid-name _call_for_each_replica = mirrored_strategy._call_for_each_replica 
-_reduce_non_distributed_value = mirrored_strategy._reduce_non_distributed_value _create_mirrored_variable = mirrored_strategy._create_mirrored_variable +all_local_devices = mirrored_strategy.all_local_devices CoreMirroredStrategy = mirrored_strategy.MirroredStrategy CoreMirroredExtended = mirrored_strategy.MirroredExtended # pylint: enable=protected-access,invalid-name @@ -115,8 +115,13 @@ class MirroredExtended(CoreMirroredExtended): num_gpus_per_worker=None, cross_device_ops=None, auto_shard_dataset=False): - super(MirroredExtended, self).__init__( - container_strategy, devices, num_gpus_per_worker, cross_device_ops) + if devices is None: + devices = mirrored_strategy.all_local_devices(num_gpus_per_worker) + elif num_gpus_per_worker is not None: + raise ValueError( + "Must only specify one of `devices` and `num_gpus_per_worker`.") + super(MirroredExtended, self).__init__(container_strategy, devices, + cross_device_ops) self._auto_shard_dataset = auto_shard_dataset def _make_dataset_iterator(self, dataset): @@ -131,22 +136,22 @@ class MirroredExtended(CoreMirroredExtended): Returns: An `InputIterator` which returns inputs for each step of the computation. """ - if self._cluster_spec: - worker_device_pairs = self._worker_devices - else: + if self._local_mode: worker = device_util.canonicalize("/device:CPU:0") worker_device_pairs = [(worker, self._devices)] + else: + worker_device_pairs = self._worker_devices return values.DatasetIterator(dataset, worker_device_pairs) def _distribute_dataset(self, dataset_fn): - if self._cluster_spec: + if self._local_mode: + return values.PerReplicaDataset( + self._call_dataset_fn(dataset_fn), self._devices) + else: return values.MultiWorkerDataset( functools.partial(self._call_dataset_fn, dataset_fn), self._worker_devices, auto_shard=self._auto_shard_dataset) - else: - return values.PerReplicaDataset( - self._call_dataset_fn(dataset_fn), self._devices) # TODO(priyag): Delete this once all strategies use global batch size. 
@property diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index fee37daa424b8ada9f18b2046599a62647d8c33d..337a86b3421fdb90c98cd5097dd880fdbe5871b9 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -38,6 +38,7 @@ from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import constant_op +from tensorflow.python.framework import func_graph from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.keras.engine import training as keras_training @@ -179,9 +180,37 @@ class MirroredStrategyVariableCreatorStackTest( variable_scope.variable_creator_scope(main_thread_creator): result = distribution.extended.call_for_each_replica(model_fn) result = distribution.unwrap(result) - expected = ["main_thread:thread_0", "main_thread:thread_1"] + expected = ("main_thread:thread_0", "main_thread:thread_1") self.assertEqual(expected, result) +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) +class MirroredStrategyCallForEachReplicaTest(test.TestCase): + + def testExecutingEagerlyOutsideFunction(self, distribution): + """Verify we preserve the value of executing_eagerly_outside_functions().""" + def model_fn(): + return ops.executing_eagerly_outside_functions() + + originally = ops.executing_eagerly_outside_functions() + with distribution.scope(): + in_scope = ops.executing_eagerly_outside_functions() + in_model_fn = distribution.extended.call_for_each_replica(model_fn) + unwrapped = distribution.unwrap(in_model_fn) + self.assertEqual(in_scope, unwrapped[0]) 
+ self.assertEqual(in_scope, originally) + + # Verify this all again, but this time in a FuncGraph. + with func_graph.FuncGraph("fg").as_default(), distribution.scope(): + in_scope = ops.executing_eagerly_outside_functions() + in_model_fn = distribution.extended.call_for_each_replica(model_fn) + unwrapped = distribution.unwrap(in_model_fn) + self.assertEqual(in_scope, unwrapped[0]) + self.assertEqual(in_scope, originally) + @combinations.generate(combinations.combine( distribution=[ @@ -190,6 +219,27 @@ class MirroredStrategyVariableCreatorStackTest( mode=["graph", "eager"])) class MirroredStrategyVariableCreationTest(test.TestCase): + # TODO(priyag): Modify more tests to use this helper and check more + # properties. + def _test_mv_properties(self, var, name): + self.assertIsInstance(var, values.MirroredVariable) + self.assertEqual(name, var.name) + for d in var.devices: + self.assertEqual(d, var.get(d).device) + + def testVariableInFuncGraph(self, distribution): + def model_fn(): + v = variable_scope.variable(2.0, name="bar") + ds_context.get_replica_context().merge_call(lambda _: _) + return v + + with func_graph.FuncGraph("fg").as_default(), distribution.scope(): + v1 = variable_scope.variable(1.0, name="foo") + v2 = distribution.extended.call_for_each_replica(model_fn) + + self._test_mv_properties(v1, "foo:0") + self._test_mv_properties(v2, "bar:0") + def testSingleVariable(self, distribution): def model_fn(): # This variable should be created only once across the threads because of @@ -201,8 +251,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with distribution.scope(): result = distribution.extended.call_for_each_replica(model_fn) - self.assertIsInstance(result, values.MirroredVariable) - self.assertEqual("foo:0", result.name) + self._test_mv_properties(result, "foo:0") def testUnnamedVariable(self, distribution): def model_fn(): @@ -212,9 +261,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with distribution.scope(): result = 
distribution.extended.call_for_each_replica(model_fn) - self.assertIsInstance(result, values.MirroredVariable) - # Default name of "Variable" will be used. - self.assertEqual("Variable:0", result.name) + self._test_mv_properties(result, "Variable:0") def testMultipleVariables(self, distribution): def model_fn(): @@ -227,8 +274,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with distribution.scope(): result = distribution.extended.call_for_each_replica(model_fn) for i, v in enumerate(result): - self.assertIsInstance(v, values.MirroredVariable) - self.assertEqual("foo" + str(i) + ":0", v.name) + self._test_mv_properties(v, "foo" + str(i) + ":0") def testMultipleVariablesWithSameCanonicalName(self, distribution): def model_fn(): @@ -757,21 +803,23 @@ class MirroredStrategyNameScopeTest(test.TestCase): self.assertEqual("c/replica_1:0", c1.name) -@combinations.generate(combinations.combine( - distribution=[ - combinations.NamedDistribution( - "Mirrored3Devices", - # pylint: disable=g-long-lambda - lambda: mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]), - required_gpus=2), - combinations.NamedDistribution( - "CoreMirrored3Devices", - # pylint: disable=g-long-lambda - lambda: mirrored_strategy.CoreMirroredStrategy( - ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]), - required_gpus=2)], - mode=["graph", "eager"])) +@combinations.generate( + combinations.combine( + distribution=[ + combinations.NamedDistribution( + "Mirrored3Devices", + # pylint: disable=g-long-lambda + lambda: mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]), + required_gpus=2), + combinations.NamedDistribution( + "CoreMirrored3Devices", + # pylint: disable=g-long-lambda + lambda: mirrored_strategy.CoreMirroredStrategy( + ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]), + required_gpus=2) + ], + mode=["graph", "eager"])) class MirroredThreeDeviceDistributionTest( 
strategy_test_lib.DistributionTestBase, parameterized.TestCase): @@ -1283,14 +1331,14 @@ class MirroredStrategyDefunTest(test.TestCase): combinations.NamedDistribution( "Mirrored", # pylint: disable=g-long-lambda - lambda: mirrored_strategy.CoreMirroredStrategy( - num_gpus_per_worker=context.num_gpus()), + lambda: mirrored_strategy.MirroredStrategy(num_gpus_per_worker= + context.num_gpus()), required_gpus=1), combinations.NamedDistribution( "CoreMirrored", # pylint: disable=g-long-lambda lambda: mirrored_strategy.CoreMirroredStrategy( - num_gpus_per_worker=context.num_gpus()), + mirrored_strategy.all_local_devices()), required_gpus=1) ], mode=["graph"])) @@ -1374,7 +1422,7 @@ class MultiWorkerMirroredStrategyTestWithChief( def testMinimizeLossGraphCoreMirroredStrategy(self): strategy = mirrored_strategy.CoreMirroredStrategy( - num_gpus_per_worker=context.num_gpus()) + mirrored_strategy.all_local_devices()) strategy.configure(cluster_spec=self._cluster_spec) self._test_minimize_loss_graph(strategy, learning_rate=0.05) diff --git a/tensorflow/contrib/distribute/python/moving_averages_test.py b/tensorflow/contrib/distribute/python/moving_averages_test.py index c492d8bafc9024ed059f05b92e5466f3702726b9..8f13e9153ea7a951dd722c4549882c97e79b57fe 100644 --- a/tensorflow/contrib/distribute/python/moving_averages_test.py +++ b/tensorflow/contrib/distribute/python/moving_averages_test.py @@ -139,6 +139,27 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase): (2.0 * 0.25 + 0.0) / (1.0 * 0.25 + 1.0)], var.eval()) + @combinations.generate(all_combinations) + def testAssignVariable(self, distribution): + + def replica_fn(): + var = variables.Variable([10.0, 11.0]) + # Here we expect to check the case when input value are variable. 
+ val = variables.Variable([1., 2.]) + decay = 0.25 + assign = moving_averages.assign_moving_average( + var, val, decay, zero_debias=False) + return var, assign + + with distribution.scope(), self.cached_session() as sess: + var, assign = distribution.call_for_each_replica(replica_fn) + variables.global_variables_initializer().run() + self.assertAllClose([10.0, 11.0], var.eval()) + sess.run(distribution.unwrap(assign)) + self.assertAllClose( + [10 * 0.25 + 1. * (1 - 0.25), 11 * 0.25 + 2. * (1 - 0.25)], + var.eval()) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py index 8eec3dc0f6ec0676353c7434d203e017b9aab80d..b05aac431f65b4281d9ed9c2fa95c210d55f4008 100644 --- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py +++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py @@ -18,8 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import contextlib import copy +import json +import os import threading import numpy as np @@ -37,6 +40,7 @@ from tensorflow.python.client import session from tensorflow.python.estimator import run_config from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import coordinator from tensorflow.python.training import server_lib ASSIGNED_PORTS = set() @@ -271,7 +275,6 @@ class MultiWorkerTestBase(test.TestCase): return config - def _run_client(self, client_fn, task_type, task_id, num_gpus, *args, **kwargs): result = client_fn(task_type, task_id, num_gpus, *args, **kwargs) @@ -303,3 +306,106 @@ class MultiWorkerTestBase(test.TestCase): for t in threads: t.join() self.assertEqual(self._result, len(threads)) + + +class MockOsEnv(collections.Mapping): + """A class that allows per-thread TF_CONFIG.""" + + def __init__(self, 
*args): + self._dict = dict() + self._thread_local = threading.local() + super(MockOsEnv, self).__init__(*args) + + def get(self, key, default=None): + if not hasattr(self._thread_local, 'dict'): + self._thread_local.dict = dict() + if key == 'TF_CONFIG': + return dict.get(self._thread_local.dict, key, default) + else: + return dict.get(self._dict, key, default) + + def __getitem__(self, key): + if not hasattr(self._thread_local, 'dict'): + self._thread_local.dict = dict() + if key == 'TF_CONFIG': + return dict.__getitem__(self._thread_local.dict, key) + else: + return dict.__getitem__(self._dict, key) + + def __setitem__(self, key, val): + if not hasattr(self._thread_local, 'dict'): + self._thread_local.dict = dict() + if key == 'TF_CONFIG': + return dict.__setitem__(self._thread_local.dict, key, val) + else: + return dict.__setitem__(self._dict, key, val) + + def __iter__(self): + if not hasattr(self._thread_local, 'dict'): + self._thread_local.dict = dict() + for x in self._thread_local.dict.items(): + yield x + for x in self._dict.items(): + yield x + + def __len__(self): + if not hasattr(self._thread_local, 'dict'): + self._thread_local.dict = dict() + return self._thread_local.dict.__len__() + self._dict.__len__() + + +class IndependentWorkerTestBase(test.TestCase): + """Testing infra for independent workers.""" + + def setUp(self): + self._mock_os_env = MockOsEnv() + self._mock_context = test.mock.patch.object(os, 'environ', + self._mock_os_env) + self._coord = coordinator.Coordinator() + super(IndependentWorkerTestBase, self).setUp() + self._mock_context.__enter__() + + def tearDown(self): + self._mock_context.__exit__(None, None, None) + super(IndependentWorkerTestBase, self).tearDown() + + def _task_thread(self, task_fn, tf_config, *args, **kwargs): + with self._coord.stop_on_exception(): + os.environ['TF_CONFIG'] = json.dumps(tf_config) + task_fn(*args, **kwargs) + + def _run_task_in_thread(self, task_fn, cluster_spec, task_type, task_id, + *args, 
**kwargs): + if task_type: + tf_config = { + 'cluster': cluster_spec, + 'task': { + 'type': task_type, + 'index': task_id + } + } + else: + tf_config = { + 'cluster': cluster_spec, + } + t = threading.Thread( + target=self._task_thread, + args=(task_fn, tf_config) + args, + kwargs=kwargs) + t.start() + return t + + def run_multiple_tasks_in_threads(self, task_fn, cluster_spec, *args, + **kwargs): + # The task_fn should create std_server by itself. + threads = {} + for task_type in cluster_spec.keys(): + threads[task_type] = [] + for task_id in range(len(cluster_spec[task_type])): + t = self._run_task_in_thread(task_fn, cluster_spec, task_type, task_id, + *args, **kwargs) + threads[task_type].append(t) + return threads + + def join_independent_workers(self, worker_threads): + self._coord.join(worker_threads) diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index e322b6acb84c166a885c9aaa3002f331903a5063..fdbfba4e04358451a46b23ef250dc7c534c855a0 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -60,7 +60,7 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): if isinstance(colocate_with, six.string_types): with ops.device(colocate_with): return next_creator(*args, **kwargs) - if (isinstance(colocate_with, list) and len(colocate_with) == 1 and + if (isinstance(colocate_with, (list, tuple)) and len(colocate_with) == 1 and isinstance(colocate_with[0], six.string_types)): with ops.device(colocate_with[0]): return next_creator(*args, **kwargs) @@ -166,7 +166,7 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): return array_ops.identity(replica_local_var) def _unwrap(self, value): - return [value] + return (value,) def value_container(self, value): return value @@ -177,15 +177,15 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): 
@property def worker_devices(self): - return [self._device] + return (self._device,) @property def parameter_devices(self): - return [self._device] + return (self._device,) def non_slot_devices(self, var_list): del var_list - return [self._device] + return (self._device,) @property def experimental_should_init(self): @@ -216,4 +216,4 @@ class _OneDeviceReplicaContext(distribute_lib.ReplicaContext): @property def devices(self): - return [self._distribution_strategy.extended.worker_devices[0]] + return self._distribution_strategy.extended.worker_devices diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index eaeb4d703015fc0762359b24dc23888c01e69111..ca51b07be6601dd615e24137e51c4b34793fdbc0 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -145,14 +145,14 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): # replica. When there are GPUs, replicate operations on these GPUs. # Otherwise, place operations on CPU. if num_gpus_per_worker > 0: - self._compute_devices = [ + self._compute_devices = tuple( "%s/device:GPU:%d" % (self._worker_device, i) for i in range(num_gpus_per_worker) - ] + ) else: - self._compute_devices = [self._worker_device] + self._compute_devices = (self._worker_device,) - self._compute_devices = list( + self._compute_devices = tuple( map(device_util.resolve, self._compute_devices)) self._canonical_compute_device_set = set(self._compute_devices) @@ -176,8 +176,8 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): # The `_parameter_devices` is needed for the `parameter_devices` property # and is a list of all variable devices. Here parameter devices are all # tasks of the "ps" job. 
- self._parameter_devices = map("/job:ps/task:{}".format, - range(num_ps_replicas)) + self._parameter_devices = tuple(map("/job:ps/task:{}".format, + range(num_ps_replicas))) # Add a default device so that ops without specified devices will not end up # on other workers. @@ -204,24 +204,24 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): # replica. When there are GPUs, replicate operations on these GPUs. # Otherwise, place operations on CPU. if num_gpus_per_worker > 0: - self._compute_devices = list( + self._compute_devices = tuple( map("/device:GPU:{}".format, range(num_gpus_per_worker))) else: - self._compute_devices = [_LOCAL_CPU] + self._compute_devices = (_LOCAL_CPU,) - self._compute_devices = list( + self._compute_devices = tuple( map(device_util.resolve, self._compute_devices)) self._canonical_compute_device_set = set(self._compute_devices) # If there is only one GPU, put everything on that GPU. Otherwise, place # variables on CPU. if num_gpus_per_worker == 1: - assert len(list(self._compute_devices)) == 1 + assert len(self._compute_devices) == 1 self._variable_device = _LOCAL_GPU_0 - self._parameter_devices = [_LOCAL_GPU_0] + self._parameter_devices = (_LOCAL_GPU_0,) else: self._variable_device = _LOCAL_CPU - self._parameter_devices = [_LOCAL_CPU] + self._parameter_devices = (_LOCAL_CPU,) self._is_chief = True self._cluster_spec = None @@ -356,7 +356,7 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): self._verify_destinations_not_different_worker(destinations) if not isinstance(value, values.DistributedValues): # pylint: disable=protected-access - return mirrored_strategy._reduce_non_distributed_value( + return cross_device_ops_lib.reduce_non_distributed_value( self, reduce_op, value, destinations) return self._cross_device_ops.reduce( reduce_op, value, destinations=destinations) @@ -417,9 +417,9 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): if isinstance(val, 
values.DistributedValues): # Return in a deterministic order. if set(val.devices) == self._canonical_compute_device_set: - return [val.get(device=d) for d in self._compute_devices] - return [val.get(device=d) for d in sorted(val.devices)] - return [val] + return tuple(val.get(device=d) for d in self._compute_devices) + return tuple(val.get(device=d) for d in sorted(val.devices)) + return (val,) def value_container(self, val): if (hasattr(val, "_aggregating_container") and @@ -497,12 +497,11 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): @property def worker_devices(self): - # Make a copy to prevent users from accidentally mutating our copy. - return list(self._compute_devices) + return self._compute_devices @property def parameter_devices(self): - return list(self._parameter_devices) + return self._parameter_devices def non_slot_devices(self, var_list): return min(var_list, key=lambda x: x.name) diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index d50b142c5e9ad36522b11a77219140a7b40d9bf6..d441b5af5f6aa41efde2c75d09d9589516c54992 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -290,4 +290,4 @@ class DistributionTestBase(test.TestCase): self.evaluate(strategy.group(train_ops)) global_step_tensors = strategy.unwrap(value) global_step_values = self.evaluate(global_step_tensors) - self.assertEqual([1] * len(global_step_tensors), global_step_values) + self.assertEqual((1,) * len(global_step_tensors), global_step_values) diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 39ed8f7cf10371c0e8dd70e2bdf53f13e8ce8383..7ea245eb6eb9738bc95e8ac54c1c43de0ddcef7c 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -28,6 +28,8 @@ 
from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib from tensorflow.contrib.tpu.python.tpu import training_loop +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session as session_lib from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib @@ -43,12 +45,10 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest -_TPU_INITIALIZE_SYSTEM_COLLECTION = "TPU_STRATEGY_INITIALIZE" - - def get_tpu_system_metadata(tpu_cluster_resolver): """Retrieves TPU system metadata given a TPUClusterResolver.""" master = tpu_cluster_resolver.master() @@ -145,6 +145,9 @@ class TPUStrategy(distribute_lib.DistributionStrategy): class TPUExtended(distribute_lib.DistributionStrategyExtended): """Implementation of TPUStrategy.""" + # Track what TPU devices have been initialized. + _initialized_devices = [] + def __init__(self, container_strategy, tpu_cluster_resolver, steps_per_run, num_cores=None): super(TPUExtended, self).__init__(container_strategy) @@ -159,16 +162,41 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): if "device:TPU:" in d.name} self._device_index = values.PerReplica(device_map) self._host_device = self.get_host_cpu_device(0) - self._tpu_devices = sorted(device_map.keys()) + self._tpu_devices = tuple(sorted(device_map.keys())) # Only create variables for the number of replicas we're running. 
self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync] # TODO(sourabhbajaj): Remove this once performance of running one step # at a time is comparable to multiple steps. self.steps_per_run = steps_per_run - self._require_static_shapes = True + # Initialize the TPU devices. + self._initialize_tpu() + + def _initialize_tpu(self): + """Initialize the TPU devices in a separate session and graph. + + We keep track of all the TPU devices that we're initialized as we should + only be running TPU initialize once for the entire process. + """ + master = self._tpu_cluster_resolver.master() + # Verify TPU has not already been initialized in this process. + if master in TPUExtended._initialized_devices: + logging.info("TPU master %s has already been initialized." % master) + return + + logging.info("Initializing the TPU system.") + session_config = config_pb2.ConfigProto(allow_soft_placement=True) + self._configure(session_config) + with ops.Graph().as_default(): + with session_lib.Session(config=session_config, target=master) as sess: + sess.run([tpu.initialize_system()]) + logging.info("Finized initializing TPU system.") + + # Update Strategy state to make sure we can track device initialization. + TPUExtended._initialized_devices.append(master) + def _get_enqueue_op_per_host(self, host_id, multi_worker_iterator, input_shapes, iterations): """Create an enqueue op for a single host identified using host_id. @@ -380,22 +408,14 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): # TODO(priyag): Add appopriate call here when eager is supported for TPUs. raise NotImplementedError("Eager mode not supported in TPUStrategy.") else: - # TODO(jhseu): We need this hack because DistributionStrategies must be - # pickleable for copy.deepcopy(). Remove when initialize_system goes away. 
- graph = ops.get_default_graph() - tpu_init = graph.get_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION) - if tpu_init: - return tpu_init - graph.add_to_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION, - tpu.initialize_system()) - return graph.get_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION) + return [] def _finalize(self): if context.executing_eagerly(): # TODO(priyag): Add appopriate call here when eager is supported for TPUs. raise NotImplementedError("Eager mode not supported in TPUStrategy.") else: - return [tpu.shutdown_system()] + return [] def _get_devices_from(self, colocate_with=None): # TODO(jhseu): Change this when we support model parallelism. @@ -445,6 +465,14 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): "Currently only support sum & mean in TPUStrategy.") return tpu_ops.cross_replica_sum(value) + if not isinstance(value, values.DistributedValues): + # This function handles reducing values that are not PerReplica or + # Mirrored values. For example, the same value could be present on all + # replicas in which case `value` would be a single value or value could + # be 0. + return cross_device_ops_lib.reduce_non_distributed_value( + self, reduce_op, value, destinations) + # Validate that the destination is same as the host device # Note we don't do this when in replicate context as the reduction is # performed on the TPU device itself. @@ -487,13 +515,13 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): def _unwrap(self, val): if isinstance(val, values.DistributedValues): # Return in a deterministic order. - return [val.get(device=d) for d in sorted(val.devices)] + return tuple(val.get(device=d) for d in sorted(val.devices)) elif isinstance(val, list): # TODO(josh11b): We need to remove this case; per device values should # be represented using a PerReplica wrapper instead of a list with # one entry per device. 
- return val - return [val] + return tuple(val) + return (val,) def value_container(self, value): return value @@ -599,4 +627,4 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext): distribute_lib.require_replica_context(self) ds = self._distribution_strategy replica_id = tensor_util.constant_value(self._replica_id_in_sync_group) - return [ds.extended.worker_devices[replica_id]] + return (ds.extended.worker_devices[replica_id],) diff --git a/tensorflow/contrib/eager/python/evaluator.py b/tensorflow/contrib/eager/python/evaluator.py index 7949a3f6da293abdd85512209242bae76ab4d816..51443d24829bdc31a41813e0ff50ad7102422112 100644 --- a/tensorflow/contrib/eager/python/evaluator.py +++ b/tensorflow/contrib/eager/python/evaluator.py @@ -22,6 +22,7 @@ import six from tensorflow.contrib.eager.python import datasets from tensorflow.contrib.eager.python import metrics +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.framework import errors_impl @@ -164,8 +165,8 @@ class Evaluator(object): self.__call__(example, *args, **kwargs) return self.all_metric_results(summary_logdir) # Graph construction - call_op = self.__call__(dataset.make_one_shot_iterator().get_next(), *args, - **kwargs) + call_op = self.__call__( + dataset_ops.make_one_shot_iterator(dataset).get_next(), *args, **kwargs) init_op = self.init_variables() results_op = self.all_metric_results(summary_logdir) return (init_op, call_op, results_op) diff --git a/tensorflow/contrib/eager/python/examples/densenet/BUILD b/tensorflow/contrib/eager/python/examples/densenet/BUILD index 2dc196f550a10367066730f6f042c4ed69533ec3..e2154fcc5fcf774dcd52285d9442dfd5073a4992 100644 --- a/tensorflow/contrib/eager/python/examples/densenet/BUILD +++ b/tensorflow/contrib/eager/python/examples/densenet/BUILD @@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) 
load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "py_binary") py_binary( name = "densenet", diff --git a/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py b/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py index 4b3cb624bc947a1d1956eff6accb6d4da3bf3b87..24f6b007b526b29157011f3b1e9abdbd50bacc8e 100644 --- a/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py @@ -119,7 +119,8 @@ class DensenetBenchmark(tf.test.Benchmark): with tf.Graph().as_default(): np_images, np_labels = random_batch(batch_size) dataset = tf.data.Dataset.from_tensors((np_images, np_labels)).repeat() - (images, labels) = dataset.make_one_shot_iterator().get_next() + (images, labels) = tf.compat.v1.data.make_one_shot_iterator( + dataset).get_next() model = densenet.DenseNet(self.depth, self.growth_rate, self.num_blocks, self.output_classes, diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py b/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py index 12b39b0cde49d4c017acfa74572c725036c54eff..e73841fbf724e05eaa3be90cc8650f795d3e1ccf 100644 --- a/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py @@ -42,7 +42,8 @@ class MnistGraphGanBenchmark(tf.test.Benchmark): # Generate some random data. 
images_data = np.random.randn(batch_size, 784).astype(np.float32) dataset = tf.data.Dataset.from_tensors(images_data) - images = dataset.repeat().make_one_shot_iterator().get_next() + images = tf.compat.v1.data.make_one_shot_iterator( + dataset.repeat()).get_next() # Create the models and optimizers generator = mnist.Generator(data_format()) diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb index ca27a85a229d41a85fa26ecdc982da478fe9e202..1a08cc0fd06516be4af5c2b0b46a3ffcf9101e95 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb @@ -470,7 +470,7 @@ "\n", " if epoch % 1 == 0:\n", " loss = tfe.metrics.Mean()\n", - " for test_x in test_dataset.make_one_shot_iterator():\n", + " for test_x in test_dataset:\n", " loss(compute_loss(model, test_x))\n", " elbo = -loss.result()\n", " display.clear_output(wait=False)\n", diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb index 3acecd283cda83992bab0c37cf0b8037ed2cf27a..12c5eff2b4aa901bdab52bf545e95b1e4dce7468 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb @@ -1,1184 +1,1174 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "K2s1A9eLRPEj" + }, + "source": [ + "##### Copyright 2018 The TensorFlow Authors.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\").\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": 
"Cffg2i257iMS" + }, + "source": [ + "# Image Captioning with Attention\n", + "\n", + "
\n", + "\n", + " Run in Google Colab \n", + "\n", + "View source on GitHub
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "QASbY_HGo4Lq" + }, + "source": [ + "Image captioning is the task of generating a caption for an image. Given an image like this:\n", + "\n", + "![Man Surfing](https://tensorflow.org/images/surf.jpg) \n", + "\n", + "[Image Source](https://commons.wikimedia.org/wiki/Surfing#/media/File:Surfing_in_Hawaii.jpg), License: Public Domain\n", + "\n", + "Our goal is to generate a caption, such as \"a surfer riding on a wave\". Here, we'll use an attention-based model. This enables us to see which parts of the image the model focuses on as it generates a caption.\n", + "\n", + "![Prediction](https://tensorflow.org/images/imcap_prediction.png)\n", + "\n", + "This model architecture below is similar to [Show, Attend and Tell: Neural Image Caption Generation with Visual Attention](https://arxiv.org/abs/1502.03044). \n", + "\n", + "The code uses [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager), which you can learn more about in the linked guides.\n", + "\n", + "This notebook is an end-to-end example. If you run it, it will download the [MS-COCO](http://cocodataset.org/#home) dataset, preprocess and cache a subset of the images using Inception V3, train an encoder-decoder model, and use it to generate captions on new images.\n", + "\n", + "The code requires TensorFlow version >=1.9. If you're running this in [Colab]()\n", + "\n", + "In this example, we're training on a relatively small amount of data as an example. On a single P100 GPU, this example will take about ~2 hours to train. 
We train on the first 30,000 captions (corresponding to about ~20,000 images depending on shuffling, as there are multiple captions per image in the dataset)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { "colab": { - "name": "image_captioning_with_attention.ipynb", - "version": "0.3.2", - "views": {}, - "default_view": {}, - "provenance": [ - { - "file_id": "1HI8OK2sMjcx9CTWVn0122QAHOuXaOaMg", - "timestamp": 1530222436922 - } - ], - "private_outputs": true, - "collapsed_sections": [], - "toc_visible": true + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "accelerator": "GPU" + "colab_type": "code", + "id": "U8l4RJ0XRPEm" + }, + "outputs": [], + "source": [ + "# Import TensorFlow and enable eager execution\n", + "# This code requires TensorFlow version >=1.9\n", + "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", + "\n", + "# We'll generate plots of attention in order to see which parts of an image\n", + "# our model focuses on during captioning\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# Scikit-learn includes many helpful utilities\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.utils import shuffle\n", + "\n", + "import re\n", + "import numpy as np\n", + "import os\n", + "import time\n", + "import json\n", + "from glob import glob\n", + "from PIL import Image\n", + "import pickle" + ] }, - "cells": [ - { - "metadata": { - "id": "K2s1A9eLRPEj", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "##### Copyright 2018 The TensorFlow Authors.\n", - "\n", - "Licensed under the Apache License, Version 2.0 (the \"License\").\n" - ] - }, - { - "metadata": { - "id": "Cffg2i257iMS", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "# Image Captioning with Attention\n", - "\n", - "
\n", - "\n", - " Run in Google Colab \n", - "\n", - "View source on GitHub
" - ] - }, - { - "metadata": { - "id": "QASbY_HGo4Lq", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "Image captioning is the task of generating a caption for an image. Given an image like this:\n", - "\n", - "![Man Surfing](https://tensorflow.org/images/surf.jpg) \n", - "\n", - "[Image Source](https://commons.wikimedia.org/wiki/Surfing#/media/File:Surfing_in_Hawaii.jpg), License: Public Domain\n", - "\n", - "Our goal is to generate a caption, such as \"a surfer riding on a wave\". Here, we'll use an attention-based model. This enables us to see which parts of the image the model focuses on as it generates a caption.\n", - "\n", - "![Prediction](https://tensorflow.org/images/imcap_prediction.png)\n", - "\n", - "This model architecture below is similar to [Show, Attend and Tell: Neural Image Caption Generation with Visual Attention](https://arxiv.org/abs/1502.03044). \n", - "\n", - "The code uses [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager), which you can learn more about in the linked guides.\n", - "\n", - "This notebook is an end-to-end example. If you run it, it will download the [MS-COCO](http://cocodataset.org/#home) dataset, preprocess and cache a subset of the images using Inception V3, train an encoder-decoder model, and use it to generate captions on new images.\n", - "\n", - "The code requires TensorFlow version >=1.9. If you're running this in [Colab]()\n", - "\n", - "In this example, we're training on a relatively small amount of data as an example. On a single P100 GPU, this example will take about ~2 hours to train. 
We train on the first 30,000 captions (corresponding to about ~20,000 images depending on shuffling, as there are multiple captions per image in the dataset)\n" - ] - }, - { - "metadata": { - "id": "U8l4RJ0XRPEm", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# Import TensorFlow and enable eager execution\n", - "# This code requires TensorFlow version >=1.9\n", - "import tensorflow as tf\n", - "tf.enable_eager_execution()\n", - "\n", - "# We'll generate plots of attention in order to see which parts of an image\n", - "# our model focuses on during captioning\n", - "import matplotlib.pyplot as plt\n", - "\n", - "# Scikit-learn includes many helpful utilities\n", - "from sklearn.model_selection import train_test_split\n", - "from sklearn.utils import shuffle\n", - "\n", - "import re\n", - "import numpy as np\n", - "import os\n", - "import time\n", - "import json\n", - "from glob import glob\n", - "from PIL import Image\n", - "import pickle" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "b6qbGw8MRPE5", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Download and prepare the MS-COCO dataset\n", - "\n", - "We will use the [MS-COCO dataset](http://cocodataset.org/#home) to train our model. This dataset contains >82,000 images, each of which has been annotated with at least 5 different captions. The code below will download and extract the dataset automatically. \n", - "\n", - "**Caution: large download ahead**. We'll use the training set, it's a 13GB file." 
- ] - }, - { - "metadata": { - "id": "krQuPYTtRPE7", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "annotation_zip = tf.keras.utils.get_file('captions.zip', \n", - " cache_subdir=os.path.abspath('.'),\n", - " origin = 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip',\n", - " extract = True)\n", - "annotation_file = os.path.dirname(annotation_zip)+'/annotations/captions_train2014.json'\n", - "\n", - "name_of_zip = 'train2014.zip'\n", - "if not os.path.exists(os.path.abspath('.') + '/' + name_of_zip):\n", - " image_zip = tf.keras.utils.get_file(name_of_zip, \n", - " cache_subdir=os.path.abspath('.'),\n", - " origin = 'http://images.cocodataset.org/zips/train2014.zip',\n", - " extract = True)\n", - " PATH = os.path.dirname(image_zip)+'/train2014/'\n", - "else:\n", - " PATH = os.path.abspath('.')+'/train2014/'" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "aANEzb5WwSzg", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Optionally, limit the size of the training set for faster training\n", - "For this example, we'll select a subset of 30,000 captions and use these and the corresponding images to train our model. As always, captioning quality will improve if you choose to use more data." 
- ] - }, - { - "metadata": { - "id": "4G3b8x8_RPFD", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# read the json file\n", - "with open(annotation_file, 'r') as f:\n", - " annotations = json.load(f)\n", - "\n", - "# storing the captions and the image name in vectors\n", - "all_captions = []\n", - "all_img_name_vector = []\n", - "\n", - "for annot in annotations['annotations']:\n", - " caption = ' ' + annot['caption'] + ' '\n", - " image_id = annot['image_id']\n", - " full_coco_image_path = PATH + 'COCO_train2014_' + '%012d.jpg' % (image_id)\n", - " \n", - " all_img_name_vector.append(full_coco_image_path)\n", - " all_captions.append(caption)\n", - "\n", - "# shuffling the captions and image_names together\n", - "# setting a random state\n", - "train_captions, img_name_vector = shuffle(all_captions,\n", - " all_img_name_vector,\n", - " random_state=1)\n", - "\n", - "# selecting the first 30000 captions from the shuffled set\n", - "num_examples = 30000\n", - "train_captions = train_captions[:num_examples]\n", - "img_name_vector = img_name_vector[:num_examples]" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "mPBMgK34RPFL", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "len(train_captions), len(all_captions)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "8cSW4u-ORPFQ", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Preprocess the images using InceptionV3\n", - "Next, we will use InceptionV3 (pretrained on Imagenet) to classify each image. We will extract features from the last convolutional layer. 
\n", - "\n", - "First, we will need to convert the images into the format inceptionV3 expects by:\n", - "* Resizing the image to (299, 299)\n", - "* Using the [preprocess_input](https://www.tensorflow.org/api_docs/python/tf/keras/applications/inception_v3/preprocess_input) method to place the pixels in the range of -1 to 1 (to match the format of the images used to train InceptionV3)." - ] - }, - { - "metadata": { - "id": "zXR0217aRPFR", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "def load_image(image_path):\n", - " img = tf.read_file(image_path)\n", - " img = tf.image.decode_jpeg(img, channels=3)\n", - " img = tf.image.resize_images(img, (299, 299))\n", - " img = tf.keras.applications.inception_v3.preprocess_input(img)\n", - " return img, image_path" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "MDvIu4sXRPFV", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Initialize InceptionV3 and load the pretrained Imagenet weights\n", - "\n", - "To do so, we'll create a tf.keras model where the output layer is the last convolutional layer in the InceptionV3 architecture. \n", - "* Each image is forwarded through the network and the vector that we get at the end is stored in a dictionary (image_name --> feature_vector). \n", - "* We use the last convolutional layer because we are using attention in this example. The shape of the output of this layer is ```8x8x2048```. \n", - "* We avoid doing this during training so it does not become a bottleneck. \n", - "* After all the images are passed through the network, we pickle the dictionary and save it to disk." 
- ] - }, - { - "metadata": { - "id": "RD3vW4SsRPFW", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "image_model = tf.keras.applications.InceptionV3(include_top=False, \n", - " weights='imagenet')\n", - "new_input = image_model.input\n", - "hidden_layer = image_model.layers[-1].output\n", - "\n", - "image_features_extract_model = tf.keras.Model(new_input, hidden_layer)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "rERqlR3WRPGO", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Caching the features extracted from InceptionV3\n", - "\n", - "We will pre-process each image with InceptionV3 and cache the output to disk. Caching the output in RAM would be faster but memory intensive, requiring 8 \\* 8 \\* 2048 floats per image. At the time of writing, this would exceed the memory limitations of Colab (although these may change, an instance appears to have about 12GB of memory currently). \n", - "\n", - "Performance could be improved with a more sophisticated caching strategy (e.g., by sharding the images to reduce random access disk I/O) at the cost of more code.\n", - "\n", - "This will take about 10 minutes to run in Colab with a GPU. If you'd like to see a progress bar, you could: install [tqdm](https://github.com/tqdm/tqdm) (```!pip install tqdm```), then change this line: \n", - "\n", - "```for img, path in image_dataset:``` \n", - "\n", - "to:\n", - "\n", - "```for img, path in tqdm(image_dataset):```." 
- ] - }, - { - "metadata": { - "id": "Dx_fvbVgRPGQ", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# getting the unique images\n", - "encode_train = sorted(set(img_name_vector))\n", - "\n", - "# feel free to change the batch_size according to your system configuration\n", - "image_dataset = tf.data.Dataset.from_tensor_slices(\n", - " encode_train).map(load_image).batch(16)\n", - "\n", - "for img, path in image_dataset:\n", - " batch_features = image_features_extract_model(img)\n", - " batch_features = tf.reshape(batch_features, \n", - " (batch_features.shape[0], -1, batch_features.shape[3]))\n", - "\n", - " for bf, p in zip(batch_features, path):\n", - " path_of_feature = p.numpy().decode(\"utf-8\")\n", - " np.save(path_of_feature, bf.numpy())" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "nyqH3zFwRPFi", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Preprocess and tokenize the captions\n", - "\n", - "* First, we'll tokenize the captions (e.g., by splitting on spaces). This will give us a vocabulary of all the unique words in the data (e.g., \"surfing\", \"football\", etc).\n", - "* Next, we'll limit the vocabulary size to the top 5,000 words to save memory. We'll replace all other words with the token \"UNK\" (for unknown).\n", - "* Finally, we create a word --> index mapping and vice-versa.\n", - "* We will then pad all sequences to the be same length as the longest one. 
" - ] - }, - { - "metadata": { - "id": "HZfK8RhQRPFj", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# This will find the maximum length of any caption in our dataset\n", - "def calc_max_length(tensor):\n", - " return max(len(t) for t in tensor)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "oJGE34aiRPFo", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# The steps above is a general process of dealing with text processing\n", - "\n", - "# choosing the top 5000 words from the vocabulary\n", - "top_k = 5000\n", - "tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=top_k, \n", - " oov_token=\"\", \n", - " filters='!\"#$%&()*+.,-/:;=?@[\\]^_`{|}~ ')\n", - "tokenizer.fit_on_texts(train_captions)\n", - "train_seqs = tokenizer.texts_to_sequences(train_captions)" - ], - "execution_count": 0, - "outputs": [] + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "b6qbGw8MRPE5" + }, + "source": [ + "## Download and prepare the MS-COCO dataset\n", + "\n", + "We will use the [MS-COCO dataset](http://cocodataset.org/#home) to train our model. This dataset contains >82,000 images, each of which has been annotated with at least 5 different captions. The code below will download and extract the dataset automatically. \n", + "\n", + "**Caution: large download ahead**. We'll use the training set, it's a 13GB file." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "8Q44tNQVRPFt", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "tokenizer.word_index = {key:value for key, value in tokenizer.word_index.items() if value <= top_k}\n", - "# putting token in the word2idx dictionary\n", - "tokenizer.word_index[tokenizer.oov_token] = top_k + 1\n", - "tokenizer.word_index[''] = 0" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "krQuPYTtRPE7" + }, + "outputs": [], + "source": [ + "annotation_zip = tf.keras.utils.get_file('captions.zip', \n", + " cache_subdir=os.path.abspath('.'),\n", + " origin = 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip',\n", + " extract = True)\n", + "annotation_file = os.path.dirname(annotation_zip)+'/annotations/captions_train2014.json'\n", + "\n", + "name_of_zip = 'train2014.zip'\n", + "if not os.path.exists(os.path.abspath('.') + '/' + name_of_zip):\n", + " image_zip = tf.keras.utils.get_file(name_of_zip, \n", + " cache_subdir=os.path.abspath('.'),\n", + " origin = 'http://images.cocodataset.org/zips/train2014.zip',\n", + " extract = True)\n", + " PATH = os.path.dirname(image_zip)+'/train2014/'\n", + "else:\n", + " PATH = os.path.abspath('.')+'/train2014/'" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "aANEzb5WwSzg" + }, + "source": [ + "## Optionally, limit the size of the training set for faster training\n", + "For this example, we'll select a subset of 30,000 captions and use these and the corresponding images to train our model. As always, captioning quality will improve if you choose to use more data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "0fpJb5ojRPFv", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# creating the tokenized vectors\n", - "train_seqs = tokenizer.texts_to_sequences(train_captions)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "4G3b8x8_RPFD" + }, + "outputs": [], + "source": [ + "# read the json file\n", + "with open(annotation_file, 'r') as f:\n", + " annotations = json.load(f)\n", + "\n", + "# storing the captions and the image name in vectors\n", + "all_captions = []\n", + "all_img_name_vector = []\n", + "\n", + "for annot in annotations['annotations']:\n", + " caption = ' ' + annot['caption'] + ' '\n", + " image_id = annot['image_id']\n", + " full_coco_image_path = PATH + 'COCO_train2014_' + '%012d.jpg' % (image_id)\n", + " \n", + " all_img_name_vector.append(full_coco_image_path)\n", + " all_captions.append(caption)\n", + "\n", + "# shuffling the captions and image_names together\n", + "# setting a random state\n", + "train_captions, img_name_vector = shuffle(all_captions,\n", + " all_img_name_vector,\n", + " random_state=1)\n", + "\n", + "# selecting the first 30000 captions from the shuffled set\n", + "num_examples = 30000\n", + "train_captions = train_captions[:num_examples]\n", + "img_name_vector = img_name_vector[:num_examples]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "olQArbgbRPF1", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# creating a reverse mapping (index -> word)\n", - "index_word = {value:key for key, value in 
tokenizer.word_index.items()}" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "mPBMgK34RPFL" + }, + "outputs": [], + "source": [ + "len(train_captions), len(all_captions)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "8cSW4u-ORPFQ" + }, + "source": [ + "## Preprocess the images using InceptionV3\n", + "Next, we will use InceptionV3 (pretrained on Imagenet) to classify each image. We will extract features from the last convolutional layer. \n", + "\n", + "First, we will need to convert the images into the format inceptionV3 expects by:\n", + "* Resizing the image to (299, 299)\n", + "* Using the [preprocess_input](https://www.tensorflow.org/api_docs/python/tf/keras/applications/inception_v3/preprocess_input) method to place the pixels in the range of -1 to 1 (to match the format of the images used to train InceptionV3)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "AidglIZVRPF4", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# padding each vector to the max_length of the captions\n", - "# if the max_length parameter is not provided, pad_sequences calculates that automatically\n", - "cap_vector = tf.keras.preprocessing.sequence.pad_sequences(train_seqs, padding='post')" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "zXR0217aRPFR" + }, + "outputs": [], + "source": [ + "def load_image(image_path):\n", + " img = tf.read_file(image_path)\n", + " img = tf.image.decode_jpeg(img, channels=3)\n", + " img = tf.image.resize_images(img, (299, 299))\n", + " img = tf.keras.applications.inception_v3.preprocess_input(img)\n", + " return img, image_path" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": 
"MDvIu4sXRPFV" + }, + "source": [ + "## Initialize InceptionV3 and load the pretrained Imagenet weights\n", + "\n", + "To do so, we'll create a tf.keras model where the output layer is the last convolutional layer in the InceptionV3 architecture. \n", + "* Each image is forwarded through the network and the vector that we get at the end is stored in a dictionary (image_name --> feature_vector). \n", + "* We use the last convolutional layer because we are using attention in this example. The shape of the output of this layer is ```8x8x2048```. \n", + "* We avoid doing this during training so it does not become a bottleneck. \n", + "* After all the images are passed through the network, we pickle the dictionary and save it to disk." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "gL0wkttkRPGA", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# calculating the max_length \n", - "# used to store the attention weights\n", - "max_length = calc_max_length(train_seqs)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "RD3vW4SsRPFW" + }, + "outputs": [], + "source": [ + "image_model = tf.keras.applications.InceptionV3(include_top=False, \n", + " weights='imagenet')\n", + "new_input = image_model.input\n", + "hidden_layer = image_model.layers[-1].output\n", + "\n", + "image_features_extract_model = tf.keras.Model(new_input, hidden_layer)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "rERqlR3WRPGO" + }, + "source": [ + "## Caching the features extracted from InceptionV3\n", + "\n", + "We will pre-process each image with InceptionV3 and cache the output to disk. Caching the output in RAM would be faster but memory intensive, requiring 8 \\* 8 \\* 2048 floats per image. 
At the time of writing, this would exceed the memory limitations of Colab (although these may change, an instance appears to have about 12GB of memory currently). \n", + "\n", + "Performance could be improved with a more sophisticated caching strategy (e.g., by sharding the images to reduce random access disk I/O) at the cost of more code.\n", + "\n", + "This will take about 10 minutes to run in Colab with a GPU. If you'd like to see a progress bar, you could: install [tqdm](https://github.com/tqdm/tqdm) (```!pip install tqdm```), then change this line: \n", + "\n", + "```for img, path in image_dataset:``` \n", + "\n", + "to:\n", + "\n", + "```for img, path in tqdm(image_dataset):```." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "M3CD75nDpvTI", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Split the data into training and testing" - ] + "colab_type": "code", + "id": "Dx_fvbVgRPGQ" + }, + "outputs": [], + "source": [ + "# getting the unique images\n", + "encode_train = sorted(set(img_name_vector))\n", + "\n", + "# feel free to change the batch_size according to your system configuration\n", + "image_dataset = tf.data.Dataset.from_tensor_slices(\n", + " encode_train).map(load_image).batch(16)\n", + "\n", + "for img, path in image_dataset:\n", + " batch_features = image_features_extract_model(img)\n", + " batch_features = tf.reshape(batch_features, \n", + " (batch_features.shape[0], -1, batch_features.shape[3]))\n", + "\n", + " for bf, p in zip(batch_features, path):\n", + " path_of_feature = p.numpy().decode(\"utf-8\")\n", + " np.save(path_of_feature, bf.numpy())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "nyqH3zFwRPFi" + }, + "source": [ + "## Preprocess and tokenize the captions\n", + "\n", + "* First, we'll tokenize the captions (e.g., by splitting on 
spaces). This will give us a vocabulary of all the unique words in the data (e.g., \"surfing\", \"football\", etc).\n", + "* Next, we'll limit the vocabulary size to the top 5,000 words to save memory. We'll replace all other words with the token \"UNK\" (for unknown).\n", + "* Finally, we create a word --> index mapping and vice-versa.\n", + "* We will then pad all sequences to be the same length as the longest one. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "iS7DDMszRPGF", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# Create training and validation sets using 80-20 split\n", - "img_name_train, img_name_val, cap_train, cap_val = train_test_split(img_name_vector, \n", - " cap_vector, \n", - " test_size=0.2, \n", - " random_state=0)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "HZfK8RhQRPFj" + }, + "outputs": [], + "source": [ + "# This will find the maximum length of any caption in our dataset\n", + "def calc_max_length(tensor):\n", + " return max(len(t) for t in tensor)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "XmViPkRFRPGH", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "len(img_name_train), len(cap_train), len(img_name_val), len(cap_val)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "oJGE34aiRPFo" + }, + "outputs": [], + "source": [ + "# The steps above is a general process of dealing with text processing\n", + "\n", + "# choosing the top 5000 words from the vocabulary\n", + "top_k = 5000\n", + "tokenizer = 
tf.keras.preprocessing.text.Tokenizer(num_words=top_k, \n", + " oov_token=\"\", \n", + " filters='!\"#$%&()*+.,-/:;=?@[\\]^_`{|}~ ')\n", + "tokenizer.fit_on_texts(train_captions)\n", + "train_seqs = tokenizer.texts_to_sequences(train_captions)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "uEWM9xrYcg45", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Our images and captions are ready! Next, let's create a tf.data dataset to use for training our model.\n", - "\n" - ] + "colab_type": "code", + "id": "8Q44tNQVRPFt" + }, + "outputs": [], + "source": [ + "tokenizer.word_index[''] = 0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "Q3TnZ1ToRPGV", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# feel free to change these parameters according to your system's configuration\n", - "\n", - "BATCH_SIZE = 64\n", - "BUFFER_SIZE = 1000\n", - "embedding_dim = 256\n", - "units = 512\n", - "vocab_size = len(tokenizer.word_index)\n", - "# shape of the vector extracted from InceptionV3 is (64, 2048)\n", - "# these two variables represent that\n", - "features_shape = 2048\n", - "attention_features_shape = 64" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "0fpJb5ojRPFv" + }, + "outputs": [], + "source": [ + "# creating the tokenized vectors\n", + "train_seqs = tokenizer.texts_to_sequences(train_captions)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "SmZS2N0bXG3T", - "colab_type": "code", - "colab": { - "autoexec": { - 
"startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# loading the numpy files \n", - "def map_func(img_name, cap):\n", - " img_tensor = np.load(img_name.decode('utf-8')+'.npy')\n", - " return img_tensor, cap" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "AidglIZVRPF4" + }, + "outputs": [], + "source": [ + "# padding each vector to the max_length of the captions\n", + "# if the max_length parameter is not provided, pad_sequences calculates that automatically\n", + "cap_vector = tf.keras.preprocessing.sequence.pad_sequences(train_seqs, padding='post')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "FDF_Nm3tRPGZ", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "dataset = tf.data.Dataset.from_tensor_slices((img_name_train, cap_train))\n", - "\n", - "# using map to load the numpy files in parallel\n", - "# NOTE: Be sure to set num_parallel_calls to the number of CPU cores you have\n", - "# https://www.tensorflow.org/api_docs/python/tf/py_func\n", - "dataset = dataset.map(lambda item1, item2: tf.py_func(\n", - " map_func, [item1, item2], [tf.float32, tf.int32]), num_parallel_calls=8)\n", - "\n", - "# shuffling and batching\n", - "dataset = dataset.shuffle(BUFFER_SIZE)\n", - "# https://www.tensorflow.org/api_docs/python/tf/contrib/data/batch_and_drop_remainder\n", - "dataset = dataset.batch(BATCH_SIZE)\n", - "dataset = dataset.prefetch(1)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "gL0wkttkRPGA" + }, + "outputs": [], + "source": [ + "# calculating the max_length \n", + "# used to store the attention weights\n", + "max_length = calc_max_length(train_seqs)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", 
+ "id": "M3CD75nDpvTI" + }, + "source": [ + "## Split the data into training and testing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "nrvoDphgRPGd", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Model\n", - "\n", - "Fun fact, the decoder below is identical to the one in the example for [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb).\n", - "\n", - "The model architecture is inspired by the [Show, Attend and Tell](https://arxiv.org/pdf/1502.03044.pdf) paper.\n", - "\n", - "* In this example, we extract the features from the lower convolutional layer of InceptionV3 giving us a vector of shape (8, 8, 2048). \n", - "* We squash that to a shape of (64, 2048).\n", - "* This vector is then passed through the CNN Encoder(which consists of a single Fully connected layer).\n", - "* The RNN(here GRU) attends over the image to predict the next word." 
- ] + "colab_type": "code", + "id": "iS7DDMszRPGF" + }, + "outputs": [], + "source": [ + "# Create training and validation sets using 80-20 split\n", + "img_name_train, img_name_val, cap_train, cap_val = train_test_split(img_name_vector, \n", + " cap_vector, \n", + " test_size=0.2, \n", + " random_state=0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "AAppCGLKRPGd", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "def gru(units):\n", - " # If you have a GPU, we recommend using the CuDNNGRU layer (it provides a \n", - " # significant speedup).\n", - " if tf.test.is_gpu_available():\n", - " return tf.keras.layers.CuDNNGRU(units, \n", - " return_sequences=True, \n", - " return_state=True, \n", - " recurrent_initializer='glorot_uniform')\n", - " else:\n", - " return tf.keras.layers.GRU(units, \n", - " return_sequences=True, \n", - " return_state=True, \n", - " recurrent_activation='sigmoid', \n", - " recurrent_initializer='glorot_uniform')" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "XmViPkRFRPGH" + }, + "outputs": [], + "source": [ + "len(img_name_train), len(cap_train), len(img_name_val), len(cap_val)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "uEWM9xrYcg45" + }, + "source": [ + "## Our images and captions are ready! 
Next, let's create a tf.data dataset to use for training our model.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "ja2LFTMSdeV3", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "class BahdanauAttention(tf.keras.Model):\n", - " def __init__(self, units):\n", - " super(BahdanauAttention, self).__init__()\n", - " self.W1 = tf.keras.layers.Dense(units)\n", - " self.W2 = tf.keras.layers.Dense(units)\n", - " self.V = tf.keras.layers.Dense(1)\n", - " \n", - " def call(self, features, hidden):\n", - " # features(CNN_encoder output) shape == (batch_size, 64, embedding_dim)\n", - " \n", - " # hidden shape == (batch_size, hidden_size)\n", - " # hidden_with_time_axis shape == (batch_size, 1, hidden_size)\n", - " hidden_with_time_axis = tf.expand_dims(hidden, 1)\n", - " \n", - " # score shape == (batch_size, 64, hidden_size)\n", - " score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))\n", - " \n", - " # attention_weights shape == (batch_size, 64, 1)\n", - " # we get 1 at the last axis because we are applying score to self.V\n", - " attention_weights = tf.nn.softmax(self.V(score), axis=1)\n", - " \n", - " # context_vector shape after sum == (batch_size, hidden_size)\n", - " context_vector = attention_weights * features\n", - " context_vector = tf.reduce_sum(context_vector, axis=1)\n", - " \n", - " return context_vector, attention_weights" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "Q3TnZ1ToRPGV" + }, + "outputs": [], + "source": [ + "# feel free to change these parameters according to your system's configuration\n", + "\n", + "BATCH_SIZE = 64\n", + "BUFFER_SIZE = 1000\n", + "embedding_dim = 256\n", + "units = 512\n", + "vocab_size = len(tokenizer.word_index)\n", + "# shape of the 
vector extracted from InceptionV3 is (64, 2048)\n", + "# these two variables represent that\n", + "features_shape = 2048\n", + "attention_features_shape = 64" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "AZ7R1RxHRPGf", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "class CNN_Encoder(tf.keras.Model):\n", - " # Since we have already extracted the features and dumped it using pickle\n", - " # This encoder passes those features through a Fully connected layer\n", - " def __init__(self, embedding_dim):\n", - " super(CNN_Encoder, self).__init__()\n", - " # shape after fc == (batch_size, 64, embedding_dim)\n", - " self.fc = tf.keras.layers.Dense(embedding_dim)\n", - " \n", - " def call(self, x):\n", - " x = self.fc(x)\n", - " x = tf.nn.relu(x)\n", - " return x" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "SmZS2N0bXG3T" + }, + "outputs": [], + "source": [ + "# loading the numpy files \n", + "def map_func(img_name, cap):\n", + " img_tensor = np.load(img_name.decode('utf-8')+'.npy')\n", + " return img_tensor, cap" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "V9UbGQmERPGi", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "class RNN_Decoder(tf.keras.Model):\n", - " def __init__(self, embedding_dim, units, vocab_size):\n", - " super(RNN_Decoder, self).__init__()\n", - " self.units = units\n", - "\n", - " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", - " self.gru = gru(self.units)\n", - " self.fc1 = tf.keras.layers.Dense(self.units)\n", - " self.fc2 
= tf.keras.layers.Dense(vocab_size)\n", - " \n", - " self.attention = BahdanauAttention(self.units)\n", - " \n", - " def call(self, x, features, hidden):\n", - " # defining attention as a separate model\n", - " context_vector, attention_weights = self.attention(features, hidden)\n", - " \n", - " # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n", - " x = self.embedding(x)\n", - " \n", - " # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n", - " x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n", - " \n", - " # passing the concatenated vector to the GRU\n", - " output, state = self.gru(x)\n", - " \n", - " # shape == (batch_size, max_length, hidden_size)\n", - " x = self.fc1(output)\n", - " \n", - " # x shape == (batch_size * max_length, hidden_size)\n", - " x = tf.reshape(x, (-1, x.shape[2]))\n", - " \n", - " # output shape == (batch_size * max_length, vocab)\n", - " x = self.fc2(x)\n", - "\n", - " return x, state, attention_weights\n", - "\n", - " def reset_state(self, batch_size):\n", - " return tf.zeros((batch_size, self.units))" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "FDF_Nm3tRPGZ" + }, + "outputs": [], + "source": [ + "dataset = tf.data.Dataset.from_tensor_slices((img_name_train, cap_train))\n", + "\n", + "# using map to load the numpy files in parallel\n", + "# NOTE: Be sure to set num_parallel_calls to the number of CPU cores you have\n", + "# https://www.tensorflow.org/api_docs/python/tf/py_func\n", + "dataset = dataset.map(lambda item1, item2: tf.py_func(\n", + " map_func, [item1, item2], [tf.float32, tf.int32]), num_parallel_calls=8)\n", + "\n", + "# shuffling and batching\n", + "dataset = dataset.shuffle(BUFFER_SIZE)\n", + "# https://www.tensorflow.org/api_docs/python/tf/contrib/data/batch_and_drop_remainder\n", + "dataset = dataset.batch(BATCH_SIZE)\n", + "dataset = dataset.prefetch(1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + 
"colab_type": "text", + "id": "nrvoDphgRPGd" + }, + "source": [ + "## Model\n", + "\n", + "Fun fact, the decoder below is identical to the one in the example for [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb).\n", + "\n", + "The model architecture is inspired by the [Show, Attend and Tell](https://arxiv.org/pdf/1502.03044.pdf) paper.\n", + "\n", + "* In this example, we extract the features from the lower convolutional layer of InceptionV3 giving us a vector of shape (8, 8, 2048). \n", + "* We squash that to a shape of (64, 2048).\n", + "* This vector is then passed through the CNN Encoder(which consists of a single Fully connected layer).\n", + "* The RNN(here GRU) attends over the image to predict the next word." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "Qs_Sr03wRPGk", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "encoder = CNN_Encoder(embedding_dim)\n", - "decoder = RNN_Decoder(embedding_dim, units, vocab_size)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "AAppCGLKRPGd" + }, + "outputs": [], + "source": [ + "def gru(units):\n", + " # If you have a GPU, we recommend using the CuDNNGRU layer (it provides a \n", + " # significant speedup).\n", + " if tf.test.is_gpu_available():\n", + " return tf.keras.layers.CuDNNGRU(units, \n", + " return_sequences=True, \n", + " return_state=True, \n", + " recurrent_initializer='glorot_uniform')\n", + " else:\n", + " return tf.keras.layers.GRU(units, \n", + " return_sequences=True, \n", + " return_state=True, \n", + " recurrent_activation='sigmoid', \n", + " recurrent_initializer='glorot_uniform')" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "-bYN7xA0RPGl", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "optimizer = tf.train.AdamOptimizer()\n", - "\n", - "# We are masking the loss calculated for padding\n", - "def loss_function(real, pred):\n", - " mask = 1 - np.equal(real, 0)\n", - " loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n", - " return tf.reduce_mean(loss_)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "ja2LFTMSdeV3" + }, + "outputs": [], + "source": [ + "class BahdanauAttention(tf.keras.Model):\n", + " def __init__(self, units):\n", + " super(BahdanauAttention, self).__init__()\n", + " self.W1 = tf.keras.layers.Dense(units)\n", + " self.W2 = tf.keras.layers.Dense(units)\n", + " self.V = tf.keras.layers.Dense(1)\n", + " \n", + " def call(self, features, hidden):\n", + " # features(CNN_encoder output) shape == (batch_size, 64, embedding_dim)\n", + " \n", + " # hidden shape == (batch_size, hidden_size)\n", + " # hidden_with_time_axis shape == (batch_size, 1, hidden_size)\n", + " hidden_with_time_axis = tf.expand_dims(hidden, 1)\n", + " \n", + " # score shape == (batch_size, 64, hidden_size)\n", + " score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))\n", + " \n", + " # attention_weights shape == (batch_size, 64, 1)\n", + " # we get 1 at the last axis because we are applying score to self.V\n", + " attention_weights = tf.nn.softmax(self.V(score), axis=1)\n", + " \n", + " # context_vector shape after sum == (batch_size, hidden_size)\n", + " context_vector = attention_weights * features\n", + " context_vector = tf.reduce_sum(context_vector, axis=1)\n", + " \n", + " return context_vector, attention_weights" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "PHod7t72RPGn", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Training\n", - "\n", - "* We extract the features stored in the respective `.npy` files and then pass those features through the encoder.\n", - "* The encoder output, hidden state(initialized to 0) and the decoder input (which is the start token) is passed to the decoder.\n", - "* The decoder returns the predictions and the decoder hidden state.\n", - "* The decoder hidden state is then passed back into the model and the predictions are used to calculate the loss.\n", - "* Use teacher forcing to decide the next input to the decoder.\n", - "* Teacher forcing is the technique where the target word is passed as the next input to the decoder.\n", - "* The final step is to calculate the gradients and apply it to the optimizer and backpropagate.\n" - ] + "colab_type": "code", + "id": "AZ7R1RxHRPGf" + }, + "outputs": [], + "source": [ + "class CNN_Encoder(tf.keras.Model):\n", + " # Since we have already extracted the features and dumped it using pickle\n", + " # This encoder passes those features through a Fully connected layer\n", + " def __init__(self, embedding_dim):\n", + " super(CNN_Encoder, self).__init__()\n", + " # shape after fc == (batch_size, 64, embedding_dim)\n", + " self.fc = tf.keras.layers.Dense(embedding_dim)\n", + " \n", + " def call(self, x):\n", + " x = self.fc(x)\n", + " x = tf.nn.relu(x)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "Vt4WZ5mhJE-E", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# adding this in a separate cell because if you run the training cell 
\n", - "# many times, the loss_plot array will be reset\n", - "loss_plot = []" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "V9UbGQmERPGi" + }, + "outputs": [], + "source": [ + "class RNN_Decoder(tf.keras.Model):\n", + " def __init__(self, embedding_dim, units, vocab_size):\n", + " super(RNN_Decoder, self).__init__()\n", + " self.units = units\n", + "\n", + " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", + " self.gru = gru(self.units)\n", + " self.fc1 = tf.keras.layers.Dense(self.units)\n", + " self.fc2 = tf.keras.layers.Dense(vocab_size)\n", + " \n", + " self.attention = BahdanauAttention(self.units)\n", + " \n", + " def call(self, x, features, hidden):\n", + " # defining attention as a separate model\n", + " context_vector, attention_weights = self.attention(features, hidden)\n", + " \n", + " # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n", + " x = self.embedding(x)\n", + " \n", + " # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n", + " x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n", + " \n", + " # passing the concatenated vector to the GRU\n", + " output, state = self.gru(x)\n", + " \n", + " # shape == (batch_size, max_length, hidden_size)\n", + " x = self.fc1(output)\n", + " \n", + " # x shape == (batch_size * max_length, hidden_size)\n", + " x = tf.reshape(x, (-1, x.shape[2]))\n", + " \n", + " # output shape == (batch_size * max_length, vocab)\n", + " x = self.fc2(x)\n", + "\n", + " return x, state, attention_weights\n", + "\n", + " def reset_state(self, batch_size):\n", + " return tf.zeros((batch_size, self.units))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "UlA4VIQpRPGo", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - 
"cell_type": "code", - "source": [ - "EPOCHS = 20\n", - "\n", - "for epoch in range(EPOCHS):\n", - " start = time.time()\n", - " total_loss = 0\n", - " \n", - " for (batch, (img_tensor, target)) in enumerate(dataset):\n", - " loss = 0\n", - " \n", - " # initializing the hidden state for each batch\n", - " # because the captions are not related from image to image\n", - " hidden = decoder.reset_state(batch_size=target.shape[0])\n", - "\n", - " dec_input = tf.expand_dims([tokenizer.word_index['']] * BATCH_SIZE, 1)\n", - " \n", - " with tf.GradientTape() as tape:\n", - " features = encoder(img_tensor)\n", - " \n", - " for i in range(1, target.shape[1]):\n", - " # passing the features through the decoder\n", - " predictions, hidden, _ = decoder(dec_input, features, hidden)\n", - "\n", - " loss += loss_function(target[:, i], predictions)\n", - " \n", - " # using teacher forcing\n", - " dec_input = tf.expand_dims(target[:, i], 1)\n", - " \n", - " total_loss += (loss / int(target.shape[1]))\n", - " \n", - " variables = encoder.variables + decoder.variables\n", - " \n", - " gradients = tape.gradient(loss, variables) \n", - " \n", - " optimizer.apply_gradients(zip(gradients, variables), tf.train.get_or_create_global_step())\n", - " \n", - " if batch % 100 == 0:\n", - " print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, \n", - " batch, \n", - " loss.numpy() / int(target.shape[1])))\n", - " # storing the epoch end loss value to plot later\n", - " loss_plot.append(total_loss / len(cap_vector))\n", - " \n", - " print ('Epoch {} Loss {:.6f}'.format(epoch + 1, \n", - " total_loss/len(cap_vector)))\n", - " print ('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "Qs_Sr03wRPGk" + }, + "outputs": [], + "source": [ + "encoder = CNN_Encoder(embedding_dim)\n", + "decoder = RNN_Decoder(embedding_dim, units, vocab_size)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + 
"metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "1Wm83G-ZBPcC", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "plt.plot(loss_plot)\n", - "plt.xlabel('Epochs')\n", - "plt.ylabel('Loss')\n", - "plt.title('Loss Plot')\n", - "plt.show()" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "-bYN7xA0RPGl" + }, + "outputs": [], + "source": [ + "optimizer = tf.train.AdamOptimizer()\n", + "\n", + "# We are masking the loss calculated for padding\n", + "def loss_function(real, pred):\n", + " mask = 1 - np.equal(real, 0)\n", + " loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n", + " return tf.reduce_mean(loss_)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "PHod7t72RPGn" + }, + "source": [ + "## Training\n", + "\n", + "* We extract the features stored in the respective `.npy` files and then pass those features through the encoder.\n", + "* The encoder output, hidden state(initialized to 0) and the decoder input (which is the start token) is passed to the decoder.\n", + "* The decoder returns the predictions and the decoder hidden state.\n", + "* The decoder hidden state is then passed back into the model and the predictions are used to calculate the loss.\n", + "* Use teacher forcing to decide the next input to the decoder.\n", + "* Teacher forcing is the technique where the target word is passed as the next input to the decoder.\n", + "* The final step is to calculate the gradients and apply it to the optimizer and backpropagate.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "xGvOcLQKghXN", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - 
"## Caption!\n", - "\n", - "* The evaluate function is similar to the training loop, except we don't use teacher forcing here. The input to the decoder at each time step is its previous predictions along with the hidden state and the encoder output.\n", - "* Stop predicting when the model predicts the end token.\n", - "* And store the attention weights for every time step." - ] + "colab_type": "code", + "id": "Vt4WZ5mhJE-E" + }, + "outputs": [], + "source": [ + "# adding this in a separate cell because if you run the training cell \n", + "# many times, the loss_plot array will be reset\n", + "loss_plot = []" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "RCWpDtyNRPGs", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "def evaluate(image):\n", - " attention_plot = np.zeros((max_length, attention_features_shape))\n", - "\n", - " hidden = decoder.reset_state(batch_size=1)\n", - "\n", - " temp_input = tf.expand_dims(load_image(image)[0], 0)\n", - " img_tensor_val = image_features_extract_model(temp_input)\n", - " img_tensor_val = tf.reshape(img_tensor_val, (img_tensor_val.shape[0], -1, img_tensor_val.shape[3]))\n", - "\n", - " features = encoder(img_tensor_val)\n", - "\n", - " dec_input = tf.expand_dims([tokenizer.word_index['']], 0)\n", - " result = []\n", - "\n", - " for i in range(max_length):\n", - " predictions, hidden, attention_weights = decoder(dec_input, features, hidden)\n", - "\n", - " attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n", - "\n", - " predicted_id = tf.argmax(predictions[0]).numpy()\n", - " result.append(index_word[predicted_id])\n", - "\n", - " if index_word[predicted_id] == '':\n", - " return result, attention_plot\n", - "\n", - " dec_input = tf.expand_dims([predicted_id], 0)\n", - "\n", - " 
attention_plot = attention_plot[:len(result), :]\n", - " return result, attention_plot" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "UlA4VIQpRPGo" + }, + "outputs": [], + "source": [ + "EPOCHS = 20\n", + "\n", + "for epoch in range(EPOCHS):\n", + " start = time.time()\n", + " total_loss = 0\n", + " \n", + " for (batch, (img_tensor, target)) in enumerate(dataset):\n", + " loss = 0\n", + " \n", + " # initializing the hidden state for each batch\n", + " # because the captions are not related from image to image\n", + " hidden = decoder.reset_state(batch_size=target.shape[0])\n", + "\n", + " dec_input = tf.expand_dims([tokenizer.word_index['']] * BATCH_SIZE, 1)\n", + " \n", + " with tf.GradientTape() as tape:\n", + " features = encoder(img_tensor)\n", + " \n", + " for i in range(1, target.shape[1]):\n", + " # passing the features through the decoder\n", + " predictions, hidden, _ = decoder(dec_input, features, hidden)\n", + "\n", + " loss += loss_function(target[:, i], predictions)\n", + " \n", + " # using teacher forcing\n", + " dec_input = tf.expand_dims(target[:, i], 1)\n", + " \n", + " total_loss += (loss / int(target.shape[1]))\n", + " \n", + " variables = encoder.variables + decoder.variables\n", + " \n", + " gradients = tape.gradient(loss, variables) \n", + " \n", + " optimizer.apply_gradients(zip(gradients, variables), tf.train.get_or_create_global_step())\n", + " \n", + " if batch % 100 == 0:\n", + " print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, \n", + " batch, \n", + " loss.numpy() / int(target.shape[1])))\n", + " # storing the epoch end loss value to plot later\n", + " loss_plot.append(total_loss / len(cap_vector))\n", + " \n", + " print ('Epoch {} Loss {:.6f}'.format(epoch + 1, \n", + " total_loss/len(cap_vector)))\n", + " print ('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": 
false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "fD_y7PD6RPGt", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "def plot_attention(image, result, attention_plot):\n", - " temp_image = np.array(Image.open(image))\n", - "\n", - " fig = plt.figure(figsize=(10, 10))\n", - " \n", - " len_result = len(result)\n", - " for l in range(len_result):\n", - " temp_att = np.resize(attention_plot[l], (8, 8))\n", - " ax = fig.add_subplot(len_result//2, len_result//2, l+1)\n", - " ax.set_title(result[l])\n", - " img = ax.imshow(temp_image)\n", - " ax.imshow(temp_att, cmap='gray', alpha=0.6, extent=img.get_extent())\n", - "\n", - " plt.tight_layout()\n", - " plt.show()" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "1Wm83G-ZBPcC" + }, + "outputs": [], + "source": [ + "plt.plot(loss_plot)\n", + "plt.xlabel('Epochs')\n", + "plt.ylabel('Loss')\n", + "plt.title('Loss Plot')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "xGvOcLQKghXN" + }, + "source": [ + "## Caption!\n", + "\n", + "* The evaluate function is similar to the training loop, except we don't use teacher forcing here. The input to the decoder at each time step is its previous predictions along with the hidden state and the encoder output.\n", + "* Stop predicting when the model predicts the end token.\n", + "* And store the attention weights for every time step." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "io7ws3ReRPGv", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# captions on the validation set\n", - "rid = np.random.randint(0, len(img_name_val))\n", - "image = img_name_val[rid]\n", - "real_caption = ' '.join([index_word[i] for i in cap_val[rid] if i not in [0]])\n", - "result, attention_plot = evaluate(image)\n", - "\n", - "print ('Real Caption:', real_caption)\n", - "print ('Prediction Caption:', ' '.join(result))\n", - "plot_attention(image, result, attention_plot)\n", - "# opening the image\n", - "Image.open(img_name_val[rid])" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "RCWpDtyNRPGs" + }, + "outputs": [], + "source": [ + "def evaluate(image):\n", + " attention_plot = np.zeros((max_length, attention_features_shape))\n", + "\n", + " hidden = decoder.reset_state(batch_size=1)\n", + "\n", + " temp_input = tf.expand_dims(load_image(image)[0], 0)\n", + " img_tensor_val = image_features_extract_model(temp_input)\n", + " img_tensor_val = tf.reshape(img_tensor_val, (img_tensor_val.shape[0], -1, img_tensor_val.shape[3]))\n", + "\n", + " features = encoder(img_tensor_val)\n", + "\n", + " dec_input = tf.expand_dims([tokenizer.word_index['']], 0)\n", + " result = []\n", + "\n", + " for i in range(max_length):\n", + " predictions, hidden, attention_weights = decoder(dec_input, features, hidden)\n", + "\n", + " attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n", + "\n", + " predicted_id = tf.argmax(predictions[0]).numpy()\n", + " result.append(tokenizer.index_word[predicted_id])\n", + "\n", + " if tokenizer.index_word[predicted_id] == '':\n", + " return result, attention_plot\n", + "\n", + " dec_input = tf.expand_dims([predicted_id], 0)\n", 
+ "\n", + " attention_plot = attention_plot[:len(result), :]\n", + " return result, attention_plot" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "Rprk3HEvZuxb", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Try it on your own images\n", - "For fun, below we've provided a method you can use to caption your own images with the model we've just trained. Keep in mind, it was trained on a relatively small amount of data, and your images may be different from the training data (so be prepared for weird results!)\n" - ] + "colab_type": "code", + "id": "fD_y7PD6RPGt" + }, + "outputs": [], + "source": [ + "def plot_attention(image, result, attention_plot):\n", + " temp_image = np.array(Image.open(image))\n", + "\n", + " fig = plt.figure(figsize=(10, 10))\n", + " \n", + " len_result = len(result)\n", + " for l in range(len_result):\n", + " temp_att = np.resize(attention_plot[l], (8, 8))\n", + " ax = fig.add_subplot(len_result//2, len_result//2, l+1)\n", + " ax.set_title(result[l])\n", + " img = ax.imshow(temp_image)\n", + " ax.imshow(temp_att, cmap='gray', alpha=0.6, extent=img.get_extent())\n", + "\n", + " plt.tight_layout()\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "9Psd1quzaAWg", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "image_url = 'https://tensorflow.org/images/surf.jpg'\n", - "image_extension = image_url[-4:]\n", - "image_path = tf.keras.utils.get_file('image'+image_extension, \n", - " origin=image_url)\n", - "\n", - "result, attention_plot = evaluate(image_path)\n", - "print ('Prediction Caption:', ' '.join(result))\n", - 
"plot_attention(image_path, result, attention_plot)\n", - "# opening the image\n", - "Image.open(image_path)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "io7ws3ReRPGv" + }, + "outputs": [], + "source": [ + "# captions on the validation set\n", + "rid = np.random.randint(0, len(img_name_val))\n", + "image = img_name_val[rid]\n", + "real_caption = ' '.join([tokenizer.index_word[i] for i in cap_val[rid] if i not in [0]])\n", + "result, attention_plot = evaluate(image)\n", + "\n", + "print ('Real Caption:', real_caption)\n", + "print ('Prediction Caption:', ' '.join(result))\n", + "plot_attention(image, result, attention_plot)\n", + "# opening the image\n", + "Image.open(img_name_val[rid])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Rprk3HEvZuxb" + }, + "source": [ + "## Try it on your own images\n", + "For fun, below we've provided a method you can use to caption your own images with the model we've just trained. Keep in mind, it was trained on a relatively small amount of data, and your images may be different from the training data (so be prepared for weird results!)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, + "colab_type": "code", + "id": "9Psd1quzaAWg" + }, + "outputs": [], + "source": [ + "image_url = 'https://tensorflow.org/images/surf.jpg'\n", + "image_extension = image_url[-4:]\n", + "image_path = tf.keras.utils.get_file('image'+image_extension, \n", + " origin=image_url)\n", + "\n", + "result, attention_plot = evaluate(image_path)\n", + "print ('Prediction Caption:', ' '.join(result))\n", + "plot_attention(image_path, result, attention_plot)\n", + "# opening the image\n", + "Image.open(image_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "VJZXyJco6uLO" + }, + "source": [ + "# Next steps\n", + "\n", + "Congrats! 
You've just trained an image captioning model with attention. Next, we recommend taking a look at this example [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb). It uses a similar architecture to translate between Spanish and English sentences. You can also experiment with training the code in this notebook on a different dataset." + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "image_captioning_with_attention.ipynb", + "private_outputs": true, + "provenance": [ { - "metadata": { - "id": "VJZXyJco6uLO", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "# Next steps\n", - "\n", - "Congrats! You've just trained an image captioning model with attention. Next, we recommend taking a look at this example [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb). It uses a similar architecture to translate between Spanish and English sentences. You can also experiment with training the code in this notebook on a different dataset." 
- ] + "file_id": "1HI8OK2sMjcx9CTWVn0122QAHOuXaOaMg", + "timestamp": 1530222436922 } - ] + ], + "toc_visible": true, + "version": "0.3.2", + "views": {} + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py index 557ad42752144243ae3da61b955b31398cba846e..d412b25b368260b81256fd58034330b884261b2b 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py @@ -36,7 +36,7 @@ class GraphLinearRegressionBenchmark(tf.test.Benchmark): noise_level=0.01, batch_size=batch_size, num_batches=num_batches) - iterator = dataset.make_initializable_iterator() + iterator = tf.compat.v1.data.make_initializable_iterator(dataset) x, y = iterator.get_next() model = linear_regression.LinearModel() diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py index f3bb978875e226f58d6a00e09154191673a97415..fb7975d8fe867711cff31d627788a2d62a520aa9 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py @@ -142,7 +142,8 @@ class ResNet50Benchmarks(tf.test.Benchmark): with tf.Graph().as_default(): np_images, np_labels = random_batch(batch_size) dataset = tf.data.Dataset.from_tensors((np_images, np_labels)).repeat() 
- (images, labels) = dataset.make_one_shot_iterator().get_next() + images, labels = tf.compat.v1.data.make_one_shot_iterator( + dataset).get_next() model = resnet50.ResNet50(data_format()) logits = model(images, training=True) diff --git a/tensorflow/contrib/eager/python/examples/revnet/main.py b/tensorflow/contrib/eager/python/examples/revnet/main.py index b702e91f92220c2a9003a1b82411131332012a9e..9585f3565f83af724b6336e466d3671443ba2361 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/main.py +++ b/tensorflow/contrib/eager/python/examples/revnet/main.py @@ -72,14 +72,11 @@ def main(_): train_one_iter(model, x, y, optimizer, global_step=global_step) if global_step.numpy() % config.log_every == 0: - it_test = ds_test.make_one_shot_iterator() - acc_test, loss_test = evaluate(model, it_test) + acc_test, loss_test = evaluate(model, ds_test) if FLAGS.validate: - it_train = ds_train_one_shot.make_one_shot_iterator() - it_validation = ds_validation.make_one_shot_iterator() - acc_train, loss_train = evaluate(model, it_train) - acc_validation, loss_validation = evaluate(model, it_validation) + acc_train, loss_train = evaluate(model, ds_train_one_shot) + acc_validation, loss_validation = evaluate(model, ds_validation) print("Iter {}, " "training set accuracy {:.4f}, loss {:.4f}; " "validation set accuracy {:.4f}, loss {:.4f}; " @@ -218,11 +215,11 @@ def train_one_iter(model, inputs, labels, optimizer, global_step=None): return logits, loss -def evaluate(model, iterator): +def evaluate(model, dataset): """Compute accuracy with the given dataset iterator.""" mean_loss = tfe.metrics.Mean() accuracy = tfe.metrics.Accuracy() - for x, y in iterator: + for x, y in dataset: logits, _ = model(x, training=False) loss = model.compute_loss(logits=logits, labels=y) accuracy( diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb_graph_test.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb_graph_test.py index 
63b5c4c54d13e9c2448ec1f572ca1389f2443bef..770484abed96e540cf75cc5368a1410c31a8d2d0 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb_graph_test.py @@ -82,7 +82,7 @@ class PTBBenchmark(tf.test.Benchmark): tf.ones( [PTBBenchmark.SEQ_LEN, PTBBenchmark.BATCH_SIZE], dtype=tf.int64)).repeat(num_iters + num_warmup) - inputs = dataset.make_one_shot_iterator().get_next() + inputs = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next() with tf.device(tf.test.gpu_device_name()): outputs = model(inputs, training=True) @@ -124,7 +124,8 @@ class PTBBenchmark(tf.test.Benchmark): dtype=tf.int64)).repeat(num_iters + num_warmup) # inputs and labels have the same shape dataset = tf.data.Dataset.zip((dataset, dataset)) - (inputs, labels) = dataset.make_one_shot_iterator().get_next() + (inputs, labels) = tf.compat.v1.data.make_one_shot_iterator( + dataset).get_next() with tf.device(tf.test.gpu_device_name()): optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0) diff --git a/tensorflow/contrib/eager/python/saver.py b/tensorflow/contrib/eager/python/saver.py index f9c716360c5755ee1902b576545d776725f9966f..1d0d6c6c14ce4a8e454206e0be9fea4724f09192 100644 --- a/tensorflow/contrib/eager/python/saver.py +++ b/tensorflow/contrib/eager/python/saver.py @@ -115,6 +115,11 @@ def restore_variables_on_create(save_path, map_func=None): class Saver(object): """A tf.train.Saver adapter for use when eager execution is enabled. + + `Saver`'s name-based checkpointing strategy is fragile. Please switch to + `tf.train.Checkpoint` or `tf.keras.Model.save_weights`, which perform a more + robust object-based saving. These APIs will load checkpoints written by + `Saver`. 
""" def __init__(self, var_list): diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index 33c988fd9065e7fbe7b9aeb85cad82eb3c119f76..8882a863c30d8b222c68d6952279c3744345883c 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -41,6 +41,8 @@ To use, at program startup, call `tf.enable_eager_execution()`. @@add_execution_callback @@clear_execution_callbacks +@@errstate +@@ExecutionCallback @@inf_callback @@inf_nan_callback @@nan_callback @@ -119,6 +121,8 @@ from tensorflow.python.eager.context import set_server_def from tensorflow.python.eager.def_function import function from tensorflow.python.eager.execution_callbacks import add_execution_callback from tensorflow.python.eager.execution_callbacks import clear_execution_callbacks +from tensorflow.python.eager.execution_callbacks import errstate +from tensorflow.python.eager.execution_callbacks import ExecutionCallback from tensorflow.python.eager.execution_callbacks import inf_callback from tensorflow.python.eager.execution_callbacks import inf_nan_callback from tensorflow.python.eager.execution_callbacks import nan_callback diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD index 1cd83bdb5de7c2f6dc91c980750b49aca1a7790b..4c1d1a29f20b5574b63cf87ecf62db95f92902cd 100644 --- a/tensorflow/contrib/feature_column/BUILD +++ b/tensorflow/contrib/feature_column/BUILD @@ -110,8 +110,8 @@ py_test( "//tensorflow/python:parsing_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", - "//tensorflow/python/feature_column", "//tensorflow/python/feature_column:feature_column_py", + "//tensorflow/python/feature_column:feature_column_v2_test", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py 
b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py index 0d34ad161855476b6a4cd9a258521dbe122b4140..83b93ec332044f754f9dcde8d7c5c19b26e53a4a 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py @@ -203,7 +203,8 @@ def sequence_categorical_column_with_identity( columns = [watches_embedding] features = tf.parse_example(..., features=make_parse_example_spec(columns)) - input_layer, sequence_length = sequence_input_layer(features, columns) + sequence_feature_layer = SequenceFeatureLayer(columns) + input_layer, sequence_length = sequence_feature_layer(features) rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) outputs, state = tf.nn.dynamic_rnn( @@ -219,15 +220,17 @@ def sequence_categorical_column_with_identity( `[0, num_buckets)`, and will replace out-of-range inputs. Returns: - A `_SequenceCategoricalColumn`. + A `SequenceCategoricalColumn`. Raises: ValueError: if `num_buckets` is less than one. ValueError: if `default_value` is not in range `[0, num_buckets)`. 
""" - return fc_old._SequenceCategoricalColumn( - fc_old._categorical_column_with_identity( - key=key, num_buckets=num_buckets, default_value=default_value)) + return fc.SequenceCategoricalColumn( + fc.categorical_column_with_identity( + key=key, + num_buckets=num_buckets, + default_value=default_value)) def sequence_categorical_column_with_hash_bucket( @@ -247,7 +250,8 @@ def sequence_categorical_column_with_hash_bucket( columns = [tokens_embedding] features = tf.parse_example(..., features=make_parse_example_spec(columns)) - input_layer, sequence_length = sequence_input_layer(features, columns) + sequence_feature_layer = SequenceFeatureLayer(columns) + input_layer, sequence_length = sequence_feature_layer(features) rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) outputs, state = tf.nn.dynamic_rnn( @@ -260,15 +264,17 @@ def sequence_categorical_column_with_hash_bucket( dtype: The type of features. Only string and integer types are supported. Returns: - A `_SequenceCategoricalColumn`. + A `SequenceCategoricalColumn`. Raises: ValueError: `hash_bucket_size` is not greater than 1. ValueError: `dtype` is neither string nor integer. 
""" - return fc_old._SequenceCategoricalColumn( - fc_old._categorical_column_with_hash_bucket( - key=key, hash_bucket_size=hash_bucket_size, dtype=dtype)) + return fc.SequenceCategoricalColumn( + fc.categorical_column_with_hash_bucket( + key=key, + hash_bucket_size=hash_bucket_size, + dtype=dtype)) def sequence_categorical_column_with_vocabulary_file( @@ -290,7 +296,8 @@ def sequence_categorical_column_with_vocabulary_file( columns = [states_embedding] features = tf.parse_example(..., features=make_parse_example_spec(columns)) - input_layer, sequence_length = sequence_input_layer(features, columns) + sequence_feature_layer = SequenceFeatureLayer(columns) + input_layer, sequence_length = sequence_feature_layer(features) rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) outputs, state = tf.nn.dynamic_rnn( @@ -314,7 +321,7 @@ def sequence_categorical_column_with_vocabulary_file( dtype: The type of features. Only string and integer types are supported. Returns: - A `_SequenceCategoricalColumn`. + A `SequenceCategoricalColumn`. Raises: ValueError: `vocabulary_file` is missing or cannot be opened. @@ -323,8 +330,8 @@ def sequence_categorical_column_with_vocabulary_file( ValueError: `num_oov_buckets` and `default_value` are both specified. ValueError: `dtype` is neither string nor integer. 
""" - return fc_old._SequenceCategoricalColumn( - fc_old._categorical_column_with_vocabulary_file( + return fc.SequenceCategoricalColumn( + fc.categorical_column_with_vocabulary_file( key=key, vocabulary_file=vocabulary_file, vocabulary_size=vocabulary_size, @@ -351,7 +358,8 @@ def sequence_categorical_column_with_vocabulary_list( columns = [colors_embedding] features = tf.parse_example(..., features=make_parse_example_spec(columns)) - input_layer, sequence_length = sequence_input_layer(features, columns) + sequence_feature_layer = SequenceFeatureLayer(columns) + input_layer, sequence_length = sequence_feature_layer(features) rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) outputs, state = tf.nn.dynamic_rnn( @@ -375,7 +383,7 @@ def sequence_categorical_column_with_vocabulary_list( with `default_value`. Returns: - A `_SequenceCategoricalColumn`. + A `SequenceCategoricalColumn`. Raises: ValueError: if `vocabulary_list` is empty, or contains duplicate keys. @@ -383,8 +391,8 @@ def sequence_categorical_column_with_vocabulary_list( ValueError: `num_oov_buckets` and `default_value` are both specified. ValueError: if `dtype` is not integer or string. 
""" - return fc_old._SequenceCategoricalColumn( - fc_old._categorical_column_with_vocabulary_list( + return fc.SequenceCategoricalColumn( + fc.categorical_column_with_vocabulary_list( key=key, vocabulary_list=vocabulary_list, dtype=dtype, diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py index ca4398a142065de0be7bee57cd7e54670bbae12e..be012a87690c24c6d9b7808790393e1aa6d01211 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py @@ -26,7 +26,7 @@ from tensorflow.contrib.feature_column.python.feature_column import sequence_fea from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column_v2 as sfc from tensorflow.python.feature_column import feature_column as fc_old from tensorflow.python.feature_column import feature_column_lib as fc -from tensorflow.python.feature_column.feature_column import _LazyBuilder +from tensorflow.python.feature_column.feature_column_v2_test import _TestStateManager from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -131,7 +131,7 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): feature_columns=[embedding_column_b, embedding_column_a]) global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual( + self.assertCountEqual( ('sequence_input_layer/aaa_embedding/embedding_weights:0', 'sequence_input_layer/bbb_embedding/embedding_weights:0'), tuple([v.name for v in global_vars])) @@ -223,7 +223,7 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): feature_columns=shared_embedding_columns) global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - 
self.assertItemsEqual( + self.assertCountEqual( ('sequence_input_layer/aaa_bbb_shared_embedding/embedding_weights:0',), tuple([v.name for v in global_vars])) with monitored_session.MonitoredSession() as sess: @@ -670,6 +670,23 @@ def _assert_sparse_tensor_indices_shape(test_case, expected, actual): test_case.assertAllEqual(expected.dense_shape, actual.dense_shape) +def _get_sequence_dense_tensor(column, features): + return column.get_sequence_dense_tensor( + fc.FeatureTransformationCache(features), None) + + +def _get_sequence_dense_tensor_state(column, features): + state_manager = _TestStateManager() + column.create_state(state_manager) + return column.get_sequence_dense_tensor( + fc.FeatureTransformationCache(features), state_manager) + + +def _get_sparse_tensors(column, features): + return column.get_sparse_tensors( + fc.FeatureTransformationCache(features), None) + + class SequenceCategoricalColumnWithIdentityTest( test.TestCase, parameterized.TestCase): @@ -698,7 +715,7 @@ class SequenceCategoricalColumnWithIdentityTest( expected = sparse_tensor.SparseTensorValue(**expected_args) column = sfc.sequence_categorical_column_with_identity('aaa', num_buckets=9) - id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + id_weight_pair = _get_sparse_tensors(column, {'aaa': inputs}) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: @@ -737,7 +754,7 @@ class SequenceCategoricalColumnWithHashBucketTest( column = sfc.sequence_categorical_column_with_hash_bucket( 'aaa', hash_bucket_size=10) - id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + id_weight_pair = _get_sparse_tensors(column, {'aaa': inputs}) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: @@ -790,7 +807,7 @@ class SequenceCategoricalColumnWithVocabularyFileTest( vocabulary_file=self._wire_vocabulary_file_name, vocabulary_size=self._wire_vocabulary_size) - 
id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + id_weight_pair = _get_sparse_tensors(column, {'aaa': inputs}) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: @@ -814,8 +831,7 @@ class SequenceCategoricalColumnWithVocabularyFileTest( input_placeholder_shape[1] = None input_placeholder = array_ops.sparse_placeholder( dtypes.string, shape=input_placeholder_shape) - id_weight_pair = column._get_sparse_tensors( - _LazyBuilder({'aaa': input_placeholder})) + id_weight_pair = _get_sparse_tensors(column, {'aaa': input_placeholder}) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: @@ -855,7 +871,7 @@ class SequenceCategoricalColumnWithVocabularyListTest( key='aaa', vocabulary_list=('omar', 'stringer', 'marlo')) - id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + id_weight_pair = _get_sparse_tensors(column, {'aaa': inputs}) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: @@ -922,13 +938,12 @@ class SequenceEmbeddingColumnTest( categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc_old._embedding_column( - categorical_column, - dimension=embedding_dimension, + embedding_column = fc.embedding_column( + categorical_column, dimension=embedding_dimension, initializer=_initializer) - embedding_lookup, _ = embedding_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': inputs})) + embedding_lookup, _ = _get_sequence_dense_tensor_state( + embedding_column, {'aaa': inputs}) global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) self.assertItemsEqual( @@ -961,10 +976,11 @@ class SequenceEmbeddingColumnTest( categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc_old._embedding_column(categorical_column, dimension=2) + 
embedding_column = fc.embedding_column( + categorical_column, dimension=2) - _, sequence_length = embedding_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': inputs})) + _, sequence_length = _get_sequence_dense_tensor_state( + embedding_column, {'aaa': inputs}) with monitored_session.MonitoredSession() as sess: sequence_length = sess.run(sequence_length) @@ -988,10 +1004,11 @@ class SequenceEmbeddingColumnTest( categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc_old._embedding_column(categorical_column, dimension=2) + embedding_column = fc.embedding_column( + categorical_column, dimension=2) - _, sequence_length = embedding_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) + _, sequence_length = _get_sequence_dense_tensor_state( + embedding_column, {'aaa': sparse_input}) with monitored_session.MonitoredSession() as sess: self.assertAllEqual( @@ -1058,22 +1075,18 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): key='aaa', num_buckets=vocabulary_size) categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc.shared_embedding_columns_v2( [categorical_column_a, categorical_column_b], dimension=embedding_dimension, initializer=_initializer) - embedding_lookup_a = shared_embedding_columns[0]._get_sequence_dense_tensor( - _LazyBuilder({ - 'aaa': sparse_input_a - }))[0] - embedding_lookup_b = shared_embedding_columns[1]._get_sequence_dense_tensor( - _LazyBuilder({ - 'bbb': sparse_input_b - }))[0] + embedding_lookup_a = _get_sequence_dense_tensor( + shared_embedding_columns[0], {'aaa': sparse_input_a})[0] + embedding_lookup_b = _get_sequence_dense_tensor( + shared_embedding_columns[1], {'bbb': sparse_input_b})[0] global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - 
self.assertItemsEqual(('embedding_weights:0',), + self.assertItemsEqual(('aaa_bbb_shared_embedding:0',), tuple([v.name for v in global_vars])) with monitored_session.MonitoredSession() as sess: self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) @@ -1104,17 +1117,13 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): expected_sequence_length_b = [2, 1] categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc.shared_embedding_columns_v2( [categorical_column_a, categorical_column_b], dimension=2) - sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor( - _LazyBuilder({ - 'aaa': sparse_input_a - }))[1] - sequence_length_b = shared_embedding_columns[1]._get_sequence_dense_tensor( - _LazyBuilder({ - 'bbb': sparse_input_b - }))[1] + sequence_length_a = _get_sequence_dense_tensor( + shared_embedding_columns[0], {'aaa': sparse_input_a})[1] + sequence_length_b = _get_sequence_dense_tensor( + shared_embedding_columns[1], {'bbb': sparse_input_b})[1] with monitored_session.MonitoredSession() as sess: sequence_length_a = sess.run(sequence_length_a) @@ -1155,17 +1164,13 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc.shared_embedding_columns_v2( [categorical_column_a, categorical_column_b], dimension=2) - sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor( - _LazyBuilder({ - 'aaa': sparse_input_a - }))[1] - sequence_length_b = shared_embedding_columns[1]._get_sequence_dense_tensor( - _LazyBuilder({ - 'bbb': sparse_input_b - }))[1] + sequence_length_a = _get_sequence_dense_tensor( + shared_embedding_columns[0], {'aaa': sparse_input_a})[1] + sequence_length_b = 
_get_sequence_dense_tensor( + shared_embedding_columns[1], {'bbb': sparse_input_b})[1] with monitored_session.MonitoredSession() as sess: self.assertAllEqual( @@ -1221,10 +1226,10 @@ class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column = fc_old._indicator_column(categorical_column) + indicator_column = fc.indicator_column(categorical_column) - indicator_tensor, _ = indicator_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': inputs})) + indicator_tensor, _ = _get_sequence_dense_tensor( + indicator_column, {'aaa': inputs}) with monitored_session.MonitoredSession() as sess: self.assertAllEqual(expected, indicator_tensor.eval(session=sess)) @@ -1253,10 +1258,10 @@ class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column = fc_old._indicator_column(categorical_column) + indicator_column = fc.indicator_column(categorical_column) - _, sequence_length = indicator_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': inputs})) + _, sequence_length = _get_sequence_dense_tensor( + indicator_column, {'aaa': inputs}) with monitored_session.MonitoredSession() as sess: sequence_length = sess.run(sequence_length) @@ -1282,19 +1287,14 @@ class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): key='aaa', num_buckets=vocabulary_size) indicator_column = fc.indicator_column(categorical_column) - _, sequence_length = indicator_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) + _, sequence_length = _get_sequence_dense_tensor( + indicator_column, {'aaa': sparse_input}) with monitored_session.MonitoredSession() as sess: self.assertAllEqual( expected_sequence_length, sequence_length.eval(session=sess)) -def _get_sequence_dense_tensor(column, 
features): - return column.get_sequence_dense_tensor( - fc.FeatureTransformationCache(features), None) - - class SequenceNumericColumnTest(test.TestCase, parameterized.TestCase): def test_defaults(self): diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc index 93b1aaa85e88e00c1b12a388321a4d6fb10f1611..c541c71f996c7a1b36cf28ae9a1783f8dca0a72c 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc @@ -522,7 +522,7 @@ void LaunchFusedConv2DBiasActivationOp:: auto bias_ptr = AsDeviceMemory(bias.template flat().data(), bias.template flat().size()); - static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit( + static int64 ConvolveScratchSize = GetDnnWorkspaceLimit( // default value is in bytes despite the name of the environment variable "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB ); @@ -570,7 +570,7 @@ void LaunchFusedConv2DBiasActivationOp:: for (auto profile_algorithm : algorithms) { // TODO(zhengxq): profile each algorithm multiple times to better // accuracy. 
- CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); + DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); dnn::ProfileResult profile_result; bool cudnn_launch_status = stream @@ -609,7 +609,7 @@ void LaunchFusedConv2DBiasActivationOp:: algorithm_config); } - CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); + DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); bool cudnn_launch_status = stream ->ThenFusedConvolveWithAlgorithm( diff --git a/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py b/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py index e2594faf85bcf91cbe09f266e4d4211d20bdee17..364fa4eb461c62784803f0c309e3b7c5855df199 100644 --- a/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py +++ b/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py @@ -64,6 +64,9 @@ def condition_tensor(tensor, conditioning): """ tensor.shape[1:].assert_is_fully_defined() num_features = tensor.shape[1:].num_elements() + if conditioning.shape.ndims < 2: + raise ValueError('conditioning must be at least 2D, but saw shape: %s' + % conditioning.shape) mapped_conditioning = layers.linear( layers.flatten(conditioning), num_features) diff --git a/tensorflow/contrib/gan/python/features/python/conditioning_utils_test.py b/tensorflow/contrib/gan/python/features/python/conditioning_utils_test.py index 0aad769793761be69ee9d1e3416e44c7b3d8cea0..f5c7d53cf2c9aa08ba0074950983ef3ecd90168b 100644 --- a/tensorflow/contrib/gan/python/features/python/conditioning_utils_test.py +++ b/tensorflow/contrib/gan/python/features/python/conditioning_utils_test.py @@ -45,7 +45,7 @@ class ConditioningUtilsTest(test.TestCase): array_ops.placeholder(dtypes.float32, (5, None)), array_ops.placeholder(dtypes.float32, (5, 1))) - with self.assertRaisesRegexp(ValueError, 'expected min_ndim=2'): + with self.assertRaisesRegexp(ValueError, 'at least 2D'): 
conditioning_utils.condition_tensor( array_ops.placeholder(dtypes.float32, (5, 2)), array_ops.placeholder(dtypes.float32, (5))) diff --git a/tensorflow/contrib/gdr/BUILD b/tensorflow/contrib/gdr/BUILD index e534fdc17749974ebe713c2730682bea6d7a85e4..704be917b3680a1b5712f4f1dc5059b354db8610 100644 --- a/tensorflow/contrib/gdr/BUILD +++ b/tensorflow/contrib/gdr/BUILD @@ -37,7 +37,7 @@ tf_proto_library_cc( ], ) -tf_cuda_library( +cc_library( name = "gdr_memory_manager", srcs = ["gdr_memory_manager.cc"], hdrs = ["gdr_memory_manager.h"], @@ -58,7 +58,7 @@ tf_cuda_library( ], ) -tf_cuda_library( +cc_library( name = "gdr_worker", srcs = ["gdr_worker.cc"], hdrs = ["gdr_worker.h"], diff --git a/tensorflow/contrib/gdr/gdr.proto b/tensorflow/contrib/gdr/gdr.proto index c0b89245b150bfa49cb527d25b6e1f324f353b25..bd438787c3374be6ead4f6233101fd1f548643ea 100644 --- a/tensorflow/contrib/gdr/gdr.proto +++ b/tensorflow/contrib/gdr/gdr.proto @@ -9,5 +9,4 @@ message RemoteMemoryRegion { uint64 addr = 3; uint32 rkey = 4; uint32 tensor_key = 5; - uint64 checksum = 6; } diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.cc b/tensorflow/contrib/gdr/gdr_memory_manager.cc index 53587fcf3050f313c85485f77ce411cba7faccff..ce1875151597f926aeb6392e7fc8307312da123f 100644 --- a/tensorflow/contrib/gdr/gdr_memory_manager.cc +++ b/tensorflow/contrib/gdr/gdr_memory_manager.cc @@ -26,17 +26,14 @@ limitations under the License. 
#include #include #include -#include #include "tensorflow/contrib/gdr/gdr.pb.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/dma_helper.h" -#include "tensorflow/core/common_runtime/process_state.h" -#if GOOGLE_CUDA #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" -#include "tensorflow/core/common_runtime/gpu/gpu_util.h" -#endif // GOOGLE_CUDA +#include "tensorflow/core/common_runtime/process_state.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/numa.h" @@ -81,10 +78,6 @@ int TryToReadNumaNode(ibv_device* device) { int32 value; if (strings::safe_strto32(content, &value)) { if (value < 0) { - LOG(INFO) << "Successful NUMA node read from SysFS had negative value (" - << value - << "), but there must be at least one NUMA node" - ", so returning NUMA node zero"; return port::kNUMANoAffinity; } LOG(INFO) << "NUMA node for device: " << device->name << " is " << value; @@ -114,7 +107,7 @@ class GdrMemoryManager : public RemoteMemoryManager { public: GdrMemoryManager(const string& host, const string& port); - virtual ~GdrMemoryManager(); + virtual ~GdrMemoryManager() {} virtual Status Init() override; @@ -140,7 +133,7 @@ class GdrMemoryManager : public RemoteMemoryManager { return ptr < reinterpret_cast(other->addr) + other->length; } - ibv_mr* FindMemoryRegion(void* addr, size_t length); + ibv_mr* FindMemoryRegion(const Tensor* tensor); void InsertMemoryRegion(void* addr, size_t length, const std::string& allocator_name); @@ -152,7 +145,6 @@ class GdrMemoryManager : public RemoteMemoryManager { const string port_; RdmaEndpointPtr listening_; std::atomic stopped_; - int epfd_; int numa_node_; // Server side endpoints @@ -163,15 +155,19 @@ class GdrMemoryManager : public RemoteMemoryManager { std::atomic next_key_; // Server side on-the-fly 
tensor buffers - mutex server_mu_; - std::map tensor_buffers_ - GUARDED_BY(server_mu_); + mutex buf_mu_; + std::map tensor_buffers_ GUARDED_BY(buf_mu_); // Client side endpoints mutex client_mu_; std::map, RdmaEndpointPtr> clients_ GUARDED_BY(client_mu_); + // Client side callbacks + mutex callback_mu_; + std::map tensor_callbacks_ + GUARDED_BY(callback_mu_); + // Managed memory regions mutex alloc_mu_; std::vector mrs_ GUARDED_BY(alloc_mu_); @@ -184,16 +180,9 @@ GdrMemoryManager::GdrMemoryManager(const string& host, const string& port) port_(port), listening_(nullptr, EndpointDeleter), stopped_(true), - next_key_(0) {} - -GdrMemoryManager::~GdrMemoryManager() { close(epfd_); } + next_key_(static_cast(random::New64())) {} Status GdrMemoryManager::Init() { - epfd_ = epoll_create1(0); - if (epfd_ == -1) { - return errors::Unavailable(strerror(errno), ": ", "epoll_create"); - } - rdma_addrinfo* addrinfo; rdma_addrinfo hints = {}; hints.ai_port_space = RDMA_PS_TCP; @@ -206,7 +195,7 @@ Status GdrMemoryManager::Init() { ibv_qp_init_attr init_attr = {}; init_attr.qp_type = IBV_QPT_RC; - init_attr.cap.max_recv_wr = 32; + init_attr.cap.max_recv_wr = 1024; init_attr.cap.max_send_wr = 1; init_attr.cap.max_recv_sge = 1; init_attr.cap.max_send_sge = 1; @@ -239,14 +228,6 @@ Status GdrMemoryManager::Init() { "cannot set server to non-blocking mode"); } - epoll_event event = {}; - event.events = EPOLLIN | EPOLLPRI; - event.data.ptr = listening_.get(); - if (epoll_ctl(epfd_, EPOLL_CTL_ADD, listening_->channel->fd, &event)) { - return errors::Unavailable(strerror(errno), ": ", - "cannot add server to epoll"); - } - numa_node_ = TryToReadNumaNode(listening_->verbs->device); SubAllocator::Visitor alloc_visitor = [this](void* ptr, int numa_node, @@ -265,121 +246,114 @@ Status GdrMemoryManager::Init() { ProcessState::singleton()->AddCPUFreeVisitor(free_visitor); LOG(INFO) << "Instrumenting CPU allocator(s)"; -#if GOOGLE_CUDA for (int numa_idx = 0; numa_idx < port::NUMANumNodes(); 
++numa_idx) { GPUProcessState::singleton()->AddCUDAHostAllocVisitor(numa_idx, alloc_visitor); GPUProcessState::singleton()->AddCUDAHostFreeVisitor(numa_idx, free_visitor); } + if (IsGDRAvailable()) { SubAllocator::Visitor cuda_alloc_visitor = [this](void* ptr, int gpu_id, size_t num_bytes) { VLOG(2) << "Registering RDMA capable memory region on GPU " << gpu_id; InsertMemoryRegion(ptr, num_bytes, strings::StrCat("GPU:", gpu_id)); }; - for (int numa_idx = 0; numa_idx < port::NUMANumNodes(); ++numa_idx) { - GPUProcessState::singleton()->AddGPUAllocVisitor(numa_idx, - cuda_alloc_visitor); - } - VLOG(1) << "Instrumenting GPU allocator(s) for all Numas"; + GPUProcessState::singleton()->AddGPUAllocVisitor(numa_node_, + cuda_alloc_visitor); + LOG(INFO) << "Instrumenting GPU allocator for NUMA " << numa_node_; } -#endif // GOOGLE_CUDA + return Status::OK(); } void GdrMemoryManager::Run() { stopped_ = false; while (!stopped_) { - epoll_event events[32]; - int ret = epoll_wait(epfd_, events, 32, 1); - if (ret == -1) { - LOG(ERROR) << "epoll_wait: " << strerror(errno); - return; - } - for (int i = 0; i < ret; i++) { - rdma_cm_id* id = static_cast(events[i].data.ptr); - if (id == listening_.get()) { - // Accept incoming connections - if (!rdma_get_request(listening_.get(), &id)) { - if (!rdma_accept(id, nullptr)) { - LOG(INFO) << "Accepted new RDMA connection"; - if (ibv_req_notify_cq(id->recv_cq, 0)) { - LOG(ERROR) << strerror(errno) << ": ibv_req_notify_cq failed"; - EndpointDeleter(id); - continue; - } - for (int i = 0; i < 32; i++) { - if (rdma_post_recvv(id, nullptr, nullptr, 0)) { - LOG(ERROR) << strerror(errno) << ": rdma_post_recvv failed"; - EndpointDeleter(id); - continue; - } - } - int flags = fcntl(id->recv_cq_channel->fd, F_GETFL, 0); - if (fcntl(id->recv_cq_channel->fd, F_SETFL, flags | O_NONBLOCK)) { - LOG(ERROR) << strerror(errno) - << ": cannot set server_client to non-blocking mode"; - EndpointDeleter(id); - continue; - } - epoll_event event = {}; - 
event.events = EPOLLIN | EPOLLPRI; - event.data.ptr = id; - if (epoll_ctl(epfd_, EPOLL_CTL_ADD, id->recv_cq_channel->fd, - &event)) { - LOG(ERROR) << strerror(errno) - << ": cannot add server client to epoll"; - EndpointDeleter(id); - continue; - } - server_clients_.push_back({id, EndpointDeleter}); + rdma_cm_id* id = nullptr; + // Accept incoming connections + if (!rdma_get_request(listening_.get(), &id)) { + if (!rdma_accept(id, nullptr)) { + LOG(INFO) << "Accepted new RDMA connection"; + for (int i = 0; i < 1024; i++) { + if (rdma_post_recvv(id, nullptr, nullptr, 0)) { + LOG(ERROR) << strerror(errno) << ": rdma_post_recvv failed"; + EndpointDeleter(id); + continue; } } - } else { - // Polling work completions - ibv_cq* cq; - void* context; - if (!ibv_get_cq_event(id->recv_cq_channel, &cq, &context)) { - ibv_ack_cq_events(id->recv_cq, 1); - if (ibv_req_notify_cq(id->recv_cq, 0)) { - LOG(ERROR) << strerror(errno) << ": ibv_req_notify_cq failed"; - continue; + server_clients_.push_back({id, EndpointDeleter}); + } + } + // Polling server side work completions + for (const auto& client : server_clients_) { + ibv_wc wc[32]; + int ret = ibv_poll_cq(client->recv_cq, 32, wc); + if (ret < 0) { + LOG(ERROR) << "ibv_poll_cq failed"; + continue; + } + for (int i = 0; i < ret; i++) { + if (wc[i].opcode != IBV_WC_RECV_RDMA_WITH_IMM) { + LOG(ERROR) << "Received unknown operation " << wc[i].opcode; + } + if (wc[i].status != 0) { + LOG(ERROR) << ibv_wc_status_str(wc[i].status); + } + TensorKey tensor_key = ntohl(wc[i].imm_data); + + if (rdma_post_recvv(client.get(), nullptr, nullptr, 0)) { + perror("rdma_post_recvv"); + LOG(ERROR) << "rdma_post_recvv failed"; + } + + mutex_lock l(buf_mu_); + auto iter = tensor_buffers_.find(tensor_key); + if (iter == std::end(tensor_buffers_)) { + LOG(ERROR) << "Cannot find tensor buffer for tensor key " + << tensor_key; + } else { + const TensorBuffer* buffer = iter->second; + buffer->Unref(); + tensor_buffers_.erase(iter); + } + } + } + // 
Polling client side work completions + if (client_mu_.try_lock()) { + for (const auto& client : clients_) { + ibv_wc wc[32]; + int ret = ibv_poll_cq(client.second->send_cq, 32, wc); + for (int i = 0; i < ret; i++) { + Status s; + if (wc[i].status) { + s = errors::Unavailable(ibv_wc_status_str(wc[i].status)); + } else { + s = Status::OK(); } - ibv_wc wc[32]; - int ret = ibv_poll_cq(id->recv_cq, 32, wc); - if (ret < 0) { - LOG(ERROR) << "ibv_poll_cq failed"; - continue; + TensorKey key = wc[i].wr_id; + + ibv_send_wr wr = {}; + wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; + wr.imm_data = htonl(key); + ibv_send_wr* bad_wr; + if (ibv_post_send(client.second->qp, &wr, &bad_wr)) { + LOG(ERROR) << strerror(errno) + << ": ibv_post_send failed for tensor_key " << key; } - for (int i = 0; i < ret; i++) { - if (wc[i].opcode != IBV_WC_RECV_RDMA_WITH_IMM) { - LOG(ERROR) << "Received unknown operation " << wc[i].opcode; - } - if (wc[i].status != 0) { - LOG(ERROR) << ibv_wc_status_str(wc[i].status); - } - TensorKey tensor_key = ntohl(wc[i].imm_data); - { - mutex_lock l(server_mu_); - auto iter = tensor_buffers_.find(tensor_key); - if (iter == std::end(tensor_buffers_)) { - LOG(ERROR) << "Cannot find tensor buffer for tensor key " - << tensor_key; - } else { - const TensorBuffer* buffer = iter->second; - buffer->Unref(); - tensor_buffers_.erase(iter); - } - } - if (rdma_post_recvv(id, nullptr, nullptr, 0)) { - perror("rdma_post_recvv"); - LOG(ERROR) << "rdma_post_recvv failed"; - continue; - } + + mutex_lock l(callback_mu_); + auto iter = tensor_callbacks_.find(key); + if (iter != std::end(tensor_callbacks_)) { + iter->second(s); + tensor_callbacks_.erase(iter); + } else { + LOG(WARNING) << "Cannot find client callback with tensor key " + << key; } } } + client_mu_.unlock(); } } } @@ -390,116 +364,58 @@ void GdrMemoryManager::TransportOptionsFromTensor( ::google::protobuf::Any* mutable_transport_options, const Tensor& tensor, Device* device, DeviceContext* device_context, bool on_host, 
StatusCallback done) { - auto buffer = DMAHelper::buffer(&tensor); - void* addr = buffer->data(); - size_t length = buffer->size(); - if (length == 0) { - done(errors::Unavailable("Cannot register tensor buffer of size 0")); - return; - } + ibv_mr* mr = FindMemoryRegion(&tensor); + const TensorBuffer* buffer = DMAHelper::buffer(&tensor); - ibv_mr* mr = FindMemoryRegion(addr, length); - -#if GOOGLE_CUDA - if (device->tensorflow_gpu_device_info() && !on_host) { - Allocator* alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0); - Tensor* host_copy = new Tensor(alloc, tensor.dtype(), tensor.shape()); - GPUUtil::CopyGPUTensorToCPU( - device, device_context, &tensor, host_copy, - [done, host_copy, mutable_transport_options, this](const Status& s) { - if (!s.ok()) { - done(s); - delete host_copy; - return; - } - auto buffer = DMAHelper::buffer(host_copy); - void* addr = buffer->data(); - size_t length = buffer->size(); - ibv_mr* mr = FindMemoryRegion(addr, length); - - if (mr == nullptr) { - done(errors::Unavailable("Cannot find pinned memory region")); - delete host_copy; - return; - } - - buffer->Ref(); - TensorKey tensor_key = next_key_++; - { - mutex_lock l(server_mu_); - tensor_buffers_.insert(std::make_pair(tensor_key, buffer)); - } - - uint64_t checksum = 0; - if (VLOG_IS_ON(2)) { - checksum = GPUUtil::Checksum(*host_copy); - } - - RemoteMemoryRegion remote_mr; - remote_mr.set_host(host_); - remote_mr.set_port(port_); - remote_mr.set_addr(reinterpret_cast(addr)); - remote_mr.set_rkey(mr->rkey); - remote_mr.set_tensor_key(tensor_key); - remote_mr.set_checksum(checksum); - mutable_transport_options->PackFrom(remote_mr); - - done(Status::OK()); - delete host_copy; - }); - return; - } -#endif + Tensor* copy = nullptr; if (mr == nullptr) { - Allocator* alloc = ProcessState::singleton()->GetCPUAllocator(numa_node_); - Tensor host_copy(alloc, tensor.dtype(), tensor.shape()); - - std::memcpy(DMAHelper::buffer(&host_copy)->data(), buffer->data(), length); - VLOG(2) 
<< "Copying " << length << " bytes unpinned tensor buffer"; - - buffer = DMAHelper::buffer(&host_copy); - addr = buffer->data(); - length = buffer->size(); - - mr = FindMemoryRegion(addr, length); + AllocatorAttributes alloc_attrs; + alloc_attrs.set_gpu_compatible(true); + alloc_attrs.set_nic_compatible(true); + alloc_attrs.set_on_host(true); + Allocator* alloc = device->GetAllocator(alloc_attrs); + copy = new Tensor(alloc, tensor.dtype(), tensor.shape()); + + mr = FindMemoryRegion(copy); + buffer = DMAHelper::buffer(copy); if (mr == nullptr) { done(errors::Unavailable("Cannot find pinned memory region")); + delete copy; return; } - - buffer->Ref(); - } else { - buffer->Ref(); } TensorKey tensor_key = next_key_++; + buffer->Ref(); { - mutex_lock l(server_mu_); + mutex_lock l(buf_mu_); tensor_buffers_.insert(std::make_pair(tensor_key, buffer)); } - uint64_t checksum = 0; - if (VLOG_IS_ON(2)) { -#ifdef GOOGLE_CUDA - if (device->tensorflow_gpu_device_info() && !on_host) { - checksum = GPUUtil::Checksum(device, device_context, tensor); - } else { - checksum = GPUUtil::Checksum(tensor); - } -#endif - } - RemoteMemoryRegion remote_mr; remote_mr.set_host(host_); remote_mr.set_port(port_); - remote_mr.set_addr(reinterpret_cast(addr)); + remote_mr.set_addr(reinterpret_cast(buffer->data())); remote_mr.set_rkey(mr->rkey); remote_mr.set_tensor_key(tensor_key); - remote_mr.set_checksum(checksum); mutable_transport_options->PackFrom(remote_mr); - done(Status::OK()); + if (copy && device->tensorflow_gpu_device_info() && !on_host) { + device_context->CopyDeviceTensorToCPU(&tensor, "" /* tensor_name */, device, + copy, [done, copy](const Status& s) { + done(s); + delete copy; + }); + return; + } else if (copy) { + std::memcpy(buffer->data(), DMAHelper::buffer(&tensor)->data(), + buffer->size()); + done(Status::OK()); + delete copy; // OK to delete; we have reffed the underlying TensorBuffer + } else { + done(Status::OK()); + } } void GdrMemoryManager::TensorFromTransportOptions( @@ 
-512,42 +428,10 @@ void GdrMemoryManager::TensorFromTransportOptions( return; } - auto buffer = DMAHelper::buffer(tensor); - void* addr = buffer->data(); - size_t length = buffer->size(); - ibv_mr* mr = FindMemoryRegion(addr, length); - - Tensor host_copy; -#if GOOGLE_CUDA - if (mr == nullptr && !on_host) { - Allocator* alloc = - GPUProcessState::singleton()->GetCUDAHostAllocator(numa_node_); - host_copy = Tensor(alloc, tensor->dtype(), tensor->shape()); - buffer = DMAHelper::buffer(&host_copy); - addr = buffer->data(); - length = buffer->size(); - mr = FindMemoryRegion(addr, length); - } -#endif // GOOGLE_CUDA - - if (mr == nullptr) { - Allocator* alloc = ProcessState::singleton()->GetCPUAllocator(numa_node_); - host_copy = Tensor(alloc, tensor->dtype(), tensor->shape()); - - buffer = DMAHelper::buffer(&host_copy); - addr = buffer->data(); - length = buffer->size(); - - mr = FindMemoryRegion(addr, length); - if (mr == nullptr) { - done(errors::Unavailable("Cannot find pinned memory region")); - return; - } - } - - decltype(clients_)::iterator iter; - bool success; + rdma_cm_id* id = nullptr; { + decltype(clients_)::iterator iter; + bool success; mutex_lock l(client_mu_); std::tie(iter, success) = clients_.insert( std::make_pair(std::make_pair(remote_mr.host(), remote_mr.port()), @@ -560,93 +444,94 @@ void GdrMemoryManager::TensorFromTransportOptions( return; } } - } - rdma_cm_id* id = iter->second.get(); - - uint64_t start = Env::Default()->NowMicros(); - - if (rdma_post_read(id, nullptr, buffer->data(), buffer->size(), mr, 0, - remote_mr.addr(), remote_mr.rkey())) { - done(errors::Unavailable(strerror(errno), ": ", "rdma_post_read failed")); - return; + id = iter->second.get(); } - ibv_send_wr wr = {}; - wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; - wr.imm_data = htonl(remote_mr.tensor_key()); - wr.send_flags = IBV_SEND_SIGNALED; - ibv_send_wr* bad_wr; - if (ibv_post_send(id->qp, &wr, &bad_wr)) { - done(errors::Unavailable(strerror(errno), ": ", "ibv_post_send 
failed")); - return; - } + ibv_mr* mr = FindMemoryRegion(tensor); + const TensorBuffer* buffer = DMAHelper::buffer(tensor); - ibv_wc wc = {}; - int ret; - while ((ret = ibv_poll_cq(id->send_cq, 1, &wc)) == 0) - ; - if (ret < 0 || wc.status) { - done(errors::Unavailable(ibv_wc_status_str(wc.status))); - return; - } + const Tensor* copy = nullptr; -#if GOOGLE_CUDA - if (device->tensorflow_gpu_device_info() && !on_host && - host_copy.NumElements() > 0) { - uint64_t checksum = 0; - if (VLOG_IS_ON(2)) { - checksum = GPUUtil::Checksum(host_copy); - CHECK(checksum == remote_mr.checksum()) - << "Checksum mismatch: " << checksum << "!=" << remote_mr.checksum(); + if (mr == nullptr) { + AllocatorAttributes alloc_attrs; + alloc_attrs.set_gpu_compatible(true); + alloc_attrs.set_nic_compatible(true); + alloc_attrs.set_on_host(true); + Allocator* alloc = device->GetAllocator(alloc_attrs); + copy = new Tensor(alloc, tensor->dtype(), tensor->shape()); + + mr = FindMemoryRegion(copy); + buffer = DMAHelper::buffer(copy); + if (mr == nullptr) { + done(errors::Unavailable("Cannot find pinned memory region")); + delete copy; + return; } - Tensor* ref = new Tensor; - std::swap(host_copy, *ref); - GPUUtil::CopyCPUTensorToGPU( - ref, device_context, device, tensor, - [ref, done, buffer, remote_mr, start](const Status& s) { - if (!s.ok()) { - done(s); - delete ref; - return; - } - uint64_t end = Env::Default()->NowMicros(); - - VLOG(2) << "RDMA from remote memory region " << remote_mr.rkey() - << " of size " << buffer->size() << " with tensor key " - << remote_mr.tensor_key() << " took " << (end - start) - << " micros"; - done(Status::OK()); - delete ref; - }); - return; } -#endif // GOOGLE_CUDA - if ((on_host || !device->tensorflow_gpu_device_info()) && - host_copy.NumElements() > 0) { - std::memcpy(DMAHelper::buffer(tensor)->data(), addr, length); - VLOG(2) << "Copying " << length << " bytes unpinned tensor buffer"; - } + uint64_t start = Env::Default()->NowMicros(); - uint64_t end = 
Env::Default()->NowMicros(); + TensorKey tensor_key = remote_mr.tensor_key(); - VLOG(2) << "RDMA from remote memory region " << remote_mr.rkey() - << " of size " << buffer->size() << " with tensor key " - << remote_mr.tensor_key() << " took " << (end - start) << " micros"; + StatusCallback callback = [done, copy, device, device_context, on_host, + tensor, start, tensor_key](const Status& s) { + if (!s.ok()) { + done(s); + if (copy) { + delete copy; + } + return; + } - uint64_t checksum = 0; - if (VLOG_IS_ON(2)) { -#ifdef GOOGLE_CUDA - if (device->tensorflow_gpu_device_info() && !on_host) { - checksum = GPUUtil::Checksum(device, device_context, *tensor); + VLOG(2) << "RDMA of tensor " << tensor_key << " of size " + << DMAHelper::buffer(tensor)->size() << " took " + << (Env::Default()->NowMicros() - start) << " micros"; + + if (copy && device->tensorflow_gpu_device_info() && !on_host) { + device_context->CopyCPUTensorToDevice(copy, device, tensor, + [done, copy](const Status& s) { + done(s); + delete copy; + }); + } else if (copy) { + std::memcpy(DMAHelper::buffer(tensor)->data(), + DMAHelper::buffer(copy)->data(), + DMAHelper::buffer(copy)->size()); + done(s); + delete copy; } else { - checksum = GPUUtil::Checksum(*tensor); + done(s); + } + }; + + { + mutex_lock l(callback_mu_); + if (tensor_callbacks_.find(tensor_key) == std::end(tensor_callbacks_)) { + tensor_callbacks_.insert(std::make_pair(tensor_key, std::move(callback))); + } else { + done(errors::Unavailable("Received duplicated tensor key")); + if (copy) { + delete copy; + } + return; + } + } + + if (rdma_post_read(id, reinterpret_cast(tensor_key), buffer->data(), + buffer->size(), mr, IBV_SEND_SIGNALED, remote_mr.addr(), + remote_mr.rkey())) { + done(errors::Unavailable(strerror(errno), ": ", "rdma_post_read failed")); + { + mutex_lock l(callback_mu_); + auto iter = tensor_callbacks_.find(tensor_key); + if (iter != std::end(tensor_callbacks_)) { + tensor_callbacks_.erase(iter); + } + } + if (copy) { + 
delete copy; } - CHECK(checksum == remote_mr.checksum()) - << "Checksum mismatch: " << checksum << "!=" << remote_mr.checksum(); -#endif } - done(Status::OK()); } Status GdrMemoryManager::CreateEndpoint(const string& host, const string& port, @@ -663,7 +548,7 @@ Status GdrMemoryManager::CreateEndpoint(const string& host, const string& port, ibv_qp_init_attr init_attr = {}; init_attr.qp_type = IBV_QPT_RC; init_attr.cap.max_recv_wr = 1; - init_attr.cap.max_send_wr = 32; + init_attr.cap.max_send_wr = 1024; init_attr.cap.max_recv_sge = 1; init_attr.cap.max_send_sge = 1; @@ -687,8 +572,8 @@ Status GdrMemoryManager::CreateEndpoint(const string& host, const string& port, return Status::OK(); } -ibv_mr* GdrMemoryManager::FindMemoryRegion(void* addr, size_t length) { - if (length == 0) return nullptr; +ibv_mr* GdrMemoryManager::FindMemoryRegion(const Tensor* tensor) { + const void* addr = DMAHelper::buffer(tensor)->data(); mutex_lock l(alloc_mu_); auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator); if (iter == std::end(mrs_) || iter->get()->addr > addr) { diff --git a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc index fbccbead03fc0d641db40ede661bf3677d44c45d..5f8c300155770ed03ad12a9fa5ac74456edaf024 100644 --- a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc +++ b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc @@ -58,11 +58,9 @@ class GdrRecvTensorCall : public BaseRecvTensorCall { resp_.InitAlloc(dst_device_, recv_args_.alloc_attrs); StatusCallback cb = [this, recv_done](const Status& s) { bool dma_ok = resp_.metadata().has_transport_options(); - if (s.ok() && tensor().TotalBytes() > 0 && (!is_dead()) && dma_ok) { + if (s.ok() && tensor().TotalBytes() > 1024 && (!is_dead()) && dma_ok) { auto transport_options = resp_.metadata().transport_options(); - const bool on_host = - (dst_device_->tensorflow_gpu_device_info() == nullptr) || - recv_args_.alloc_attrs.on_host(); + const bool on_host = 
recv_args_.alloc_attrs.on_host(); remote_memory_manager_->TensorFromTransportOptions( const_cast(&tensor()), transport_options, dst_device_, recv_args_.device_context, on_host, @@ -70,9 +68,6 @@ class GdrRecvTensorCall : public BaseRecvTensorCall { if (!s.ok()) { mutex_lock l(mu_); status_.Update(s); - LOG(ERROR) << "Cannot find pinned memory region from allocator " - << dst_device_->GetAllocator(recv_args_.alloc_attrs) - ->Name(); } recv_done(); }); diff --git a/tensorflow/contrib/gdr/gdr_server_lib.cc b/tensorflow/contrib/gdr/gdr_server_lib.cc index b3f48ec1dd9c75055f4e1ea76eb203b6ccf94718..dc0d5d548b80d36409778ef34e63171441f10142 100644 --- a/tensorflow/contrib/gdr/gdr_server_lib.cc +++ b/tensorflow/contrib/gdr/gdr_server_lib.cc @@ -74,9 +74,8 @@ Status GdrServer::Start() { } Status GdrServer::Stop() { - TF_RETURN_IF_ERROR(GrpcServer::Stop()); remote_memory_manager_->Stop(); - return Status::OK(); + return GrpcServer::Stop(); } Status GdrServer::Join() { diff --git a/tensorflow/contrib/gdr/gdr_worker.cc b/tensorflow/contrib/gdr/gdr_worker.cc index 867cb83f42034c8e9061e333ea671457745f92c3..016e5ea27b397830c69b6e1761b5994ebcfa9c3d 100644 --- a/tensorflow/contrib/gdr/gdr_worker.cc +++ b/tensorflow/contrib/gdr/gdr_worker.cc @@ -18,9 +18,6 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/dma_helper.h" -#if GOOGLE_CUDA -#include "tensorflow/core/common_runtime/gpu/gpu_util.h" -#endif // GOOGLE_CUDA #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/distributed_runtime/graph_mgr.h" @@ -78,7 +75,7 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, const bool dma_ok = request->dma_ok(); env_->rendezvous_mgr->RecvLocalAsync( step_id, parsed, - [this, opts, response, done, src_dev, dma_ok]( + [this, opts, response, done, src_dev, request, dma_ok]( const Status& status, const Rendezvous::Args& send_args, const Rendezvous::Args&, const Tensor& val, const bool is_dead) { opts->ClearCancelCallback(); @@ -89,10 +86,8 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, // 3) the tensor has the on_host allocation attribute, // i.e. it's in CPU RAM *independent of its assigned // device type*. - const bool on_host = - (src_dev->tensorflow_gpu_device_info() == nullptr) || - send_args.alloc_attrs.on_host(); - if (val.TotalBytes() > 0 && (!is_dead) && + const bool on_host = send_args.alloc_attrs.on_host(); + if (val.TotalBytes() > 1024 && (!is_dead) && DMAHelper::CanUseDMA(&val) && dma_ok) { // DMA cases. RecvTensorResponse* proto = new RecvTensorResponse; @@ -117,8 +112,7 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, } else { // Non-DMA cases. 
if (src_dev->tensorflow_gpu_device_info() && (!on_host)) { -#if GOOGLE_CUDA - const DeviceContext* send_dev_context = send_args.device_context; + DeviceContext* send_dev_context = send_args.device_context; AllocatorAttributes alloc_attrs; alloc_attrs.set_gpu_compatible(true); alloc_attrs.set_on_host(true); @@ -127,7 +121,8 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, CHECK(send_dev_context) << "send dev name: " << src_dev->name() << " gpu_info: " << src_dev->tensorflow_gpu_device_info(); - // "val" is on a GPU. Uses GPUUtil to fill the response proto. + // "val" is on an accelerator device. Uses the device_context to + // fill the copy on host. StatusCallback copy_ready = [response, done, copy, is_dead](const Status& s) { // The value is now ready to be returned on the wire. @@ -136,11 +131,8 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, delete copy; }; - GPUUtil::CopyGPUTensorToCPU(src_dev, send_dev_context, &val, copy, - copy_ready); -#else - done(errors::Internal("No GPU device in process")); -#endif // GOOGLE_CUDA + send_dev_context->CopyDeviceTensorToCPU( + &val, request->rendezvous_key(), src_dev, copy, copy_ready); } else { grpc::EncodeTensorToByteBuffer(is_dead, val, response); done(Status::OK()); diff --git a/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py b/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py index f7f1189bb93c611719186a697c40f371644f63a2..bc941ae9f23eaa5c46fcca95b9aba0ac0d87960a 100644 --- a/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py +++ b/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import os from tensorflow.contrib.hadoop.python.ops import hadoop_dataset_ops +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -47,7 +48,7 @@ class 
SequenceFileDatasetTest(test.TestCase): dataset = hadoop_dataset_ops.SequenceFileDataset(filenames).repeat( num_repeats) - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) init_op = iterator.initializer get_next = iterator.get_next() diff --git a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py index bf398b838dfaaff6fdaf33a6cd7086ef13e43a3e..77813519c136665a2fea30d4387f5e7a9776b20b 100644 --- a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py +++ b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py @@ -20,15 +20,19 @@ from __future__ import print_function from tensorflow.contrib.hadoop.python.ops import gen_dataset_ops from tensorflow.contrib.hadoop.python.ops import hadoop_op_loader # pylint: disable=unused-import from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape +from tensorflow.python.util import deprecation class SequenceFileDataset(dataset_ops.DatasetSource): """A Sequence File Dataset that reads the sequence file.""" + @deprecation.deprecated( + None, + "tf.contrib.hadoop will be removed in 2.0, the support for Apache Hadoop " + "will continue to be provided through the tensorflow/io GitHub project.") def __init__(self, filenames): """Create a `SequenceFileDataset`. @@ -40,15 +44,12 @@ class SequenceFileDataset(dataset_ops.DatasetSource): For example: ```python + tf.enable_eager_execution() + dataset = tf.contrib.hadoop.SequenceFileDataset("/foo/bar.seq") - iterator = dataset.make_one_shot_iterator() - next_element = iterator.get_next() # Prints the (key, value) pairs inside a hadoop sequence file. 
- while True: - try: - print(sess.run(next_element)) - except tf.errors.OutOfRangeError: - break + for key, value in dataset: + print(key, value) ``` Args: @@ -60,16 +61,10 @@ class SequenceFileDataset(dataset_ops.DatasetSource): def _as_variant_tensor(self): return gen_dataset_ops.sequence_file_dataset( - self._filenames, nest.flatten(self.output_types)) - - @property - def output_classes(self): - return ops.Tensor, ops.Tensor - - @property - def output_shapes(self): - return (tensor_shape.TensorShape([]), tensor_shape.TensorShape([])) + self._filenames, self._element_structure._flat_types) # pylint: disable=protected-access @property - def output_types(self): - return dtypes.string, dtypes.string + def _element_structure(self): + return structure.NestedStructure( + (structure.TensorStructure(dtypes.string, []), + structure.TensorStructure(dtypes.string, []))) diff --git a/tensorflow/contrib/ignite/README.md b/tensorflow/contrib/ignite/README.md index c7db0b77e25668fb8a42d204776044420f403e44..5a8c650fb927be0c835aaceffc516c048195c7bf 100644 --- a/tensorflow/contrib/ignite/README.md +++ b/tensorflow/contrib/ignite/README.md @@ -54,14 +54,12 @@ jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (3, 'LITTLE BALL ```python >>> import tensorflow as tf >>> from tensorflow.contrib.ignite import IgniteDataset ->>> +>>> tf.enable_eager_execution() +>>> >>> dataset = IgniteDataset(cache_name="SQL_PUBLIC_KITTEN_CACHE") ->>> iterator = dataset.make_one_shot_iterator() ->>> next_obj = iterator.get_next() >>> ->>> with tf.Session() as sess: ->>> for _ in range(3): ->>> print(sess.run(next_obj)) +>>> for element in dataset: +>>> print(element) {'key': 1, 'val': {'NAME': b'WARM KITTY'}} {'key': 2, 'val': {'NAME': b'SOFT KITTY'}} @@ -74,23 +72,22 @@ jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (3, 'LITTLE BALL ```python >>> import tensorflow as tf >>> from tensorflow.contrib.ignite import IgniteDataset ->>> +>>> tf.enable_eager_execution() +>>> >>> 
dataset = IgniteDataset(cache_name="IMAGES") ->>> iterator = dataset.make_one_shot_iterator() ->>> next_obj = iterator.get_next() >>> ->>> with tf.Session() as sess: ->>> print(sess.run(next_obj)) +>>> for element in dataset.take(1): +>>> print(element) { - 'key': 'kitten.png', + 'key': 'kitten.png', 'val': { 'metadata': { 'file_name': b'kitten.png', 'label': b'little ball of fur', - width: 800, + width: 800, height: 600 - }, + }, 'pixels': [0, 0, 0, 0, ..., 0] } } @@ -100,13 +97,11 @@ jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (3, 'LITTLE BALL ```python >>> import tensorflow as tf >>> from tensorflow.contrib.ignite import IgniteDataset ->>> +>>> >>> dataset = IgniteDataset(cache_name="IMAGES").map(lambda obj: obj['val']['pixels']) ->>> iterator = dataset.make_one_shot_iterator() ->>> next_obj = iterator.get_next() >>> ->>> with tf.Session() as sess: ->>> print(sess.run(next_obj)) +>>> for element in dataset: +>>> print(element) [0, 0, 0, 0, ..., 0] ``` @@ -126,18 +121,18 @@ Ignite Dataset allows using these two aspects of distributed neural network trai ```python >>> import tensorflow as tf >>> from tensorflow.contrib.ignite import IgniteDataset ->>> +>>> >>> dataset = IgniteDataset("IMAGES") >>> >>> # Compute gradients locally on every worker node. ->>> gradients = [] +>>> gradients = [] >>> for i in range(5): >>> with tf.device("/job:WORKER/task:%d" % i): ->>> device_iterator = dataset.make_one_shot_iterator() +>>> device_iterator = tf.compat.v1.data.make_one_shot_iterator(dataset) >>> device_next_obj = device_iterator.get_next() >>> gradient = compute_gradient(device_next_obj) ->>> gradients.append(gradient) ->>> +>>> gradients.append(gradient) +>>> >>> # Aggregate them on master node. 
>>> result_gradient = tf.reduce_sum(gradients) >>> @@ -145,7 +140,7 @@ Ignite Dataset allows using these two aspects of distributed neural network trai >>> print(sess.run(result_gradient)) ``` -High-level TensorFlow API for [distributed training](https://www.tensorflow.org/api_docs/python/tf/contrib/distribute/DistributionStrategy) is supported as well. +High-level TensorFlow API for [distributed training](https://www.tensorflow.org/api_docs/python/tf/contrib/distribute/DistributionStrategy) is supported as well. ### Distributed File System diff --git a/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py index 936b29a4f50794380d48efed99e267c6b4c44dc6..66e654ca636a5a051c6f9cd35bf9001dfbcbf7f4 100644 --- a/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py +++ b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py @@ -27,17 +27,16 @@ import six from tensorflow.contrib.ignite.python.ops import gen_dataset_ops from tensorflow.contrib.ignite.python.ops import ignite_op_loader # pylint: disable=unused-import from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.util import deprecation @six.add_metaclass(abc.ABCMeta) class Readable(object): - """Readable abstract class that exposes methods to do reading-related - - operations. - """ + """Abstract class that exposes methods to do reading-related operations.""" @abc.abstractmethod def __init__(self): @@ -227,10 +226,7 @@ types = { class TypeTreeNode(object): - """TypeTreeNode class exposes methods to format object tree structure - - data. 
- """ + """TypeTreeNode class exposes methods to format object tree structure data.""" def __init__(self, name, type_id, fields=None, permutation=None): """Constructs a new instance of TypeTreeNode. @@ -692,18 +688,22 @@ class IgniteClient(TcpClient): class IgniteDataset(dataset_ops.DatasetSource): - """Apache Ignite is a memory-centric distributed database, caching, and - - processing platform for transactional, analytical, and streaming workloads, - delivering in-memory speeds at petabyte scale. This contrib package - contains an integration between Apache Ignite and TensorFlow. The - integration is based on tf.data from TensorFlow side and Binary Client - Protocol from Apache Ignite side. It allows to use Apache Ignite as a - datasource for neural network training, inference and all other + """Apache Ignite is a memory-centric distributed database. + + It acts as a caching and processing platform for transactional, analytical, + and streaming workloads, delivering in-memory speeds at petabyte scale. + This contrib package contains an integration between Apache Ignite and + TensorFlow. The integration is based on tf.data from TensorFlow side and + Binary Client Protocol from Apache Ignite side. It allows to use Apache + Ignite as a datasource for neural network training, inference and all other computations supported by TensorFlow. Ignite Dataset is based on Apache Ignite Binary Client Protocol. 
""" + @deprecation.deprecated( + None, + "tf.contrib.ignite will be removed in 2.0, the support for Apache Ignite " + "will continue to be provided through the tensorflow/io GitHub project.") def __init__(self, cache_name, host="localhost", @@ -756,6 +756,9 @@ class IgniteDataset(dataset_ops.DatasetSource): self.cache_type.to_permutation(), dtype=dtypes.int32, name="permutation") + self._structure = structure.convert_legacy_structure( + self.cache_type.to_output_types(), self.cache_type.to_output_shapes(), + self.cache_type.to_output_classes()) def _as_variant_tensor(self): return gen_dataset_ops.ignite_dataset(self.cache_name, self.host, self.port, @@ -763,13 +766,5 @@ class IgniteDataset(dataset_ops.DatasetSource): self.schema, self.permutation) @property - def output_classes(self): - return self.cache_type.to_output_classes() - - @property - def output_shapes(self): - return self.cache_type.to_output_shapes() - - @property - def output_types(self): - return self.cache_type.to_output_types() + def _element_structure(self): + return self._structure diff --git a/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py b/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py index ef29b5f14a4b2fea2400ec4d56a7ad2cf44cf2cb..ff5d4c458c859fd8e5e3ae65ee41a454d55d6538 100644 --- a/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py +++ b/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py @@ -21,6 +21,7 @@ import os from tensorflow.contrib.ignite import IgniteDataset from tensorflow.python.client import session +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.platform import test @@ -65,7 +66,7 @@ class IgniteDatasetTest(test.TestCase): self.assertEqual(dtypes.string, dataset.output_types["val"]["NAME"]) self.assertEqual(dtypes.int64, dataset.output_types["val"]["VAL"]) - it = dataset.make_one_shot_iterator() + it = 
dataset_ops.make_one_shot_iterator(dataset) ne = it.get_next() with session.Session() as sess: diff --git a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py index 4997c31a7fc7f4243d03b22fc9c01fb13a2a25a4..ba5cdfebf92c07e496ed588848d5859ff6a5bff2 100644 --- a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py +++ b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py @@ -281,6 +281,13 @@ class ImageOpsTest(test_util.TensorFlowTestCase): value.eval(), np.array([[4, 4], [4, 4]]).astype(dtype.as_numpy_dtype())) + @test_util.run_in_graph_and_eager_modes + def test_transform_eager(self): + image = constant_op.constant([[1., 2.], [3., 4.]]) + value = image_ops.transform(image, [1] * 8) + with self.test_session(use_gpu=True): + self.assertAllEqual(self.evaluate(value), np.array([[4, 4], [4, 4]])) + class BipartiteMatchTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py index d4fb99a017faebe30384d739f22f4ff5fa986bc4..b25a6f7b5742917a032946fe03a0dab20e7dc1ad 100644 --- a/tensorflow/contrib/image/python/ops/image_ops.py +++ b/tensorflow/contrib/image/python/ops/image_ops.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.eager import context from tensorflow.contrib.image.ops import gen_image_ops from tensorflow.contrib.util import loader from tensorflow.python.framework import common_shapes @@ -271,8 +272,11 @@ def transform(images, raise TypeError("Images should have rank between 2 and 4.") if output_shape is None: - output_shape = tensor_util.constant_value( - array_ops.shape(images)[1:3]) or array_ops.shape(images)[1:3] + output_shape = array_ops.shape(images)[1:3] + if not context.executing_eagerly(): + output_shape_value = tensor_util.constant_value(output_shape) + if 
output_shape_value is not None: + output_shape = output_shape_value output_shape = ops.convert_to_tensor( output_shape, dtypes.int32, name="output_shape") diff --git a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py index 7129f09e8b42e48a9c768fd4a66cde3d4da9d31d..b399e1b6c2ac47db205b5d8bbc81875ef5c08a31 100644 --- a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py +++ b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py @@ -20,15 +20,20 @@ from __future__ import print_function from tensorflow.contrib.kafka.python.ops import gen_dataset_ops from tensorflow.contrib.kafka.python.ops import kafka_op_loader # pylint: disable=unused-import from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape +from tensorflow.python.util import deprecation class KafkaDataset(dataset_ops.DatasetSource): """A Kafka Dataset that consumes the message. 
""" + @deprecation.deprecated( + None, + "tf.contrib.kafka will be removed in 2.0, the support for Apache Kafka " + "will continue to be provided through the tensorflow/io GitHub project.") def __init__(self, topics, servers="localhost", @@ -63,13 +68,5 @@ class KafkaDataset(dataset_ops.DatasetSource): self._group, self._eof, self._timeout) @property - def output_classes(self): - return ops.Tensor - - @property - def output_shapes(self): - return tensor_shape.scalar() - - @property - def output_types(self): - return dtypes.string + def _element_structure(self): + return structure.TensorStructure(dtypes.string, []) diff --git a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py index 75806dbbeb1819bb0a6965bbc384e02df9895210..2b1d478a9b0fd12ca25c72da6872acccfd7285fc 100644 --- a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py +++ b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py @@ -20,9 +20,10 @@ from __future__ import print_function from tensorflow.contrib.kinesis.python.ops import gen_dataset_ops from tensorflow.contrib.kinesis.python.ops import kinesis_op_loader # pylint: disable=unused-import from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape +from tensorflow.python.util import deprecation class KinesisDataset(dataset_ops.DatasetSource): @@ -34,15 +35,12 @@ class KinesisDataset(dataset_ops.DatasetSource): For example, we can construct and use the KinesisDataset as follows: ```python + tf.enable_eager_execution() + dataset = tf.contrib.kinesis.KinesisDataset( "kinesis_stream_name", read_indefinitely=False) - next = dataset.make_one_shot_iterator().get_next() - with tf.Session() as sess: - while True: - try: - print(sess.run(nxt)) - except tf.errors.OutOfRangeError: - break 
+ for element in dataset: + print(element) ``` Since Kinesis is a data streaming service, data may not be available @@ -53,6 +51,10 @@ class KinesisDataset(dataset_ops.DatasetSource): is returned immediately instead. """ + @deprecation.deprecated( + None, + "tf.contrib.kinesis will be removed in 2.0, the support for Kinesis " + "will continue to be provided through the tensorflow/io GitHub project.") def __init__(self, stream, shard="", @@ -84,13 +86,5 @@ class KinesisDataset(dataset_ops.DatasetSource): self._stream, self._shard, self._read_indefinitely, self._interval) @property - def output_classes(self): - return ops.Tensor - - @property - def output_shapes(self): - return tensor_shape.scalar() - - @property - def output_types(self): - return dtypes.string + def _element_structure(self): + return structure.TensorStructure(dtypes.string, []) diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 0a4d2c6d4cb5cad7da93cea89478bc0fca2ac4d6..d791418c9d0f887058ceb535092fa8122da1aa75 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1459,13 +1459,6 @@ class DropoutTest(test.TestCase): class FlattenTest(test.TestCase): - def testInvalidRank(self): - with ops.Graph().as_default() as g, self.session(g): - inputs = array_ops.placeholder(dtype=dtypes.float32) - inputs.set_shape(tensor_shape.TensorShape((5,))) - with self.assertRaisesRegexp(ValueError, 'incompatible with the layer'): - _layers.flatten(inputs) - def testUnknownLastDim(self): with ops.Graph().as_default() as g, self.session(g): inputs = array_ops.placeholder(dtype=dtypes.float32) @@ -1502,6 +1495,12 @@ class FlattenTest(test.TestCase): images.get_shape().num_elements()) self.assertEqual(output.get_shape()[0], images.get_shape()[0]) + def testFlatten0D(self): + with self.cached_session(): + scalars = random_ops.random_uniform((5,), seed=1, name='scalars') 
+ output = _layers.flatten(scalars) + self.assertEqual(output.shape, (5, 1)) + def testFlattenBatchSize(self): height, width = 3, 3 with self.cached_session() as sess: diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 238504f6d60aeb1a7ff25deab4a86881285e8c03..14065fcee51c014a1af227504eaaca1fa39941e1 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -274,6 +274,7 @@ py_test( name = "estimator_test", size = "medium", srcs = ["python/learn/estimators/estimator_test.py"], + shard_count = 2, srcs_version = "PY2AND3", tags = [ "manual", diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py index 8466dc36d13e223aed4f1dfe8e39a6f91c99fa55..d49834dc860a8b4341ddd3720fde52281f7474f7 100644 --- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py +++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for SdcaModel.""" +"""Tests for SdcaModel (deprecated). + +This module and all its submodules are deprecated. To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. 
+""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py index f3f1dcd98db5ae24af154d1f0851a0688d2bc611..c056a12fa5307a7e9ac4cf30e1386ddfd5cd7d75 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py @@ -12,7 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Proximal stochastic dual coordinate ascent optimizer for linear models.""" +# pylint: disable=line-too-long +"""Proximal stochastic dual coordinate ascent optimizer for linear models (deprecated). + +This module and all its submodules are deprecated. To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. +""" +# pylint: enable=line-too-long from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -40,6 +47,7 @@ from tensorflow.python.ops import variables as var_ops from tensorflow.python.ops.nn import log_poisson_loss from tensorflow.python.ops.nn import sigmoid_cross_entropy_with_logits from tensorflow.python.summary import summary +from tensorflow.python.util import deprecation __all__ = ['SdcaModel'] @@ -48,7 +56,7 @@ __all__ = ['SdcaModel'] class SdcaModel(object): """Stochastic dual coordinate ascent solver for linear models. - Loss functions supported: + Loss functions supported: * Binary logistic loss * Squared loss @@ -109,6 +117,10 @@ class SdcaModel(object): ``` """ + @deprecation.deprecated( + None, 'This class is deprecated. 
To UPDATE or USE linear optimizers, ' + 'please check its latest version in core: ' + 'tensorflow_estimator/python/estimator/canned/linear_optimizer/.') def __init__(self, examples, variables, options): """Create a new sdca optimizer.""" diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py index a001555e8f257c88a52fdb40d4181f5cd9c92e84..a28394964a12013c43d85701b5a0ab5c559afd62 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Sharded mutable dense hash table.""" +"""Sharded mutable dense hash table (deprecated). + +This module and all its submodules are deprecated. To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. +""" from __future__ import absolute_import from __future__ import division @@ -28,6 +33,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.util import deprecation # TODO(rohanj): This should subclass Checkpointable and implement @@ -45,6 +51,10 @@ class ShardedMutableDenseHashTable(object): # TODO(andreasst): consider moving this to lookup module + @deprecation.deprecated( + None, 'This class is deprecated. 
To UPDATE or USE linear optimizers, ' + 'please check its latest version in core: ' + 'tensorflow_estimator/python/estimator/canned/linear_optimizer/.') def __init__(self, key_dtype, value_dtype, diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py index 2b56d0fa3a8b8564b7c73a62bd99cc900d6f5c54..2d1457f9e4cc576da696be191e718814dd9ff4e5 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for sharded_mutable_dense_hashtable.py.""" +"""Tests for sharded_mutable_dense_hashtable.py (deprecated). + +This module and all its submodules are deprecated. To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column.py b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column.py index 003795233ff2b28e33fc10388ef25efb63c43bb0..64730f8eed1ff9bfcd4a980dceb28abb98e39f73 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Sparse feature column.""" +"""Sparse feature column (deprecated). + +This module and all its submodules are deprecated. 
To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. +""" from __future__ import absolute_import from __future__ import division @@ -21,6 +26,7 @@ from __future__ import print_function from tensorflow.python.framework import dtypes from tensorflow.python.framework.ops import internal_convert_to_tensor from tensorflow.python.framework.ops import name_scope +from tensorflow.python.util import deprecation class SparseFeatureColumn(object): @@ -68,6 +74,10 @@ class SparseFeatureColumn(object): @@feature_values """ + @deprecation.deprecated( + None, 'This class is deprecated. To UPDATE or USE linear optimizers, ' + 'please check its latest version in core: ' + 'tensorflow_estimator/python/estimator/canned/linear_optimizer/.') def __init__(self, example_indices, feature_indices, feature_values): """Creates a `SparseFeatureColumn` representation. diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py index 51c4f68543da2f563481cc2d35b556796616cf9d..0ae780e1a100c7dadde7196803f2ae0d4bcb2334 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for sparse_feature_column.py.""" +"""Tests for sparse_feature_column.py (deprecated). + +This module and all its submodules are deprecated. To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. 
+""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index e52fb5ab1431e086f99b4033a6216636a83bad79..229a72a780d5ccce8263444ffeae7700f6ac8613 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -91,7 +91,7 @@ def index_table_from_tensor(mapping, The bucket ID range is `[mapping size, mapping size + num_oov_buckets - 1]`. The underlying table must be initialized by calling - `tf.tables_initializer.run()` or `table.initializer.run()` once. + `session.run(tf.tables_initializer)` or `session.run(table.init)` once. Elements in `mapping` cannot have duplicates, otherwise when executing the table initializer op, it will throw a `FailedPreconditionError`. @@ -158,7 +158,7 @@ def string_to_index(tensor, mapping, default_value=-1, name=None): will throw a FailedPreconditionError. The underlying table must be initialized by calling - `tf.tables_initializer.run()` once. + `session.run(tf.tables_initializer)` once. For example: @@ -202,7 +202,7 @@ def index_to_string_table_from_tensor(mapping, default_value="UNK", name=None): (an out-of-vocabulary entry) is assigned the `default_value` The underlying table must be initialized by calling - `tf.tables_initializer.run()` or `table.initializer.run()` once. + `session.run(tf.tables_initializer)` or `session.run(table.init)` once. Elements in `mapping` cannot have duplicates, otherwise when executing the table initializer op, it will throw a `FailedPreconditionError`. @@ -257,7 +257,7 @@ def index_to_string(tensor, mapping, default_value="UNK", name=None): (an out-of-vocabulary entry) is assigned the `default_value` The underlying table must be initialized by calling - `tf.tables_initializer.run()` once. + `session.run(tf.tables_initializer)` once. 
For example: diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 5e99ef460518fa761b12533e5dc07dc252f1d582..9b2c2dd87cc8a92fbb6b45504939be3788b60839 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -25,6 +25,7 @@ import six from tensorflow.contrib import lookup from tensorflow.python.client import session from tensorflow.python.data.experimental.ops import counter +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -2737,7 +2738,7 @@ class MutableHashTableBenchmark(test.Benchmark): def benchmark_many_repeated_scalar_insert_scalar(self): table = self._create_table() - c = counter.Counter().make_one_shot_iterator().get_next() + c = dataset_ops.make_one_shot_iterator(counter.Counter()).get_next() value = variables.Variable(1.0) insert = table.insert(c, value) size = table.size() @@ -2758,7 +2759,7 @@ class MutableHashTableBenchmark(test.Benchmark): def benchmark_many_repeated_batch_32_insert_scalar(self): table = self._create_table() - c = counter.Counter().make_one_shot_iterator().get_next() + c = dataset_ops.make_one_shot_iterator(counter.Counter()).get_next() value = variables.Variable([1.0] * 32) insert = table.insert(32 * c + list(range(32)), value) size = table.size() diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh index b396c527673902d61072dc9cf7d2766476be8369..2a5232b476712a96f84be0f4725beb78bc138297 100755 --- a/tensorflow/contrib/makefile/download_dependencies.sh +++ b/tensorflow/contrib/makefile/download_dependencies.sh @@ -30,11 +30,13 @@ EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | 
head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" NSYNC_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" -# Note: The Protobuf source in `tensorflow/workspace.bzl` in TensorFlow -# 1.10 branch does not work. `make distclean` fails and blocks the build -# process. For now we're hardcoding to the version which is used by -# TensorFlow 1.9. -PROTOBUF_URL="https://mirror.bazel.build/github.com/google/protobuf/archive/v3.6.0.tar.gz" + +# Note: The protobuf repo needs to be cloned due to its submodules. +# These variables contain the GitHub repo and the sha, from `tensorflow/workspace.bzl`, +# from which to clone it from and checkout to. +readonly PROTOBUF_REPO="https://github.com/protocolbuffers/protobuf.git" +readonly PROTOBUF_TAG="$(grep -o 'https://github.com/protocolbuffers/protobuf/archive/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1 | awk '{print substr($0, index($0, "archive") + 8, index($0, "tar") - index($0, "archive") - 9) }')" + # TODO (yongtang): Replace the following with 'https://mirror.bazel.build/github.com/google/re2/.*tar\.gz' once # the archive has been propagated in mirror.bazel.build. 
RE2_URL="$(grep -o 'https://github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" @@ -91,11 +93,34 @@ download_and_extract() { find "${dir}" -type f -name '*BUILD' -delete } +function clone_repository() { + local repo_url="${1}" + local destination_directory="${2}" + local commit_sha="${3}" + + if [[ -d "${destination_directory}" ]]; then + rm -rf "${destination_directory}" + fi + + git clone "${repo_url}" "${destination_directory}" + + pushd "$(pwd)" 1>/dev/null + + cd "${destination_directory}" + + if [[ -n "${commit_sha}" ]]; then + git checkout "${PROTOBUF_TAG}" + fi + + git submodule update --init + + popd 1>/dev/null +} + download_and_extract "${EIGEN_URL}" "${DOWNLOADS_DIR}/eigen" download_and_extract "${GEMMLOWP_URL}" "${DOWNLOADS_DIR}/gemmlowp" download_and_extract "${GOOGLETEST_URL}" "${DOWNLOADS_DIR}/googletest" download_and_extract "${NSYNC_URL}" "${DOWNLOADS_DIR}/nsync" -download_and_extract "${PROTOBUF_URL}" "${DOWNLOADS_DIR}/protobuf" download_and_extract "${RE2_URL}" "${DOWNLOADS_DIR}/re2" download_and_extract "${FFT2D_URL}" "${DOWNLOADS_DIR}/fft2d" download_and_extract "${DOUBLE_CONVERSION_URL}" "${DOWNLOADS_DIR}/double_conversion" @@ -106,6 +131,8 @@ download_and_extract "${CUB_URL}" "${DOWNLOADS_DIR}/cub/external/cub_archive" download_and_extract "${FARMHASH_URL}" "${DOWNLOADS_DIR}/farmhash" download_and_extract "${FLATBUFFERS_URL}" "${DOWNLOADS_DIR}/flatbuffers" +clone_repository "${PROTOBUF_REPO}" "${DOWNLOADS_DIR}/protobuf" "${PROTOBUF_TAG}" + replace_by_sed 's#static uint32x4_t p4ui_CONJ_XOR = vld1q_u32( conj_XOR_DATA );#static uint32x4_t p4ui_CONJ_XOR; // = vld1q_u32( conj_XOR_DATA ); - Removed by script#' \ "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h" replace_by_sed 's#static uint32x2_t p2ui_CONJ_XOR = vld1_u32( conj_XOR_DATA );#static uint32x2_t p2ui_CONJ_XOR;// = vld1_u32( conj_XOR_DATA ); - Removed by scripts#' \ diff --git a/tensorflow/contrib/metrics/python/metrics/classification.py 
b/tensorflow/contrib/metrics/python/metrics/classification.py index 062deb74b165329d8e72efa73b9d81f4174f8831..9aabc4bec3053871e3ff6cd3a88fd76d293f48cc 100644 --- a/tensorflow/contrib/metrics/python/metrics/classification.py +++ b/tensorflow/contrib/metrics/python/metrics/classification.py @@ -18,13 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics_impl from tensorflow.python.ops import variable_scope -from tensorflow.python.training import distribution_strategy_context # TODO(nsilberman): move into metrics/python/ops/ diff --git a/tensorflow/contrib/metrics/python/metrics/classification_test.py b/tensorflow/contrib/metrics/python/metrics/classification_test.py index d6a670f97b32a29129cb9ea0cd71c5a2b7597a47..e789d2cb9dfbac7b1e145be48b3f707af3fd4e18 100644 --- a/tensorflow/contrib/metrics/python/metrics/classification_test.py +++ b/tensorflow/contrib/metrics/python/metrics/classification_test.py @@ -291,12 +291,11 @@ class F1ScoreTest(test.TestCase): labels = labels.astype(np.float32) predictions = predictions.astype(np.float32) - tf_predictions, tf_labels = (dataset_ops.Dataset - .from_tensor_slices((predictions, labels)) - .repeat() - .batch(batch_size) - .make_one_shot_iterator() - .get_next()) + tf_predictions, tf_labels = dataset_ops.make_one_shot_iterator( + dataset_ops.Dataset + .from_tensor_slices((predictions, labels)) + .repeat() + .batch(batch_size)).get_next() f1, f1_op = classification.f1_score(tf_labels, tf_predictions, num_thresholds=3) diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py b/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py index 
1b0383d24c0c472b4875d15c3650e37dfd2439e1..c922d0cd11fda3c51a51ceccf69798df7ce75f26 100644 --- a/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py @@ -29,7 +29,7 @@ from tensorflow.python.platform import test def _GetExampleIter(inputs): dataset = dataset_ops.Dataset.from_tensor_slices(inputs) - return dataset.make_one_shot_iterator() + return dataset_ops.make_one_shot_iterator(dataset) class FixedLossScaleManagerTest(test.TestCase): diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py index 9009df0eefec13146090ba5fc2096e71ba6eb89d..33f9a43e803ea845a25bba284e41e5a0e6228dad 100644 --- a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py @@ -132,7 +132,7 @@ class LossScaleOptimizerTest(test.TestCase): x = variable_scope.get_variable("x", initializer=1., dtype=dtypes.float32) dataset = dataset_ops.Dataset.from_tensor_slices([np.nan, np.inf, 0.1]) - itr = dataset.make_one_shot_iterator() + itr = dataset_ops.make_one_shot_iterator(dataset) lr = 1 opt = gd.GradientDescentOptimizer(lr) @@ -182,7 +182,7 @@ class LossScaleOptimizerTest(test.TestCase): x = variable_scope.get_variable("x", initializer=1., dtype=dtypes.float32) dataset = dataset_ops.Dataset.from_tensor_slices([np.nan, np.inf, 0.1]) - itr = dataset.make_one_shot_iterator() + itr = dataset_ops.make_one_shot_iterator(dataset) lr = 1 init_loss_scale = 8 diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py index f6b4373edd0544555dd16a373802d2feb5d674b1..43ea66ac5a178f6ffe87df99ddced3d0442111c1 100644 --- a/tensorflow/contrib/model_pruning/python/pruning.py +++ b/tensorflow/contrib/model_pruning/python/pruning.py @@ -214,7 +214,7 @@ def get_pruning_hparams(): 
target_sparsity=0.5, sparsity_function_begin_step=0, sparsity_function_end_step=100, - sparsity_function_exponent=3, + sparsity_function_exponent=3.0, use_tpu=False) diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index f4ac70eb1a720c2acc3ef942f269228156749cba..0446e823d95f8ecbed6a0c34a83ade009e68448b 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -14,6 +14,7 @@ py_library( name = "opt_py", srcs = [ "__init__.py", + "python/training/adam_gs_optimizer.py", "python/training/adamax.py", "python/training/addsign.py", "python/training/agn_optimizer.py", @@ -22,6 +23,7 @@ py_library( "python/training/external_optimizer.py", "python/training/ggt.py", "python/training/lars_optimizer.py", + "python/training/lazy_adam_gs_optimizer.py", "python/training/lazy_adam_optimizer.py", "python/training/matrix_functions.py", "python/training/model_average_optimizer.py", @@ -60,6 +62,21 @@ py_library( ], ) +py_test( + name = "adam_gs_optimizer_test", + srcs = ["python/training/adam_gs_optimizer_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:training", + "//third_party/py/numpy", + ], +) + py_test( name = "adamax_test", srcs = ["python/training/adamax_test.py"], @@ -148,6 +165,25 @@ py_test( ], ) +py_test( + name = "lazy_adam_gs_optimizer_test", + srcs = ["python/training/lazy_adam_gs_optimizer_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:variables", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + 
py_test( name = "lazy_adam_optimizer_test", srcs = ["python/training/lazy_adam_optimizer_test.py"], diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py index c7ea68efa9a13a471bba3f41d0600855793b20a2..e8fc52342ceabb47da97ca0f3c8a01e419a221a1 100644 --- a/tensorflow/contrib/opt/__init__.py +++ b/tensorflow/contrib/opt/__init__.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import +from tensorflow.contrib.opt.python.training.adam_gs_optimizer import * from tensorflow.contrib.opt.python.training.adamax import * from tensorflow.contrib.opt.python.training.addsign import * from tensorflow.contrib.opt.python.training.agn_optimizer import * @@ -28,6 +29,7 @@ from tensorflow.contrib.opt.python.training.external_optimizer import * from tensorflow.contrib.opt.python.training.lars_optimizer import * from tensorflow.contrib.opt.python.training.ggt import * from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import * +from tensorflow.contrib.opt.python.training.lazy_adam_gs_optimizer import * from tensorflow.contrib.opt.python.training.model_average_optimizer import * from tensorflow.contrib.opt.python.training.moving_average_optimizer import * from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import * @@ -44,12 +46,14 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'AdaMaxOptimizer', + 'AdamGSOptimizer', 'PowerSignOptimizer', 'AddSignOptimizer', 'DelayCompensatedGradientDescentOptimizer', 'DropStaleGradientOptimizer', 'ExternalOptimizerInterface', 'LARSOptimizer', + 'LazyAdamGSOptimizer', 'LazyAdamOptimizer', 'NadamOptimizer', 'MovingAverageOptimizer', diff --git a/tensorflow/contrib/opt/python/training/adam_gs_optimizer.py b/tensorflow/contrib/opt/python/training/adam_gs_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..3fb649ea82e79b3bc78a2da6d5c3e9a071adec6d --- 
/dev/null +++ b/tensorflow/contrib/opt/python/training/adam_gs_optimizer.py @@ -0,0 +1,217 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Adam rewrite to use global step for computing beta1 & beta2 accumulation.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.training import optimizer +from tensorflow.python.training import training_ops +from tensorflow.python.util.tf_export import tf_export + + +@tf_export("train.AdamOptimizer") +class AdamGSOptimizer(optimizer.Optimizer): + """Optimizer that implements the Adam algorithm. + + See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980) + ([pdf](http://arxiv.org/pdf/1412.6980.pdf)). + """ + + def __init__(self, global_step=0, learning_rate=0.001, + beta1=0.9, beta2=0.999, epsilon=1e-8, + use_locking=False, name="Adam"): + """Construct a new Adam optimizer. + + Branched from tf.train.AdamOptimizer. 
The only difference is to pass + global step for computing beta1 and beta2 accumulators, instead of having + optimizer keep its own independent beta1 and beta2 accumulators as non-slot + variables. + + Initialization: + + $$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$ + $$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$ + $$t := 0 \text{(Initialize timestep)}$$ + + The update rule for `variable` with gradient `g` uses an optimization + described at the end of section2 of the paper: + + $$t := t + 1$$ + $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$ + + $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$ + $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$ + $$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$ + + The default value of 1e-8 for epsilon might not be a good default in + general. For example, when training an Inception network on ImageNet a + current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the + formulation just before Section 2.1 of the Kingma and Ba paper rather than + the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon + hat" in the paper. + + The sparse implementation of this algorithm (used when the gradient is an + IndexedSlices object, typically because of `tf.gather` or an embedding + lookup in the forward pass) does apply momentum to variable slices even if + they were not used in the forward pass (meaning they have a gradient equal + to zero). Momentum decay (beta1) is also applied to the entire momentum + accumulator. This means that the sparse behavior is equivalent to the dense + behavior (in contrast to some momentum implementations which ignore momentum + unless a variable slice was actually used). + + Args: + global_step: tensorflow variable indicating the step. + learning_rate: A Tensor or a floating point value. The learning rate. + beta1: A float value or a constant float tensor. + The exponential decay rate for the 1st moment estimates. 
+ beta2: A float value or a constant float tensor. + The exponential decay rate for the 2nd moment estimates. + epsilon: A small constant for numerical stability. This epsilon is + "epsilon hat" in the Kingma and Ba paper (in the formula just before + Section 2.1), not the epsilon in Algorithm 1 of the paper. + use_locking: If True use locks for update operations. + name: Optional name for the operations created when applying gradients. + Defaults to "Adam". + + @compatibility(eager) + When eager execution is enabled, `learning_rate`, `beta1`, `beta2`, and + `epsilon` can each be a callable that takes no arguments and returns the + actual value to use. This can be useful for changing these values across + different invocations of optimizer functions. + @end_compatibility + """ + super(AdamGSOptimizer, self).__init__(use_locking, name) + self._lr = learning_rate + self._beta1 = beta1 + self._beta2 = beta2 + self._epsilon = epsilon + self._global_step = global_step + self._global_step_on_worker = None + + # Tensor versions of the constructor arguments, created in _prepare(). + self._lr_t = None + self._beta1_t = None + self._beta2_t = None + self._epsilon_t = None + + # Created in SparseApply if needed. + self._updated_lr = None + + def _get_beta_accumulators(self): + return (math_ops.pow(self._beta1_t, self._global_step_on_worker), + math_ops.pow(self._beta2_t, self._global_step_on_worker)) + + def _create_slots(self, var_list): + # Create slots for the first and second moments. 
+ for v in var_list: + self._zeros_slot(v, "m", self._name) + self._zeros_slot(v, "v", self._name) + + def _prepare(self): + lr = self._call_if_callable(self._lr) + beta1 = self._call_if_callable(self._beta1) + beta2 = self._call_if_callable(self._beta2) + epsilon = self._call_if_callable(self._epsilon) + + self._lr_t = ops.convert_to_tensor(lr, name="learning_rate") + self._beta1_t = ops.convert_to_tensor(beta1, name="beta1") + self._beta2_t = ops.convert_to_tensor(beta2, name="beta2") + self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon") + + # Performance optimization so that worker creates a copy of the global step + # to avoid overloading the parameter server holding the global step. + self._global_step_on_worker = math_ops.cast( + array_ops.identity(self._global_step) + 1, dtypes.float32) + + def _apply_dense(self, grad, var): + m = self.get_slot(var, "m") + v = self.get_slot(var, "v") + beta1_power, beta2_power = self._get_beta_accumulators() + return training_ops.apply_adam( + var, m, v, + math_ops.cast(beta1_power, var.dtype.base_dtype), + math_ops.cast(beta2_power, var.dtype.base_dtype), + math_ops.cast(self._lr_t, var.dtype.base_dtype), + math_ops.cast(self._beta1_t, var.dtype.base_dtype), + math_ops.cast(self._beta2_t, var.dtype.base_dtype), + math_ops.cast(self._epsilon_t, var.dtype.base_dtype), + grad, use_locking=self._use_locking).op + + def _resource_apply_dense(self, grad, var): + m = self.get_slot(var, "m") + v = self.get_slot(var, "v") + beta1_power, beta2_power = self._get_beta_accumulators() + return training_ops.resource_apply_adam( + var.handle, m.handle, v.handle, + math_ops.cast(beta1_power, grad.dtype.base_dtype), + math_ops.cast(beta2_power, grad.dtype.base_dtype), + math_ops.cast(self._lr_t, grad.dtype.base_dtype), + math_ops.cast(self._beta1_t, grad.dtype.base_dtype), + math_ops.cast(self._beta2_t, grad.dtype.base_dtype), + math_ops.cast(self._epsilon_t, grad.dtype.base_dtype), + grad, use_locking=self._use_locking) + + def 
_apply_sparse_shared(self, grad, var, indices, scatter_add): + beta1_power, beta2_power = self._get_beta_accumulators() + beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) + beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) + lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) + beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) + beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) + epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) + lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) + # m_t = beta1 * m + (1 - beta1) * g_t + m = self.get_slot(var, "m") + m_scaled_g_values = grad * (1 - beta1_t) + m_t = state_ops.assign(m, m * beta1_t, + use_locking=self._use_locking) + with ops.control_dependencies([m_t]): + m_t = scatter_add(m, indices, m_scaled_g_values) + # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) + v = self.get_slot(var, "v") + v_scaled_g_values = (grad * grad) * (1 - beta2_t) + v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking) + with ops.control_dependencies([v_t]): + v_t = scatter_add(v, indices, v_scaled_g_values) + v_sqrt = math_ops.sqrt(v_t) + var_update = state_ops.assign_sub(var, + lr * m_t / (v_sqrt + epsilon_t), + use_locking=self._use_locking) + return control_flow_ops.group(*[var_update, m_t, v_t]) + + def _apply_sparse(self, grad, var): + return self._apply_sparse_shared( + grad.values, var, grad.indices, + lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda + x, i, v, use_locking=self._use_locking)) + + def _resource_scatter_add(self, x, i, v): + with ops.control_dependencies( + [resource_variable_ops.resource_scatter_add( + x.handle, i, v)]): + return x.value() + + def _resource_apply_sparse(self, grad, var, indices): + return self._apply_sparse_shared( + grad, var, indices, self._resource_scatter_add) diff --git a/tensorflow/contrib/opt/python/training/adam_gs_optimizer_test.py 
b/tensorflow/contrib/opt/python/training/adam_gs_optimizer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c68c965aef3729bebe7d0e0dd707c344321d9e3f --- /dev/null +++ b/tensorflow/contrib/opt/python/training/adam_gs_optimizer_test.py @@ -0,0 +1,382 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for AdamGS.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.opt.python.training import adam_gs_optimizer +from tensorflow.python.client import session +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def adam_update_numpy(param, + g_t, + t, + m, + v, + alpha=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-8): + alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t) + + m_t = beta1 * m + (1 - beta1) * g_t + v_t = beta2 * v + (1 - beta2) * g_t * g_t + 
+ param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon) + return param_t, m_t, v_t + + +class AdamGSOptimizerTest(test.TestCase): + + def doTestSparse(self, use_resource=False): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64)) + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + else: + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0_np_indices = np.array([0, 1], dtype=np.int32) + grads0 = ops.IndexedSlices( + constant_op.constant(grads0_np), + constant_op.constant(grads0_np_indices), constant_op.constant([2])) + grads1_np_indices = np.array([0, 1], dtype=np.int32) + grads1 = ops.IndexedSlices( + constant_op.constant(grads1_np), + constant_op.constant(grads1_np_indices), constant_op.constant([2])) + opt = adam_gs_optimizer.AdamGSOptimizer(global_step=global_step) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**t, + 
self.evaluate(beta2_power)) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + + def testSparse(self): + self.doTestSparse(use_resource=False) + + def testResourceSparse(self): + self.doTestSparse(use_resource=True) + + def testSparseDevicePlacement(self): + for index_dtype in [dtypes.int32, dtypes.int64]: + with self.cached_session(force_gpu=test.is_gpu_available()): + # If a GPU is available, tests that all optimizer ops can be placed on + # it (i.e. they have GPU kernels). + var = variables.Variable([[1.0], [2.0]]) + indices = constant_op.constant([0, 1], dtype=index_dtype) + gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices)) + optimizer = adam_gs_optimizer.AdamGSOptimizer(3.0) + minimize_op = optimizer.minimize(gathered_sum) + variables.global_variables_initializer().run() + minimize_op.run() + + def testSparseRepeatedIndices(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + repeated_index_global_step = variables.Variable( + array_ops.zeros([], dtypes.int64)) + aggregated_global_step = variables.Variable( + array_ops.zeros([], dtypes.int64)) + repeated_index_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + aggregated_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + grad_repeated_index = ops.IndexedSlices( + constant_op.constant( + [0.1, 0.1], shape=[2, 1], dtype=dtype), + constant_op.constant([1, 1]), + constant_op.constant([2, 1])) + grad_aggregated = ops.IndexedSlices( + constant_op.constant( + [0.2], shape=[1, 1], dtype=dtype), + constant_op.constant([1]), + constant_op.constant([2, 1])) + repeated_update = adam_gs_optimizer.AdamGSOptimizer( + 
global_step=repeated_index_global_step).apply_gradients( + [(grad_repeated_index, repeated_index_update_var)], + global_step=repeated_index_global_step) + aggregated_update = adam_gs_optimizer.AdamGSOptimizer( + global_step=aggregated_global_step).apply_gradients( + [(grad_aggregated, aggregated_update_var)], + global_step=aggregated_global_step) + variables.global_variables_initializer().run() + self.assertAllClose(aggregated_update_var.eval(), + self.evaluate(repeated_index_update_var)) + for _ in range(3): + repeated_update.run() + aggregated_update.run() + self.assertAllClose(aggregated_update_var.eval(), + self.evaluate(repeated_index_update_var)) + + def doTestBasic(self, use_resource=False, use_callable_params=False): + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + with self.session(graph=ops.Graph()): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64), name="global_step_%d" % i) + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + else: + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + learning_rate = lambda: 0.001 + beta1 = lambda: 0.9 + beta2 = lambda: 0.999 + epsilon = lambda: 1e-8 + if not use_callable_params: + learning_rate = learning_rate() + beta1 = beta1() + beta2 = beta2() + epsilon = epsilon() + + opt = 
adam_gs_optimizer.AdamGSOptimizer(global_step=global_step, + learning_rate=learning_rate) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + opt_variables = opt.variables() + beta1_power, beta2_power = opt._get_beta_accumulators() + self.assertTrue(beta1_power is not None) + self.assertTrue(beta2_power is not None) + self.assertNotIn(beta1_power, opt_variables) + self.assertNotIn(beta2_power, opt_variables) + + if not context.executing_eagerly(): + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. + self.assertEqual(0, len(opt.variables())) + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + # Run 3 steps of Adam + for t in range(1, 4): + if not context.executing_eagerly(): + self.evaluate(update) + self.assertAllCloseAccordingToType( + 0.9**(t + 1), self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType( + 0.999**(t + 1), self.evaluate(beta2_power)) + else: + if t > 1: + opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + beta1_power, beta2_power = opt._get_beta_accumulators() + self.assertAllCloseAccordingToType( + 0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType( + 0.999**t, self.evaluate(beta2_power)) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + if use_resource: + self.assertEqual("var0_%d/Adam:0" % (i,), + opt.get_slot(var=var0, name="m").name) + + def testBasic(self): + with self.cached_session(): + self.doTestBasic(use_resource=False) + + 
@test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTestBasic(use_resource=True) + + def testBasicCallableParams(self): + with context.eager_mode(): + self.doTestBasic(use_resource=True, use_callable_params=True) + + def testTensorLearningRate(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = adam_gs_optimizer.AdamGSOptimizer( + global_step=global_step, learning_rate=constant_op.constant(0.001)) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**t, + self.evaluate(beta2_power)) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + + def testSharing(self): + for dtype in 
[dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = adam_gs_optimizer.AdamGSOptimizer(global_step=global_step) + update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + # Run 3 steps of intertwined Adam1 and Adam2. 
+ for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**t, + self.evaluate(beta2_power)) + if t % 2 == 0: + update1.run() + else: + update2.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + + def testTwoSessions(self): + optimizer = adam_gs_optimizer.AdamGSOptimizer() + + with context.eager_mode(): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + + g = ops.Graph() + with g.as_default(): + with session.Session(): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + + gg = ops.Graph() + with gg.as_default(): + with session.Session(): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + + # If the optimizer saves any state not keyed by graph the following line + # fails. + optimizer.apply_gradients([(grads0, var0)]) + + def testSlotsUniqueEager(self): + with context.eager_mode(): + v1 = resource_variable_ops.ResourceVariable(1.) + v2 = resource_variable_ops.ResourceVariable(1.) + opt = adam_gs_optimizer.AdamGSOptimizer(1.) + opt.minimize(lambda: v1 + v2) + # There should be two unique slot variables for v1 and v2 respectively. 
+ self.assertEqual(4, len(set(opt.variables()))) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py index 6c203e5519e6a66d20e2509eca3c74eb66bf32c7..fa1a7aaff0aa59a6a64b1f0bf836a273926d785d 100644 --- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py @@ -30,6 +30,7 @@ from tensorflow.python.ops import variables from tensorflow.python.training import optimizer from tensorflow.python.training import saver from tensorflow.python.training import session_run_hook +from tensorflow.python.training.saving import saveable_object_util LOCAL_VARIABLE_NAME = 'local_center_variable' GLOBAL_VARIABLE_NAME = 'global_center_variable' @@ -424,7 +425,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer): if var_list is None: var_list = variables.trainable_variables() if not isinstance(var_list, dict): - var_list = saver.BaseSaverBuilder.OpListToDict(var_list) + var_list = saveable_object_util.op_list_to_dict(var_list) swapped_var_list = {} for key, var in var_list.items(): @@ -464,4 +465,4 @@ class _ElasticAverageOptimizerHook(session_run_hook.SessionRunHook): def after_create_session(self, session, coord): """Run initialization ops""" - session.run(self._variable_init_op) \ No newline at end of file + session.run(self._variable_init_op) diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer.py b/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..8827007e4d7f6722398a8e36bd626377842d92ef --- /dev/null +++ b/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer.py @@ -0,0 +1,114 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""LazyAdam rewrite to use global step for computing beta1 & beta2 accumulation. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.opt.python.training import adam_gs_optimizer +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops + + +class LazyAdamGSOptimizer(adam_gs_optimizer.AdamGSOptimizer): + """Variant of the Adam optimizer that handles sparse updates more efficiently. + + Branched from tf.contrib.opt.LazyAdamOptimizer. The only difference is to + pass global step for computing beta1 and beta2 accumulators, instead of having + optimizer keep its own independent beta1 and beta2 accumulators as non-slot + variables. + + The original Adam algorithm maintains two moving-average accumulators for + each trainable variable; the accumulators are updated at every step. + This class provides lazier handling of gradient updates for sparse variables. + It only updates moving-average accumulators for sparse variable indices that + appear in the current batch, rather than updating the accumulators for all + indices.
Compared with the original Adam optimizer, it can provide large + improvements in model training throughput for some applications. However, it + provides slightly different semantics than the original Adam algorithm, and + may lead to different empirical results. + """ + + def _apply_sparse(self, grad, var): + beta1_power, beta2_power = self._get_beta_accumulators() + beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) + beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) + lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) + beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) + beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) + epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) + lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) + + # \\(m := beta1 * m + (1 - beta1) * g_t\\) + m = self.get_slot(var, "m") + m_t = state_ops.scatter_update(m, grad.indices, + beta1_t * array_ops.gather(m, grad.indices) + + (1 - beta1_t) * grad.values, + use_locking=self._use_locking) + + # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) + v = self.get_slot(var, "v") + v_t = state_ops.scatter_update(v, grad.indices, + beta2_t * array_ops.gather(v, grad.indices) + + (1 - beta2_t) * math_ops.square(grad.values), + use_locking=self._use_locking) + + # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\) + m_t_slice = array_ops.gather(m_t, grad.indices) + v_t_slice = array_ops.gather(v_t, grad.indices) + denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t + var_update = state_ops.scatter_sub(var, grad.indices, + lr * m_t_slice / denominator_slice, + use_locking=self._use_locking) + return control_flow_ops.group(var_update, m_t, v_t) + + def _resource_apply_sparse(self, grad, var, indices): + beta1_power, beta2_power = self._get_beta_accumulators() + beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) + beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) + lr_t = 
math_ops.cast(self._lr_t, var.dtype.base_dtype) + beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) + beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) + epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) + lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) + + # \\(m := beta1 * m + (1 - beta1) * g_t\\) + m = self.get_slot(var, "m") + m_t_slice = beta1_t * array_ops.gather(m, indices) + (1 - beta1_t) * grad + m_update_op = resource_variable_ops.resource_scatter_update(m.handle, + indices, + m_t_slice) + + # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) + v = self.get_slot(var, "v") + v_t_slice = (beta2_t * array_ops.gather(v, indices) + + (1 - beta2_t) * math_ops.square(grad)) + v_update_op = resource_variable_ops.resource_scatter_update(v.handle, + indices, + v_t_slice) + + # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\) + var_slice = lr * m_t_slice / (math_ops.sqrt(v_t_slice) + epsilon_t) + var_update_op = resource_variable_ops.resource_scatter_sub(var.handle, + indices, + var_slice) + + return control_flow_ops.group(var_update_op, m_update_op, v_update_op) diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer_test.py b/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..bdc9a02a546c8399172d0c5b58941b4d80179955 --- /dev/null +++ b/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer_test.py @@ -0,0 +1,402 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for LazyAdamGSOptimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.opt.python.training import lazy_adam_gs_optimizer +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def adam_update_numpy(param, + g_t, + t, + m, + v, + alpha=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-8): + alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t) + + m_t = beta1 * m + (1 - beta1) * g_t + v_t = beta2 * v + (1 - beta2) * g_t * g_t + + param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon) + return param_t, m_t, v_t + + +class LazyAdamGSOptimizerTest(test.TestCase, parameterized.TestCase): + + @parameterized.parameters([False, True]) + def testSparse(self, use_resource): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64)) + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + else: + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + + grads0_np_indices = np.array([0, 1], dtype=np.int32) + grads0 = ops.IndexedSlices( + constant_op.constant(grads0_np), + constant_op.constant(grads0_np_indices), constant_op.constant([2])) + grads1_np_indices = np.array([0, 1], dtype=np.int32) + grads1 = ops.IndexedSlices( + constant_op.constant(grads1_np), + constant_op.constant(grads1_np_indices), constant_op.constant([2])) + opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=global_step) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + @parameterized.parameters([False, 
True]) + def testSparseDevicePlacement(self, use_resource): + for index_dtype in [dtypes.int32, dtypes.int64]: + with self.cached_session(force_gpu=test.is_gpu_available()): + # If a GPU is available, tests that all optimizer ops can be placed on + # it (i.e. they have GPU kernels). + if use_resource: + global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64)) + var = resource_variable_ops.ResourceVariable([[1.0], [2.0]]) + else: + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + var = variables.Variable([[1.0], [2.0]]) + + indices = constant_op.constant([0, 1], dtype=index_dtype) + gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices)) + optimizer = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=global_step, learning_rate=3.0) + minimize_op = optimizer.minimize(gathered_sum, global_step=global_step) + variables.global_variables_initializer().run() + minimize_op.run() + + @parameterized.parameters([False, True]) + def testSparseRepeatedIndices(self, use_resource): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + if use_resource: + repeated_index_global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64)) + aggregated_global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64)) + repeated_index_update_var = resource_variable_ops.ResourceVariable( + [[1.0], [2.0]], dtype=dtype) + aggregated_update_var = resource_variable_ops.ResourceVariable( + [[1.0], [2.0]], dtype=dtype) + else: + repeated_index_global_step = variables.Variable( + array_ops.zeros([], dtypes.int64)) + aggregated_global_step = variables.Variable( + array_ops.zeros([], dtypes.int64)) + repeated_index_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + aggregated_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + + grad_repeated_index = ops.IndexedSlices( + constant_op.constant( + [0.1, 0.1], 
shape=[2, 1], dtype=dtype), + constant_op.constant([1, 1]), + constant_op.constant([2, 1])) + grad_aggregated = ops.IndexedSlices( + constant_op.constant( + [0.2], shape=[1, 1], dtype=dtype), + constant_op.constant([1]), + constant_op.constant([2, 1])) + repeated_update_opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=repeated_index_global_step) + repeated_update = repeated_update_opt.apply_gradients( + [(grad_repeated_index, repeated_index_update_var)], + global_step=repeated_index_global_step) + aggregated_update_opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=aggregated_global_step) + aggregated_update = aggregated_update_opt.apply_gradients( + [(grad_aggregated, aggregated_update_var)], + global_step=aggregated_global_step) + variables.global_variables_initializer().run() + self.assertAllClose(aggregated_update_var.eval(), + repeated_index_update_var.eval()) + for _ in range(3): + repeated_update.run() + aggregated_update.run() + self.assertAllClose(aggregated_update_var.eval(), + repeated_index_update_var.eval()) + + def doTestBasic(self, use_resource=False, use_callable_params=False): + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + with self.session(graph=ops.Graph()): + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64), name="global_step_%d" % i) + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + else: + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + learning_rate = lambda: 0.001 + beta1 = lambda: 0.9 + beta2 = lambda: 0.999 + epsilon = lambda: 1e-8 + if not use_callable_params: + learning_rate = learning_rate() + beta1 = beta1() + beta2 = beta2() + epsilon = epsilon() + + opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=global_step, learning_rate=learning_rate) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + opt_variables = opt.variables() + beta1_power, beta2_power = opt._get_beta_accumulators() + self.assertIsNotNone(beta1_power) + self.assertIsNotNone(beta2_power) + self.assertNotIn(beta1_power, opt_variables) + self.assertNotIn(beta2_power, opt_variables) + + if not context.executing_eagerly(): + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs.
+ self.assertEqual(0, len(opt.variables())) + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + # Run 3 steps of Adam + for t in range(1, 4): + if not context.executing_eagerly(): + self.evaluate(update) + self.assertAllCloseAccordingToType( + 0.9**(t + 1), self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType( + 0.999**(t + 1), self.evaluate(beta2_power)) + else: + if t > 1: + opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + beta1_power, beta2_power = opt._get_beta_accumulators() + self.assertAllCloseAccordingToType( + 0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType( + 0.999**t, self.evaluate(beta2_power)) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + if use_resource: + self.assertEqual("var0_%d/Adam:0" % (i,), + opt.get_slot(var=var0, name="m").name) + + def testBasic(self): + with self.cached_session(): + self.doTestBasic(use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTestBasic(use_resource=True) + + def testBasicCallableParams(self): + with context.eager_mode(): + self.doTestBasic(use_resource=True, use_callable_params=True) + + def testTensorLearningRate(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=global_step, learning_rate=constant_op.constant(0.001)) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testSharing(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=global_step) + update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run 3 steps of intertwined Adam1 and Adam2. 
+ for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + if t % 2 == 0: + update1.run() + else: + update2.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testTwoSessions(self): + optimizer = lazy_adam_gs_optimizer.LazyAdamGSOptimizer() + + with context.eager_mode(): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + + g = ops.Graph() + with g.as_default(): + with self.session(graph=g): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + + gg = ops.Graph() + with gg.as_default(): + with self.session(graph=gg): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + + # If the optimizer saves any state not keyed by graph the following line + # fails. + optimizer.apply_gradients([(grads0, var0)]) + + def testSlotsUniqueEager(self): + with context.eager_mode(): + v1 = resource_variable_ops.ResourceVariable(1.) + v2 = resource_variable_ops.ResourceVariable(1.) + opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer(1.) + opt.minimize(lambda: v1 + v2) + # There should be no non-slot variables, and two unique slot variables + # for v1 and v2 respectively.
+ self.assertLen(set(opt.variables()), 4) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/opt/python/training/moving_average_optimizer.py b/tensorflow/contrib/opt/python/training/moving_average_optimizer.py index b7fd2d2fb9db3eed15eb1cc2934199939790b1c0..bf3e5c51f78cc3ca3c7c77009c9cf428c4988953 100644 --- a/tensorflow/contrib/opt/python/training/moving_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/moving_average_optimizer.py @@ -26,6 +26,7 @@ from tensorflow.python.ops import variables from tensorflow.python.training import moving_averages from tensorflow.python.training import optimizer from tensorflow.python.training import saver +from tensorflow.python.training.saving import saveable_object_util class MovingAverageOptimizer(optimizer.Optimizer): @@ -165,7 +166,7 @@ class MovingAverageOptimizer(optimizer.Optimizer): if var_list is None: var_list = variables.global_variables() if not isinstance(var_list, dict): - var_list = saver.BaseSaverBuilder.OpListToDict(var_list) + var_list = saveable_object_util.op_list_to_dict(var_list) v_name_to_tensor = {} for k, tensor_or_list in six.iteritems(var_list): diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py index 200b0d200826a6212a236680327f4daf7d07831f..8b8065c678e11e8fc237e71cf1d392ced5c22ada 100644 --- a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py +++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py @@ -59,6 +59,23 @@ class DecoupledWeightDecayExtension(object): Note that this extension decays weights BEFORE applying the update based on the gradient, i.e. this extension only has the desired behaviour for optimizers which do not depend on the value of'var' in the update step! + + Note: when applying a decay to the learning rate, be sure to manually apply + the decay to the `weight_decay` as well. 
For example: + + ```python + schedule = tf.train.piecewise_constant(tf.train.get_global_step(), + [10000, 15000], [1e-0, 1e-1, 1e-2]) + lr = 1e-1 * schedule() + wd = lambda: 1e-4 * schedule() + + # ... + + optimizer = tf.contrib.opt.MomentumWOptimizer(learning_rate=lr, + weight_decay=wd, + momentum=0.9, + use_nesterov=True) + ``` """ def __init__(self, weight_decay, **kwargs): diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index 73a556f0b299614b098ceef0fb9d32f148227b03..7fb23abc38d9dc101204ed83808aebe5a8ef1e78 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -25,6 +25,7 @@ import abc import six from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx from tensorflow.python.distribute import reduce_util as ds_reduce_util from tensorflow.python.eager import backprop from tensorflow.python.eager import context @@ -36,7 +37,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -from tensorflow.python.training import distribution_strategy_context as distribute_ctx from tensorflow.python.training import optimizer as optimizer_v1 from tensorflow.python.training import slot_creator from tensorflow.python.training.checkpointable import base as checkpointable @@ -997,10 +997,10 @@ class OptimizerV2(optimizer_v1.Optimizer): with ops.control_dependencies([update_ops]): finish_updates = distribution.extended.update_non_slot( non_slot_devices, finish, group=False) - # We said grouped=False, which means finish_updates is always a list. - # It will be [None] when finish() returns None. - if finish_updates == [None]: - finish_updates = [update_ops] + # We said group=False, which means finish_updates is always a tuple. 
+ # It will be (None,) when finish() returns None. + if finish_updates == (None,): + finish_updates = (update_ops,) # Update `global_step` (if any). if global_step is None: diff --git a/tensorflow/contrib/predictor/BUILD b/tensorflow/contrib/predictor/BUILD index d50b52b8ff1ce8188ab52c6968d716378efd9daa..53a3bc63e1d770b451846c45370fdee9ffa72d70 100644 --- a/tensorflow/contrib/predictor/BUILD +++ b/tensorflow/contrib/predictor/BUILD @@ -42,6 +42,7 @@ py_library( name = "saved_model_predictor", srcs = ["saved_model_predictor.py"], srcs_version = "PY2AND3", + visibility = ["//learning/brain/contrib/learn/tpu:__subpackages__"], deps = [ ":base_predictor", "//tensorflow/contrib/saved_model:saved_model_py", diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index 21d1b1213090273b5abd8e012f8711db98c94347..7c973fe597181b822e617db1f85a08f1b678e26f 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -685,7 +685,7 @@ def _InsertQuantOp(context, [1; 2^bits - 1] or wide range [0; 2^bits - 1]. producer_scope: The restriction of producer scope. If not None, the new op will be inserted only when the producer is in this scope. - consumer_scope: The restriction of producer scope. If not None, the new op + consumer_scope: The restriction of consumer scope. If not None, the new op will be inserted only when all the consumers are in this scope. 
Raises: ValueError: When producer operation is not directly connected to the diff --git a/tensorflow/contrib/rate/BUILD b/tensorflow/contrib/rate/BUILD index c461a7145e27c4238161cec989448be807acd543..76db9aecf615d0a94f65cd7ea799db245828db1c 100644 --- a/tensorflow/contrib/rate/BUILD +++ b/tensorflow/contrib/rate/BUILD @@ -34,6 +34,11 @@ py_test( name = "rate_test", size = "small", srcs = ["rate_test.py"], + tags = [ + "manual", # TODO(b/120555555) + "no_oss", # TODO(b/120555555) + "notap", # TODO(b/120555555) + ], deps = [ ":rate", "//tensorflow/python:array_ops", diff --git a/tensorflow/contrib/receptive_field/README.md b/tensorflow/contrib/receptive_field/README.md index 79b015a9163f5727caa40b54579c71e57621c92f..d1c41e4c0a11028765c9fc0dc345cb29453baa31 100644 --- a/tensorflow/contrib/receptive_field/README.md +++ b/tensorflow/contrib/receptive_field/README.md @@ -185,5 +185,4 @@ Effective padding (vertical) = 1482 ## Authors -André Araujo (github id: andrefaraujo) and Mark Sandler (github id: -marksandler) +André Araujo (@andrefaraujo) and Mark Sandler (@marksandler) diff --git a/tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py b/tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py index d6fdd12bbe37fb0e0cb12f1d0adc3fce29b19e8a..72f98ccc32e945b48b5f1b570bcca323a5b5f48a 100644 --- a/tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py +++ b/tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Computes Receptive Field (RF) information given a graph protobuf. 
- -For an example of usage, see accompanying file compute_rf.sh -""" +"""Computes Receptive Field (RF) information given a graph protobuf.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py b/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py index a298b4d49038468299b58140758c69675368e855..325929a5937ac60a6134fae064e7633a4c57473d 100644 --- a/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py +++ b/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py @@ -16,8 +16,6 @@ The receptive field (and related parameters) for the different models are printed to stdout, and may also optionally be written to a CSV file. - -For an example of usage, see rf_benchmark.sh """ from __future__ import absolute_import @@ -262,11 +260,11 @@ def _model_rf(graphdef, information will be computed. model_type: Type of model to be used, used only for printing purposes. csv_writer: A CSV writer for RF parameters, which is used if it is not None. - input_resolution: Input resolution to use when computing RF - parameters. This is important for the case where padding can only be - defined if the input resolution is known, which may happen if using SAME - padding. This is assumed the resolution for both height and width. If - None, we consider the resolution is unknown. + input_resolution: Input resolution to use when computing RF parameters. This + is important for the case where padding can only be defined if the input + resolution is known, which may happen if using SAME padding. This is + assumed the resolution for both height and width. If None, we consider the + resolution is unknown. 
""" for desired_end_point_key in desired_end_point_keys: print('- %s:' % desired_end_point_key) @@ -283,10 +281,10 @@ def _model_rf(graphdef, if (receptive_field_x == receptive_field_y) and ( effective_stride_x == effective_stride_y) and ( effective_padding_x == effective_padding_y): - print('Receptive field size = %5s, effective stride = %5s, effective ' - 'padding = %5s' % (str(receptive_field_x), - str(effective_stride_x), - str(effective_padding_x))) + print( + 'Receptive field size = %5s, effective stride = %5s, effective ' + 'padding = %5s' % (str(receptive_field_x), str(effective_stride_x), + str(effective_padding_x))) else: print('Receptive field size: horizontal = %5s, vertical = %5s. ' 'Effective stride: horizontal = %5s, vertical = %5s. Effective ' @@ -362,9 +360,8 @@ def _process_model_rf(model_type='resnet_v1_50', defined if the input resolution is known, which may happen if using SAME padding. The entries in the list are assumed the resolution for both height and width. If one of the elements in the list is None, we consider - it to mean that the resolution is unknown. If the list itself is None, - we use the default list [None, 224, 321]. - + it to mean that the resolution is unknown. If the list itself is None, we + use the default list [None, 224, 321]. """ # Process default value for this list. if input_resolutions is None: @@ -477,8 +474,8 @@ def _mobilenet_v1_rf(csv_writer=None): csv_writer: A CSV writer for RF parameters, which is used if it is not None. 
""" for model_type in _SUPPORTED_MOBILENETV1_VARIANTS: - with slim.arg_scope( - [slim.batch_norm, slim.dropout], is_training=False) as arg_sc: + with slim.arg_scope([slim.batch_norm, slim.dropout], + is_training=False) as arg_sc: _process_model_rf(model_type, csv_writer, arg_sc) diff --git a/tensorflow/contrib/receptive_field/python/util/receptive_field.py b/tensorflow/contrib/receptive_field/python/util/receptive_field.py index b9bd2f09761ab10a62d37e8e2580b93b9b8a4453..9127c772c75279d9c8eacc5a17680beba9247d01 100644 --- a/tensorflow/contrib/receptive_field/python/util/receptive_field.py +++ b/tensorflow/contrib/receptive_field/python/util/receptive_field.py @@ -12,12 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Functions to compute receptive field of a fully-convolutional network. - -Please refer to the following g3doc for detailed explanation on how this -computation is performed, and why it is important: -g3doc/photos/vision/features/delf/g3doc/rf_computation.md -""" +"""Functions to compute receptive field of a fully-convolutional network.""" from __future__ import absolute_import from __future__ import division @@ -96,8 +91,8 @@ class ReceptiveField(object): Args: y: An array of feature coordinates with shape `(..., d)`, where `d` is the number of dimensions of the coordinates. - axis: The dimensions for which to compute the input center coordinates. - If `None` (the default), compute the input center coordinates for all + axis: The dimensions for which to compute the input center coordinates. If + `None` (the default), compute the input center coordinates for all dimensions. Returns: @@ -127,8 +122,8 @@ class ReceptiveField(object): Args: x: An array of input center coordinates with shape `(..., d)`, where `d` is the number of dimensions of the coordinates. 
- axis: The dimensions for which to compute the feature coordinates. - If `None` (the default), compute the feature coordinates for all + axis: The dimensions for which to compute the feature coordinates. If + `None` (the default), compute the feature coordinates for all dimensions. Returns: @@ -274,14 +269,15 @@ def compute_receptive_field_from_graph_def(graph_def, continue # Get params for this layer. - (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, - padding_y, _, _) = parse_layer_parameters.get_layer_params( + (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, padding_y, + _, _) = parse_layer_parameters.get_layer_params( node, name_to_node, node_info[node.name].input_size) - logging.vlog(3, "kernel_size_x = %s, kernel_size_y = %s, " - "stride_x = %s, stride_y = %s, " - "padding_x = %s, padding_y = %s, input size = %s" % - (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, - padding_y, node_info[node.name].input_size)) + logging.vlog( + 3, "kernel_size_x = %s, kernel_size_y = %s, " + "stride_x = %s, stride_y = %s, " + "padding_x = %s, padding_y = %s, input size = %s" % + (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, + padding_y, node_info[node.name].input_size)) if padding_x is None or padding_y is None: undefined_padding = True @@ -352,15 +348,15 @@ def compute_receptive_field_from_graph_def(graph_def, raise ValueError( "Graph is not aligned since effective stride from different " "paths is different in vertical direction") - if (rf_sizes_x[inp_name] - 1 - ) / 2 - effective_paddings_x[inp_name] != ( - rf_size_input_x - 1) / 2 - effective_padding_input_x: + if (rf_sizes_x[inp_name] - + 1) / 2 - effective_paddings_x[inp_name] != ( + rf_size_input_x - 1) / 2 - effective_padding_input_x: raise ValueError( "Graph is not aligned since center shift from different " "paths is different in horizontal direction") - if (rf_sizes_y[inp_name] - 1 - ) / 2 - effective_paddings_y[inp_name] != ( - rf_size_input_y - 1) / 2 - 
effective_padding_input_y: + if (rf_sizes_y[inp_name] - + 1) / 2 - effective_paddings_y[inp_name] != ( + rf_size_input_y - 1) / 2 - effective_padding_input_y: raise ValueError( "Graph is not aligned since center shift from different " "paths is different in vertical direction") diff --git a/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py b/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py index d8ca0eab276b39f025d018edebb78eed7a8433bb..cec4c3c23305034d167a248a637425507750064e 100644 --- a/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py +++ b/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py @@ -164,6 +164,15 @@ class ResamplerOpsTest(xla_test.XLATestCase): expected = [[[0.0], [27.62]]] self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + expected_grad_data = [[[[0.12], [0.27999997]], [[0.18000001], + [0.42000002]]]] + expected_grad_warp = [[[0., 0.], [22.60000038, 35.20000076]]] + + grad_output = np.ones([1, 2, 1], dtype=dtype) + self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output, + expected_grad_data, + expected_grad_warp) + # One of (x, y) is less than 0. for dtype in self.float_types: input_shape = [1, 2, 2, 1] @@ -171,11 +180,21 @@ class ResamplerOpsTest(xla_test.XLATestCase): input_np = np.array(input_data, dtype=dtype).reshape(input_shape) warp_shape = [1, 2, 2] + # -1 is out of bound for grad_warp. warp_data = [-1, 0.1, 0.7, 0.6] warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape) expected = [[[0.0], [27.62]]] self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + expected_grad_data = [[[[0.12], [0.27999997]], [[0.18000001], + [0.42000002]]]] + expected_grad_warp = [[[0., 0.], [22.60000038, 35.20000076]]] + + grad_output = np.ones([1, 2, 1], dtype=dtype) + self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output, + expected_grad_data, + expected_grad_warp) + # Both of (x, y) are greater than image size. 
for dtype in self.float_types: input_shape = [1, 2, 2, 1] @@ -183,11 +202,20 @@ class ResamplerOpsTest(xla_test.XLATestCase): input_np = np.array(input_data, dtype=dtype).reshape(input_shape) warp_shape = [1, 2, 2] + # -0.1 is *inbound* for grad_warp and grad_data, 2.1 is out of bound. warp_data = [-0.1, 0.1, 1.2, 2.1] warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape) expected = [[[0.0], [0.0]]] self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + expected_grad_data = [[[[0.81], [0.0]], [[0.09], [0.0]]]] + expected_grad_warp = [[[10.30, 2.7], [0.0, 0.0]]] + + grad_output = np.ones([1, 2, 1], dtype=dtype) + self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output, + expected_grad_data, + expected_grad_warp) + # One of (x, y) is greater than image size. for dtype in self.float_types: input_shape = [1, 2, 2, 1] @@ -200,6 +228,14 @@ class ResamplerOpsTest(xla_test.XLATestCase): expected = [[[0.0], [0.0]]] self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + expected_grad_data = [[[[0.81], [0.81]], [[0.0], [0.08]]]] + expected_grad_warp = [[[-4.5, 9.5], [-9.9, 39.20]]] + + grad_output = np.ones([1, 2, 1], dtype=dtype) + self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output, + expected_grad_data, + expected_grad_warp) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py index ffba514bb96f5ce8d963cb0a0482738eafe88355..2a4b6eae367fe617e9a19d80f16eb3fda9ade1c0 100644 --- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py +++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py @@ -22,53 +22,57 @@ import os import six from tensorflow.python.client import session -from tensorflow.python.estimator import keras as estimator_keras_util -from tensorflow.python.estimator import model_fn as model_fn_lib -from 
tensorflow.python.estimator.export import export as export_helpers from tensorflow.python.framework import ops from tensorflow.python.keras import backend as K from tensorflow.python.keras import models as models_lib from tensorflow.python.keras import optimizers from tensorflow.python.keras.engine import sequential +from tensorflow.python.keras.engine import training_utils from tensorflow.python.keras.metrics import Metric from tensorflow.python.keras.models import model_from_json from tensorflow.python.lib.io import file_io from tensorflow.python.ops import variables -from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import builder as saved_model_builder from tensorflow.python.saved_model import constants +from tensorflow.python.saved_model import save as save_lib from tensorflow.python.saved_model import utils_impl as saved_model_utils from tensorflow.python.training import saver as saver_lib from tensorflow.python.training.checkpointable import util as checkpointable_utils from tensorflow.python.util import compat +from tensorflow.python.util import nest +from tensorflow_estimator.python.estimator import keras as estimator_keras_util +from tensorflow_estimator.python.estimator import model_fn as model_fn_lib +from tensorflow_estimator.python.estimator.export import export as export_helpers def save_keras_model( - model, saved_model_path, custom_objects=None, as_text=None): - """Save a `tf.keras.Model` into Tensorflow SavedModel format. + model, saved_model_path, custom_objects=None, as_text=None, + input_signature=None, serving_only=False): + """Saves a `tf.keras.Model` into Tensorflow SavedModel format. `save_model` generates new files/folders under the `saved_model_path` folder: - 1) an asset folder containing the json string of the model's - configuration (topology). - 2) a checkpoint containing the model weights. 
- 3) a saved_model.pb file containing the model's MetaGraphs. The prediction + 1) a checkpoint containing the model weights. + 2) a saved_model.pb file containing the model's MetaGraphs. The prediction graph is always exported. The evaluaton and training graphs are exported if the following conditions are met: - Evaluation: model loss is defined. - Training: model is compiled with an optimizer defined under `tf.train`. This is because `tf.keras.optimizers.Optimizer` instances cannot be saved to checkpoints. - - Model Requirements: - - Model must be a sequential model or functional model. Subclassed models can - not be saved via this function, unless you provide an implementation for - get_config() and from_config(). - - All variables must be saveable by the model. In general, this condition is - met through the use of layers defined in the keras library. However, - there is currently a bug with variables created in Lambda layer functions - not being saved correctly (see - https://github.com/keras-team/keras/issues/9740). + 3) Model's json configuration, if model.get_config() has been implemented. + This file can be used to reload the model using + tf.keras.models.model_from_json(). Note that if any custom objects were + used, they should be passed to the `custom_objects` argument when loading + the model. + + Model limitations: + - Sequential and functional models can always be saved. + - Subclassed models can only be saved when `serving_only=True`. This is due to + the current implementation copying the model in order to export the training + and evaluation graphs. Because the topology of subclassed models cannot be + determined, the subclassed models cannot be cloned. Subclassed models will + be entirely exportable in the future. Note that each mode is exported in separate graphs, so different modes do not share variables.
To use the train graph with evaluation or prediction graphs, @@ -94,38 +98,88 @@ def save_keras_model( ``` Args: - model: A `tf.keras.Model` to be saved. + model: A `tf.keras.Model` to be saved. If the model is subclassed, the flag + `serving_only` must be set to True. saved_model_path: a string specifying the path to the SavedModel directory. The SavedModel will be saved to a timestamped folder created within this directory. custom_objects: Optional dictionary mapping string names to custom classes or functions (e.g. custom loss functions). - as_text: whether to write the `SavedModel` proto in text format. + as_text: whether to write the `SavedModel` proto in text format. Currently + unavailable in serving-only mode. + input_signature: A possibly nested sequence of `tf.TensorSpec` objects, used + to specify the expected model inputs. `input_signature`'s nested structure + should match the expected nested structure of the inputs to the model. If + this is not set, this function will attempt to infer the input shapes and + dtypes from the model. Note that if the model is subclassed, the tensor + inputs to the call function should be nested in the first argument (this + is a general requirement for using subclassed models with Keras functions + .fit(), .predict(), etc.). + serving_only: Export only the outputs produced from calling the model in + predict mode. The losses, optimizer, and other training configurations are + not saved. If the SavedModel will only be used for serving (rather than + retraining), or if the model is subclassed, this can be set to True. Returns: String path to the SavedModel folder, a subdirectory of `saved_model_path`. Raises: - NotImplementedError: If the model is a subclassed model. - ValueError: If a Sequential model does not have input shapes defined by the - user, and is not built. + NotImplementedError: If the model is a subclassed model, and serving_only is + False. 
+ ValueError: If the input signature cannot be inferred from the model. """ + export_dir = export_helpers.get_timestamped_export_dir(saved_model_path) + + if serving_only: + save_lib.save( + model, export_dir, + signatures=training_utils.trace_model_call(model, input_signature)) + else: + _save_v1_format(model, export_dir, custom_objects, as_text, input_signature) + + try: + _export_model_json(model, export_dir) + except NotImplementedError: + logging.warning('Skipped saving model JSON, subclassed model does not have ' + 'get_config() defined.') + + return export_dir + + +def _export_model_json(model, saved_model_path): + """Saves model configuration as a json string under assets folder.""" + model_json = model.to_json() + model_json_filepath = os.path.join( + saved_model_utils.get_or_create_assets_dir(saved_model_path), + compat.as_text(constants.SAVED_MODEL_FILENAME_JSON)) + file_io.write_string_to_file(model_json_filepath, model_json) + + +def _export_model_variables(model, saved_model_path): + """Saves model weights in checkpoint format under variables folder.""" + saved_model_utils.get_or_create_variables_dir(saved_model_path) + checkpoint_prefix = saved_model_utils.get_variables_path(saved_model_path) + model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True) + return checkpoint_prefix + + +def _save_v1_format(model, path, custom_objects, as_text, input_signature): + """Exports model to v1 SavedModel format.""" if not model._is_graph_network: if isinstance(model, sequential.Sequential): # If input shape is not directly set in the model, the exported model - # will assume that the inputs have the same shape as the shape the model - # was built model with. - if not model.built: + # will infer the expected shapes of the input from the model. + if not model.built and input_signature is None: raise ValueError( - 'Sequential model must be built before it can be exported.') + 'Sequential model\'s input shape is unknown. 
Please build the ' + 'model, or use the input_signature argument to specify the ' + 'model inputs.') else: raise NotImplementedError( - 'Exporting subclassed models is not yet supported.') + 'Subclassed models can only be exported for serving. Please set ' + 'argument serving_only=True.') - export_dir = export_helpers.get_timestamped_export_dir(saved_model_path) - temp_export_dir = export_helpers.get_temp_export_dir(export_dir) - - builder = saved_model_builder._SavedModelBuilder(temp_export_dir) + builder = saved_model_builder._SavedModelBuilder(path) # Manually save variables to export them in an object-based checkpoint. This # skips the `builder.add_meta_graph_and_variables()` step, which saves a @@ -133,7 +187,7 @@ def save_keras_model( # TODO(b/113134168): Add fn to Builder to save with object-based saver. # TODO(b/113178242): This should only export the model json structure. Only # one save is needed once the weights can be copied from the model to clone. - checkpoint_path = _export_model_json_and_variables(model, temp_export_dir) + checkpoint_path = _export_model_variables(model, path) # Export each mode. Use ModeKeys enums defined for `Estimator` to ensure that # Keras models and `Estimator`s are exported with the same format. @@ -143,10 +197,12 @@ def save_keras_model( export_args = {'builder': builder, 'model': model, 'custom_objects': custom_objects, - 'checkpoint_path': checkpoint_path} + 'checkpoint_path': checkpoint_path, + 'input_signature': input_signature} has_saved_vars = False if model.optimizer: + # TODO(kathywu): Verify this works with v2 optimizer. 
if isinstance(model.optimizer, optimizers.TFOptimizer): _export_mode(model_fn_lib.ModeKeys.TRAIN, has_saved_vars, **export_args) has_saved_vars = True @@ -161,34 +217,20 @@ def save_keras_model( builder.save(as_text) - gfile.Rename(temp_export_dir, export_dir) - return export_dir - - -def _export_model_json_and_variables(model, saved_model_path): - """Save model variables and json structure into SavedModel subdirectories.""" - # Save model configuration as a json string under assets folder. - model_json = model.to_json() - model_json_filepath = os.path.join( - saved_model_utils.get_or_create_assets_dir(saved_model_path), - compat.as_text(constants.SAVED_MODEL_FILENAME_JSON)) - file_io.write_string_to_file(model_json_filepath, model_json) - - # Save model weights in checkpoint format under variables folder. - saved_model_utils.get_or_create_variables_dir(saved_model_path) - checkpoint_prefix = saved_model_utils.get_variables_path(saved_model_path) - model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True) - return checkpoint_prefix - def _get_var_list(model): - """Return list of all checkpointed saveable objects in the model.""" + """Returns list of all checkpointed saveable objects in the model.""" return checkpointable_utils.named_saveables(model) +def create_placeholder(spec): + return K.placeholder(shape=spec.shape, dtype=spec.dtype, name=spec.name) + + def _export_mode( - mode, has_saved_vars, builder, model, custom_objects, checkpoint_path): - """Export a model, and optionally save new vars from the clone model. + mode, has_saved_vars, builder, model, custom_objects, checkpoint_path, + input_signature): + """Exports a model, and optionally saves new vars from the clone model. Args: mode: A `tf.estimator.ModeKeys` string. @@ -199,6 +241,8 @@ def _export_mode( custom_objects: A dictionary mapping string names to custom classes or functions. checkpoint_path: String path to checkpoint. 
+ input_signature: Nested TensorSpec containing the expected inputs. Can be + `None`, in which case the signature will be inferred from the model. Raises: ValueError: If the train/eval mode is being exported, but the model does @@ -214,10 +258,16 @@ def _export_mode( K.set_learning_phase(mode == model_fn_lib.ModeKeys.TRAIN) + if input_signature is None: + input_tensors = None + else: + input_tensors = nest.map_structure(create_placeholder, input_signature) + # Clone the model into blank graph. This will create placeholders for inputs # and targets. clone = models_lib.clone_and_build_model( - model, custom_objects=custom_objects, compile_clone=compile_clone) + model, input_tensors=input_tensors, custom_objects=custom_objects, + compile_clone=compile_clone) # Make sure that iterations variable is added to the global step collection, # to ensure that, when the SavedModel graph is loaded, the iterations @@ -271,7 +321,7 @@ def _export_mode( def _create_signature_def_map(model, mode): - """Create a SignatureDef map from a Keras model.""" + """Creates a SignatureDef map from a Keras model.""" inputs_dict = {name: x for name, x in zip(model.input_names, model.inputs)} if model.optimizer: targets_dict = {x.name.split(':')[0]: x @@ -309,14 +359,14 @@ def _create_signature_def_map(model, mode): def _assert_same_non_optimizer_objects(model, model_graph, clone, clone_graph): # pylint: disable=unused-argument - """Assert model and clone contain the same checkpointable objects.""" + """Asserts model and clone contain the same checkpointable objects.""" # TODO(fchollet, kathywu): make sure this works in eager mode. return True def load_keras_model(saved_model_path): - """Load a keras.Model from SavedModel. + """Loads a keras.Model from SavedModel. 
load_model reinstantiates model state by: 1) loading model topology from json (this will eventually come diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py index 93d73e1b484ed810fb347b13e95022dfca3584c2..fbf8138493362d4a3c8a75e1ee1bb2fbe8096499 100644 --- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py +++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py @@ -29,7 +29,9 @@ from tensorflow.python import keras from tensorflow.python.client import session from tensorflow.python.eager import context from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.keras.engine import training from tensorflow.python.keras.utils import tf_utils @@ -215,7 +217,7 @@ class LayerWithLearningPhase(keras.engine.base_layer.Layer): return input_shape -def functional_model(uses_learning_phase): +def functional_model(uses_learning_phase=True): inputs = keras.layers.Input(shape=(3,)) x = keras.layers.Dense(2)(inputs) x = keras.layers.Dense(3)(x) @@ -224,7 +226,7 @@ def functional_model(uses_learning_phase): return keras.models.Model(inputs, x) -def sequential_model(uses_learning_phase): +def sequential_model(uses_learning_phase=True): model = keras.models.Sequential() model.add(keras.layers.Dense(2, input_shape=(3,))) model.add(keras.layers.Dense(3)) @@ -233,7 +235,7 @@ def sequential_model(uses_learning_phase): return model -def sequential_model_without_input_shape(uses_learning_phase): +def sequential_model_without_input_shape(uses_learning_phase=True): model = keras.models.Sequential() model.add(keras.layers.Dense(2)) model.add(keras.layers.Dense(3)) @@ -242,10 +244,30 @@ def 
sequential_model_without_input_shape(uses_learning_phase): return model +class Subclassed(keras.models.Model): + + def __init__(self): + super(Subclassed, self).__init__() + self.dense1 = keras.layers.Dense(2) + self.dense2 = keras.layers.Dense(3) + + def call(self, inputs): + x = self.dense1(inputs) + x = self.dense2(x) + return x + + +def subclassed_model(): + return Subclassed() + + def load_model(sess, path, mode): tags = model_fn_lib.EXPORT_TAG_MAP[mode] - sig_def_key = (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY - if mode == model_fn_lib.ModeKeys.PREDICT else mode) + if mode == model_fn_lib.ModeKeys.PREDICT: + sig_def_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + else: + sig_def_key = mode + meta_graph_def = loader_impl.load(sess, tags, path) inputs = { k: sess.graph.get_tensor_by_name(v.name) @@ -463,13 +485,54 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase): clone.compile(loss='mse', optimizer=keras.optimizers.RMSprop(lr=0.0001)) clone.train_on_batch(input_arr, target_arr) - def testSaveSeqModelWithoutInputShapesRaisesError(self): - """A Sequential model that hasn't been built should raise an error.""" + def testSaveSequentialModelWithoutInputShapes(self): model = sequential_model_without_input_shape(True) - with self.assertRaisesRegexp( - ValueError, 'must be built'): + # A Sequential model that hasn't been built should raise an error. 
+ with self.assertRaisesRegexp(ValueError, 'Please build the model'): keras_saved_model.save_keras_model(model, '') + saved_model_path = self._save_model_dir() + output_path = keras_saved_model.save_keras_model( + model, saved_model_path, + input_signature=tensor_spec.TensorSpec(shape=(10, 11, 12, 13, 14), + dtype=dtypes.float32, + name='spec_input')) + + with session.Session(graph=ops.Graph()) as sess: + inputs, outputs, _ = load_model(sess, output_path, + model_fn_lib.ModeKeys.PREDICT) + self.assertEqual(5, inputs[next(iter(inputs.keys()))].shape.ndims) + self.assertEqual(5, outputs[next(iter(outputs.keys()))].shape.ndims) + self.assertEqual(3, outputs[next(iter(outputs.keys()))].shape[-1]) + + @test_util.run_v2_only + @parameterized.parameters( + { + 'model_builder': sequential_model_without_input_shape, + 'input_signature': [tensor_spec.TensorSpec(shape=[None, 3], + dtype=dtypes.float32)]}, + { + 'model_builder': subclassed_model, + 'input_signature': [tensor_spec.TensorSpec(shape=[None, 3], + dtype=dtypes.float32)]}) + def testServingOnly(self, model_builder, input_signature): + saved_model_path = self._save_model_dir() + input_arr = np.random.random((5, 3)).astype(np.float32) + model = model_builder() + ref_predict = model.predict(input_arr) + + output_path = keras_saved_model.save_keras_model( + model, saved_model_path, serving_only=True, + input_signature=input_signature) + + # Load predict graph, and test predictions + with session.Session(graph=ops.Graph()) as sess: + inputs, outputs, _ = load_model(sess, output_path, + model_fn_lib.ModeKeys.PREDICT) + predictions = sess.run(outputs[next(iter(outputs.keys()))], + {inputs[next(iter(inputs.keys()))]: input_arr}) + self.assertAllClose(ref_predict, predictions, atol=1e-05) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py index 
922f21b98b35dfff19c8c605a25e89c5d2da8d98..d815f81f847ad79ddcc6c6ecf5c050598e185d8d 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell from tensorflow.python.ops import variables from tensorflow.python.ops import variable_scope as vs @@ -992,5 +993,67 @@ class AttentionWrapperTest(test.TestCase): expected_final_alignment_history=expected_final_alignment_history, name='testMultiAttention') + def testCustomizedAttention(self): + batch_size = 2 + max_time = 3 + num_units = 2 + memory = constant_op.constant([[[1., 1.], [2., 2.], [3., 3.]], + [[4., 4.], [5., 5.], [6., 6.]]]) + memory_sequence_length = constant_op.constant([3, 2]) + attention_mechanism = wrapper.BahdanauAttention(num_units, memory, + memory_sequence_length) + + # Sets all returned values to be all ones. + def _customized_attention(unused_attention_mechanism, unused_cell_output, + unused_attention_state, unused_attention_layer): + """Customized attention. + + Returns: + attention: `Tensor` of shape [batch_size, num_units], attention output. + alignments: `Tensor` of shape [batch_size, max_time], sigma value for + each input memory (prob. function of input keys). + next_attention_state: A `Tensor` representing the next state for the + attention. + """ + attention = array_ops.ones([batch_size, num_units]) + alignments = array_ops.ones([batch_size, max_time]) + next_attention_state = alignments + return attention, alignments, next_attention_state + + attention_cell = wrapper.AttentionWrapper( + rnn_cell.LSTMCell(2), + attention_mechanism, + attention_layer_size=None, # don't use attention layer. 
+ output_attention=False, + alignment_history=(), + attention_fn=_customized_attention, + name='attention') + self.assertEqual(num_units, attention_cell.output_size) + + initial_state = attention_cell.zero_state( + batch_size=2, dtype=dtypes.float32) + source_input_emb = array_ops.ones([2, 3, 2]) + source_input_length = constant_op.constant([3, 2]) + + # 'state' is a tuple of + # (cell_state, h, attention, alignments, alignment_history, attention_state) + output, state = rnn.dynamic_rnn( + attention_cell, + inputs=source_input_emb, + sequence_length=source_input_length, + initial_state=initial_state, + dtype=dtypes.float32) + + with self.session() as sess: + sess.run(variables.global_variables_initializer()) + output_value, state_value = sess.run([output, state], feed_dict={}) + self.assertAllEqual(np.array([2, 3, 2]), output_value.shape) + self.assertAllClose(np.array([[1., 1.], [1., 1.]]), state_value.attention) + self.assertAllClose( + np.array([[1., 1., 1.], [1., 1., 1.]]), state_value.alignments) + self.assertAllClose( + np.array([[1., 1., 1.], [1., 1., 1.]]), state_value.attention_state) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index 77e9f848b137911b53e1b4df5dd740fe38af55bb..60ec3efffe771a3a6d6f36ed4b51a34ef9509612 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -1088,7 +1088,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): output_attention=True, initial_cell_state=None, name=None, - attention_layer=None): + attention_layer=None, + attention_fn=None): """Construct the `AttentionWrapper`. **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in @@ -1132,7 +1133,9 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): feed the context and cell output into the attention layer to generate attention at each time step. 
If attention_mechanism is a list, attention_layer_size must be a list of the same length. If - attention_layer is set, this must be None. + attention_layer is set, this must be None. If attention_fn is set, + it must be guaranteed that the outputs of attention_fn also meet the + above requirements. alignment_history: Python boolean, whether to store alignment history from all time steps in the final output state (currently stored as a time major `TensorArray` on which you must call `stack()`). @@ -1158,6 +1161,12 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): the context as attention at each time step. If attention_mechanism is a list, attention_layer must be a list of the same length. If attention_layers_size is set, this must be None. + attention_fn: An optional callable function that allows users to provide + their own customized attention function, which takes input + (attention_mechanism, cell_output, attention_state, attention_layer) and + outputs (attention, alignments, next_attention_state). If provided, + the attention_layer_size should be the size of the outputs of + attention_fn. 
Raises: TypeError: `attention_layer_size` is not None and (`attention_mechanism` @@ -1240,6 +1249,10 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): tensor_shape.dimension_value(attention_mechanism.values.shape[-1]) for attention_mechanism in attention_mechanisms) + if attention_fn is None: + attention_fn = _compute_attention + self._attention_fn = attention_fn + self._cell = cell self._attention_mechanisms = attention_mechanisms self._cell_input_fn = cell_input_fn @@ -1443,7 +1456,7 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): all_attention_states = [] maybe_all_histories = [] for i, attention_mechanism in enumerate(self._attention_mechanisms): - attention, alignments, next_attention_state = _compute_attention( + attention, alignments, next_attention_state = self._attention_fn( attention_mechanism, cell_output, previous_attention_state[i], self._attention_layers[i] if self._attention_layers else None) alignment_history = previous_alignment_history[i].write( diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index ab36848f13ab3078cd232c18f140188e12db703b..8f8f057702951094758b277ce060955f3dc6e99d 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -921,6 +921,7 @@ def _get_scores(log_probs, sequence_lengths, length_penalty_weight, """ length_penalty_ = _length_penalty( sequence_lengths=sequence_lengths, penalty_factor=length_penalty_weight) + length_penalty_ = math_ops.cast(length_penalty_, dtype=log_probs.dtype) scores = log_probs / length_penalty_ coverage_penalty_weight = ops.convert_to_tensor( diff --git a/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py b/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py index 8fcd7aeef6a6964902666a4f3c17e05b0c7b52ee..f31bdbd399c9de4f2f5d557b75b1ece6d64a765e 100644 --- 
a/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py +++ b/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python import tf2 from tensorflow.contrib.solvers.python.ops import lanczos from tensorflow.contrib.solvers.python.ops import util from tensorflow.python.framework import constant_op @@ -80,7 +81,8 @@ if __name__ == "__main__": for shape in [[4, 4], [7, 4], [5, 8]]: for orthogonalize in True, False: for steps in range(1, min(shape) + 1): - for use_static_shape in True, False: + # TF2 does not support placeholders so we skip it + for use_static_shape in set([True, tf2.enabled()]): arg_string = "%s_%s_%s_%s_staticshape_%s" % ( dtype.__name__, "_".join(map(str, shape)), orthogonalize, steps, use_static_shape) diff --git a/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py b/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py index 2a9100903aae5689919a6b25fcb18ff192f250b3..841a41a2339824ab8ca15f4bdd74be697cd6fe9f 100644 --- a/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py +++ b/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python import tf2 from tensorflow.contrib.solvers.python.ops import least_squares from tensorflow.contrib.solvers.python.ops import util from tensorflow.python.framework import constant_op @@ -76,7 +77,8 @@ def _get_least_squares_tests(dtype_, use_static_shape_, shape_): if __name__ == "__main__": for dtype in np.float32, np.float64: for shape in [[4, 4], [8, 5], [3, 7]]: - for use_static_shape in True, False: + # TF2 does not support placeholders under eager so we skip it + for use_static_shape in set([True, tf2.enabled()]): arg_string = "%s_%s_staticshape_%s" % (dtype.__name__, "_".join(map(str, shape)), use_static_shape) diff --git 
a/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py b/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py index a0e6eb87bc06fb1303a7eb86fa6760458f20a9b9..10807f7a80617e56abeb6d13ce419a49a2269aac 100644 --- a/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py +++ b/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python import tf2 from tensorflow.contrib.solvers.python.ops import linear_equations from tensorflow.contrib.solvers.python.ops import util from tensorflow.python.framework import constant_op @@ -113,7 +114,8 @@ def _get_linear_equations_tests(dtype_, use_static_shape_, shape_): if __name__ == "__main__": for dtype in np.float32, np.float64: for size in 1, 4, 10: - for use_static_shape in True, False: + # TF2 does not support placeholders under eager so we skip it + for use_static_shape in set([True, tf2.enabled()]): shape = [size, size] arg_string = "%s_%s_staticshape_%s" % (dtype.__name__, size, use_static_shape) diff --git a/tensorflow/contrib/tensorrt/README.md b/tensorflow/contrib/tensorrt/README.md index caf8b6db0dc0a220d593f9c0afc9464ca51a1e05..a9c2ad78a3db409e6e8669c48c4df37c8db19c4b 100644 --- a/tensorflow/contrib/tensorrt/README.md +++ b/tensorflow/contrib/tensorrt/README.md @@ -1,8 +1,46 @@ -# Using TensorRT in TensorFlow +# Using TensorRT in TensorFlow (TF-TRT) -This module provides necessary bindings and introduces TRT_engine_op operator -that wraps a subgraph in TensorRT. This is still a work in progress but should -be useable with most common graphs. +This module provides necessary bindings and introduces `TRTEngineOp` operator +that wraps a subgraph in TensorRT. This module is under active development. + +## Installing TF-TRT + +Currently TensorFlow nightly builds include TF-TRT by default, which means you +don't need to install TF-TRT separately. 
You can pull the latest TF containers +from docker hub or install the latest TF pip package to get access to the latest +TF-TRT. + +If you want to use TF-TRT on NVIDIA Jetson platform, you can find the download +links for the relevant TensorFlow pip packages here: +https://docs.nvidia.com/deeplearning/dgx/index.html#installing-frameworks-for-jetson + +## Installing TensorRT + +In order to make use of TF-TRT, you will need a local installation of TensorRT. +Installation instructions for compatibility with TensorFlow are provided on the +[TensorFlow GPU support](https://www.tensorflow.org/install/gpu) guide. + +## Examples + +You can find example scripts for running inference on deep learning models in +this repository: https://github.com/tensorflow/tensorrt + +We have used these examples to verify the accuracy and performance of TF-TRT. +For more information see +[Verified Models](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#verified-models). + +## Documentation + +[TF-TRT documentation](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html) +gives an overview of the supported functionalities, provides tutorials and +verified models, explains best practices with troubleshooting guides. + +## Tests + +TF-TRT includes both Python tests and C++ unit tests. Most of Python tests are +located in the test directory and they can be executed using `bazel test` or +directly with the Python command. Most of the C++ unit tests are used to test +the conversion functions that convert each TF op to a number of TensorRT layers. ## Compilation @@ -18,12 +56,3 @@ bazel build --config=cuda --config=opt //tensorflow/tools/pip_package:build_pip_ bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/ ``` -After the installation of tensorflow package, TensorRT transformation will be -available. 
An example use can be found in test/test_tftrt.py script - -## Installing TensorRT 3.0.4 - -In order to make use of TensorRT integration, you will need a local installation -of TensorRT 3.0.4 from the [NVIDIA Developer website](https://developer.nvidia.com/tensorrt). -Installation instructions for compatibility with TensorFlow are provided on the -[TensorFlow GPU support](https://www.tensorflow.org/install/gpu) guide. diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index ecbd4ea802083cd742b496a65a13b72eb9eda9d9..746514b930c6c4c602c727a51313a8c5da271fa6 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -59,7 +59,7 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT #include "cuda/include/cuda_runtime_api.h" -#include "tensorrt/NvInfer.h" +#include "tensorrt/include/NvInfer.h" namespace tensorflow { namespace tensorrt { namespace convert { @@ -89,49 +89,52 @@ Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { // TODO(laigd): move this set to TrtNodeValidator where it should belong. 
// LINT.IfChange static const std::set candidate_ops = { - "Identity", - "Snapshot", - "Const", - "Conv2D", - "MaxPool", - "BiasAdd", - "Relu", - "Sigmoid", - "Tanh", + "Abs", "Add", - "Mul", - "Sub", - "Rsqrt", - "Pad", - "Mean", "AvgPool", + "BatchMatMul", + "BiasAdd", "ConcatV2", + "Const", + "Conv2D", "DepthwiseConv2dNative", - "FusedBatchNorm", - "FusedBatchNormV2", "Div", - "RealDiv", - "Rsqrt", - "Reciprocal", "Exp", + "ExpandDims", + "FusedBatchNorm", + "FusedBatchNormV2", + "Identity", "Log", - "Sqrt", - "Abs", - "Neg", - "Transpose", - "Reshape", "MatMul", - "BatchMatMul", - "Softmax", - "Minimum", - "Maximum", - "TopKV2", - "Sum", - "Prod", "Max", + "MaxPool", + "Maximum", + "Mean", "Min", + "Minimum", + "Mul", + "Neg", + "Pad", + "Prod", + "RealDiv", + "Reciprocal", + "Relu", "Relu6", + "Reshape", + "Rsqrt", + "Rsqrt", + "Sigmoid", + "Snapshot", + "Softmax", + "Sqrt", "Square", + "Squeeze", + "StridedSlice", + "Sub", + "Sum", + "Tanh", + "TopKV2", + "Transpose", }; bool is_supported_op_type = (candidate_ops.count(node->type_string()) || @@ -320,6 +323,13 @@ tensorflow::Status ConvertGraphDefToTensorRT( return Status::OK(); } +struct EdgePtrCompare { + bool operator()(const tensorflow::Edge* lhs, + const tensorflow::Edge* rhs) const { + return lhs->id() < rhs->id(); + } +}; + // Function to get subsegment information structure. tensorflow::Status GetEngineInfo( const tensorflow::Graph* g, @@ -358,8 +368,12 @@ tensorflow::Status GetEngineInfo( } const int node_id = node->id(); subgraph_node_ids.push_back(node_id); - // Create input connections. - for (const auto edge : node->in_edges()) { + // Create input connections. Sort edges first to make deterministic since + // in_edges is a set of pointers. 
+ std::vector in_edges(node->in_edges().begin(), + node->in_edges().end()); + std::sort(in_edges.begin(), in_edges.end(), EdgePtrCompare()); + for (const auto edge : in_edges) { auto input_node = edge->src(); if (input_node->IsSource() || segment_nodes.count(input_node->name())) { continue; @@ -407,8 +421,12 @@ tensorflow::Status GetEngineInfo( node_id, edge->dst_input(), /*input_edge=*/true, port); } } - // Create output connections. - for (const auto edge : node->out_edges()) { + // Create output connections. Sort edges first to make deterministic since + // out_edges is a set of pointers. + std::vector out_edges(node->out_edges().begin(), + node->out_edges().end()); + std::sort(out_edges.begin(), out_edges.end(), EdgePtrCompare()); + for (const auto edge : out_edges) { auto output_node = edge->dst(); if (output_node->IsSink() || segment_nodes.count(output_node->name())) { continue; @@ -585,6 +603,14 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, } } } + // We don't support segments with no inputs. Fall back to native TF here to + // avoid crash later. Constant folding should've folded the ops that make up + // these segments. + if (inputs.empty()) { + return tensorflow::errors::Internal( + "Segment has no inputs (possible " + "constfold failure)"); + } const bool calibrate_int8 = (info.precision_mode == INT8MODE && info.use_calibration); diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 938cadc7c44270cc48eba73fe33f6559193ac4b3..adf8831b960172fc29b5d631e5b0533318d4764d 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -48,7 +48,7 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "tensorrt/NvInfer.h" +#include "tensorrt/include/NvInfer.h" // Check if the types are equal. Cast to int first so that failure log message // would work! 
@@ -120,6 +120,15 @@ inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape, return trt_dims; } +Status TensorShapeArrayToTrtDims(const std::vector& shape, + nvinfer1::Dims* out, + bool ignore_first_dim = false) { + PartialTensorShape tensor_shape; + TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(shape, &tensor_shape)); + *out = TensorShapeToTrtDims(tensor_shape, ignore_first_dim); + return tensorflow::Status::OK(); +} + void GetOutputProperties(const grappler::GraphProperties& graph_properties, const Node* node, const int out_port, PartialTensorShape* shape, @@ -623,6 +632,11 @@ bool TFAttrs::get(const string& key) const { return this->at(key)->b(); } +template <> +int TFAttrs::get(const string& key) const { + return this->at(key)->i(); +} + // TODO(jie): reorder4 & reorder2 should be merged? // TODO(aaroey): fix the order of parameters. template @@ -1524,6 +1538,24 @@ enum class ConvolutionType { DEFAULT, DEPTHWISE_CONV }; tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; + if (inputs.at(0).is_weights()) { + return tensorflow::errors::Unimplemented( + node_def.op(), " is only implemented for tensors, not weights, at ", + node_def.name()); + } + if (inputs.at(1).is_tensor()) { + return tensorflow::errors::Unimplemented("Kernel for ", node_def.op(), + " must be constant weights, at ", + node_def.name()); + } + TRT_ShapedWeights weights_rsck = inputs.at(1).weights(); + VLOG(2) << "weight shape: " << weights_rsck.DebugString(); + if (weights_rsck.shape_.nbDims != 4) { + return tensorflow::errors::Internal( + "Conv2D expects kernel of dimension 4, at: " + node_def.name()); + } + if (params->validation_only) return tensorflow::Status::OK(); + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); TFAttrs attrs(node_def); @@ -1545,12 +1577,6 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) { if (num_groups == 0) 
num_groups = tensor_dim.d[0]; // depthwise convolution VLOG(2) << "groups count: " << num_groups; - TRT_ShapedWeights weights_rsck = inputs.at(1).weights(); - VLOG(2) << "weight shape: " << weights_rsck.DebugString(); - if (weights_rsck.shape_.nbDims != 4) { - return tensorflow::errors::Internal( - "Conv2D expects kernel of dimension 4, at: " + node_def.name()); - } if (params->converter->precision_mode() == FP16MODE) { weights_rsck = ConvertFP32ToFP16(params->weight_store, inputs.at(1).weights()); @@ -1637,7 +1663,7 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, case ConvolutionType::DEPTHWISE_CONV: return ConvertConv2DHelper(params, 0); } - return tensorflow::errors::Unimplemented("unsupported convolution type at, " + + return tensorflow::errors::Unimplemented("Unsupported convolution type, at ", params->node_def.name()); } @@ -1880,6 +1906,372 @@ tensorflow::Status ConvertReshape(OpConverterParams* params) { return tensorflow::Status::OK(); } +tensorflow::Status ConvertExpandDims(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + if (inputs.size() != 2) { + return tensorflow::errors::InvalidArgument( + "Two inputs expected for ExpandDims, at ", node_def.name()); + } + if (inputs.at(0).is_weights()) { + return tensorflow::errors::Unimplemented( + "ExpandDims expects tensor for input, at ", node_def.name()); + } + if (!inputs.at(1).is_weights()) { + return tensorflow::errors::InvalidArgument( + "ExpandDims expects weights for axis, at ", node_def.name()); + } + // Get input shape as vector. + TRT_TensorOrWeights input_tensor = inputs.at(0); + const nvinfer1::Dims dims = input_tensor.GetTrtDims(); + std::vector input_dims(dims.d, dims.d + dims.nbDims); + // Add batch dim back. + input_dims.insert(input_dims.begin(), -1); + const int input_rank = input_dims.size(); + // Get axis to expand on. 
+ TRT_ShapedWeights weights = inputs.at(1).weights(); + if (weights.count() != 1) { + return tensorflow::errors::InvalidArgument( + "ExpandDims axis must be a scalar, at ", node_def.name()); + } + const int* weights_ptr = + static_cast(const_cast(weights.GetValues())); + int axis = weights_ptr[0]; + // Make sure axis is valid. + if ((axis < (-input_rank - 1)) || (axis > input_rank)) { + return tensorflow::errors::InvalidArgument( + "Axis for ExpandDims is invalid, must be in the range " + "[-rank(input) - 1, rank(input)], at ", + node_def.name()); + } + // Convert negative axis to corresponding positive axis. + if (axis < 0) axis += input_rank + 1; + if (axis == 0) { + return tensorflow::errors::Unimplemented( + "Modifying batch dimension is not supported for ExpandDims, at ", + node_def.name()); + } + if (params->validation_only) return Status::OK(); + + // ExpandDims: Insert new dim of size 1. + input_dims.insert(input_dims.begin() + axis, 1); + // Reshape tensor. + nvinfer1::Dims new_dims; + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims, + /*ignore_first_dim=*/true)); + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + input_tensor, new_dims, &output_tensor)); + params->outputs->push_back( + TRT_TensorOrWeights(const_cast(output_tensor))); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertSqueeze(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + if (inputs.size() != 1) { + return tensorflow::errors::InvalidArgument( + "One input expected for Squeeze, at ", node_def.name()); + } + if (inputs.at(0).is_weights()) { + return tensorflow::errors::Unimplemented( + "Squeeze expects tensor for input, at ", node_def.name()); + } + // Get input shape. 
+ TRT_TensorOrWeights input_tensor = inputs.at(0); + const nvinfer1::Dims dims = input_tensor.GetTrtDims(); + std::vector input_dims(dims.d, dims.d + dims.nbDims); + // Add batch dim back. + input_dims.insert(input_dims.begin(), -1); + const int input_rank = input_dims.size(); + // Mark axes to remove by setting them to 0. + TFAttrs attrs(node_def); + auto squeeze_dims = attrs.get>("squeeze_dims"); + if (squeeze_dims.size() == 0) { + return tensorflow::errors::Unimplemented( + "Squeeze is only implemented for explicit dims, at ", node_def.name()); + } + for (int axis : squeeze_dims) { + // Make sure axis is valid. + if ((axis < -input_rank) || (axis >= input_rank)) { + return tensorflow::errors::InvalidArgument( + "Axis for Squeeze is invalid, must be in the range " + "[-rank(input), rank(input)), at ", + node_def.name()); + } + // Convert negative axis to corresponding positive axis. + if (axis < 0) axis += input_rank; + // Don't squeeze batch dim. + if (axis == 0) { + return tensorflow::errors::Unimplemented( + "Cannot squeeze batch dimension, at ", node_def.name()); + } + // Make sure target dimension is size 1. + if (input_dims[axis] != 1) { + return tensorflow::errors::InvalidArgument( + "Cannot squeeze a dimension which isn't size 1, at ", + node_def.name()); + } + // Mark dim for removal by setting to 0. + input_dims[axis] = 0; + } + if (params->validation_only) return Status::OK(); + + // Remove all dims which are equal to 0. + input_dims.erase(std::remove(input_dims.begin(), input_dims.end(), 0), + input_dims.end()); + // Reshape tensor. 
+ nvinfer1::Dims new_dims; + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims, + /*ignore_first_dim=*/true)); + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + input_tensor, new_dims, &output_tensor)); + params->outputs->push_back( + TRT_TensorOrWeights(const_cast(output_tensor))); + return tensorflow::Status::OK(); +} + +// Gets the bounds (start or end) from the weights of a StridedSlice op. +tensorflow::Status GetStridedSliceBound(const std::vector& input_dims, + const TRT_ShapedWeights& bound_weights, + int mask, bool begin, string node_name, + std::vector* output_bound) { + const string bound_name = (begin) ? "begin" : "end"; + const int* weights_ptr = static_cast(bound_weights.GetValues()); + *output_bound = + std::vector(weights_ptr, weights_ptr + bound_weights.count()); + if (output_bound->size() != input_dims.size()) { + return tensorflow::errors::InvalidArgument( + "StridedSlice \"", bound_name, "\" specified ", + std::to_string(output_bound->size()), " dimensions, but input rank is ", + std::to_string(input_dims.size()), ", at ", node_name); + } + for (int i = 0; i < output_bound->size(); i++) { + if ((1 << i) & mask) { + // Apply mask. + (*output_bound)[i] = (begin) ? 0 : input_dims[i]; + // Masked bound will always result in a valid, non-negative bound, so we + // don't need the following checks. For the common case of using masks on + // a undefined batch dim (-1), we specifically don't want to do the + // following checks because they will erroneously detect an out of range + // bound or try to correct the negative value. + continue; + } + // Make sure bound is valid. 
+ if (((*output_bound)[i] < -input_dims[i]) || + ((*output_bound)[i] > input_dims[i])) { + return tensorflow::errors::InvalidArgument( + bound_name, " value of ", std::to_string((*output_bound)[i]), + " for StridedSlice is invalid, must be in the range " + "[-dim_size(i), dim_size(i)], at ", + node_name); + } + // Convert negative values to their positive equivalent. + if ((*output_bound)[i] < 0) { + (*output_bound)[i] += input_dims[i]; + } + } + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + if (inputs.size() != 4) { + return tensorflow::errors::InvalidArgument( + "StridedSlice expects 4 inputs, at ", node_def.name()); + } + if (!inputs.at(1).is_weights() || !inputs.at(2).is_weights() || + !inputs.at(3).is_weights()) { + return tensorflow::errors::InvalidArgument( + "StridedSlice expects weights for begin, end, and strides, at ", + node_def.name()); + } + if (!inputs.at(0).is_tensor()) { + return tensorflow::errors::Unimplemented( + "StridedSlice is only implemented for tensors, at ", node_def.name()); + } + // Get input dims. + nvinfer1::Dims dims = inputs.at(0).GetTrtDims(); + std::vector input_dims(dims.d, dims.d + dims.nbDims); + if (inputs.at(0).is_tensor()) { + // Temporarily add batch dimension so that indexes line up properly. + input_dims.insert(input_dims.begin(), inputs.at(0).batch_size()); + } + if (input_dims.size() > 4) { + return tensorflow::errors::Unimplemented( + "StridedSlice is not implemented for tensors with rank > 4, at ", + node_def.name()); + } + TFAttrs attrs(node_def); + // Get begin and end bounds per axis. 
+ std::vector begin, end; + TF_RETURN_IF_ERROR(GetStridedSliceBound(input_dims, inputs.at(1).weights(), + attrs.get("begin_mask"), true, + node_def.name(), &begin)); + TF_RETURN_IF_ERROR(GetStridedSliceBound(input_dims, inputs.at(2).weights(), + attrs.get("end_mask"), false, + node_def.name(), &end)); + // Get strides per axis (must all be 1). + TRT_ShapedWeights stride_weights = inputs.at(3).weights(); + const int* stride_weights_ptr = static_cast(stride_weights.GetValues()); + std::vector strides(stride_weights_ptr, + stride_weights_ptr + stride_weights.count()); + for (int x : strides) { + if (x != 1) { + return tensorflow::errors::Unimplemented( + "StridedSlice is only implemented for stride of 1, at ", + node_def.name()); + } + } + // Unsupported mask options. + for (const string& attr : + {"ellipsis_mask", "new_axis_mask", "shrink_axis_mask"}) { + int attr_val = attrs.get(attr); + if (attr_val != 0) { + return tensorflow::errors::Unimplemented( + attr, " is not supported for StridedSlice, at ", node_def.name()); + } + } + + nvinfer1::ITensor* tensor = + const_cast(inputs.at(0).tensor()); + // Reshape if necessary to 4-D, since IPaddingLayer requires a 4-D input. + const bool need_reshape = (input_dims.size() != 4); + int reshape_dims_added = 0; + nvinfer1::Dims reshape_dims; + if (need_reshape) { + // Add new dims after batch dim until tensor is 4D. + while (input_dims.size() < 4) { + input_dims.insert(input_dims.begin() + 1, 1); + begin.insert(begin.begin() + 1, 0); + end.insert(end.begin() + 1, 1); + reshape_dims_added++; + } + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &reshape_dims, + /*ignore_first_dim=*/true)); + } + // Find dimensions which need to be sliced. 
+ std::vector pad_dims; + for (int i = 0; i < input_dims.size(); i++) { + if ((begin[i] != 0) || (end[i] != input_dims[i])) { + if (i == 0) { + return tensorflow::errors::Unimplemented( + "StridedSlice can't modify batch dim, at ", node_def.name()); + } else if ((end[i] - begin[i]) < 0) { + return tensorflow::errors::InvalidArgument( + "New size of sliced dimension is negative, at ", node_def.name()); + } + pad_dims.push_back(i); + } + } + if (pad_dims.size() == 0) { + // No dimensions are changed. We could create a padding layer anyway with + // values of 0. + if (params->validation_only) return Status::OK(); + params->outputs->push_back(inputs.at(0)); + return tensorflow::Status::OK(); + } else if (pad_dims.size() == 1) { + // Only one dim is modified but we have to have 2, mark a second dim which + // will have padding of 0. The dim we add is chosen to avoid an unnecessary + // transpose. + if (pad_dims[0] != 2) { + pad_dims.push_back(2); + } else { + pad_dims.push_back(3); + } + } else if (pad_dims.size() > 2) { + return tensorflow::errors::Unimplemented( + "StridedSlice can only modify 2 dimensions, at ", node_def.name()); + } + std::sort(pad_dims.begin(), pad_dims.end()); + // Convert to pre/post padding values. Since TRT does not have a StridedSlice + // or Slice layer, we instead create an IPaddingLayer with negative padding. + nvinfer1::DimsHW pre_padding, post_padding; + for (int i = 0; i < pad_dims.size(); i++) { + const int axis = pad_dims[i]; + pre_padding.d[i] = -begin[axis]; + post_padding.d[i] = end[axis] - input_dims[axis]; + } + + // IPaddingLayer will always apply the padding to dims 2,3 (input format is + // NCHW). 
+ const bool need_transpose = !(pad_dims[0] == 2 && pad_dims[1] == 3); + std::vector transpose_order(input_dims.size()); + std::vector inv_transpose_order(input_dims.size()); + if (need_transpose) { + if (pad_dims[0] == 1 && pad_dims[1] == 3) { + transpose_order = {0, 2, 1, 3}; + inv_transpose_order = {0, 2, 1, 3}; + } else if (pad_dims[0] == 1 && pad_dims[1] == 2) { + transpose_order = {0, 3, 1, 2}; + inv_transpose_order = {0, 2, 3, 1}; + } + } + if (params->validation_only) return Status::OK(); + + // Start conversion. + if (need_reshape) { + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + inputs.at(0), reshape_dims, &output_tensor)); + tensor = const_cast(output_tensor); + } + if (need_transpose) { + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->TransposeTensor( + tensor, transpose_order, &output_tensor)); + tensor = const_cast(output_tensor); + } + + // Add padding layer + nvinfer1::IPaddingLayer* layer = params->converter->network()->addPadding( + *const_cast(tensor), pre_padding, post_padding); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + params->converter->MarkQuantizationRangesAsInferrable(tensor, + layer->getOutput(0)); + tensor = layer->getOutput(0); + + // Restore transpose + if (need_transpose) { + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->TransposeTensor( + tensor, inv_transpose_order, &output_tensor)); + tensor = const_cast(output_tensor); + } + // Restore reshape + if (need_reshape) { + // Calculate output dimensions + for (int i = 0; i < pad_dims.size(); i++) { + const int axis = pad_dims[i]; + input_dims[axis] = end[axis] - begin[axis]; + } + // Remove added 1 dimensions + for (int i = 0; i < reshape_dims_added; i++) { + int value = input_dims[1]; + if (value != 1) { + return tensorflow::errors::Internal( + "StridedSlice error when reshaping, at ", node_def.name()); + } + 
input_dims.erase(input_dims.begin() + 1); + } + + nvinfer1::Dims new_dims; + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims, + /*ignore_first_dim=*/true)); + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + TRT_TensorOrWeights(tensor), new_dims, &output_tensor)); + tensor = const_cast(output_tensor); + } + + params->outputs->push_back( + TRT_TensorOrWeights(const_cast(tensor))); + return tensorflow::Status::OK(); +} + tensorflow::Status ConvertConv2D(OpConverterParams* params) { return ConvertConv2DHelper(params, ConvolutionType::DEFAULT); } @@ -1891,9 +2283,29 @@ tensorflow::Status ConvertConv2DDepthwise(OpConverterParams* params) { tensorflow::Status ConvertPool(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + if (inputs.at(0).is_weights()) { + return tensorflow::errors::Unimplemented( + node_def.op(), " is only implemented for tensors, not weights, at ", + node_def.name()); + } + nvinfer1::PoolingType type; + if (node_def.op() == "MaxPool") { + type = nvinfer1::PoolingType::kMAX; + } else if (node_def.op() == "AvgPool") { + type = nvinfer1::PoolingType::kAVERAGE; + } else { + return tensorflow::errors::Unimplemented( + "Unsupported pooling type: ", node_def.op(), ", at ", node_def.name()); + } TFAttrs attrs(node_def); + const string padding_type = attrs.get("padding"); + if ((padding_type != "SAME") && (padding_type != "VALID")) { + return tensorflow::errors::Unimplemented( + "Unsupported padding type: ", padding_type, ", at ", node_def.name()); + } + if (params->validation_only) return Status::OK(); + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); int h_index = 2; int w_index = 3; const auto data_format = attrs.get("data_format"); @@ -1904,16 +2316,6 @@ tensorflow::Status ConvertPool(OpConverterParams* params) { const_cast(tensor), {0, 3, 1, 2}, 
&tensor)); } - nvinfer1::PoolingType type; - if (node_def.op() == "MaxPool") { - type = nvinfer1::PoolingType::kMAX; - } else if (node_def.op() == "AvgPool") { - type = nvinfer1::PoolingType::kAVERAGE; - } else { - return tensorflow::errors::Unimplemented("Unsupported pool type: ", - node_def.op()); - } - const auto tf_stride = attrs.get>("strides"); const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); @@ -1922,7 +2324,6 @@ tensorflow::Status ConvertPool(OpConverterParams* params) { auto tensor_dim = tensor->getDimensions(); std::vector> padding; - const string padding_type = attrs.get("padding"); if (padding_type == "SAME") { // This is NCHW tensor with no batch dimension. // 1 -> h @@ -1932,9 +2333,6 @@ tensorflow::Status ConvertPool(OpConverterParams* params) { {static_cast(tensor_dim.d[1]), static_cast(tensor_dim.d[2])}); } else if (padding_type == "VALID") { padding = {{0, 0}, {0, 0}}; - } else { - return tensorflow::errors::Unimplemented("Unsupported padding type: ", - padding_type); } if (padding[0].first != padding[0].second || @@ -2701,6 +3099,7 @@ tensorflow::Status ConvertPad(OpConverterParams* params) { return tensorflow::errors::Unimplemented( "Padding layer does not support padding on dimension 1 and 3 yet"); } + if (params->validation_only) return Status::OK(); bool legit_pad = true; nvinfer1::DimsHW pre_padding(0, 0); @@ -2804,6 +3203,7 @@ tensorflow::Status ConvertConcat(OpConverterParams* params) { inputs_vec.push_back(tensor_i); } + if (params->validation_only) return tensorflow::Status::OK(); // nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); nvinfer1::IConcatenationLayer* layer = @@ -2825,12 +3225,35 @@ tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) { auto data_format = attrs.get("data_format"); if (data_format != "NCHW") { return tensorflow::errors::Unimplemented( - "only data_format=NCHW is supported, at " + node_def.name()); + node_def.op(), " only supports data_format=NCHW, at ", 
node_def.name()); } bool is_training = attrs.get("is_training"); if (is_training) { + // Trying to use batchnorm in training mode is a very common problem. + // Because the error message will only be printed in VLOG(1) by the + // segmenter, we issue a special warning so that users will actually see it. + LOG(WARNING) << node_def.op() << " only supports is_training=false. If you " + << "are using Keras, please call " + << "keras.backend.set_learning_phase(0) before constructing " + << "your model. At " << node_def.name(); return tensorflow::errors::Unimplemented( - "only is_training=false is supported, at " + node_def.name()); + node_def.op(), " only supports is_training=false, at ", + node_def.name()); + } + if (inputs.at(0).is_weights()) { + return tensorflow::errors::Unimplemented( + node_def.op(), + " is only implemented for tensor inputs, not weights, at ", + node_def.name()); + } + for (int i = 1; i < 5; i++) { + if (inputs.at(i).is_tensor()) { + return tensorflow::errors::Unimplemented( + node_def.op(), + " must have constant inputs for scale, offset, mean and variance, " + "at ", + node_def.name()); + } } nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); @@ -2845,7 +3268,7 @@ tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) { for (int i = 1; i < 5; i++) { if (inputs.at(i).weights().type_ != parameter_type) { return tensorflow::errors::Unimplemented( - "Inconsistent parameter type for batchnormis not supported, at: " + + "Inconsistent parameter type for batchnorm is not supported, at: " + node_def.name()); } } @@ -2865,6 +3288,8 @@ tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) { "Inconsistent batchnorm parameter count, at: " + node_def.name()); } } + if (params->validation_only) return Status::OK(); + // We could technically have two weights with different shape. 
// that requires two addScale op, arguably less performant TRT_ShapedWeights combined_scale_weights = @@ -3150,12 +3575,19 @@ static void RegisterValidatableOpConverters( std::unordered_map* registration) { // TODO(laigd): support all op types. (*registration)["BiasAdd"] = ConvertBiasAdd; + (*registration)["ConcatV2"] = ConvertConcat; (*registration)["Const"] = ConvertConst; - (*registration)["Transpose"] = ConvertTranspose; - (*registration)["Reshape"] = ConvertReshape; + (*registration)["Conv2D"] = ConvertConv2D; + (*registration)["DepthwiseConv2dNative"] = ConvertConv2DDepthwise; + (*registration)["ExpandDims"] = ConvertExpandDims; (*registration)["MatMul"] = ConvertMatMul; + (*registration)["Pad"] = ConvertPad; (*registration)["Relu6"] = ConvertRelu6; + (*registration)["Reshape"] = ConvertReshape; (*registration)["Square"] = ConvertSquare; + (*registration)["Squeeze"] = ConvertSqueeze; + (*registration)["StridedSlice"] = ConvertStridedSlice; + (*registration)["Transpose"] = ConvertTranspose; for (auto quantization_op_type : {"QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3", @@ -3169,6 +3601,12 @@ static void RegisterValidatableOpConverters( for (auto activation_op_type : {"Relu", "Sigmoid", "Tanh"}) { (*registration)[activation_op_type] = ConvertActivation; } + for (auto pool_op_type : {"AvgPool", "MaxPool"}) { + (*registration)[pool_op_type] = ConvertPool; + } + for (auto normalization_op_type : {"FusedBatchNorm", "FusedBatchNormV2"}) { + (*registration)[normalization_op_type] = ConvertFusedBatchNorm; + } } void TrtNodeValidator::RegisterOpValidators() { @@ -3177,21 +3615,10 @@ void TrtNodeValidator::RegisterOpValidators() { void Converter::RegisterOpConverters() { RegisterValidatableOpConverters(&op_registry_); - - op_registry_["Conv2D"] = ConvertConv2D; - op_registry_["DepthwiseConv2dNative"] = ConvertConv2DDepthwise; - op_registry_["MaxPool"] = ConvertPool; - op_registry_["AvgPool"] = ConvertPool; // TODO(ben,jie): this is a temp hack. 
op_registry_["Identity"] = ConvertIdentity; // Identity should be removed op_registry_["Snapshot"] = ConvertIdentity; // Snapshot should be removed - op_registry_["Pad"] = ConvertPad; - - op_registry_["ConcatV2"] = ConvertConcat; - op_registry_["FusedBatchNorm"] = ConvertFusedBatchNorm; - op_registry_["FusedBatchNormV2"] = ConvertFusedBatchNorm; - op_registry_["Rsqrt"] = ConvertUnary; op_registry_["Reciprocal"] = ConvertUnary; op_registry_["Exp"] = ConvertUnary; diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h index daa311119282221a5eccf4f166f67b479d0d3776..54e19b73957bccdae2b23bd3556de9ad00b864e5 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h @@ -34,7 +34,7 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "tensorrt/NvInfer.h" +#include "tensorrt/include/NvInfer.h" namespace tensorflow { namespace tensorrt { diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc index 4790622e83ee1f77be2754a3655e6f8881609d26..a2ddfbffa5b0d8c421bcfe054097a9e42b79fe8f 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc @@ -44,7 +44,7 @@ limitations under the License. #if GOOGLE_TENSORRT #include "cuda/include/cuda.h" #include "cuda/include/cuda_runtime_api.h" -#include "tensorrt/NvInfer.h" +#include "tensorrt/include/NvInfer.h" namespace tensorflow { namespace tensorrt { @@ -2113,6 +2113,512 @@ TEST_F(OpConverterTest, ConvertActivation) { } } +TEST_F(OpConverterTest, ConvertExpandDims) { + { + // Input list is empty, should fail. 
+ NodeDef node_def = MakeNodeDef("my_expanddims", "ExpandDims", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Two inputs expected for ExpandDims, at my_expanddims"); + } + + // Get the NodeDef for ExpandDims. + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32); + auto expanddims = + ops::ExpandDims(s.WithOpName("my_expanddims"), input, weights); + const NodeDef& node_def = expanddims.operation.node()->def(); + { + // Input is weights, should fail. + Reset(); + AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); + AddTestWeights("weights", {1}, {1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "ExpandDims expects tensor for input, at my_expanddims"); + } + { + // Axis is a tensor, should fail. + Reset(); + AddTestTensor("input", {1, 2, 3}); + AddTestTensor("weights", {3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "ExpandDims expects weights for axis, at my_expanddims"); + } + { + // Add dim at batch dimension, should fail. + Reset(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights", {1}, {0}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Modifying batch dimension is not supported for ExpandDims, at " + "my_expanddims"); + } + { + // Add dim at batch dimension via negative axis, should fail. + Reset(); + AddTestTensor("input", {1, 2, 3}); + // Input is rank 4 (batch dim included) + AddTestWeights("weights", {1}, {-5}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Modifying batch dimension is not supported for ExpandDims, at " + "my_expanddims"); + } + { + // Axis > rank(input), should fail. 
+ Reset(); + AddTestTensor("input", {1, 2, 3}); + // Input is rank 4 (batch dim included) + AddTestWeights("weights", {1}, {5}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Axis for ExpandDims is invalid, must be in the range " + "[-rank(input) - 1, rank(input)], at my_expanddims"); + } + { + // Axis < -rank(input)-1, should fail. + Reset(); + AddTestTensor("input", {1, 2, 3}); + // Input is rank 4 (batch dim included) + AddTestWeights("weights", {1}, {-6}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Axis for ExpandDims is invalid, must be in the range " + "[-rank(input) - 1, rank(input)], at my_expanddims"); + } + + struct TestParams { + TestParams(const std::vector& input_dims, int axis, + const std::vector& expected_output_dims) + : input_dims(input_dims), + axis(axis), + expected_output_dims(expected_output_dims) {} + std::vector input_dims; + int axis; + std::vector expected_output_dims; + }; + + // Ok. + const int kExpandDimsOKCases = 8; + TestParams ok_params[kExpandDimsOKCases] = { + TestParams{{2, 3}, 1, {1, 2, 3}}, TestParams{{2, 3}, -3, {1, 2, 3}}, + TestParams{{2, 3}, 3, {2, 3, 1}}, TestParams{{2, 3}, -1, {2, 3, 1}}, + TestParams{{2, 3}, 2, {2, 1, 3}}, TestParams{{2, 3}, -2, {2, 1, 3}}, + TestParams{{6}, 1, {1, 6}}, TestParams{{6}, -1, {6, 1}}, + }; + for (int i = 0; i < kExpandDimsOKCases; ++i) { + Reset(); + AddTestTensor("input", ok_params[i].input_dims); + AddTestWeights("weights", {1}, {ok_params[i].axis}); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_expanddims", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); + + std::vector output_data(6); + BuildAndRun({{"input", {1, 2, 3, 4, 5, 6}}}, "my_expanddims", + &output_data); + EXPECT_THAT(output_data, ElementsAre(1, 2, 3, 4, 5, 6)); + } +} + +TEST_F(OpConverterTest, ConvertSqueeze) { + { + // 
Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_squeeze", "Squeeze", {}); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "One input expected for Squeeze, at my_squeeze"); + } + { + // No attrs, should fail. + Reset(); + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto squeeze = ops::Squeeze(s.WithOpName("my_squeeze"), input); + const NodeDef& node_def = squeeze.operation.node()->def(); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Squeeze is only implemented for explicit dims, at my_squeeze"); + } + + // Get the NodeDef for Squeeze. + auto get_squeeze_nodedef = [](std::vector axis) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + ops::Squeeze::Attrs squeeze_attrs; + squeeze_attrs.axis_ = gtl::ArraySlice(axis); + auto squeeze = + ops::Squeeze(s.WithOpName("my_squeeze"), input, squeeze_attrs); + return squeeze.operation.node()->def(); + }; + + { + // Input is weights, should fail. + Reset(); + NodeDef node_def = get_squeeze_nodedef({0}); + AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Squeeze expects tensor for input, at my_squeeze"); + } + { + // Squeeze batch dim, should fail. + Reset(); + NodeDef node_def = get_squeeze_nodedef({0}); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Cannot squeeze batch dimension, at my_squeeze"); + } + { + // Squeeze batch dim via negative axis, should fail. + Reset(); + NodeDef node_def = get_squeeze_nodedef({-4}); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Cannot squeeze batch dimension, at my_squeeze"); + } + { + // Squeeze >= rank(input), should fail. 
+ Reset(); + NodeDef node_def = get_squeeze_nodedef({4}); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Axis for Squeeze is invalid, must be in the range " + "[-rank(input), rank(input)), at my_squeeze"); + } + { + // Squeeze < -rank(input), should fail. + Reset(); + NodeDef node_def = get_squeeze_nodedef({-5}); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Axis for Squeeze is invalid, must be in the range " + "[-rank(input), rank(input)), at my_squeeze"); + } + + struct TestParams { + TestParams(const std::vector& input_dims, const std::vector& axis, + const std::vector& expected_output_dims) + : input_dims(input_dims), + axis(axis), + expected_output_dims(expected_output_dims) {} + std::vector input_dims; + std::vector axis; + std::vector expected_output_dims; + }; + + // Ok. + const int kSqueezeOKCases = 10; + TestParams ok_params[kSqueezeOKCases] = { + TestParams{{1, 2, 3}, {1}, {2, 3}}, + TestParams{{1, 2, 3}, {-3}, {2, 3}}, + TestParams{{2, 3, 1}, {3}, {2, 3}}, + TestParams{{2, 3, 1}, {-1}, {2, 3}}, + TestParams{{1, 2, 1, 3, 1}, {1, 3, 5}, {2, 3}}, + TestParams{{1, 2, 1, 3, 1}, {3, 1, 5}, {2, 3}}, + TestParams{{1, 2, 1, 3, 1}, {-1, -3, -5}, {2, 3}}, + TestParams{{1, 2, 1, 3, 1}, {1, -3, 5}, {2, 3}}, + TestParams{{1, 6}, {1}, {6}}, + TestParams{{6, 1}, {2}, {6}}, + }; + for (int i = 0; i < kSqueezeOKCases; ++i) { + Reset(); + NodeDef node_def = get_squeeze_nodedef(ok_params[i].axis); + AddTestTensor("input", ok_params[i].input_dims); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_squeeze", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); + + std::vector output_data(6); + BuildAndRun({{"input", {1, 2, 3, 4, 5, 6}}}, "my_squeeze", + &output_data); + EXPECT_THAT(output_data, ElementsAre(1, 
2, 3, 4, 5, 6)); + } +} + +TEST_F(OpConverterTest, ConvertStridedSlice) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_strided_slice", "StridedSlice", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "StridedSlice expects 4 inputs, at my_strided_slice"); + } + + // Get nodedef for StridedSlice layer. + auto get_strided_slice_nodedef = + [](int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, + int new_axis_mask = 0, int shrink_axis_mask = 0) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto begin = ops::Placeholder(s.WithOpName("begin"), DT_INT32); + auto end = ops::Placeholder(s.WithOpName("end"), DT_INT32); + auto strides = ops::Placeholder(s.WithOpName("strides"), DT_INT32); + ops::StridedSlice::Attrs attrs = ops::StridedSlice::Attrs() + .BeginMask(begin_mask) + .EndMask(end_mask) + .EllipsisMask(ellipsis_mask) + .NewAxisMask(new_axis_mask) + .ShrinkAxisMask(shrink_axis_mask); + auto strided_slice = ops::StridedSlice(s.WithOpName("my_strided_slice"), + input, begin, end, strides, attrs); + return strided_slice.operation.node()->def(); + }; + + { + NodeDef node_def = get_strided_slice_nodedef(); + AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "StridedSlice is only implemented for tensors, at my_strided_slice"); + } + { + // Begin, end, strides are tensors, should fail. 
+ Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestTensor("begin", {4}); + AddTestTensor("end", {4}); + AddTestTensor("strides", {4}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "StridedSlice expects weights for begin, end, and strides, at " + "my_strided_slice"); + } + { + // Non-zero ellipsis_mask, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef( + /*begin_mask=*/0, /*end_mask=*/0, /*ellipsis_mask=*/2, + /*new_axis_mask=*/0, /*shrink_axis_mask=*/0); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "ellipsis_mask is not supported for StridedSlice, at " + "my_strided_slice"); + } + { + // Modify batch dim, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {0, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "StridedSlice can't modify batch dim, at my_strided_slice"); + } + { + // Stride is not 1, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 2, -1, 3}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "StridedSlice is only implemented for stride of " + "1, at my_strided_slice"); + } + { + // Begin out of bounds, should fail. 
+ Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {1, 2, 3, 4}); + AddTestWeights("end", {4}, {0, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "begin value of 2 for StridedSlice is invalid, must be in the range " + "[-dim_size(i), dim_size(i)], at my_strided_slice"); + } + { + // End out of bounds, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 2, 3, 4}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "end value of 2 for StridedSlice is invalid, must be in the range " + "[-dim_size(i), dim_size(i)], at my_strided_slice"); + } + { + // Size of sliced dim is negative, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 2, 0}); + AddTestWeights("end", {4}, {1, 1, 0, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "New size of sliced dimension is negative, at my_strided_slice"); + } + + struct TestParams { + TestParams(const std::vector& input_dims, + const std::vector& expected_output_dims, + const std::vector& begin, const std::vector& end, + const std::vector& begin_mask, + const std::vector& end_mask, + const std::vector& expected_output) + : input_dims(input_dims), + expected_output_dims(expected_output_dims), + begin(begin), + end(end), + expected_output(expected_output) { + // Masks are provided in terms of vectors for readability. Convert them to + // binary here. 
+ this->begin_mask = 0; + for (int i = 0; i < begin_mask.size(); i++) { + if (begin_mask[i]) this->begin_mask |= (1 << i); + } + this->end_mask = 0; + for (int i = 0; i < end_mask.size(); i++) { + if (end_mask[i]) this->end_mask |= (1 << i); + } + } + + std::vector input_dims; + std::vector expected_output_dims; + std::vector begin; + std::vector end; + int begin_mask; + int end_mask; + std::vector expected_output; + }; + + // Ok. + const int kStridedSliceOKCases = 18; + TestParams ok_params[kStridedSliceOKCases] = { + // 2D Crop. + TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 0, 1, 2}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 0, 0}, + /*expected_output=*/{1, 2}}, + TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2}, + /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 0, 0, 0}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 1}, + /*expected_output=*/{5, 6}}, + TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2}, + /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 1, 2, 3}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 0, 0}, + /*expected_output=*/{5, 6}}, + // 2D Crop, with transpose. 
+ TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 2, 1}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 2, 1}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0}, + /*expected_output=*/{1, 2}}, + TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 2, 1}, + /*begin=*/{0, 1, 1, 0}, /*end=*/{0, 2, 3, 1}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0}, + /*expected_output=*/{5, 6}}, + TestParams{/*input_dims=*/{2, 1, 3}, /*expected_output_dims=*/{1, 1, 2}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 1, 2}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0}, + /*expected_output=*/{1, 2}}, + TestParams{/*input_dims=*/{2, 1, 3}, /*expected_output_dims=*/{1, 1, 2}, + /*begin=*/{0, 1, 0, 1}, /*end=*/{0, 2, 1, 3}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0}, + /*expected_output=*/{5, 6}}, + // 2D Crop, with reshape. + TestParams{/*input_dims=*/{2, 3}, /*expected_output_dims=*/{1, 2}, + /*begin=*/{0, 0, 0}, /*end=*/{0, 1, 2}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 0}, + /*expected_output=*/{1, 2}}, + TestParams{/*input_dims=*/{2, 3}, /*expected_output_dims=*/{1, 2}, + /*begin=*/{0, 1, 1}, /*end=*/{0, 0, 0}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 1, 1}, + /*expected_output=*/{5, 6}}, + // 1D Crop. + TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 2, 2}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 0, 0, 2}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 0}, + /*expected_output=*/{1, 2, 4, 5}}, + TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 3}, + /*begin=*/{0, 0, 1, 0}, /*end=*/{0, 0, 0, 0}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 1}, + /*expected_output=*/{4, 5, 6}}, + // 1D Crop, with transpose. 
+ TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 3, 1}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 0, 0}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 1, 1}, + /*expected_output=*/{1, 2, 3}}, + TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 3, 1}, + /*begin=*/{0, 1, 0, 0}, /*end=*/{0, 0, 0, 0}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 1}, + /*expected_output=*/{4, 5, 6}}, + // 1D Crop, with reshape. + TestParams{/*input_dims=*/{6}, /*expected_output_dims=*/{3}, + /*begin=*/{0, 0}, /*end=*/{0, 3}, + /*begin_mask=*/{0, 0}, /*end_mask=*/{1, 0}, + /*expected_output=*/{1, 2, 3}}, + TestParams{/*input_dims=*/{1, 6}, /*expected_output_dims=*/{1, 3}, + /*begin=*/{0, 0, 2}, /*end=*/{0, 0, 5}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 1, 0}, + /*expected_output=*/{3, 4, 5}}, + TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{3, 1}, + /*begin=*/{0, 2, 0}, /*end=*/{0, 5, 0}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 1}, + /*expected_output=*/{3, 4, 5}}, + // Negative axis. 
+ TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{3, 1}, + /*begin=*/{0, -6, 0}, /*end=*/{0, -3, 0}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 1}, + /*expected_output=*/{1, 2, 3}}, + TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{5, 1}, + /*begin=*/{0, 0, 0}, /*end=*/{0, -1, 0}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 1}, + /*expected_output=*/{1, 2, 3, 4, 5}}, + }; + + for (int i = 0; i < kStridedSliceOKCases; i++) { + Reset(); + NodeDef node_def = get_strided_slice_nodedef(ok_params[i].begin_mask, + ok_params[i].end_mask); + AddTestTensor("input", ok_params[i].input_dims); + AddTestWeights("begin", + {static_cast(ok_params[i].begin.size())}, + ok_params[i].begin); + AddTestWeights("end", {static_cast(ok_params[i].end.size())}, + ok_params[i].end); + std::vector strides(ok_params[i].input_dims.size(), 1); + AddTestWeights("strides", {static_cast(strides.size())}, + strides); + RunValidationAndConversion(node_def); + + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_strided_slice", &output)); + std::vector output_data(ok_params[i].expected_output.size()); + BuildAndRun({{"input", {1, 2, 3, 4, 5, 6}}}, "my_strided_slice", + &output_data); + EXPECT_THAT(output_data, ElementsAreArray(ok_params[i].expected_output)); + } +} + } // namespace convert } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc index c1688d4db88a270dcd202989f89a677ed10576d9..d57f2300f8e6e6ce79c538133da6bc5cf5ead2f5 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc @@ -226,8 +226,9 @@ tensorflow::Status TRTOptimizationPass::Optimize( tensorflow::tensorrt::convert::ConversionParams cp; if (use_calibration_ && precision_mode_ != INT8MODE) { - LOG(ERROR) << "Calibration with FP32 or FP16 is not implemented. 
" - << "Falling back to use_calibration = False."; + VLOG(1) << "Calibration with FP32 or FP16 is not implemented. " + << "Falling back to use_calibration = False." + << "Note that the default value of use_calibration is True."; use_calibration_ = false; } diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h index f658e45569bbef73faa751634b0163f5687ad164..189e9c939b9ffd4450f7ba95fe1abdbbc049b430 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h @@ -23,7 +23,7 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "tensorrt/NvInfer.h" +#include "tensorrt/include/NvInfer.h" namespace tensorflow { namespace tensorrt { diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h index b801480a30552113d0e9572d173871d71b7cacd8..b545f497f32d5a1a6960b748467ca189b7debf6c 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h @@ -31,7 +31,7 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT #include "cuda/include/cuda_runtime_api.h" -#include "tensorrt/NvInfer.h" +#include "tensorrt/include/NvInfer.h" namespace tensorflow { namespace tensorrt { diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.h b/tensorflow/contrib/tensorrt/log/trt_logger.h index 58d9d05d01960d7c5222b3a8be881afdba2f79e6..96ccacb791e40143c5c4d9d691bb353702f9a28b 100644 --- a/tensorflow/contrib/tensorrt/log/trt_logger.h +++ b/tensorflow/contrib/tensorrt/log/trt_logger.h @@ -20,7 +20,7 @@ limitations under the License. 
#if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "tensorrt/NvInfer.h" +#include "tensorrt/include/NvInfer.h" namespace tensorflow { namespace tensorrt { diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h index 167e8197a70ecab0b068777d92948a92cabe6d2b..754920b60ca7439513a91ad0354833a2482b29c1 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h @@ -24,7 +24,7 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "tensorrt/NvInfer.h" +#include "tensorrt/include/NvInfer.h" namespace tensorflow { namespace tensorrt { diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h index 51393d2092c04323932273f6655d4579269e34aa..bbae9fb65c22cf69d2e7954436fd04dd16f7f6c8 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h @@ -27,7 +27,7 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "tensorrt/NvInfer.h" +#include "tensorrt/include/NvInfer.h" namespace tensorflow { namespace tensorrt { diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc index 2346cb9ba03ea033363fe5336e5fbab058a8ac6c..129bdcdbc2f8d9d5215f45f381bcadf35e4fa75e 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc @@ -23,7 +23,7 @@ limitations under the License. 
#if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "tensorrt/NvInfer.h" +#include "tensorrt/include/NvInfer.h" namespace tensorflow { namespace tensorrt { diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h index 5ded702c4189f9cced2aa1ed3c33f7d1ccf7efd1..274ce42fec9283c643004d45fba461879fc5f2dc 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h @@ -23,7 +23,7 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "tensorrt/NvInfer.h" +#include "tensorrt/include/NvInfer.h" namespace tensorflow { namespace tensorrt { diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.h b/tensorflow/contrib/tensorrt/resources/trt_allocator.h index b03fe7b8b59b01d7f5e947efffee1b7a7c45b86d..f857a9de055ee7668f0bf9bc97e030354505081b 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_allocator.h +++ b/tensorflow/contrib/tensorrt/resources/trt_allocator.h @@ -22,7 +22,7 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "tensorrt/NvInfer.h" +#include "tensorrt/include/NvInfer.h" #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc b/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc index ad6b1d7d4c57d696d3dee3b479733e152e669211..beb1284208e4c10ffe1d36ef411cf08f11dbcb78 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc +++ b/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc @@ -48,11 +48,14 @@ TEST(TRTAllocatorTest, Align) { 513ul, 700ul, 12345ul, 1ul << 32}) { for (uint64_t alignment = 1; alignment <= space * 4; alignment *= 2) { for (const uintptr_t ptr_val : - {1ul, alignment == 1 ? 1ul : alignment - 1, alignment, alignment + 1, - alignment + (alignment / 2)}) { + {static_cast(1), + alignment == 1 ? 
static_cast(1) : alignment - 1, + alignment, alignment + 1, alignment + (alignment / 2)}) { if (ptr_val % alignment == 0) { for (const uint64_t size : - {1ul, space == 1 ? 1ul : space - 1, space, space + 1}) { + {static_cast(1), + space == 1 ? static_cast(1) : space - 1, space, + space + 1}) { EXPECT_EQ(space >= size, RunTest(alignment, size, ptr_val, space)); } } else { @@ -62,8 +65,10 @@ TEST(TRTAllocatorTest, Align) { EXPECT_TRUE( RunTest(alignment, space - diff, ptr_val + diff, space - diff)); for (const uint64_t size : - {1ul, space - diff > 1 ? space - diff - 1 : 1ul, space - diff, - space - diff + 1, space - 1}) { + {static_cast(1), + space - diff > 1 ? space - diff - 1 + : static_cast(1), + space - diff, space - diff + 1, space - 1}) { EXPECT_EQ(space - diff >= size, RunTest(alignment, size, ptr_val, space)); } diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h index e8f08ad9f48453da4ef51d94fcbc3d98f6e04b3b..65466c9741989fda5f82fc27d813d026f35fe386 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h +++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h @@ -26,7 +26,7 @@ limitations under the License. #if GOOGLE_TENSORRT #include "cuda/include/cuda_runtime_api.h" -#include "tensorrt/NvInfer.h" +#include "tensorrt/include/NvInfer.h" namespace tensorflow { namespace tensorrt { diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h index 0be5d44f7a36276bdaabe9a63337844d4011cf32..aac9e5c7bd725fc10bcaa04536ebc7be071b4d4c 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_resources.h +++ b/tensorflow/contrib/tensorrt/resources/trt_resources.h @@ -31,7 +31,7 @@ limitations under the License. 
#if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "tensorrt/NvInfer.h" +#include "tensorrt/include/NvInfer.h" namespace tensorflow { namespace tensorrt { diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc index 6abc5226ccf96e472df77269bee6186726e5768d..084a96e0fa5c97edc58adf2590ed94e5ef0e4d85 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.cc +++ b/tensorflow/contrib/tensorrt/segment/segment.cc @@ -225,6 +225,24 @@ SimpleGraph::~SimpleGraph() { for (auto x : edges_) delete x; } +// Define comparison functions for std::set with pointer keys so that behavior +// is deterministic. When using std::set with pointer key types, the items are +// sorted by pointer address which is non-deterministic. This can cause issues +// for INT8 mode because the graph is converted twice and non-determinism may +// cause a mismatch between the calibration tables of the conversions. +struct SimpleEdgePtrCompare { + bool operator()(const SimpleEdge* lhs, const SimpleEdge* rhs) const { + return lhs->id() < rhs->id(); + } +}; + +struct NodePtrCompare { + bool operator()(const tensorflow::Node* lhs, + const tensorflow::Node* rhs) const { + return lhs->name() < rhs->name(); + } +}; + namespace { // Copied from TF ReverseDFS, which only works for tensorflow::Graph. @@ -476,7 +494,7 @@ tensorflow::Status SegmentGraph( // nodes. Iterate since combining two nodes may unblock other // combining. while (true) { - std::set contract_edges; + std::set contract_edges; for (const SimpleEdge* out_edge : node->out_edges()) { VLOG(3) << "... out node " << out_edge->dst()->name() << " ( " << out_edge->dst()->id() << " <- " << node->id() << " )"; @@ -530,7 +548,7 @@ tensorflow::Status SegmentGraph( // A map from the segment identifier (currently the name of the root node of // the segment tree) to the segment nodes set. 
- std::map> sg_map; + std::map> sg_map; // A map from the segment identifier (currently the name of the root node of // the segment tree) to the device names that the nodes in the segment are @@ -566,7 +584,8 @@ tensorflow::Status SegmentGraph( // --------------------------------- Step 2 --------------------------------- // Remove ineligible input/output nodes. for (auto& itr : sg_map) { - std::set& segment_nodes = itr.second; + std::set& segment_nodes = + itr.second; VLOG(1) << "Segment original size: " << segment_nodes.size(); while (true) { std::deque in_nodes_que, out_nodes_que; @@ -618,8 +637,9 @@ tensorflow::Status SegmentGraph( bool is_input_nodes, std::deque* que) { // Run a BFS on the queue to find all the input/output nodes. - std::set visited; - std::set logged(que->begin(), que->end()); + std::set visited; + std::set logged(que->begin(), + que->end()); while (!que->empty()) { auto node = que->front(); que->pop_front(); @@ -653,7 +673,8 @@ tensorflow::Status SegmentGraph( // --------------------------------- Step 3 --------------------------------- // Convert the segments into the expected return format for (const auto& itr : sg_map) { - const std::set& segment_nodes = itr.second; + const std::set& segment_nodes = + itr.second; if (VLOG_IS_ON(1)) { string s = "parent=" + itr.first + ":"; for (auto node : segment_nodes) s += " " + node->name(); diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc index ad9703325f5b0e41e7cea28eafdf91e5a1681245..f30dba59ad55317d7ad7730e4dc66c9aba4e6a6b 100644 --- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc +++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc @@ -23,7 +23,7 @@ limitations under the License. 
#if GOOGLE_TENSORRT #include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorrt/NvInfer.h" +#include "tensorrt/include/NvInfer.h" namespace tensorflow { namespace shape_inference { diff --git a/tensorflow/contrib/tensorrt/tensorrt_test.cc b/tensorflow/contrib/tensorrt/tensorrt_test.cc index 102a1d38919fde9cdb0890a0073c916a68b85601..769982c6456f76663e50fe3ec59651127e3720ac 100644 --- a/tensorflow/contrib/tensorrt/tensorrt_test.cc +++ b/tensorflow/contrib/tensorrt/tensorrt_test.cc @@ -22,7 +22,7 @@ limitations under the License. #if GOOGLE_TENSORRT #include "cuda/include/cuda.h" #include "cuda/include/cuda_runtime_api.h" -#include "tensorrt/NvInfer.h" +#include "tensorrt/include/NvInfer.h" namespace tensorflow { namespace { diff --git a/tensorflow/contrib/tensorrt/test/rank_two_test.py b/tensorflow/contrib/tensorrt/test/rank_two_test.py index 0cd733dca13462ac8f4478544005ae4000f711f1..563232fc12675d9e1b32b7ab461591af57beadb9 100644 --- a/tensorflow/contrib/tensorrt/test/rank_two_test.py +++ b/tensorflow/contrib/tensorrt/test/rank_two_test.py @@ -51,8 +51,10 @@ class RankTwoTest(trt_test.TfTrtIntegrationTestBase): c = constant_op.constant(3.0, name="c%d_3" % i) q = math_ops.add(q, c, name="add%d_3" % i) if i == 0: + axis = constant_op.constant(-1, dtype=dtypes.int32, name="axis") for j in range(2): - q = array_ops.expand_dims(q, -1, name="expand%d_%d" % (i, j)) + q = array_ops.expand_dims(q, axis, name="expand%d_%d" % (i, j)) + q = self.trt_incompatible_op(q) q = gen_math_ops.reciprocal(q, name="reciprocal%d" % i) outputs.append(q) # Combine both paths @@ -70,7 +72,7 @@ class RankTwoTest(trt_test.TfTrtIntegrationTestBase): return { "TRTEngineOp_0": [ "add0_1", "add0_2", "add0_3", "c0_1", "c0_2", "c0_3", "abs0_1", - "abs0_2" + "abs0_2", "expand0_0", "expand0_1", "axis" ], "TRTEngineOp_1": [ "add", "add1_1", "add1_2", "add1_3", "c1_1", "c1_2", "c1_3", diff --git a/tensorflow/contrib/tensorrt/test/unary_test.py 
b/tensorflow/contrib/tensorrt/test/unary_test.py index 9fc50e05952abd335e196dce8fc8a81056d7007d..b6e5e32db1236684a06c2d44298b9a3d39667152 100644 --- a/tensorflow/contrib/tensorrt/test/unary_test.py +++ b/tensorflow/contrib/tensorrt/test/unary_test.py @@ -106,10 +106,7 @@ class UnaryTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - return [ - "TRTEngineOp_0", "TRTEngineOp_1", "TRTEngineOp_2", "TRTEngineOp_3", - "TRTEngineOp_4" - ] + return ["TRTEngineOp_0"] if __name__ == "__main__": diff --git a/tensorflow/contrib/tfprof/README.md b/tensorflow/contrib/tfprof/README.md index b29d1acacf17b57549558be45c853566817c1729..f40e76f554e8815aac96344d8cb0b911bafdd712 100644 --- a/tensorflow/contrib/tfprof/README.md +++ b/tensorflow/contrib/tfprof/README.md @@ -1,7 +1,5 @@ # tfprof: TensorFlow Profiler and Beyond -

Please use `tf.profiler.xxx` instead of `tf.contrib.tfprof.xxx`

-

Full Document in tensorflow/core/profiler/README.md

diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD index ae7db35b47b326272dd2c7bc76e18047cec59865..4b90b596b28efec83aa349782c4874d79b6817c7 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD @@ -104,6 +104,7 @@ py_test( srcs = [ "estimators_test.py", ], + shard_count = 3, srcs_version = "PY2AND3", tags = [ "no_mac", diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index 05d2ebd2e8a3292a95df0e2f976df0e2871063f8..007aeaec15d6db7ea4581ab9825da2dbe8b37163 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -79,6 +79,7 @@ py_library( "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", "//tensorflow/python:platform", + "//tensorflow/python:session", "//tensorflow/python:state_ops", "//tensorflow/python:summary", "//tensorflow/python:summary_ops_v2", @@ -101,6 +102,7 @@ tf_gen_op_libs( "replication_ops", "tpu_configuration_ops", "tpu_embedding_ops", + "tpu_ordinal_selector_op", ], deps = [ "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_cc", @@ -152,6 +154,13 @@ tf_gen_op_wrapper_py( ], ) +tf_gen_op_wrapper_py( + name = "tpu_ordinal_selector_op", + deps = [ + ":tpu_ordinal_selector_op_op_lib", + ], +) + py_library( name = "profiler", srcs = ["python/profiler/__init__.py"], diff --git a/tensorflow/contrib/tpu/ops/tpu_ordinal_selector_op.cc b/tensorflow/contrib/tpu/ops/tpu_ordinal_selector_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..54e6b20f7f388b67a96ac8acfe814a4202b56a18 --- /dev/null +++ b/tensorflow/contrib/tpu/ops/tpu_ordinal_selector_op.cc @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("TPUOrdinalSelector") + .Output("device_ordinals: int32") + .SetIsStateful() + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, + c->Vector(shape_inference::InferenceContext::kUnknownDim)); + return Status::OK(); + }) + .Doc(R"doc( +A TPU core selector Op. + +This Op produces a set of TPU cores (for warm-up) or a single TPU core +(for regular inference) to execute the TPU program on. The output is +consumed by TPUPartitionedCall. + +device_ordinals: A vector 1 or more TPU cores. 
+)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py index 63641e00c5dbf4b4e635ecfea8bef98c7d0b7075..a081c4354a779d37140338793e66844c3fcf7a12 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py @@ -90,12 +90,12 @@ def main(unused_argv=None): tf_version = tf.__version__ print('TensorFlow version %s detected' % tf_version) - if FLAGS.service_addr is None and FLAGS.tpu is None: + if not FLAGS.service_addr and not FLAGS.tpu: sys.exit('You must specify either --service_addr or --tpu.') tpu_cluster_resolver = None - if FLAGS.service_addr is not None: - if FLAGS.tpu is not None: + if FLAGS.service_addr: + if FLAGS.tpu: tf.logging.warn('Both --service_addr and --tpu are set. Ignoring ' '--tpu and using --service_addr.') service_addr = FLAGS.service_addr diff --git a/tensorflow/contrib/tpu/python/tpu/datasets.py b/tensorflow/contrib/tpu/python/tpu/datasets.py index d61c824eab5337a7cd08cfa52a7e8f8b8d73b455..8d6245390fc3fa005c92d01bc9b64ddb47583582 100644 --- a/tensorflow/contrib/tpu/python/tpu/datasets.py +++ b/tensorflow/contrib/tpu/python/tpu/datasets.py @@ -156,7 +156,7 @@ def StreamingFilesDataset(files, source_dataset = source_dataset.prefetch(1) - source_iterator = source_dataset.make_one_shot_iterator() + source_iterator = dataset_ops.make_one_shot_iterator(source_dataset) source_handle = source_iterator.string_handle() @function.Defun(dtypes.string) diff --git a/tensorflow/contrib/tpu/python/tpu/datasets_test.py b/tensorflow/contrib/tpu/python/tpu/datasets_test.py index b58d05eac56f3586e183333f7c1a3867ee57456c..52d87b800401c3e584da9843916cfc7a767c082a 100644 --- a/tensorflow/contrib/tpu/python/tpu/datasets_test.py +++ b/tensorflow/contrib/tpu/python/tpu/datasets_test.py @@ -70,7 +70,7 @@ class 
DatasetsTest(test.TestCase): dataset = datasets.StreamingFilesDataset( os.path.join(self.get_temp_dir(), 'text_line.*.txt'), filetype='text') - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() @@ -94,7 +94,7 @@ class DatasetsTest(test.TestCase): dataset = datasets.StreamingFilesDataset( os.path.join(self.get_temp_dir(), 'tf_record*'), filetype='tfrecord') - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() @@ -121,7 +121,7 @@ class DatasetsTest(test.TestCase): dataset = datasets.StreamingFilesDataset(filenames, filetype='tfrecord') - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() @@ -154,7 +154,7 @@ class DatasetsTest(test.TestCase): os.path.join(self.get_temp_dir(), 'fixed_length*'), filetype=FixedLengthFile) - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() @@ -177,7 +177,7 @@ class DatasetsTest(test.TestCase): dataset = datasets.StreamingFilesDataset( dataset_ops.Dataset.range(10), filetype=gen_dataset) - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index cf3b2e68e940652220983c98e3a0acb68cf88d89..cf9672f8d867f4ad5cb0281abe710f6e3bcdf1f2 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -133,7 +133,7 @@ def _tpu_session_context(): 
An error occurred connecting or initializing your TPU. The session has been reset. re-run keras_to_tpu_model to create a new session. -""" + e) +""" + str(e)) def setup_tpu_session(cluster_resolver): @@ -729,7 +729,7 @@ class TPUDatasetInfeedManager(TPUInfeedManager): dummy_x_shape[0] *= tpu_assignment.num_towers dummy_y_shape = dataset.output_shapes[1].as_list() dummy_y_shape[0] *= tpu_assignment.num_towers - self._iterator = dataset.make_initializable_iterator() + self._iterator = dataset_ops.make_initializable_iterator(dataset) K.get_session().run(self._iterator.initializer) self._get_next_ops = [] @@ -1676,14 +1676,10 @@ class KerasTPUModel(models.Model): callbacks, self, do_validation=do_validation, - val_inputs=val_inputs, - val_targets=val_targets, - val_sample_weights=val_sample_weights, batch_size=batch_size, epochs=epochs, steps_per_epoch=steps_per_epoch, samples=num_training_samples, - validation_steps=validation_steps, verbose=verbose, count_mode=count_mode) @@ -2073,6 +2069,8 @@ class KerasTPUModel(models.Model): # tpu_model may not be compiled, e.g., loading weights and then predict. 
return for k, v in six.iteritems(cpu_optimizer_config): + if k == 'name': + continue opt_var = getattr(self._tpu_model.optimizer, k) if isinstance(opt_var, variables.Variable): logging.info('CPU -> TPU %s: %s {%s}', k, v, K.get_value(opt_var)) @@ -2101,6 +2099,8 @@ class KerasTPUModel(models.Model): self._cpu_model.set_weights(tpu_weights) for k, v in six.iteritems(tpu_optimizer_config): logging.info('TPU -> CPU %s: %s', k, v) + if k == 'name': + continue opt_var = getattr(self.cpu_optimizer, k) if isinstance(opt_var, variables.Variable): K.get_session().run(opt_var.assign(v)) diff --git a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py index 8b0b240dc7302c203a22349d583323327fc4480b..de425626c813784ef657d17eac0c7bb77599a155 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py @@ -69,6 +69,7 @@ class ReplicatedVariable(object): def __init__(self, name, variables): self._name = name self._primary_var = variables[0] + self._common_name = self._primary_var.name.split(":")[0] self._vars = variables self._cached_value = None self._dtype = variables[0].dtype diff --git a/tensorflow/contrib/tpu/python/tpu/session_support.py b/tensorflow/contrib/tpu/python/tpu/session_support.py index a95275487899c4770ef99b620a7671eec2bb81eb..3e463823c820a3ef8628324f77e1a9caf8d385d5 100644 --- a/tensorflow/contrib/tpu/python/tpu/session_support.py +++ b/tensorflow/contrib/tpu/python/tpu/session_support.py @@ -43,12 +43,19 @@ class CoordinatorShutdownException(Exception): pass +def _clone_session(session, graph=None): + return session_lib.Session( + target=session.sess_str, + config=session._config, # pylint: disable=protected-access + graph=graph if graph else session.graph) + + def _make_heartbeat_op(session, device, request_ph): """Return a heartbeat op or None if heartbeats are not supported by device.""" try: # Test if we can connect in a 
isolated graph + session with ops.Graph().as_default(): - with session_lib.Session(target=session.sess_str) as temp_session: + with _clone_session(session) as temp_session: with ops.device(device): heartbeat_op = tpu_ops.worker_heartbeat('') options = config_pb2.RunOptions(timeout_in_ms=5000) @@ -220,6 +227,7 @@ class WatchdogManager(threading.Thread): self.ping_interval = ping_interval self.shutdown_timeout = shutdown_timeout self.daemon = True + self._config = session._config # pylint: disable=protected-access self._target = session.sess_str self._running = False self._devices = devices @@ -234,6 +242,7 @@ class WatchdogManager(threading.Thread): self._session = session_lib.Session( target=self._target, graph=self._graph, + config=self._config, ) if self._devices is None: @@ -334,8 +343,7 @@ class GracefulShutdownHook(session_run_hook.SessionRunHook): with self._graph.as_default(): logging.info('Installing graceful shutdown hook.') - self._session = session_lib.Session( - target=training_session.sess_str, graph=self._graph) + self._session = _clone_session(training_session, self._graph) self._workers = WorkerHeartbeatManager.from_devices( self._session, all_worker_devices(self._session)) self._heartbeat_supported = self._workers.num_workers() > 0 diff --git a/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py index 70baea203cc6174bebc7d90646045efae5f2391d..a1494e3660bc09e3af45e81097151a35990810fb 100644 --- a/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py +++ b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py @@ -21,44 +21,56 @@ from __future__ import print_function import os import os.path import re +import sys from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import 
tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops +from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging _TRACER_LOG_PREFIX = ' [>>>TT>>>]' _DEVICE_TYPE_TPU = 'tpu' _DEVICE_TYPE_CPU = 'cpu' -_GLOBAL_STEP_OP_NAME = 'GLOBAL-STEP' _TRACE_MODE_NAN_INF = 'nan-inf' _TRACE_MODE_PART_TENSOR = 'part-tensor' _TRACE_MODE_PART_TENSOR_SIZE = 3 _TRACE_MODE_FULL_TENSOR = 'full-tensor' -_RECORD_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range' -_RECORD_SHOULD_NOT_TRACE = 'not-traced-should-not-trace' -_RECORD_FILTERED_OUT = 'not-traced-filtered-out' -_RECORD_SCALAR = 'not-traced-scalar' -_RECORD_DYNAMIC_SHAPE = 'not-traced-dynamic-shape' -_RECORD_GET_TRACED = 'get-traced' +_TRACE_MODE_NORM = 'norm' +_TRACE_MODE_MAX_ABS = 'max-abs' +_REASON_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range' +_REASON_UNSAFE_OP = 'not-traced-unsafe-op' +_REASON_UNSAFE_SCALAR = 'not-traced-unsafe-scalar' +_REASON_LESS_INTERESTING_OP = 'not-traced-less-interesting-op' +_REASON_DEVICE_MISMATCH = 'not-traced-device-mismatch' +_REASON_DYNAMIC_SHAPE = 'not-traced-dynamic-shape' +_REASON_SCALAR_GET_TRACED = 'traced-scalar' +_REASON_TENSOR_GET_TRACED = 'traced-tensor' +_REASON_USER_INCLUDED = 'traced-user-included' +_REASON_USER_EXCLUDED = 'not-traced-user-excluded' +_REASON_NON_NUMERIC_TENSOR = 'not-traced-non-numeric-tensor' _MARKER_SECTION_BEGIN = '!!!!!!! section-begin:' _MARKER_SECTION_END = '!!!!!!! 
section-end:' _SECTION_NAME_CONFIG = 'configuration' _SECTION_NAME_REASON = 'reason' _SECTION_NAME_OP_LIST = 'op-list' +_SECTION_NAME_TENSOR_LIST = 'tensor-list' _SECTION_NAME_GRAPH = 'graph' _FIELD_NAME_VERSION = 'version:' _FIELD_NAME_DEVICE = 'device:' _FIELD_NAME_TRACE_MODE = 'trace-mode:' _FIELD_NAME_NUM_REPLICAS = 'num-replicas:' _FIELD_NAME_NUM_OPS = 'number-of-ops:' +_FIELD_NAME_NUM_TENSORS = 'number-of-tensors:' _FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED = 'topological-sort-succeed:' _FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS' _FLAG_SINGLE_QUOTE_PAT = re.compile(r"\s*--([^=]+)='([^']*)'") @@ -66,13 +78,72 @@ _FLAG_DOUBLE_QUOTE_PAT = re.compile(r'\s*--([^=]+)="([^"]*)"') _FLAG_NO_QUOTE_PAT = re.compile(r'\s*--([^=]+)=(\S*)') _FLAG_NAME_ENABLE = 'enable' _FLAG_NAME_TRACE_MODE = 'trace_mode' -_FLAG_NAME_INTERESTING_OPS = 'interesting_ops' +_FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS = 'include_less_interesting_ops' +_FLAG_NAME_EXCLUDED_OPNAMES = 'excluded_opnames' +_FLAG_NAME_EXCLUDED_OPTYPES = 'excluded_optypes' +_FLAG_NAME_INCLUDED_OPNAMES = 'included_opnames' +_FLAG_NAME_INCLUDED_OPTYPES = 'included_optypes' _FLAG_NAME_TRACE_FILE = 'trace_file_path' +_FLAG_NAME_REPORT_FILE = 'report_file_path' _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR = 'use_test_undeclared_outputs_dir' _FLAG_NAME_OP_RANGE = 'op_range' _OP_RANGE_PAT = re.compile(r'(\d+):(\d+)') _OUTPUT_STREAM_ESCAPE = 'file://' _TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR' +_TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables' +_TENSOR_TRACER_CHECKPOINT = 'tensor_tracer_checkpoint' + + +def tensor_checkpoint(tensor, checkpoint_name): + """Adds a checkpoint with the given checkpoint name for the given tensor. + + The tensor will be added to the list of tensors that will be traced by the + tensor tracer. + + Args: + tensor: the tensor object for which the tracing is requested. + checkpoint_name: a string name for the checkpoint. This name has to be a + unique name if used within model comparison. 
The tensors that have the same + checkpoint identifier is compared in model comparison. + Returns: + The provided tensor. + """ + + tensor.graph.get_collection(_TENSOR_TRACER_COLLECTION) + tensor.graph.add_to_collection(_TENSOR_TRACER_COLLECTION, + (tensor, checkpoint_name)) + return tensor + + +def keras_layer_checkpoint(layer, checkpoint_name): + """An interface for adding the tensor outputs of a keras layer. + + Encapsulates tensor_checkpoint. + + Args: + layer: A keras layer. + checkpoint_name: a string name for the checkpoint. This name has to be a + unique name if used within model comparison. The tensors that have the same + checkpoint identifier is compared in model comparison. + + Returns: + The provided layer. + """ + try: + outputs = layer.output + if tensor_util.is_tensor(outputs): + tensor_checkpoint(outputs, '%s' % (checkpoint_name)) + else: + idx = 0 + for output_tensor in outputs: + if tensor_util.is_tensor(outputs): + tensor_checkpoint(output_tensor, '%s_%d' % (checkpoint_name, idx)) + idx += 1 + except AttributeError: + pass + except RuntimeError: + pass + return layer class TensorTracer(object): @@ -105,6 +176,34 @@ class TensorTracer(object): match = _FLAG_NO_QUOTE_PAT.match(flags, pos) return match + @staticmethod + def validate_flag_names(): + """Validates if the TensorTrace flags passed are valid.""" + valid_flag_names = [_FLAG_NAME_ENABLE, _FLAG_NAME_TRACE_MODE, + _FLAG_NAME_EXCLUDED_OPNAMES, + _FLAG_NAME_EXCLUDED_OPTYPES, + _FLAG_NAME_INCLUDED_OPNAMES, + _FLAG_NAME_INCLUDED_OPTYPES, + _FLAG_NAME_TRACE_FILE, _FLAG_NAME_REPORT_FILE, + _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR, + _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS, + _FLAG_NAME_OP_RANGE] + tensor_tracer_flags = os.environ.get(_FLAGS_ENV_VAR) + if not tensor_tracer_flags: + return + pos = 0 + while True: + match = TensorTracer._match_next_flag(tensor_tracer_flags, pos) + if not match: + break + flag_name = match.group(1) + if flag_name not in valid_flag_names: + raise ValueError( + 'The 
flag name "%s" passed via the environment variable "%s" ' + 'is invalid. Valid flag names are:' + '\n%s'%(flag_name, _FLAGS_ENV_VAR, valid_flag_names)) + pos = match.end() + @staticmethod def print_flag_values(): """Prints all TensorTracer flags passed via environment variables.""" @@ -146,6 +245,20 @@ class TensorTracer(object): pos = match.end() return '' + @staticmethod + def flag_value_to_re_list(flag_name): + """Converts list of strings to compiled RE.""" + + re_list = [] + flag_value = TensorTracer.get_flag_value(flag_name) + if not flag_value: + return re_list + list_of_values = flag_value.split() + for v in list_of_values: + r = re.compile(v) + re_list.append(r) + return re_list + @staticmethod def is_enabled(): """Returns True if TensorTracer is enabled.""" @@ -186,29 +299,67 @@ class TensorTracer(object): """Checks if the given trace mode is valid.""" valid_trace_modes = [_TRACE_MODE_NAN_INF, _TRACE_MODE_PART_TENSOR, - _TRACE_MODE_FULL_TENSOR] + _TRACE_MODE_FULL_TENSOR, _TRACE_MODE_NORM, + _TRACE_MODE_MAX_ABS] if trace_mode not in valid_trace_modes: raise ValueError('Invalid trace mode "%s" given to the Tensor_Tracer.' 'Valid trace modes are: %s'%(trace_mode, valid_trace_modes)) @staticmethod - def should_trace(device_type, op): - """Returns True if the given Op should be traced.""" + def unsafe_op(op): + """Returns True if this op is not safe to be traced.""" - if device_type != _DEVICE_TYPE_TPU: - raise ValueError('Non TPU device type is not supported') if control_flow_util.IsInCond(op): + return True + # Reasons for not including following op types: + # Assign: cause incorrect result with CPU tracing. + # others: compilation problems. 
+ if op.type in ['Assign', 'Pack', 'Shape', 'Reshape', 'ArgMin', 'ArgMax']: + return True + return False + + @staticmethod + def device_mismatch(device_type, op): + if device_type == _DEVICE_TYPE_TPU: + # pylint: disable=protected-access + return tpu._TPU_REPLICATE_ATTR not in op.node_def.attr + # pylint: enable=protected-access + return False + + @staticmethod + def unsafe_scalar_trace(op): + """Return true if scalar output tensor from Op is not safe to be traced.""" + + # Tracing the following causes cycle in the graph on TPU. + if op.type in ['LoopCond', 'Enter', 'Merge', 'Const', + 'Switch', 'Less', 'ReadVariableOp']: + return True + # Tracing the following will cause casting-issue + # with the norm tracing mode or other compilation issues on CPU. + if op.type in ['VarHandleOp', 'IteratorToStringHandle', + 'IteratorGetNext', 'OneShotIterator', + 'IteratorV2', 'MakeIterator', + 'BatchDatasetV2', 'MapDataset', + 'FixedLengthRecordDataset', 'TakeDataset', 'ZipDataset', + 'Placeholder', 'PlaceholderWithDefault', 'StridedSlice']: + return True + return False + + @staticmethod + def less_interesting_op(op): + """Returns True if the given Op is not an interesting one to be traced.""" + + include_less_interesting = TensorTracer.get_flag_value( + _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS) + if include_less_interesting: return False - if op.type in ['Reshape', 'ArgMin', 'ArgMax']: - return False - # pylint: disable=protected-access - return tpu._TPU_REPLICATE_ATTR in op.node_def.attr - # pylint: enable=protected-access + return op.type in ['Const', 'Identity', 'Cast', 'Shape'] @staticmethod def reason(op_idx, details): - """Returns why the Op at op_idx is traced or not.""" + """Returns reason why the Op at op_idx is traced or not.""" + return '%d %s'%(op_idx, details) @staticmethod @@ -274,6 +425,33 @@ class TensorTracer(object): assert len(unsorted_ops) == len(sorted_ops) return (True, sorted_ops) + @staticmethod + def _make_op_and_tensor_maps(op_list): + """Creates 
various maps and lists from op_list. + + Args: + op_list: a list of Ops + + Returns: + opname_idx_map: a map from Op's name to its index in op_list. + tensor_list: a list of output tensors of the Ops in op_list. + tensorname_idx_map: a map from output tensor name to its index + in tensor_list. + """ + + opname_idx_map = {} + tensor_list = [] + tensorname_idx_map = {} + for op_id, op in enumerate(op_list): + if op.name in opname_idx_map: + raise ValueError('Duplicated Op name: %s'%op.name) + opname_idx_map[op.name] = op_id + for output_tensor in op.outputs: + if output_tensor.name not in tensorname_idx_map: + tensor_list.append(output_tensor) + tensorname_idx_map[output_tensor.name] = len(tensor_list)-1 + return (opname_idx_map, tensor_list, tensorname_idx_map) + def __init__(self): """Initializes a TensorTracer. @@ -281,16 +459,20 @@ class TensorTracer(object): """ self._version = 'use-outside-compilation' self._device_type = None + TensorTracer.validate_flag_names() self._trace_mode = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_MODE) if not self._trace_mode: self._trace_mode = _TRACE_MODE_NAN_INF TensorTracer.check_trace_mode(self._trace_mode) self._part_tensor_size = _TRACE_MODE_PART_TENSOR_SIZE self._instrument_records = {} - interesting_ops = TensorTracer.get_flag_value(_FLAG_NAME_INTERESTING_OPS) - self._selected_ops = interesting_ops.split() self._set_trace_file_path() + self._set_report_file() self._set_op_range() + self._set_excluded_opnames() + self._set_excluded_optypes() + self._set_included_opnames() + self._set_included_optypes() self._num_replicas = None self._replica_id = None @@ -318,10 +500,7 @@ class TensorTracer(object): """Sets the path of the output trace file.""" self._trace_file_path = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_FILE) - if not self._trace_file_path: - raise ValueError('--%s is not set in the environment variable %s' - %(_FLAG_NAME_TRACE_FILE, _FLAGS_ENV_VAR)) - elif TensorTracer.use_test_undeclared_outputs_dir(): + if 
self._trace_file_path and TensorTracer.use_test_undeclared_outputs_dir(): if os.path.isabs(self._trace_file_path): raise ValueError('If use_test_undeclared_outputs_dir is set,' 'trace_file_path cannot be an absolute path (%s)' @@ -330,6 +509,22 @@ class TensorTracer(object): self._trace_file_path = os.path.join(outputs_dir, self._trace_file_path) + def _set_report_file(self): + """Sets the path of the output report file.""" + + self._report_file_path = TensorTracer.get_flag_value(_FLAG_NAME_REPORT_FILE) + if not self._report_file_path: + self._report_file = None + return + try: + self._report_file = gfile.Open(self._report_file_path, 'w') + except IOError as e: + raise e + + def _close_report_file(self): + if self._report_file: + self._report_file.close() + def _set_op_range(self): """Sets the index range of the Ops that we will consider tracing.""" @@ -350,19 +545,48 @@ class TensorTracer(object): return False return self._op_range[1] < 0 or idx <= self._op_range[1] - def _write_report(self, content): - """Writes the given content to the report.""" + def _set_excluded_opnames(self): + self._excluded_opname_re_list = TensorTracer.flag_value_to_re_list( + _FLAG_NAME_EXCLUDED_OPNAMES) + + def _set_excluded_optypes(self): + self._excluded_optype_re_list = TensorTracer.flag_value_to_re_list( + _FLAG_NAME_EXCLUDED_OPTYPES) + + def _set_included_opnames(self): + self._included_opname_re_list = TensorTracer.flag_value_to_re_list( + _FLAG_NAME_INCLUDED_OPNAMES) + + def _set_included_optypes(self): + self._included_optype_re_list = TensorTracer.flag_value_to_re_list( + _FLAG_NAME_INCLUDED_OPTYPES) + + def _is_user_included_op(self, op): + for opname_re in self._included_opname_re_list: + if opname_re.match(op.name): + return True + for optype_re in self._included_optype_re_list: + if optype_re.match(op.type): + return True + return False - logging.info('%s %s'%(_TRACER_LOG_PREFIX, content)) + def _is_user_excluded_op(self, op): + for opname_re in 
self._excluded_opname_re_list: + if opname_re.match(op.name): + return True + for optype_re in self._excluded_optype_re_list: + if optype_re.match(op.type): + return True + return False - def _is_selected_op(self, op_name): - """Returns True if the Op with op_name is selected to be traced.""" + def _write_report(self, content): + """Writes the given content to the report.""" - if not self._selected_ops: - return True - if op_name in self._selected_ops: - return True - return False + line = '%s %s'%(_TRACER_LOG_PREFIX, content) + if self._report_file: + self._report_file.write(line) + else: + logging.info(line) def _write_config_section(self): """Writes the config section of the report.""" @@ -382,15 +606,42 @@ class TensorTracer(object): self._write_report('"%s" %s\n'%(key, self._instrument_records[key])) self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_REASON)) - def _write_op_list_section(self, op_list): + def _write_op_list_section(self, op_list, tensorname_idx_map): """Writes the Op-list section of the report.""" self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_OP_LIST)) self._write_report('%s %d\n'%(_FIELD_NAME_NUM_OPS, len(op_list))) for i in range(0, len(op_list)): - self._write_report('%d "%s" %s\n'%(i, op_list[i].name, op_list[i].type)) + op = op_list[i] + line = '%d "%s" %s'%(i, op.name, op.type) + for out_tensor in op.outputs: + if out_tensor.name not in tensorname_idx_map: + raise ValueError( + 'out_tensor %s is not in tensorname_idx_map'%out_tensor.name) + line += ' %d'%tensorname_idx_map[out_tensor.name] + line += '\n' + self._write_report(line) self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_OP_LIST)) + def _write_tensor_list_section(self, tensor_list, opname_idx_map): + """Writes the tensor-list section of the report.""" + + self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, + _SECTION_NAME_TENSOR_LIST)) + self._write_report('%s %d\n'%(_FIELD_NAME_NUM_TENSORS, len(tensor_list))) + for i in 
range(0, len(tensor_list)): + tensor = tensor_list[i] + line = '%d "%s"'%(i, tensor.name) + for consumer_op in tensor.consumers(): + if consumer_op.name not in opname_idx_map: + raise ValueError( + 'consumer_op %s is not in opname_idx_map'%consumer_op.name) + line += ' %d'%opname_idx_map[consumer_op.name] + line += '\n' + self._write_report(line) + self._write_report('%s %s\n'%(_MARKER_SECTION_END, + _SECTION_NAME_TENSOR_LIST)) + def _write_graph_section(self, succeed, sorted_or_cycle): """Writes the graph section of the report.""" @@ -422,7 +673,7 @@ class TensorTracer(object): Args: op_name: the name of the Op that outputs the tensor to be printed. output_idx: which output of the Op it is (0 means the first output). - num_elements: number of elements to print. + num_elements: number of elements to print (-1 means print all). tensor: the tensor needs to be returned. output_tensor: the tensor needs to be printed. @@ -430,10 +681,13 @@ class TensorTracer(object): The same tensor passed via the "tensor" argument. """ msg = '"%s:%d" '%(op_name, output_idx) - output_stream = _OUTPUT_STREAM_ESCAPE + self._trace_file_path + if self._trace_file_path: + output_stream = _OUTPUT_STREAM_ESCAPE + self._trace_file_path + else: + output_stream = sys.stderr print_op = logging_ops.print_v2(msg, array_ops.shape(output_tensor), ' @', self._replica_id, - '\n', output_tensor, + '\n', output_tensor, '\n', summarize=num_elements, output_stream=output_stream) with ops.control_dependencies([print_op]): @@ -442,7 +696,8 @@ class TensorTracer(object): def _detect_nan_inf(tensor): """Trace function for detecting any NaN/Inf in the tensor.""" - if tensor.dtype.is_floating: + if tensor.dtype.__eq__(dtypes.bfloat16) or tensor.dtype.__eq__( + dtypes.float16): # Since host can't handle bf16, always convert tensor to f32. 
tensor = math_ops.cast(tensor, dtypes.float32) output_tensor = math_ops.reduce_any( @@ -450,12 +705,19 @@ class TensorTracer(object): gen_math_ops.is_inf(tensor))) else: output_tensor = constant_op.constant(0) - return _print_tensor(op_name, output_idx, 1, tensor, output_tensor) + return _print_tensor(op_name, output_idx, -1, tensor, output_tensor) - def _show_global_step(tensor): - """Trace function for printing the global step count.""" + def _show_norm(tensor): + tensor = math_ops.cast(tensor, dtypes.float64) + output_tensor = linalg_ops.norm(tensor) + return _print_tensor(op_name, output_idx, -1, tensor, output_tensor) - return _print_tensor(op_name, output_idx, 1, tensor, tensor) + def _show_max_abs(tensor): + output_tensor = math_ops.cast(math_ops.reduce_max(math_ops.abs(tensor)), + dtypes.float64) + zero = constant_op.constant(0, dtypes.float64) + output_tensor = gen_math_ops.maximum(zero, output_tensor) + return _print_tensor(op_name, output_idx, -1, tensor, output_tensor) def _show_part_tensor(tensor): """Trace function for printing part of the tensor.""" @@ -468,23 +730,139 @@ class TensorTracer(object): return _print_tensor(op_name, output_idx, -1, tensor, tensor) - if op_name == _GLOBAL_STEP_OP_NAME: - return _show_global_step if self._trace_mode == _TRACE_MODE_NAN_INF: return _detect_nan_inf if self._trace_mode == _TRACE_MODE_PART_TENSOR: return _show_part_tensor if self._trace_mode == _TRACE_MODE_FULL_TENSOR: return _show_full_tensor + if self._trace_mode == _TRACE_MODE_NORM: + return _show_norm + if self._trace_mode == _TRACE_MODE_MAX_ABS: + return _show_max_abs raise RuntimeError('Tensor trace fun for %s is not yet implemented' %self._trace_mode) + def _skip_op(self, op_id, op, user_included, user_excluded): + """Returns True if we should not trace Op.""" + + if user_included: + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _REASON_USER_INCLUDED) + return False + if user_excluded: + self._instrument_records[op.name] = 
TensorTracer.reason( + op_id, _REASON_USER_EXCLUDED) + return True + if not self._inside_op_range(op_id): + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _REASON_OUTSIDE_OP_RANGE) + return True + if TensorTracer.unsafe_op(op): + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _REASON_UNSAFE_OP) + return True + if TensorTracer.device_mismatch(self._device_type, op): + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _REASON_DEVICE_MISMATCH) + return True + if TensorTracer.less_interesting_op(op): + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _REASON_LESS_INTERESTING_OP) + return True + return False + + def _skip_tensor(self, op_id, out_tensor, user_included, + user_excluded): + """Returns True if we should not trace out_tensor.""" + + # Skips a tensor if the tensor has a non-numeric type. + # Note: we cannot use check_ops.is_numeric_tensor(out_tensor) + # because it also excludes tensors with dtypes, bool, and + # float32_ref, which we actually want to trace. 
+ non_numeric_tensor_types = set([dtypes.variant, dtypes.resource, + dtypes.string]) + if out_tensor.dtype in non_numeric_tensor_types: + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_NON_NUMERIC_TENSOR) + return True + + if user_included: + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_USER_INCLUDED) + return False + if user_excluded: + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_USER_EXCLUDED) + return True + if not out_tensor.get_shape().is_fully_defined(): + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_DYNAMIC_SHAPE) + return True + rank = len(out_tensor.shape) + if rank < 1: + # scalar + if TensorTracer.unsafe_scalar_trace(out_tensor.op): + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_UNSAFE_SCALAR) + return True + else: + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_SCALAR_GET_TRACED) + return False + else: + # tensor + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_TENSOR_GET_TRACED) + return False + + def _pre_tracing(self, graph): + """Work needs to be done prior to TPU or CPU tracing.""" + + operations = graph.get_operations() + (opname_idx_map, tensor_list, tensorname_idx_map) = ( + TensorTracer._make_op_and_tensor_maps(operations)) + self._write_config_section() + self._write_op_list_section(operations, tensorname_idx_map) + self._write_tensor_list_section(tensor_list, opname_idx_map) + # Does the topological sort before adding any nodes to the graph. 
+ (succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph) + return (operations, succeed, sorted_or_cycle) + + def _post_tracing(self, succeed, sorted_or_cycle): + """Work needs to be done after TPU or CPU tracing.""" + + self._write_reason_section() + self._write_graph_section(succeed, sorted_or_cycle) + self._close_report_file() + + def _get_checkpoints(self, graph): + """Returns the list of Ops that produce the tensors traced with API. + + Args: + graph: the graph of Ops. + + Returns: + A set of operation names which should be traced. + """ + + self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, + _TENSOR_TRACER_CHECKPOINT)) + checkpoint_operations = set() + tensor_tracer_variables = graph.get_collection(_TENSOR_TRACER_COLLECTION) + for (tensor, checkpoint_name) in tensor_tracer_variables: + self._write_report('%s %s\n'%(tensor.name, checkpoint_name)) + checkpoint_operations.add(tensor.op.name) + self._write_report('%s %s\n'%(_MARKER_SECTION_END, + _TENSOR_TRACER_CHECKPOINT)) + return checkpoint_operations + def trace_tpu(self, graph, result_tensor, num_replicas=None): """Traces the tensors generated by TPU Ops in a TF graph. Args: - graph: the graph of Ops. + graph: the graph of Ops executed on the TPU. result_tensor: a result tensor of evaluating the graph. num_replicas: number of replicas used on the TPU. @@ -502,38 +880,22 @@ class TensorTracer(object): TensorTracer.check_device_type(self._device_type) result_tensor_copy = self._add_replica_id_to_graph(num_replicas, result_tensor) - self._write_config_section() + (operations, succeed, sorted_or_cycle) = self._pre_tracing(graph) tracing_ops = [] - operations = graph.get_operations() - self._write_op_list_section(operations) - # Does the topological sort before adding any nodes to the graph. 
- (succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph) + checkpoint_operations = self._get_checkpoints(graph) + for op_id, op in enumerate(operations): - if not self._inside_op_range(op_id): - self._instrument_records[op.name] = TensorTracer.reason( - op_id, _RECORD_OUTSIDE_OP_RANGE) + if checkpoint_operations and op.name not in checkpoint_operations: continue - if not TensorTracer.should_trace(self._device_type, op): - self._instrument_records[op.name] = TensorTracer.reason( - op_id, _RECORD_SHOULD_NOT_TRACE) - continue - if not self._is_selected_op(op.name): - self._instrument_records[op.name] = TensorTracer.reason( - op_id, _RECORD_FILTERED_OUT) + user_included = self._is_user_included_op(op) + user_excluded = self._is_user_excluded_op(op) + if self._skip_op(op_id, op, user_included, user_excluded): continue for i in range(len(op.outputs)): out_tensor = op.outputs[i] - if not out_tensor.get_shape().is_fully_defined(): - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _RECORD_DYNAMIC_SHAPE) - continue # cannot trace tensors with dynamic shape. - rank = len(out_tensor.shape) - if rank < 1: - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _RECORD_SCALAR) - continue # cannot trace scalar. - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _RECORD_GET_TRACED) + if self._skip_tensor(op_id, out_tensor, user_included, + user_excluded): + continue consumers = out_tensor.consumers() trace_op = tpu.outside_compilation( self._make_tensor_trace_fun(op.name, i), out_tensor) @@ -546,8 +908,45 @@ class TensorTracer(object): # if there is no consumer, we will add the control dependence later # when we add the control dependency to the output operations. 
tracing_ops.append(trace_op) + self._post_tracing(succeed, sorted_or_cycle) + return (result_tensor_copy, tracing_ops) - self._write_reason_section() - self._write_graph_section(succeed, sorted_or_cycle) + def trace_cpu(self, graph): + """Traces the tensors generated by CPU Ops in a TF graph. - return (result_tensor_copy, tracing_ops) + Args: + graph: the graph of Ops executed on the CPU. + + Returns: + tracing_calls: a map from keys to trace calls. + A key is constructed from an Op's name. + A trace call consists of a function and a tensor ( + the function will be invoked with the tensor). + """ + + self._device_type = _DEVICE_TYPE_CPU + TensorTracer.check_device_type(self._device_type) + self._num_replicas = 1 + self._replica_id = 0 + (operations, succeed, sorted_or_cycle) = self._pre_tracing(graph) + tracing_calls = {} + checkpoint_operations = self._get_checkpoints(graph) + + for op_id, op in enumerate(operations): + if checkpoint_operations and op.name not in checkpoint_operations: + continue + user_included = self._is_user_included_op(op) + user_excluded = self._is_user_excluded_op(op) + if self._skip_op(op_id, op, user_included, user_excluded): + continue + for i in range(len(op.outputs)): + out_tensor = op.outputs[i] + if self._skip_tensor(op_id, out_tensor, user_included, + user_excluded): + continue + trace_fun = self._make_tensor_trace_fun(op.name, i) + trace_call = (trace_fun, [out_tensor]) + trace_call_key = 'tensor_tracing_cpu-%s:%d'%(op.name, i) + tracing_calls[trace_call_key] = trace_call + self._post_tracing(succeed, sorted_or_cycle) + return tracing_calls diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index def57da20d6018dcf27ccb7a9d04592f38ce2f7c..9266d81cf5fc035790062f0e307a5da0b01a9fc1 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -646,6 +646,10 @@ def split_compile_and_replicate(computation, array_ops.identity(x, 
name="replicated_input_{}".format(i)) for i, x in enumerate(computation_inputs) ] + for i in computation_inputs: + # pylint: disable=protected-access + i.op._set_attr("_tpu_input_identity", attr_value_pb2.AttrValue(b=True)) + # pylint: enable=protected-access # If there is an infeed queue, adds the dequeued values to the # computation's inputs. @@ -726,7 +730,11 @@ def split_compile_and_replicate(computation, new_output_tensors = [] for t in output_tensors: with ops.device(t.device if t.device else core(0)): - new_output_tensors.append(array_ops.identity(t)) + o = array_ops.identity(t) + # pylint: disable=protected-access + o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True)) + # pylint: enable=protected-access + new_output_tensors.append(o) output_tensors = new_output_tensors context.ExitResult(output_tensors) finally: @@ -777,15 +785,15 @@ def split_compile_and_replicate(computation, ] -def shard(computation, - inputs=None, - num_shards=1, - input_shard_axes=None, - outputs_from_all_shards=True, - output_shard_axes=None, - infeed_queue=None, - device_assignment=None, - name=None): +def split_compile_and_shard(computation, + inputs=None, + num_shards=1, + input_shard_axes=None, + outputs_from_all_shards=True, + output_shard_axes=None, + infeed_queue=None, + device_assignment=None, + name=None): """Shards `computation` for parallel execution. `inputs` must be a list of Tensors or None (equivalent to an empty list), each @@ -839,7 +847,7 @@ def shard(computation, is equal to the number of cores in the TPU system. name: (Deprecated) Does nothing. Returns: - A list of output tensors. + A tuple of (compile op, [output tensors]). 
Raises: ValueError: If num_shards <= 0 ValueError: If len(input_shard_axes) != len(inputs) @@ -874,7 +882,7 @@ def shard(computation, else: transposed_inputs = [[]] * num_shards - outputs = replicate( + compile_op, outputs = split_compile_and_replicate( computation, transposed_inputs, infeed_queue=infeed_queue, @@ -891,7 +899,7 @@ def shard(computation, # one so it can be used as a control dependency or fetch node. # TODO(b/36647078) remove disable when pylint bug is fixed. # pylint: disable=indexing-exception - return [outputs[0]] + return compile_op, [outputs[0]] # pylint: enable=indexing-exception # TODO(b/36647078) remove disable when pylint bug is fixed. @@ -925,7 +933,87 @@ def shard(computation, # TODO(phawkins): use a smarter policy, e.g., round-robin across shards. results.append(x[0]) - return results + return compile_op, results + + +def shard(computation, + inputs=None, + num_shards=1, + input_shard_axes=None, + outputs_from_all_shards=True, + output_shard_axes=None, + infeed_queue=None, + device_assignment=None, + name=None): + """Shards `computation` for parallel execution. + + `inputs` must be a list of Tensors or None (equivalent to an empty list), each + of which has a corresponding split axis (from `input_shard_axes`). Each input + is split into `num_shards` pieces along the corresponding axis, and + computation is applied to each shard in parallel. + + Tensors are broadcast to all shards if they are lexically captured by + `computation`. e.g., + + x = tf.constant(7) + def computation(): + return x + 3 + ... = shard(computation, ...) + + TODO(phawkins): consider adding support for broadcasting Tensors passed + as inputs. + + If `outputs_from_all_shards` is true, the outputs from all shards of + `computation` are concatenated back together along their `output_shards_axes`. + Otherwise, each output is taken from an arbitrary shard. + + Inputs and outputs of the computation must be at least rank-1 Tensors. 
+ + Args: + computation: A Python function that builds a computation to apply to each + shard of the input. + inputs: A list of input tensors or None (equivalent to an empty list). Each + input tensor has a corresponding shard axes, given by `input_shard_axes`, + which must have size divisible by `num_shards`. + num_shards: The number of shards. + input_shard_axes: A list of dimensions along which to shard `inputs`, or + `None`. `None` means "shard all inputs along dimension 0". If not `None`, + there must be one dimension per input. + outputs_from_all_shards: Boolean or list of boolean. For each output, if + `True`, outputs from all shards are concatenated along the corresponding + `output_shard_axes` entry. Otherwise, each output is taken + from an arbitrary shard. If the argument is a boolean, the argument's + value is used for each output. + output_shard_axes: A list of dimensions along which to concatenate the + outputs of `computation`, or `None`. `None` means "concatenate all outputs + along dimension 0". If not `None`, there must be one dimension per output. + Ignored if `outputs_from_all_shards` is False. + infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs + of `computation`. + device_assignment: If not `None`, a `DeviceAssignment` describing the + mapping between logical cores in the computation with physical cores in + the TPU topology. Uses a default device assignment if `None`. The + `DeviceAssignment` may be omitted if each shard of the computation uses + only one core, and there is either only one shard, or the number of shards + is equal to the number of cores in the TPU system. + name: (Deprecated) Does nothing. + Returns: + A list of output tensors. 
+ Raises: + ValueError: If num_shards <= 0 + ValueError: If len(input_shard_axes) != len(inputs) + ValueError: If len(output_shard_axes) != len(outputs from `computation`) + """ + return split_compile_and_shard( + computation, + inputs=inputs, + num_shards=num_shards, + input_shard_axes=input_shard_axes, + outputs_from_all_shards=outputs_from_all_shards, + output_shard_axes=output_shard_axes, + infeed_queue=infeed_queue, + device_assignment=device_assignment, + name=name)[1] def batch_parallel(computation, diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 7171587ff7298982423a5046d85d1970a4d6b1cb..44a8f7ce0e5794ec95b5d0c25adca14b194a25d1 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -31,6 +31,7 @@ import six from six.moves import queue as Queue # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result from tensorflow.contrib.tpu.python.tpu import tensor_tracer from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import error_handling @@ -45,6 +46,7 @@ from tensorflow.contrib.training.python.training import hparam from tensorflow.core.framework import variable_pb2 from tensorflow.core.framework.summary_pb2 import Summary from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session as tf_session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest as data_nest from tensorflow.python.estimator import estimator as estimator_lib @@ -335,6 +337,16 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote hooks = None if self.host_call is not None: hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])] + if tensor_tracer.TensorTracer.is_enabled(): + 
tt = tensor_tracer.TensorTracer() + tracing_calls = tt.trace_cpu(ops.get_default_graph()) + tracing_call_ret = _OutfeedHostCall.create_cpu_hostcall(tracing_calls) + tracing_functions = tracing_call_ret.values() + if tracing_functions: + if hooks: + hooks.extend([_OutfeedHostCallHook(tracing_functions)]) + else: + hooks = [_OutfeedHostCallHook(tracing_functions)] hooks = tuple(hooks or []) scaffold = self.scaffold_fn() if self.scaffold_fn else None return model_fn_lib.EstimatorSpec( @@ -411,13 +423,17 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): ctx, enqueue_ops, dequeue_ops, + tpu_compile_op, run_infeed_loop_on_coordinator=True, - rendezvous=None): + rendezvous=None, + master=None, + session_config=None): self._master_job = ctx.master_job self._enqueue_ops = enqueue_ops self._dequeue_ops = dequeue_ops self._rendezvous = rendezvous - + self._master = master + self._session_config = session_config self._run_infeed_loop_on_coordinator = run_infeed_loop_on_coordinator self._initial_infeed_sleep_secs = ( ctx.config.tpu_config.initial_infeed_sleep_secs) @@ -425,15 +441,15 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): self._feed_error = None self._finished = False self._should_initialize_tpu = True + self._tpu_compile_op = tpu_compile_op def begin(self): logging.info('TPU job name %s', self._master_job) self._iterations_per_loop_var = _create_or_get_iterations_per_loop() + self._init_ops = [] if self._should_initialize_tpu: - self._init_ops = [tpu.initialize_system(job=self._master_job)] self._finalize_ops = [tpu.shutdown_system(job=self._master_job)] else: - self._init_ops = [] self._finalize_ops = [] summary_writer_init_ops = contrib_summary.summary_writer_initializer_op() @@ -474,12 +490,31 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): def _create_infeed_controller(self, name, target, args): return _OpQueueContext(name=name, target=target, args=args) + def _assertCompilationSucceeded(self, 
result, coord): + proto = tpu_compilation_result.CompilationResultProto() + proto.ParseFromString(result) + if proto.status_error_message: + logging.error('Compilation failed: {}'.format(proto.status_error_message)) + coord.request_stop() + else: + logging.info('Compilation succeeded') + def after_create_session(self, session, coord): - logging.info('Init TPU system') - start = time.time() + if self._should_initialize_tpu: + logging.info('Init TPU system') + start = time.time() + with ops.Graph().as_default(): + with tf_session.Session( + self._master, config=self._session_config) as sess: + sess.run(tpu.initialize_system(job=self._master_job)) + logging.info('Initialized TPU in %d seconds', time.time() - start) + session.run(self._init_ops, options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000)) - logging.info('Initialized TPU in %d seconds', time.time() - start) + + if os.environ.get('TPU_SPLIT_COMPILE_AND_EXECUTE', '') == '1': + logging.info('Compiling user program: this may take a while...') + self._assertCompilationSucceeded(session.run(self._tpu_compile_op), coord) self._infeed_controller = self._create_infeed_controller( name='InfeedController', target=self._run_infeed, args=(session,)) @@ -521,13 +556,17 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): class TPUInfeedOutfeedSessionHookForPrediction(TPUInfeedOutfeedSessionHook): - def __init__(self, ctx, enqueue_ops, dequeue_ops, rendezvous=None): + def __init__(self, ctx, enqueue_ops, dequeue_ops, tpu_compile_op, + rendezvous=None, master=None, session_config=None): super(TPUInfeedOutfeedSessionHookForPrediction, self).__init__( ctx, enqueue_ops, dequeue_ops, + tpu_compile_op=tpu_compile_op, run_infeed_loop_on_coordinator=False, - rendezvous=rendezvous) + rendezvous=rendezvous, + master=master, + session_config=session_config) def _create_infeed_controller(self, name, target, args): return _OpSignalOnceQueueContext(name=name, target=target, args=args) @@ -2241,7 +2280,7 @@ class 
TPUEstimator(estimator_lib.Estimator): (k, _export_output_to_tensors(v)) for k, v in six.iteritems(estimator_spec.export_outputs)) tensors = nest.flatten(tensors_dict) - tpu_tensors = [t for t in tensors if _is_tpu_tensor(t)] + tpu_tensors = [t for t in tensors if t is not None] # We cannot return anything other than `tpu_tensors` here so we capture # the rest for later use. @@ -2255,18 +2294,10 @@ class TPUEstimator(estimator_lib.Estimator): # `tpu_tensors_on_cpu`. new_tensors = [] for t in tensors: - if _is_tpu_tensor(t): - new_tensors.append(tpu_tensors_on_cpu.pop(0)) - elif t is None: + if t is None: new_tensors.append(None) else: - # Only fetching `tpu_tensors_on_cpu` does not trigger - # TPU computation and blocks, so we add the control dependency here. - control_inputs = ( - tpu_tensors_on_cpu if _is_iterable(tpu_tensors_on_cpu) else - (tpu_tensors_on_cpu,)) - with ops.control_dependencies(control_inputs): - new_tensors.append(array_ops.identity(t)) + new_tensors.append(tpu_tensors_on_cpu.pop(0)) # Reconstruct `tensors_dict`. 
new_tensors_dict = nest.pack_sequence_as(tensors_dict, new_tensors) @@ -2523,7 +2554,7 @@ class TPUEstimator(estimator_lib.Estimator): graph.add_to_collection(_TPU_ENQUEUE_OPS, enqueue_op) if mode == model_fn_lib.ModeKeys.TRAIN: - loss, host_call, scaffold, training_hooks = ( + compile_op, loss, host_call, scaffold, training_hooks = ( _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn)) host_ops = host_call.create_tpu_hostcall() if host_ops is None: @@ -2558,9 +2589,12 @@ class TPUEstimator(estimator_lib.Estimator): ctx, enqueue_ops, host_ops, + tpu_compile_op=compile_op, run_infeed_loop_on_coordinator=( run_infeed_loop_on_coordinator), rendezvous=self._rendezvous[mode], + master=self._config.master, + session_config=self._session_config, ), InstallSignalHandlerHook() ]) @@ -2613,8 +2647,8 @@ class TPUEstimator(estimator_lib.Estimator): scaffold=scaffold) if mode == model_fn_lib.ModeKeys.EVAL: - total_loss, host_calls, scaffold, eval_hooks = _eval_on_tpu_system( - ctx, model_fn_wrapper, dequeue_fn) + compile_op, total_loss, host_calls, scaffold, eval_hooks = ( + _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn)) iterations_per_loop_var = _create_or_get_iterations_per_loop() mean_loss = math_ops.div( total_loss, @@ -2661,10 +2695,13 @@ class TPUEstimator(estimator_lib.Estimator): ctx, enqueue_ops, eval_update_ops + host_ops, + tpu_compile_op=compile_op, run_infeed_loop_on_coordinator=( run_infeed_loop_on_coordinator), - rendezvous=self._rendezvous[mode]), - ] + input_hooks + rendezvous=self._rendezvous[mode], + master=self._config.evaluation_master, + session_config=self._session_config, + )] + input_hooks if eval_hooks: hooks.extend(eval_hooks) @@ -2679,7 +2716,7 @@ class TPUEstimator(estimator_lib.Estimator): # Predict assert mode == model_fn_lib.ModeKeys.PREDICT - (dummy_predict_op, host_calls, + (compile_op, dummy_predict_op, host_calls, scaffold, prediction_hooks) = _predict_on_tpu_system( ctx, model_fn_wrapper, dequeue_fn) with 
ops.control_dependencies([dummy_predict_op]): @@ -2735,7 +2772,10 @@ class TPUEstimator(estimator_lib.Estimator): hooks = [ _StoppingPredictHook(scalar_stopping_signal), TPUInfeedOutfeedSessionHookForPrediction( - ctx, enqueue_ops, host_ops, rendezvous=self._rendezvous[mode]), + ctx, enqueue_ops, host_ops, rendezvous=self._rendezvous[mode], + tpu_compile_op=compile_op, + master=self._config.master, + session_config=self._session_config), ] + input_hooks if prediction_hooks: @@ -2750,17 +2790,6 @@ class TPUEstimator(estimator_lib.Estimator): return _model_fn -def _is_tpu_tensor(tensor): - if not isinstance(tensor, ops.Tensor): - return False - try: - tensor.op.get_attr(tpu._OUTSIDE_COMPILATION_ATTR) # pylint: disable=protected-access - except ValueError: - return True - else: - return False - - def _export_output_to_tensors(export_output): """Get a list of `Tensors` used in `export_output`. @@ -2832,15 +2861,16 @@ def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): return training_loop.repeat(iterations_per_loop_var, single_tpu_eval_step, [_ZERO_LOSS]) - (loss,) = tpu.shard( + (compile_op, loss,) = tpu.split_compile_and_shard( multi_tpu_eval_steps_on_single_shard, inputs=[], num_shards=ctx.num_replicas, outputs_from_all_shards=False, device_assignment=ctx.device_assignment) + loss = loss[0] scaffold = _get_scaffold(captured_scaffold_fn) - return loss, host_calls, scaffold, captured_eval_hooks.get() + return compile_op, loss, host_calls, scaffold, captured_eval_hooks.get() def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): @@ -2855,15 +2885,16 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): return training_loop.repeat(iterations_per_loop_var, single_tpu_train_step, [_INITIAL_LOSS]) - (loss,) = tpu.shard( + (compile_op, loss,) = tpu.split_compile_and_shard( multi_tpu_train_steps_on_single_shard, inputs=[], num_shards=ctx.num_replicas, outputs_from_all_shards=False, device_assignment=ctx.device_assignment) + loss = loss[0] scaffold = 
_get_scaffold(captured_scaffold_fn) - return loss, host_call, scaffold, captured_training_hooks.get() + return compile_op, loss, host_call, scaffold, captured_training_hooks.get() def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): @@ -2883,15 +2914,17 @@ def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): cond, single_tpu_predict_step, inputs=inputs, name=b'loop') return outputs - (dummy_predict_op,) = tpu.shard( + (compile_op, dummy_predict_op,) = tpu.split_compile_and_shard( multi_tpu_predict_steps_on_single_shard, inputs=[], num_shards=ctx.num_replicas, outputs_from_all_shards=False, device_assignment=ctx.device_assignment) + dummy_predict_op = dummy_predict_op[0] scaffold = _get_scaffold(captured_scaffold_fn) - return dummy_predict_op, host_calls, scaffold, captured_predict_hooks.get() + return (compile_op, dummy_predict_op, host_calls, scaffold, + captured_predict_hooks.get()) def _wrap_computation_in_while_loop(device, op_fn): @@ -3081,7 +3114,7 @@ class _Inputs(object): The initializer must be run before calling `features_and_labels`. 
""" - self._iterator = self._dataset.make_initializable_iterator() + self._iterator = dataset_ops.make_initializable_iterator(self._dataset) return self._iterator.initializer def features_and_labels(self): diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py index 3786e52b949dfac8c1587d1ea3041b625f00183f..e3ea983abfd24d03c964fbc647b56262e15e0a96 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py @@ -21,8 +21,8 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.tpu.python.tpu import tpu_estimator -from tensorflow.python import data as dataset_lib from tensorflow.python.client import session +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.platform import test @@ -34,10 +34,10 @@ def make_input_fn(num_samples): def input_fn(params): batch_size = params['batch_size'] - da1 = dataset_lib.Dataset.from_tensor_slices(a) - da2 = dataset_lib.Dataset.from_tensor_slices(b) + da1 = dataset_ops.Dataset.from_tensor_slices(a) + da2 = dataset_ops.Dataset.from_tensor_slices(b) - dataset = dataset_lib.Dataset.zip((da1, da2)) + dataset = dataset_ops.Dataset.zip((da1, da2)) dataset = dataset.map(lambda fa, fb: {'a': fa, 'b': fb}) dataset = dataset.batch(batch_size) return dataset @@ -50,10 +50,10 @@ def make_input_fn_with_labels(num_samples): def input_fn(params): batch_size = params['batch_size'] - da1 = dataset_lib.Dataset.from_tensor_slices(a) - da2 = dataset_lib.Dataset.from_tensor_slices(b) + da1 = dataset_ops.Dataset.from_tensor_slices(a) + da2 = dataset_ops.Dataset.from_tensor_slices(b) - dataset = dataset_lib.Dataset.zip((da1, da2)) + dataset = dataset_ops.Dataset.zip((da1, da2)) dataset = dataset.map(lambda fa, fb: ({'a': fa}, fb)) dataset = 
dataset.batch(batch_size) return dataset @@ -71,7 +71,7 @@ class TPUEstimatorStoppingSignalsTest(test.TestCase): with ops.Graph().as_default(): dataset = input_fn(params) - features = dataset.make_one_shot_iterator().get_next() + features = dataset_ops.make_one_shot_iterator(dataset).get_next() # With tf.data.Dataset.batch, the batch is None, i.e., dynamic shape. self.assertIsNone(features['a'].shape.as_list()[0]) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py index ec682e5829c4df536a043334b74200f0b6259df3..d66ecfcf4a56b8da1c2d2f518bebe4baa76b315e 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py @@ -52,6 +52,7 @@ def _query_tpu_system_metadata(master_address, cluster_def=None, devices = [] device_dict = collections.defaultdict(list) + # TODO(b/120564445): Replace with standard library for retries. retry_count = 1 while True: logging.info('Querying Tensorflow master (%s) for TPU system metadata.', diff --git a/tensorflow/contrib/tpu/tpu_estimator.md b/tensorflow/contrib/tpu/tpu_estimator.md index b6514e19dc92fe4c7cdcdb6582a7c0ad5ad573d5..552febd80bd35b37a95cdaaf8d5923278311ac8e 100644 --- a/tensorflow/contrib/tpu/tpu_estimator.md +++ b/tensorflow/contrib/tpu/tpu_estimator.md @@ -89,12 +89,9 @@ handle training: dataset = tf.data.TFRecordDataset( filename, buffer_size=FLAGS.dataset_reader_buffer_size) - dataset = dataset.map(parser).cache().repeat().batch(batch_size) - images, labels = dataset.make_one_shot_iterator().get_next() - # set_shape to give inputs statically known shapes. 
- images.set_shape([batch_size, 28 * 28]) - labels.set_shape([batch_size]) - return images, labels + dataset = dataset.map(parser).cache().repeat().batch( + batch_size, drop_remainder=True) + return dataset return input_fn diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD index 00295f57f60858db5234ce28cc643ea9eee44daa..f6427ae05a20f253edf030eff0f860361616042b 100644 --- a/tensorflow/contrib/training/BUILD +++ b/tensorflow/contrib/training/BUILD @@ -26,7 +26,6 @@ py_library( "python/training/resample.py", "python/training/sampling_ops.py", "python/training/sequence_queueing_state_saver.py", - "python/training/tensor_queue_dataset.py", "python/training/training.py", "python/training/tuner.py", ], @@ -287,28 +286,6 @@ py_test( ], ) -py_test( - name = "tensor_queue_dataset_test", - size = "large", - srcs = ["python/training/tensor_queue_dataset_test.py"], - srcs_version = "PY2AND3", - tags = ["notsan"], - deps = [ - ":training_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:random_seed", - "//tensorflow/python:training", - "//tensorflow/python:variables", - "//tensorflow/python/data", - "//tensorflow/python/data/experimental/kernel_tests/serialization:dataset_serialization_test_base", - "//third_party/py/numpy", - ], -) - tf_proto_library( name = "protos_all", srcs = glob(["**/*.proto"]), diff --git a/tensorflow/contrib/training/__init__.py b/tensorflow/contrib/training/__init__.py index 3547e71184ec2b99163ea4247c01d24487811b47..87ce57ef060a0eb9383248255713421c14988416 100644 --- a/tensorflow/contrib/training/__init__.py +++ b/tensorflow/contrib/training/__init__.py @@ -59,8 +59,6 @@ from tensorflow.contrib.training.python.training.hparam import * from tensorflow.contrib.training.python.training.resample import * from 
tensorflow.contrib.training.python.training.sampling_ops import * from tensorflow.contrib.training.python.training.sequence_queueing_state_saver import * -from tensorflow.contrib.training.python.training.tensor_queue_dataset import enqueue_in_queue_dataset -from tensorflow.contrib.training.python.training.tensor_queue_dataset import prepend_from_queue_and_padded_batch_dataset from tensorflow.contrib.training.python.training.training import add_gradients_summaries from tensorflow.contrib.training.python.training.training import clip_gradient_norms from tensorflow.contrib.training.python.training.training import clip_gradient_norms_fn @@ -79,7 +77,6 @@ _allowed_symbols = [ 'FeedingQueueRunner', 'get_or_create_eval_step', 'StopAfterNEvalsHook', 'SummaryAtEndHook', 'wait_for_new_checkpoint', 'add_gradients_summaries', 'clip_gradient_norms', 'clip_gradient_norms_fn', 'create_train_op', - 'multiply_gradients', 'enqueue_in_queue_dataset', - 'prepend_from_queue_and_padded_batch_dataset', 'train'] + 'multiply_gradients', 'train'] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py index 3beb7bfe3048a8f0294f7e9149b5a07b5fcc7d17..bcc177601b95172b05d327247bd370c2f8b65d59 100644 --- a/tensorflow/contrib/training/python/training/hparam.py +++ b/tensorflow/contrib/training/python/training/hparam.py @@ -187,7 +187,7 @@ def _cast_to_type_if_compatible(name, param_type, value): return param_type(value) -def parse_values(values, type_map): +def parse_values(values, type_map, ignore_unknown=False): """Parses hyperparameter values from a string into a python map. `values` is a string containing comma-separated `name=value` pairs. @@ -233,6 +233,9 @@ def parse_values(values, type_map): type T if either V has type T, or V is a list of elements of type T. 
Hence, for a multidimensional parameter 'x' taking float values, 'x=[0.1,0.2]' will parse successfully if type_map['x'] = float. + ignore_unknown: Bool. Whether values that are missing a type in type_map + should be ignored. If set to True, a ValueError will not be raised for + unknown hyperparameter type. Returns: A python map mapping each name to either: @@ -260,6 +263,8 @@ def parse_values(values, type_map): m_dict = m.groupdict() name = m_dict['name'] if name not in type_map: + if ignore_unknown: + continue raise ValueError('Unknown hyperparameter type for %s' % name) type_ = type_map[name] diff --git a/tensorflow/contrib/training/python/training/hparam_test.py b/tensorflow/contrib/training/python/training/hparam_test.py index 660c97f25e8458c345c8914bcaf98f37d047e50e..a990e04711ce68bd928a508484f0d6f657dd2f8c 100644 --- a/tensorflow/contrib/training/python/training/hparam_test.py +++ b/tensorflow/contrib/training/python/training/hparam_test.py @@ -216,6 +216,14 @@ class HParamsTest(test.TestCase): self.assertTrue(isinstance(parse_dict['arr'], dict)) self.assertDictEqual(parse_dict['arr'], {1: 10}) + def testParseValuesWithIndexAssigment1_IgnoreUnknown(self): + """Assignment to an index position.""" + parse_dict = hparam.parse_values( + 'arr[1]=10,b=5', {'arr': int}, ignore_unknown=True) + self.assertEqual(len(parse_dict), 1) + self.assertTrue(isinstance(parse_dict['arr'], dict)) + self.assertDictEqual(parse_dict['arr'], {1: 10}) + def testParseValuesWithIndexAssigment2(self): """Assignment to multiple index positions.""" parse_dict = hparam.parse_values('arr[0]=10,arr[5]=20', {'arr': int}) @@ -223,6 +231,14 @@ class HParamsTest(test.TestCase): self.assertTrue(isinstance(parse_dict['arr'], dict)) self.assertDictEqual(parse_dict['arr'], {0: 10, 5: 20}) + def testParseValuesWithIndexAssigment2_IgnoreUnknown(self): + """Assignment to multiple index positions.""" + parse_dict = hparam.parse_values( + 'arr[0]=10,arr[5]=20,foo=bar', {'arr': int}, ignore_unknown=True) + 
self.assertEqual(len(parse_dict), 1) + self.assertTrue(isinstance(parse_dict['arr'], dict)) + self.assertDictEqual(parse_dict['arr'], {0: 10, 5: 20}) + def testParseValuesWithIndexAssigment3(self): """Assignment to index positions in multiple names.""" parse_dict = hparam.parse_values('arr[0]=10,arr[1]=20,L[5]=100,L[10]=200', @@ -234,6 +250,17 @@ class HParamsTest(test.TestCase): self.assertTrue(isinstance(parse_dict['L'], dict)) self.assertDictEqual(parse_dict['L'], {5: 100, 10: 200}) + def testParseValuesWithIndexAssigment3_IgnoreUnknown(self): + """Assignment to index positions in multiple names.""" + parse_dict = hparam.parse_values( + 'arr[0]=10,C=5,arr[1]=20,B[0]=kkk,L[5]=100,L[10]=200', + {'arr': int, 'L': int}, ignore_unknown=True) + self.assertEqual(len(parse_dict), 2) + self.assertTrue(isinstance(parse_dict['arr'], dict)) + self.assertDictEqual(parse_dict['arr'], {0: 10, 1: 20}) + self.assertTrue(isinstance(parse_dict['L'], dict)) + self.assertDictEqual(parse_dict['L'], {5: 100, 10: 200}) + def testParseValuesWithIndexAssigment4(self): """Assignment of index positions and scalars.""" parse_dict = hparam.parse_values('x=10,arr[1]=20,y=30', @@ -246,6 +273,17 @@ class HParamsTest(test.TestCase): self.assertEqual(parse_dict['x'], 10) self.assertEqual(parse_dict['y'], 30) + def testParseValuesWithIndexAssigment4_IgnoreUnknown(self): + """Assignment of index positions and scalars.""" + parse_dict = hparam.parse_values( + 'x=10,foo[0]=bar,arr[1]=20,zzz=78,y=30', + {'x': int, 'y': int, 'arr': int}, ignore_unknown=True) + self.assertEqual(len(parse_dict), 3) + self.assertTrue(isinstance(parse_dict['arr'], dict)) + self.assertDictEqual(parse_dict['arr'], {1: 20}) + self.assertEqual(parse_dict['x'], 10) + self.assertEqual(parse_dict['y'], 30) + def testParseValuesWithIndexAssigment5(self): """Different variable types.""" parse_dict = hparam.parse_values('a[0]=5,b[1]=true,c[2]=abc,d[3]=3.14', { @@ -264,24 +302,55 @@ class HParamsTest(test.TestCase): 
self.assertTrue(isinstance(parse_dict['d'], dict)) self.assertDictEqual(parse_dict['d'], {3: 3.14}) + def testParseValuesWithIndexAssigment5_IgnoreUnknown(self): + """Different variable types.""" + parse_dict = hparam.parse_values( + 'a[0]=5,cc=4,b[1]=true,c[2]=abc,mm=2,d[3]=3.14', + {'a': int, 'b': bool, 'c': str, 'd': float}, + ignore_unknown=True) + self.assertEqual(set(parse_dict.keys()), {'a', 'b', 'c', 'd'}) + self.assertTrue(isinstance(parse_dict['a'], dict)) + self.assertDictEqual(parse_dict['a'], {0: 5}) + self.assertTrue(isinstance(parse_dict['b'], dict)) + self.assertDictEqual(parse_dict['b'], {1: True}) + self.assertTrue(isinstance(parse_dict['c'], dict)) + self.assertDictEqual(parse_dict['c'], {2: 'abc'}) + self.assertTrue(isinstance(parse_dict['d'], dict)) + self.assertDictEqual(parse_dict['d'], {3: 3.14}) + def testParseValuesWithBadIndexAssigment1(self): """Reject assignment of list to variable type.""" with self.assertRaisesRegexp(ValueError, r'Assignment of a list to a list index.'): hparam.parse_values('arr[1]=[1,2,3]', {'arr': int}) + def testParseValuesWithBadIndexAssigment1_IgnoreUnknown(self): + """Reject assignment of list to variable type.""" + with self.assertRaisesRegexp(ValueError, + r'Assignment of a list to a list index.'): + hparam.parse_values( + 'arr[1]=[1,2,3],c=8', {'arr': int}, ignore_unknown=True) + def testParseValuesWithBadIndexAssigment2(self): """Reject if type missing.""" with self.assertRaisesRegexp(ValueError, r'Unknown hyperparameter type for arr'): hparam.parse_values('arr[1]=5', {}) + def testParseValuesWithBadIndexAssigment2_IgnoreUnknown(self): + """Ignore missing type.""" + hparam.parse_values('arr[1]=5', {}, ignore_unknown=True) + def testParseValuesWithBadIndexAssigment3(self): """Reject type of the form name[index].""" with self.assertRaisesRegexp(ValueError, 'Unknown hyperparameter type for arr'): hparam.parse_values('arr[1]=1', {'arr[1]': int}) + def testParseValuesWithBadIndexAssigment3_IgnoreUnknown(self): + 
"""Ignore type of the form name[index].""" + hparam.parse_values('arr[1]=1', {'arr[1]': int}, ignore_unknown=True) + def testWithReusedVariables(self): with self.assertRaisesRegexp(ValueError, 'Multiple assignments to variable \'x\''): diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py deleted file mode 100644 index 8896a95327a4cb609a9a78412afa68b316a3131e..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py +++ /dev/null @@ -1,201 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Python wrappers for Datasets and Iterators.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import convert -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.util import nest as tf_nest - - -class _PrependFromQueueAndPaddedBatchDataset(dataset_ops.UnaryDataset): - """A `Dataset` that prepends a queue to another `Dataset`. - - A vector of handles to the queue is returned as the first component of - the associated iterator. This vector can be passed to - `enqueue_in_queue_dataset` to add new elements to the queue. 
- """ - - def __init__(self, input_dataset, batch_size, padded_shapes, padding_values): - """Initialize `PrependFromQueueAndPaddedBatchDataset`.""" - super(_PrependFromQueueAndPaddedBatchDataset, self).__init__(input_dataset) - if sparse.any_sparse(input_dataset.output_classes): - raise TypeError( - "Batching of padded sparse tensors is not currently supported") - self._input_dataset = input_dataset - self._batch_size = ops.convert_to_tensor( - batch_size, dtype=dtypes.int64, name="batch_size") - if padded_shapes is None: - self._padded_shapes = nest.map_structure( - convert.partial_shape_to_tensor, input_dataset.output_shapes) - else: - self._padded_shapes = nest.map_structure_up_to( - input_dataset.output_shapes, convert.partial_shape_to_tensor, - padded_shapes) - # pylint: disable=protected-access - padding_values = ( - padding_values if padding_values is not None else - dataset_ops._default_padding(input_dataset)) - self._padding_values = nest.map_structure_up_to( - input_dataset.output_shapes, dataset_ops._padding_value_to_tensor, - padding_values, input_dataset.output_types) - # pylint: enable=protected-access - - def _as_variant_tensor(self): - # pylint: disable=protected-access - return gen_dataset_ops.prepend_from_queue_and_padded_batch_dataset( - self._input_dataset._as_variant_tensor(), - batch_size=self._batch_size, - padded_shapes=[ - ops.convert_to_tensor(s, dtype=dtypes.int64) - for s in nest.flatten(self._padded_shapes) - ], - padding_values=nest.flatten(self._padding_values), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) - # pylint: enable=protected-access - - @property - def output_classes(self): - return (ops.Tensor, self._input_dataset.output_classes) - - def _as_batch_shape(self, shape_like): - return tensor_shape.vector(None).concatenate( - tensor_util.constant_value_as_shape(shape_like)) - - @property - def output_shapes(self): - # First output is a variant representing the Queue - return 
(tensor_shape.vector(None), - nest.map_structure(self._as_batch_shape, self._padded_shapes)) - - @property - def output_types(self): - # First output is a variant representing the Queue - return (dtypes.variant, self._input_dataset.output_types) - - -def prepend_from_queue_and_padded_batch_dataset(batch_size, - padding_values=None, - padded_shapes=None): - """A transformation that prepends a queue to a `Dataset` and batches results. - - A vector of handles to the queue is returned as the first component of the - associated iterator. This vector can be passed to `enqueue_in_queue_dataset` - to add new elements to the queue. - - Below is an example of how this dataset might be used to split incoming - variable-length sequences into "head" and "rest" parts, where "rest" parts - are re-enqueued back into the dataset. A more realistic example would - perform some calculation on the "head" and modify some components of "rest" - with the result (before re-enqueueing). - - ```python - dataset = tf.data.Dataset.from_tensor_slices([2*x for x in range(10)]) - # Make a dataset of variable-length vectors and their lengths. - dataset = dataset.map(lambda count: (count, tf.ones((count,)))) - # Emit a queue we can prepend to, and counts/values as padded batch. 
- dataset = dataset.apply( - tf.contrib.training.prepend_from_queue_and_padded_batch_dataset( - batch_size=10)) - dataset = dataset.prefetch(1) - - iterator = dataset.make_one_shot_iterator() - queue, (count, padded_value) = iterator.get_next() - - # Split the padded_value into two pieces: head and rest - rest_indices = tf.squeeze(tf.where(count > 3), axis=1) - bound = tf.minimum(3, tf.reduce_max(count)) - value_head = padded_value[:, :bound] - count_rest = tf.gather(count - 3, rest_indices) - value_rest = tf.gather(padded_value[:, bound:], rest_indices) - queue_rest = tf.gather(queue, rest_indices) - enqueue_rest_op = tf.contrib.training.enqueue_in_queue_dataset( - queue_rest, (count_rest, value_rest)) - with tf.control_dependencies([enqueue_rest_op]): - calculation = fn(value_head) - - while True: # Will raise OutOfRange when finished with all pieces. - session.run(calculation) - ``` - - Args: - batch_size: `int64` scalar tensor. The batch size to use when performing - padded batching. - padding_values: (optional) Nested tuple of scalar tensors. If provided, - the structure and dtypes of padding_values should match that of - incoming dataset's `output_types`. - padded_shapes: (optional) Nested tuple of `int64` vector tensors. - If provided, the structure must match that of the incoming dataset's - `output_types`. If not provided, the incoming dataset's `output_shapes` - is used. Any unknown (`None` or `-1`) dimensions in the shapes are - treated as being unique per-batch: for each batch time, an unknown - dimension is replaced with the maximum given value of this dimension - across all tensors for the given component in the batch. - - Returns: - A `Dataset` transformation function, which can be passed to - `tf.data.Dataset.apply`. 
- """ - - def _apply_fn(dataset): - return _PrependFromQueueAndPaddedBatchDataset( - dataset, - batch_size=batch_size, - padding_values=padding_values, - padded_shapes=padded_shapes) - - return _apply_fn - - -def enqueue_in_queue_dataset(queue, components): - """Enqueue components into queue from `PrependFromQueueAndPaddedBatchDataset`. - - The components' dtypes and shapes must be compatible with the `output_shapes` - attribute of the `dataset` created by - `prepend_from_queue_and_padded_batch_dataset`. This operation supports both - non-batched and batched modes. - - For more details, see the example in the docstring for - `prepend_from_queue_and_padded_batch_dataset`. - - Args: - queue: `variant` scalar or vector tensor. - The tensor emitted by the first component of the iterator associated with - `prepend_from_queue_and_padded_batch_dataset`. If this is a scalar, - then the `components` input tensors should not have a prepended batch - dimension. - components: Nested tuple of tensors, each with a leading batch dimension - if `queue` is a vector. The structure, dtypes, and shapes - (excluding batch dimension) must match the nested tuples - `dataset.output_types[1]` and `dataset.output_shapes[1]` (the non-queue - output types and shapes) of the `dataset` emitted by - the original `prepend_from_queue_and_padded_batch_dataset` call. - - Returns: - An `Operation` that enqueues `components` into the dataset(s) associated - with entries of `queue`. 
- """ - return gen_dataset_ops.enqueue_in_queue_dataset( - queue=queue, components=tf_nest.flatten(components)) diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py deleted file mode 100644 index c1657fec7bbe4a3227c3ea273b72176ac4066c50..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py +++ /dev/null @@ -1,355 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for TensorQueueDataset.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.training.python.training import tensor_queue_dataset as tqd -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import string_ops -from tensorflow.python.platform import test - - -class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase): - - def testNoEnqueue(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) - self.assertEqual((dtypes.variant, dtypes.int32), dataset.output_types) - self.assertAllEqual(([None],) * 2, - [x.as_list() for x in dataset.output_shapes]) - iterator = dataset.make_one_shot_iterator() - _, value = iterator.get_next() - self.assertEqual([0], self.evaluate(value)) - self.assertEqual([1], self.evaluate(value)) - self.assertEqual([2], self.evaluate(value)) - with self.assertRaisesOpError("End of sequence"): - self.evaluate(value) - - def testBatchedNoEnqueue(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=2)) - iterator = dataset.make_one_shot_iterator() - _, value = iterator.get_next() - self.assertAllEqual([0, 1], self.evaluate(value)) - self.assertAllEqual([2], self.evaluate(value)) - with self.assertRaisesOpError("End of sequence"): - self.evaluate(value) - - def 
testBatchedWithBiggerPaddingNoEnqueue(self): - dataset = dataset_ops.Dataset.from_tensor_slices([[0], [1], [2]]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset( - batch_size=2, padded_shapes=[3])) - iterator = dataset.make_one_shot_iterator() - _, value = iterator.get_next() - self.assertAllEqual([[0, 0, 0], [1, 0, 0]], self.evaluate(value)) - self.assertAllEqual([[2, 0, 0]], self.evaluate(value)) - with self.assertRaisesOpError("End of sequence"): - self.evaluate(value) - - def testBatchedWithBiggerPaddingOneEnqueue(self): - dataset = dataset_ops.Dataset.from_tensor_slices([[0], [1], [2]]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset( - batch_size=1, padded_shapes=[3])) - iterator = dataset.make_one_shot_iterator() - queue_handle, value = iterator.get_next() - enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value) - with self.cached_session() as sess: - self.assertAllEqual([[0, 0, 0]], sess.run(value)) - value_1, _ = sess.run([value, enqueue_negative]) - self.assertAllEqual([[1, 0, 0]], value_1) - value_2, _ = sess.run([value, enqueue_negative]) - self.assertAllEqual([[-1, 0, 0]], value_2) - value_3 = sess.run(value) - self.assertAllEqual([[1, 0, 0]], value_3) - value_4, _ = sess.run([value, enqueue_negative]) - self.assertAllEqual([[2, 0, 0]], value_4) - value_5 = sess.run(value) - self.assertAllEqual([[-2, 0, 0]], value_5) - with self.assertRaisesOpError("End of sequence"): - sess.run(value) - - def testOneEnqueue(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) - iterator = dataset.make_one_shot_iterator() - queue_handle, value = iterator.get_next() - enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value) - with self.cached_session() as sess: - self.assertEqual([0], sess.run(value)) - value_1, _ = sess.run([value, enqueue_negative]) - self.assertEqual([1], 
value_1) - value_2, _ = sess.run([value, enqueue_negative]) - self.assertEqual([-1], value_2) - value_3 = sess.run(value) - self.assertEqual([1], value_3) - value_4, _ = sess.run([value, enqueue_negative]) - self.assertEqual([2], value_4) - value_5 = sess.run(value) - self.assertEqual([-2], value_5) - with self.assertRaisesOpError("End of sequence"): - sess.run(value) - - def testBatchedOneEnqueue(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=2)) - iterator = dataset.make_one_shot_iterator() - queue_handle, value = iterator.get_next() - enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value) - enqueue_zeroth = tqd.enqueue_in_queue_dataset([queue_handle[0]], - array_ops.expand_dims( - value[0], axis=0)) - with self.cached_session() as sess: - value_0, _ = sess.run([value, enqueue_negative]) - self.assertAllEqual([0, 1], value_0) - value_1, _ = sess.run([value, enqueue_zeroth]) - self.assertAllEqual([0, -1], value_1) - value_2, _ = sess.run([value, enqueue_negative]) - self.assertAllEqual([0, 2], value_2) - self.assertAllEqual([0, -2], sess.run(value)) - with self.assertRaisesOpError("End of sequence"): - sess.run(value) - - def testManyEnqueue(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 1]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) - iterator = dataset.make_one_shot_iterator() - queue_handle, value = iterator.get_next() - enqueue_many_more = [ - tqd.enqueue_in_queue_dataset(queue_handle, value + 100 + i) - for i in range(1000) - ] - with self.cached_session() as sess: - value_0, _ = sess.run((value, enqueue_many_more)) - self.assertEqual([0], value_0) - rest = [] - for _ in range(1000): - rest.append(sess.run(value)) - self.assertEquals([[100 + i] for i in range(1000)], sorted(rest)) - # Going back to the original input. 
- value_1, _ = sess.run((value, enqueue_many_more)) - self.assertEqual(1, value_1) - rest = [] - for _ in range(1000): - rest.append(sess.run(value)) - self.assertEquals([[100 + i + 1] for i in range(1000)], sorted(rest)) - with self.assertRaisesOpError("End of sequence"): - sess.run(value) - - def testEnqueueWithPrefetch(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) - # Prefetching will request additional values before they are - # available to the queue. - dataset = dataset.prefetch(buffer_size=3) - iterator = dataset.make_one_shot_iterator() - queue_handle, value = iterator.get_next() - enqueue = tqd.enqueue_in_queue_dataset(queue_handle, value + 1) - with self.cached_session() as sess: - i = 0 - while i < 4: - received, _ = sess.run((value, enqueue)) - if received.size > 0: - self.assertAllEqual([i], received) - i += 1 - received_last = False - while True: - try: - received = sess.run(value) - if received.size > 0: - self.assertAllEqual([4], received) - received_last = True - except errors.OutOfRangeError: - break - self.assertTrue(received_last) - - def testDatasetWithPaddedShapeSmallerThanInputFails(self): - dataset = dataset_ops.Dataset.from_tensor_slices([[0, 0, 0]]).repeat(None) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset( - batch_size=1, padded_shapes=[2])) - iterator = dataset.make_one_shot_iterator() - _, value = iterator.get_next() - with self.cached_session() as sess: - with self.assertRaisesOpError( - r"Incompatible input shapes at component 0 between " - r"input dataset this dataset: \[3\] vs. 
\[2\]"): - sess.run(value) - - def testEnqueueWithIncompatibleInputsFailsWithInformativeError(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0]).repeat(None) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) - iterator = dataset.make_one_shot_iterator() - queue_handle, value = iterator.get_next() - - enqueue_bad_structure = tqd.enqueue_in_queue_dataset( - queue_handle, (value, value)) - enqueue_bad_dtype = tqd.enqueue_in_queue_dataset(queue_handle, - np.array( - [1.0], - dtype=np.float32)) - enqueue_bad_shape_no_batch_dim = tqd.enqueue_in_queue_dataset( - queue_handle, ([1],)) - enqueue_bad_shape = tqd.enqueue_in_queue_dataset(queue_handle, - np.array( - [[1]], dtype=np.int32)) - - with self.cached_session() as sess: - with self.assertRaisesOpError( - "mismatched number of tensors. Queue expects 1 tensors but " - "tried to insert 2"): - sess.run(enqueue_bad_structure) - with self.assertRaisesOpError(r"Expected component 0 to have batched " - r"shape \[1,...\], but saw shape: \[\]"): - sess.run(enqueue_bad_shape_no_batch_dim) - with self.assertRaisesOpError( - r"mismatched shapes at component 0. Attempted to insert tensor " - r"with shape \[1\] but queue expected shape: \[\]"): - sess.run(enqueue_bad_shape) - with self.assertRaisesOpError( - r"mismatched dtypes at component 0. 
Attempted to insert tensor " - r"of type float but queue expected type: int32"): - sess.run(enqueue_bad_dtype) - - def testEnqueueWithPaddedBatchFailsWithInformativeError(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) - with self.assertRaisesRegexp( - TypeError, r"Unable to create padding for field of type 'variant'"): - dataset.padded_batch(batch_size=10, padded_shapes=[1]) - - def testOneEnqueueWithPadding(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 2, 4, 6]) - # Make a dataset of variable-length vectors and their lengths. - dataset = dataset.map( - lambda c: (c, c * array_ops.ones((c,), dtype=c.dtype))) - # Emit a queue we can prepend to, and counts/values as padded - # batch. - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=3)) - - iterator = dataset.make_one_shot_iterator() - queue, (count, padded_value) = iterator.get_next() - - # Split the padded_value into two pieces: head and rest - rest_indices = array_ops.squeeze(array_ops.where(count > 2), axis=1) - bound = math_ops.minimum(2, math_ops.reduce_max(count)) - value_head = padded_value[:, :bound] - count_rest = array_ops.gather(count - 2, rest_indices) - value_rest = array_ops.gather(padded_value, rest_indices)[:, bound:] - queue_rest = array_ops.gather(queue, rest_indices) - enqueue_rest_op = tqd.enqueue_in_queue_dataset(queue_rest, - (count_rest, value_rest)) - with ops.control_dependencies([enqueue_rest_op]): - calc = array_ops.identity(value_head) - - with self.cached_session() as sess: - self.assertAllEqual([[0, 0], [2, 2], [4, 4]], sess.run(calc)) - self.assertAllEqual([[4, 4], [6, 6]], sess.run(calc)) - self.assertAllEqual([[6, 6]], sess.run(calc)) - self.assertAllEqual([[6, 6]], sess.run(calc)) - # Get some final batches due to prefetching. 
- for _ in range(3): - try: - self.assertAllEqual( - np.empty(shape=(0, 0), dtype=np.int32), sess.run(calc)) - except errors.OutOfRangeError as e: - self.assertTrue(str(e).startswith("End of sequence")) - - def testNonstandardPadding(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 2, 4, 6]) - # Make a dataset of variable-length vectors and their lengths. - dataset = dataset.map( - lambda c: (c, c * array_ops.ones((c,), dtype=c.dtype))) - # Emit a queue we can prepend to, and counts/values as padded - # batch. - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset( - batch_size=3, padding_values=( - 0, - -1, - ))) - - iterator = dataset.make_one_shot_iterator() - _, (unused_count, padded_value) = iterator.get_next() - - with self.cached_session() as sess: - self.assertAllEqual([[-1, -1, -1, -1], [2, 2, -1, -1], [4, 4, 4, 4]], - sess.run(padded_value)) - self.assertAllEqual([[6] * 6], sess.run(padded_value)) - with self.assertRaisesOpError("End of sequence"): - sess.run(padded_value) - - -# TODO(ebrevdo): Figure out how to use run_core_tests to test state -# saving of an iterator that's had some tensors enqueued into its queue. 
-class PrependFromQueueAndPaddedBatchDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def testPrependFromQueueAndPaddedBatch(self): - - def build_dataset(seq_lens): - return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( - lambda x: array_ops.fill([x], x)).apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=4)) - - seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) - seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) - self.run_core_tests(lambda: build_dataset(seq_lens1), - lambda: build_dataset(seq_lens2), 8) - - def testPrependFromQueueAndPaddedBatchNonDefaultPadding(self): - - def build_dataset(seq_lens): - - def fill_tuple(x): - filled = array_ops.fill([x], x) - return (filled, string_ops.as_string(filled)) - - padded_shape = [-1] - return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( - fill_tuple).apply( - tqd.prepend_from_queue_and_padded_batch_dataset( - batch_size=4, - padded_shapes=(padded_shape, padded_shape), - padding_values=(-1, ""))) - - seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) - seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) - self.run_core_tests(lambda: build_dataset(seq_lens1), - lambda: build_dataset(seq_lens2), 8) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index c268605711fb73f37773ce7b4181bf17f2a3a4fa..8bf1480d33b2d2117fb5c7ddf046262cfeb8a8ab 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -49,7 +49,7 @@ # filegroup ":android_proto_srcs" - Protos # filegroup ":android_srcs" - Core sources # cc_library ":android_tensorflow_lib" - Native library -# cc_library ":android_tensorflow_lib_selective_registration" - Native library +# cc_library ":android_tensorflow_lib_lite" - Native library, without ops, # supporting SELECTIVE_REGISTRATION feature. 
# portable_proto_library ":android_proto_lib" (Google-internal) # @@ -113,7 +113,6 @@ load( "tf_additional_device_tracer_test_flags", "tf_additional_gdr_lib_defines", "tf_additional_human_readable_json_deps", - "tf_additional_logger_deps", "tf_additional_lib_defines", "tf_additional_lib_deps", "tf_additional_lib_hdrs", @@ -446,15 +445,31 @@ cc_library( ) cc_library( - name = "logger", - srcs = tf_platform_srcs(["logger.cc"]), - hdrs = ["platform/logger.h"] + tf_platform_hdrs(["logger.h"]), + name = "logger_interface", + hdrs = ["platform/logger.h"], copts = tf_copts(), visibility = ["//visibility:public"], deps = [ - ":lib", - ":lib_internal", - ] + tf_additional_logger_deps(), + ":lib_proto_parsing", + "@protobuf_archive//:protobuf", + ], +) + +cc_library( + name = "default_logger", + srcs = ["platform/default/logger.cc"], + hdrs = ["platform/logger.h"], + deps = [ + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:logger_interface", + ], +) + +cc_library( + name = "logger", + hdrs = ["platform/logger.h"], + visibility = ["//visibility:public"], + deps = ["//tensorflow/core/platform/default/build_config:logger"], ) filegroup( @@ -492,7 +507,10 @@ cc_library( ":platform_env_internal_hdrs", ], copts = tf_copts(), - visibility = ["//tensorflow/core:__subpackages__"], + visibility = [ + "//tensorflow/c:__subpackages__", + "//tensorflow/core:__subpackages__", + ], deps = [ ":error_codes_proto_cc", ":lib", @@ -1608,6 +1626,9 @@ filegroup( "**/*main.cc", "debug/**/*", "framework/op_gen_*", + "framework/node_def_util.*", + "framework/op_kernel.*", + "framework/dataset.*", "lib/jpeg/**/*", "lib/png/**/*", "lib/gif/**/*", @@ -1616,7 +1637,6 @@ filegroup( "util/reporter.*", "platform/**/cuda_libdevice_path.*", "platform/**/logger.cc", - "platform/**/logger.h", "platform/default/test_benchmark.*", "platform/cuda.h", "platform/google/**/*", @@ -1651,6 +1671,9 @@ filegroup( "common_runtime/**/*.cc", "graph/**/*.h", "graph/**/*.cc", + "framework/node_def_util.*", + 
"framework/op_kernel.*", + "framework/dataset.*", ], exclude = [ "**/*test.*", @@ -1679,6 +1702,9 @@ filegroup( # operators, use :android_tensorflow_lib if you want full operator # support. # +# If you just need TensorFlow types, e.g. Tensors, use +# :android_tensorflow_lib_lite_no_runtime. +# # Compiles to a trivial library on non-Android to prevent irrelevant # build errors. If not building this as part of an android_binary, # a command such as the following must be used: @@ -1689,7 +1715,33 @@ filegroup( cc_library( name = "android_tensorflow_lib_lite", srcs = if_android(["//tensorflow/core:android_srcs"]), - copts = tf_copts(android_optimization_level_override = None), + copts = tf_copts(android_optimization_level_override = None) + [ + "-DSUPPORT_SELECTIVE_REGISTRATION", + ], + linkopts = ["-lz"], + tags = [ + "manual", + "notap", + ], + visibility = ["//visibility:public"], + deps = [ + ":mobile_additional_lib_deps", + ":protos_all_cc_impl", + ":stats_calculator_portable", + "//third_party/eigen3", + "@double_conversion//:double-conversion", + "@nsync//:nsync_cpp", + "@protobuf_archive//:protobuf", + ], + alwayslink = 1, +) + +cc_library( + name = "android_tensorflow_lib_lite_nortti", + srcs = if_android(["//tensorflow/core:android_srcs"]), + copts = tf_copts(android_optimization_level_override = None) + [ + "-DSUPPORT_SELECTIVE_REGISTRATION", + ] + tf_opts_nortti_if_android(), linkopts = ["-lz"], tags = [ "manual", @@ -1797,58 +1849,6 @@ cc_library( alwayslink = 1, ) -# Android library for use with the SELECTIVE_REGISTRATION feature. -# Does not contain operators. In contrast to android_tensorflow_lib_lite, -# this links in framework support for all types, relying on selective -# registration of ops to prune code size. 
-cc_library( - name = "android_tensorflow_lib_selective_registration", - srcs = if_android(["//tensorflow/core:android_srcs"]), - copts = tf_copts(android_optimization_level_override = None) + [ - "-DSUPPORT_SELECTIVE_REGISTRATION", - ], - linkopts = if_android(["-lz"]), - tags = [ - "manual", - "notap", - ], - visibility = ["//visibility:public"], - deps = [ - ":protos_all_cc_impl", - "//third_party/eigen3", - "@com_google_absl//absl/container:flat_hash_set", - "@double_conversion//:double-conversion", - "@nsync//:nsync_cpp", - "@protobuf_archive//:protobuf", - ], - alwayslink = 1, -) - -# Android library for use with the SELECTIVE_REGISTRATION feature with -# no proto_rtti. -cc_library( - name = "android_tensorflow_lib_selective_registration_nortti", - srcs = if_android(["//tensorflow/core:android_srcs"]), - copts = tf_copts(android_optimization_level_override = None) + tf_opts_nortti_if_android() + [ - "-DSUPPORT_SELECTIVE_REGISTRATION", - ], - linkopts = if_android(["-lz"]), - tags = [ - "manual", - "notap", - ], - visibility = ["//visibility:public"], - deps = [ - ":protos_all_cc_impl", - "//third_party/eigen3", - "@com_google_absl//absl/container:flat_hash_set", - "@double_conversion//:double-conversion", - "@nsync//:nsync_cpp", - "@protobuf_archive//:protobuf", - ], - alwayslink = 1, -) - filegroup( name = "android_op_registrations_and_gradients", srcs = glob( @@ -2087,9 +2087,7 @@ tf_proto_library_cc( srcs = ["protobuf/master.proto"], cc_api_version = 2, protodeps = tf_additional_all_protos(), - visibility = [ - "//tensorflow:internal", - ], + visibility = ["//tensorflow:internal"], ) tf_proto_library_cc( @@ -4060,20 +4058,6 @@ tf_cuda_cc_test( ], ) -tf_cc_test_gpu( - name = "cuda_libdevice_path_test", - size = "small", - srcs = ["platform/cuda_libdevice_path_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - tags = tf_cuda_tests_tags(), - deps = [ - ":cuda_libdevice_path", - ":lib", - ":test", - ":test_main", - ], -) - tf_cuda_only_cc_test( name = 
"util_cuda_kernel_helper_test", srcs = [ @@ -4929,7 +4913,7 @@ filegroup( cc_library( name = "cuda_libdevice_path", - srcs = ["platform/cuda_libdevice_path.cc"] + tf_additional_libdevice_srcs(), + srcs = tf_additional_libdevice_srcs(), hdrs = ["platform/cuda_libdevice_path.h"], copts = tf_copts(), data = tf_additional_libdevice_data(), diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc index d38a8424eb13009fbf84d7511fb1325085d8b809..7405e2ace72d1c08cf87cc0040e617379e18149b 100644 --- a/tensorflow/core/api_def/api_test.cc +++ b/tensorflow/core/api_def/api_test.cc @@ -35,7 +35,6 @@ limitations under the License. #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/api_def/base_api/api_def_BatchDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_BatchDataset.pbtxt index 639d962874d083472e6df13550e107026fd2d0a1..32def912f83e420eab58a3071f573ae81139a298 100644 --- a/tensorflow/core/api_def/base_api/api_def_BatchDataset.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_BatchDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "BatchDataset" + visibility: HIDDEN in_arg { name: "batch_size" description: <