diff --git a/tools/bazel.rc b/.bazelrc similarity index 93% rename from tools/bazel.rc rename to .bazelrc index 1fdf51f53e29c7111cf89c016400b710051cf9c6..ceba7bfdbac74d1e44aadc3010e5e84bd36ce3ee 100644 --- a/tools/bazel.rc +++ b/.bazelrc @@ -76,7 +76,6 @@ build:nonccl --define=no_nccl_support=true build --define=use_fast_cpp_protos=true build --define=allow_oversize_protos=true -build --define=grpc_no_ares=true build --spawn_strategy=standalone build --genrule_strategy=standalone @@ -93,3 +92,14 @@ build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS build --define=PREFIX=/usr build --define=LIBDIR=$(PREFIX)/lib build --define=INCLUDEDIR=$(PREFIX)/include + +# Disable MKL-DNN contraction kernels by default. +build --define=tensorflow_mkldnn_contraction_kernel=0 + +# Default options should come above this line + +# Options from ./configure +try-import %workspace%/.tf_configure.bazelrc + +# Put user-specific options in .bazelrc.user +try-import %workspace%/.bazelrc.user diff --git a/.gitignore b/.gitignore index 90324058600bee46af56e49028977971848a80de..e1d352c238a1b2d4febe0f5d4a30cfa0c942f7e7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,7 @@ .DS_Store .ipynb_checkpoints node_modules -/.bazelrc +/.bazelrc.user /.tf_configure.bazelrc /bazel-* /bazel_pip diff --git a/CODEOWNERS b/CODEOWNERS index bfcdc2a23f4753336e357a45afd6259b531f36ec..cb3fa2312405ce44d5dfc30ea4164740f436e07e 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,7 +1,7 @@ # Where component owners are known, add them here. 
/tenosrflow/core/debug @caisq -/tensorflow/core/nccl/ @azaks @csigg +/tensorflow/core/nccl/ @azaks2 @chsigg /tensorflow/core/platform/windows/ @mrry /tensorflow/core/platform/s3 @yongtang /tensorflow/go @asimshankar diff --git a/README.md b/README.md index 044174947a094d43a51f7140dd40ec0f17801d40..519815d006cc33be10132909baf414a4bd843435 100644 --- a/README.md +++ b/README.md @@ -113,11 +113,12 @@ The TensorFlow project strives to abide by generally accepted best practices in Build Type | Status | Artifacts ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA -**IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | TBA -**IBM ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) -**IBM ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) +**Linux ppc64le CPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | 
[Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/) +**Linux ppc64le CPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) +**Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) +**Linux ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) **Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) -**Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.4
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.11.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp27-cp27mu-linux_x86_64.whl)
[1.11.0 py3.4](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp34-cp34m-linux_x86_64.whl)
[1.11.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp35-cp35m-linux_x86_64.whl)
[1.11.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp36-cp36m-linux_x86_64.whl) +**Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.4
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.12.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp27-cp27mu-linux_x86_64.whl)
[1.12.0 py3.4](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp34-cp34m-linux_x86_64.whl)
[1.12.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp35-cp35m-linux_x86_64.whl)
[1.12.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp36-cp36m-linux_x86_64.whl) ## For more information diff --git a/RELEASE.md b/RELEASE.md index b13b071bd6cf4d3a260c8e248a67d23e1a688498..282430d12303bde980e19e3c3602eb91b1a54d63 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -7,6 +7,8 @@ Serving. * Keras models now support evaluating with a `tf.data.Dataset`. * TensorFlow binaries are built with XLA support linked in by default. +* Ignite Dataset added to contrib/ignite that allows to work with Apache + Ignite. ## Bug Fixes and Other Changes @@ -280,50 +282,76 @@ Ag Ramesh, Alex Wiltschko, Alexander Pantyukhin, Amogh Mannekote, An Jiaoyang, A ## Bug Fixes and Other Changes -* `tfe.Network` is deprecated. Please inherit from `tf.keras.Model`. -* Layered variable names have changed in the following conditions: - * Using `tf.keras.layers` with custom variable scopes. - * Using `tf.layers` in a subclassed `tf.keras.Model` class. See - [here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/layers) for more details -* `tf.data`: - * `Dataset.from_generator()` now accepts an `args` list, in order to create nested generators. - * `Dataset.list_files()` now produces determinstic results when `shuffle=False` or a `seed` is passed. - * `tf.contrib.data.sample_from_datasets()` and `tf.contrib.data.choose_from_datasets()` make it easier to sample or deterministically choose elements from multiple datasets. - * `tf.contrib.data.make_csv_dataset()` now supports line breaks in quoted strings, and two infrequently used arguments removed. - * (C++) `DatasetBase::DebugString()` is now `const`. - * (C++) `DatasetBase::MakeIterator()` has been renamed to `DatasetBase::MakeIteratorInternal()`. - * (C++) `IteratorBase::Initialize()` method was added to support raising errors during iterator construction. -* Eager Execution: - * Added the ability to pause recording operations for gradient computation via `tf.GradientTape.stop_recording`. 
- * Updated documentation, introductory notebooks. -* `tf.keras`: - * Move Keras code out of _impl folder and remove API files. - * `tf.keras.Model.save_weights` now saves in TensorFlow format by default. - * Enable dataset iterators to be passed to `tf.keras.Model` training/eval methods. -* TensorFlow Debugger (tfdbg) CLI: fix an issue in which the TensorBoard Debugger Plugin could not handle total source file size exceeding gRPC message size limit (4 MB). -* `tf.contrib`: - * `tf.contrib.framework.zero_initializer` supports ResourceVariable. - * Adding "constrained_optimization" to tensorflow/contrib. -* Other: - * Add GCS Configuration Ops. - * Changing signature of `MakeIterator` to enable propagating error status. - * KL divergence for two Dirichlet distributions. - * More consistent GcsFileSystem behavior for certain reads past EOF. - * Update benchmark for tf.scan to match ranges across eager and graph modes. - * Fixed bug in `tf.reduce_prod gradient` for complex dtypes. - * Allow the use of '.' in variables (e.g. "hparams.parse('a.b=1.0')"), which would previously raise an error. This will correspond to an attribute name with an embedded '.' symbol (e.g. 'a.b'), which can only be accessed indirectly (e.g. through getattr and setattr). To set this up the user will first need to explicitly add the variable to the hparam object (e.g. "hparams.add_hparam(name='a.b', value=0.0)"). - * Benchmark for tf.scan in graph and eager modes. - * Added complex128 support to FFT, FFT2D, FFT3D, IFFT, IFFT2D, and IFFT3D. - * Making ids unique in `nn.embedding_lookup_sparse`. This helps to reduce RPC calls for looking up the embeddings when there are repeated ids in the batch. - * Support indicator column in boosted trees. - * Prevent `tf.gradients()` from backpropagating through integer tensors. - * LinearOperator[1D,2D,3D]Circulant added to `tensorflow.linalg`. - * Conv3D, Conv3DBackpropInput, Conv3DBackpropFilter now supports arbitrary. 
- * Added `tf.train.Checkpoint` for reading/writing object-based checkpoints. - * Added LinearOperatorKronecker, a dense-free implementation of the Kronecker Product. - * Allow LinearOperator to broadcast. - * SavedModelBuilder will now deduplicate asset names that point to files with the same basename and the same contents. Note that this may result in new asset files included in SavedModels in cases where assets with the same name but different contents were previously overwriting each other. - +* `tfe.Network` is deprecated. Please inherit from `tf.keras.Model`. +* Layered variable names have changed in the following conditions: + * Using `tf.keras.layers` with custom variable scopes. + * Using `tf.layers` in a subclassed `tf.keras.Model` class. See + [here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/layers) + for more details +* `tf.data`: + * `Dataset.from_generator()` now accepts an `args` list, in order to + create nested generators. + * `Dataset.list_files()` now produces deterministic results when + `shuffle=False` or a `seed` is passed. + * `tf.contrib.data.sample_from_datasets()` and + `tf.contrib.data.choose_from_datasets()` make it easier to sample or + deterministically choose elements from multiple datasets. + * `tf.contrib.data.make_csv_dataset()` now supports line breaks in quoted + strings, and two infrequently used arguments removed. + * (C++) `DatasetBase::DebugString()` is now `const`. + * (C++) `DatasetBase::MakeIterator()` has been renamed to + `DatasetBase::MakeIteratorInternal()`. + * (C++) `IteratorBase::Initialize()` method was added to support raising + errors during iterator construction. +* Eager Execution: + * Added the ability to pause recording operations for gradient computation + via `tf.GradientTape.stop_recording`. + * Updated documentation, introductory notebooks. +* `tf.keras`: + * Move Keras code out of _impl folder and remove API files. 
+ * `tf.keras.Model.save_weights` now saves in TensorFlow format by default. + * Enable dataset iterators to be passed to `tf.keras.Model` training/eval + methods. +* TensorFlow Debugger (tfdbg) CLI: fix an issue in which the TensorBoard + Debugger Plugin could not handle total source file size exceeding gRPC + message size limit (4 MB). +* `tf.contrib`: + * `tf.contrib.framework.zero_initializer` supports ResourceVariable. + * Adding "constrained_optimization" to tensorflow/contrib. +* Other: + * Add GCS Configuration Ops. + * Changing signature of `MakeIterator` to enable propagating error status. + * KL divergence for two Dirichlet distributions. + * More consistent GcsFileSystem behavior for certain reads past EOF. + * Update benchmark for tf.scan to match ranges across eager and graph + modes. + * Fixed bug in `tf.reduce_prod gradient` for complex dtypes. + * Allow the use of '.' in variables (e.g. "hparams.parse('a.b=1.0')"), + which would previously raise an error. This will correspond to an + attribute name with an embedded '.' symbol (e.g. 'a.b'), which can only + be accessed indirectly (e.g. through getattr and setattr). To set this + up the user will first need to explicitly add the variable to the hparam + object (e.g. "hparams.add_hparam(name='a.b', value=0.0)"). + * Benchmark for tf.scan in graph and eager modes. + * Added complex128 support to FFT, FFT2D, FFT3D, IFFT, IFFT2D, and IFFT3D. + * Making ids unique in `nn.embedding_lookup_sparse`. This helps to reduce + RPC calls for looking up the embeddings when there are repeated ids in + the batch. + * Support indicator column in boosted trees. + * Prevent `tf.gradients()` from backpropagating through integer tensors. + * LinearOperator[1D,2D,3D]Circulant added to `tensorflow.linalg`. + * Conv3D, Conv3DBackpropInput, Conv3DBackpropFilter now supports + arbitrary. + * Added `tf.train.Checkpoint` for reading/writing object-based + checkpoints. 
+ * Added LinearOperatorKronecker, a dense-free implementation of the + Kronecker Product. + * Allow LinearOperator to broadcast. + * SavedModelBuilder will now deduplicate asset names that point to files + with the same basename and the same contents. Note that this may result + in new asset files included in SavedModels in cases where assets with + the same name but different contents were previously overwriting each + other. ## Thanks to our Contributors diff --git a/WORKSPACE b/WORKSPACE index 7cc08e0164a202581ad7ebbe107a9e19410e70e4..2277e83a3f67b62cf4ee1311767ee06c0549c697 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -1,6 +1,6 @@ workspace(name = "org_tensorflow") -load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file") http_archive( name = "io_bazel_rules_closure", @@ -16,38 +16,64 @@ load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories") closure_repositories() -http_archive( - name = "base_images_docker", - sha256 = "e2b1b7254270bb7605e814a9dbf6d1e4ae04a11136ff1714fbfdabe3f87f7cf9", - strip_prefix = "base-images-docker-12801524f867e657fbb5d1a74f31618aff181ac6", - urls = ["https://github.com/GoogleCloudPlatform/base-images-docker/archive/12801524f867e657fbb5d1a74f31618aff181ac6.tar.gz"], -) +load("//third_party/toolchains/preconfig/generate:archives.bzl", + "bazel_toolchains_archive") -http_archive( - name = "bazel_toolchains", - sha256 = "15b5858b1b5541ec44df31b94c3b8672815b31d71215a98398761ea9f4c4eedb", - strip_prefix = "bazel-toolchains-6200b238c9c2d137c0d9a7262c80cc71d98e692b", - urls = [ - "https://github.com/bazelbuild/bazel-toolchains/archive/6200b238c9c2d137c0d9a7262c80cc71d98e692b.tar.gz", - ], +bazel_toolchains_archive() + +load( + "@bazel_toolchains//repositories:repositories.bzl", + bazel_toolchains_repositories = "repositories", ) -http_archive( - name = "io_bazel_rules_docker", - sha256 = 
"29d109605e0d6f9c892584f07275b8c9260803bf0c6fcb7de2623b2bedc910bd", - strip_prefix = "rules_docker-0.5.1", - urls = ["https://github.com/bazelbuild/rules_docker/archive/v0.5.1.tar.gz"], +bazel_toolchains_repositories() + +load( + "@io_bazel_rules_docker//container:container.bzl", + container_repositories = "repositories", ) -load("//third_party/toolchains/preconfig/generate:workspace.bzl", "remote_config_workspace") +container_repositories() + +load("//third_party/toolchains/preconfig/generate:workspace.bzl", + "remote_config_workspace") remote_config_workspace() +# Apple and Swift rules. +http_archive( + name = "build_bazel_rules_apple", + sha256 = "4fe4ee824200b48821730f89ff260984332dc3551db587c24691235d1d96a8a7", + strip_prefix = "rules_apple-0.10.0", + urls = ["https://github.com/bazelbuild/rules_apple/archive/0.10.0.tar.gz"], +) +http_archive( + name = "build_bazel_rules_swift", + sha256 = "6544ff5615febec0342de1127144d2f3e43ea80fb7f9b1ade65e6a184e39e618", + strip_prefix = "rules_swift-0.5.0", + urls = ["https://github.com/bazelbuild/rules_swift/archive/0.5.0.tar.gz"], +) +http_archive( + name = "bazel_skylib", + sha256 = "eb5c57e4c12e68c0c20bc774bfbc60a568e800d025557bc4ea022c6479acc867", + strip_prefix = "bazel-skylib-0.6.0", + urls = ["https://github.com/bazelbuild/bazel-skylib/archive/0.6.0.tar.gz"], +) +http_file( + name = "xctestrunner", + executable = 1, + urls = ["https://github.com/google/xctestrunner/releases/download/0.2.5/ios_test_runner.par"], +) +load("@build_bazel_rules_apple//apple:repositories.bzl", "apple_rules_dependencies") +apple_rules_dependencies(ignore_version_differences = True) +load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies") +swift_rules_dependencies() + # We must check the bazel version before trying to parse any other BUILD # files, in case the parsing of those build files depends on the bazel # version we require here. 
load("//tensorflow:version_check.bzl", "check_bazel_version_at_least") -check_bazel_version_at_least("0.15.0") +check_bazel_version_at_least("0.18.0") load("//tensorflow:workspace.bzl", "tf_workspace") diff --git a/tensorflow/opensource_only/arm_compiler.BUILD b/arm_compiler.BUILD similarity index 100% rename from tensorflow/opensource_only/arm_compiler.BUILD rename to arm_compiler.BUILD diff --git a/configure.py b/configure.py index 6c905a0be3d685b5921dfbc5bddfbe6471a82625..1e732db26404906901a9eeab97a5e75137ee8388 100644 --- a/configure.py +++ b/configure.py @@ -255,18 +255,6 @@ def setup_python(environ_cp): def reset_tf_configure_bazelrc(): """Reset file that contains customized config settings.""" open(_TF_BAZELRC, 'w').close() - bazelrc_path = os.path.join(_TF_WORKSPACE_ROOT, '.bazelrc') - - data = [] - if os.path.exists(bazelrc_path): - with open(bazelrc_path, 'r') as f: - data = f.read().splitlines() - with open(bazelrc_path, 'w') as f: - for l in data: - if _TF_BAZELRC_FILENAME in l: - continue - f.write('%s\n' % l) - f.write('import %%workspace%%/%s\n' % _TF_BAZELRC_FILENAME) def cleanup_makefile(): """Delete any leftover BUILD files from the Makefile build. @@ -488,11 +476,12 @@ def check_bazel_version(min_version, max_version): if curr_version_int < min_version_int: print('Please upgrade your bazel installation to version %s or higher to ' 'build TensorFlow!' % min_version) - sys.exit(0) - if curr_version_int > max_version_int: + sys.exit(1) + if (curr_version_int > max_version_int and + 'TF_IGNORE_MAX_BAZEL_VERSION' not in os.environ): print('Please downgrade your bazel installation to version %s or lower to ' 'build TensorFlow!' % max_version) - sys.exit(0) + sys.exit(1) return curr_version @@ -1565,11 +1554,9 @@ def main(): # environment variables. 
environ_cp = dict(os.environ) - check_bazel_version('0.15.0', '0.20.0') + check_bazel_version('0.19.0', '0.20.0') reset_tf_configure_bazelrc() - # Explicitly import tools/bazel.rc, this is needed for Bazel 0.19.0 or later - write_to_bazelrc('import %workspace%/tools/bazel.rc') cleanup_makefile() setup_python(environ_cp) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index fd4b94202aad24a82abef8abd16431f61a8326f0..f07e7365d3482cde5b7bb76ebf22890150e98651 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -202,6 +202,12 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "arm", + values = {"cpu": "arm"}, + visibility = ["//visibility:public"], +) + config_setting( name = "freebsd", values = {"cpu": "freebsd"}, @@ -267,6 +273,15 @@ config_setting( visibility = ["//visibility:public"], ) +# By default, XLA GPU is compiled into tensorflow when building with +# --config=cuda even when `with_xla_support` is false. The config setting +# here allows us to override the behavior if needed. 
+config_setting( + name = "no_xla_deps_in_cuda", + define_values = {"no_xla_deps_in_cuda": "true"}, + visibility = ["//visibility:public"], +) + config_setting( name = "with_gdr_support", define_values = {"with_gdr_support": "true"}, @@ -359,7 +374,9 @@ package_group( name = "internal", packages = [ "-//third_party/tensorflow/python/estimator", + "//learning/deepmind/...", "//learning/meta_rank/...", + "//learning/pathways/...", # While dataset C++ api requires internals "//tensorflow/...", "//tensorflow_estimator/contrib/...", "//tensorflow_fold/llgtm/...", @@ -606,9 +623,11 @@ py_library( name = "tensorflow_py", srcs_version = "PY2AND3", visibility = ["//visibility:public"], - deps = [ + deps = select({ + "api_version_2": [], + "//conditions:default": ["//tensorflow/contrib:contrib_py"], + }) + [ ":tensorflow_py_no_contrib", - "//tensorflow/contrib:contrib_py", "//tensorflow/python/estimator:estimator_py", ], ) @@ -618,7 +637,11 @@ py_library( srcs = select({ "api_version_2": [":tf_python_api_gen_v2"], "//conditions:default": [":tf_python_api_gen_v1"], - }) + [":root_init_gen"], + }) + [":root_init_gen"] + [ + "//tensorflow/python/keras/api:keras_python_api_gen", + "//tensorflow/python/keras/api:keras_python_api_gen_compat_v1", + "//tensorflow/python/keras/api:keras_python_api_gen_compat_v2", + ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = ["//tensorflow/python:no_contrib"], diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index d81cf067eb07e88e2b8a86cf5643674235eb3f3b..2c0a7452692e5cdb184f7f0a77eb1b646a1772d4 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -18,27 +18,78 @@ from __future__ import absolute_import as _absolute_import from __future__ import division as _division from __future__ import print_function as _print_function +import distutils as _distutils +import inspect as _inspect import os as _os +import site as _site +import sys as _sys + +# 
API IMPORTS PLACEHOLDER # pylint: disable=g-bad-import-order from tensorflow.python.tools import component_api_helper as _component_api_helper _component_api_helper.package_hook( parent_package_str=__name__, - child_package_str=('tensorflow_estimator.python.estimator.api.estimator')) - -# API IMPORTS PLACEHOLDER + child_package_str=( + 'tensorflow_estimator.python.estimator.api._v2.estimator')) +_current_module = _sys.modules[__name__] +if not hasattr(_current_module, 'estimator'): + _component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=( + 'tensorflow_estimator.python.estimator.api.estimator')) +_component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=('tensorflow.python.keras.api._v2.keras')) # Make sure directory containing top level submodules is in # the __path__ so that "from tensorflow.foo import bar" works. # We're using bitwise, but there's nothing special about that. _tf_api_dir = _os.path.dirname(_os.path.dirname(bitwise.__file__)) # pylint: disable=undefined-variable -if _tf_api_dir not in __path__: +if not hasattr(_current_module, '__path__'): + __path__ = [_tf_api_dir] +elif _tf_api_dir not in __path__: __path__.append(_tf_api_dir) # Enable TF2 behaviors from tensorflow.python.compat import compat as _compat # pylint: disable=g-import-not-at-top _compat.enable_v2_behavior() + +# Load all plugin libraries from site-packages/tensorflow-plugins if we are +# running under pip. +# TODO(gunan): Enable setting an environment variable to define arbitrary plugin +# directories. +# TODO(gunan): Find a better location for this code snippet. +from tensorflow.python.framework import load_library as _ll +from tensorflow.python.lib.io import file_io as _fi + +# Get sitepackages directories for the python installation. 
+_site_packages_dirs = [] +_site_packages_dirs += [_site.USER_SITE] +_site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p] +if 'getsitepackages' in dir(_site): + _site_packages_dirs += _site.getsitepackages() + +if 'sysconfig' in dir(_distutils): + _site_packages_dirs += [_distutils.sysconfig.get_python_lib()] + +_site_packages_dirs = list(set(_site_packages_dirs)) + +# Find the location of this exact file. +_current_file_location = _inspect.getfile(_inspect.currentframe()) + +def _running_from_pip_package(): + return any( + _current_file_location.startswith(dir_) for dir_ in _site_packages_dirs) + +if _running_from_pip_package(): + for s in _site_packages_dirs: + # TODO(gunan): Add sanity checks to loaded modules here. + plugin_dir = _os.path.join(s, 'tensorflow-plugins') + if _fi.file_exists(plugin_dir): + _ll.load_library(plugin_dir) + # These symbols appear because we import the python package which # in turn imports from tensorflow.core and tensorflow.python. They # must come from this module. 
So python adds these symbols for the @@ -59,4 +110,6 @@ try: del compiler except NameError: pass + + # pylint: enable=undefined-variable diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index 65bdb6cb1b5e6fb0656a12b932d767aeacfccd29..514aba1b59631f882523396aab0f4d3d5e88a893 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -18,20 +18,42 @@ from __future__ import absolute_import as _absolute_import from __future__ import division as _division from __future__ import print_function as _print_function +import distutils as _distutils +import inspect as _inspect import os as _os +import site as _site +import sys as _sys # pylint: disable=g-bad-import-order from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import +# API IMPORTS PLACEHOLDER + from tensorflow.python.tools import component_api_helper as _component_api_helper _component_api_helper.package_hook( parent_package_str=__name__, - child_package_str=('tensorflow_estimator.python.estimator.api.estimator')) - -# API IMPORTS PLACEHOLDER + child_package_str=( + 'tensorflow_estimator.python.estimator.api._v1.estimator')) +_current_module = _sys.modules[__name__] +if not hasattr(_current_module, 'estimator'): + _component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=( + 'tensorflow_estimator.python.estimator.api.estimator')) +_component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=('tensorflow.python.keras.api._v1.keras')) from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top -contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') +_CONTRIB_WARNING = """ +WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0. 
+For more information, please see: + * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md + * https://github.com/tensorflow/addons +If you depend on functionality not listed there, please file an issue. +""" +contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib', + _CONTRIB_WARNING) del LazyLoader # The templated code that replaces the placeholder above sometimes # sets the __all__ variable. If it does, we have to be sure to add @@ -45,9 +67,44 @@ app.flags = flags # pylint: disable=undefined-variable # Make sure directory containing top level submodules is in # the __path__ so that "from tensorflow.foo import bar" works. _tf_api_dir = _os.path.dirname(_os.path.dirname(app.__file__)) # pylint: disable=undefined-variable -if _tf_api_dir not in __path__: +if not hasattr(_current_module, '__path__'): + __path__ = [_tf_api_dir] +elif _tf_api_dir not in __path__: __path__.append(_tf_api_dir) +# Load all plugin libraries from site-packages/tensorflow-plugins if we are +# running under pip. +# TODO(gunan): Enable setting an environment variable to define arbitrary plugin +# directories. +# TODO(gunan): Find a better location for this code snippet. +from tensorflow.python.framework import load_library as _ll +from tensorflow.python.lib.io import file_io as _fi + +# Get sitepackages directories for the python installation. +_site_packages_dirs = [] +_site_packages_dirs += [_site.USER_SITE] +_site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p] +if 'getsitepackages' in dir(_site): + _site_packages_dirs += _site.getsitepackages() + +if 'sysconfig' in dir(_distutils): + _site_packages_dirs += [_distutils.sysconfig.get_python_lib()] + +_site_packages_dirs = list(set(_site_packages_dirs)) + +# Find the location of this exact file. 
+_current_file_location = _inspect.getfile(_inspect.currentframe()) + +def _running_from_pip_package(): + return any( + _current_file_location.startswith(dir_) for dir_ in _site_packages_dirs) + +if _running_from_pip_package(): + for s in _site_packages_dirs: + # TODO(gunan): Add sanity checks to loaded modules here. + plugin_dir = _os.path.join(s, 'tensorflow-plugins') + if _fi.file_exists(plugin_dir): + _ll.load_library(plugin_dir) # These symbols appear because we import the python package which # in turn imports from tensorflow.core and tensorflow.python. They diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index ad2ae08a37b628b7343e58088a5340d6525675d1..3e1f220db233001ba652120657631f8c1a296b35 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -123,7 +123,6 @@ tf_cuda_library( "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_internal", "//tensorflow/compiler/jit:flags", - "//tensorflow/contrib/tpu:all_ops", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -175,6 +174,32 @@ tf_cuda_library( ], ) +tf_cuda_library( + name = "env", + srcs = [ + "env.cc", + ], + hdrs = [ + "env.h", + ], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = select({ + "//tensorflow:android": [ + ":c_api", + ":tf_status_helper", + "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:lib", + ], + "//conditions:default": [ + ":c_api", + ":tf_status_helper", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], + }) + [":c_api_internal"], +) + tf_cuda_library( name = "kernels", srcs = [ @@ -223,6 +248,24 @@ tf_cuda_library( ], ) +tf_cc_test( + name = "c_test", + srcs = ["c_test.c"], + extra_copts = ["-std=c11"], + tags = [ + # TODO(b/121223209): Re-enable after fixing asan memory leaks and MacOS + # build errors. 
+ "noasan", + "no_mac", + ], + deps = [ + ":c_api", + ":c_api_experimental", + ":env", + ":kernels", + ], +) + tf_cuda_cc_test( name = "c_api_test", size = "small", @@ -334,6 +377,27 @@ tf_kernel_library( alwayslink = 1, ) +tf_cuda_cc_test( + name = "env_test", + size = "small", + srcs = ["env_test.cc"], + linkopts = select({ + "//tensorflow:darwin": ["-headerpad_max_install_names"], + "//conditions:default": [], + }), + tags = ["noasan"], + # We must ensure that the dependencies can be dynamically linked since + # the shared library must be able to use core:framework. + # linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":c_api", + ":env", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cuda_cc_test( name = "kernels_test", size = "small", diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 94d18eb8b04e3534be547aca5cfbb32da40ffbf6..9580215a317b1a6b1cdacbd430a1764af61be990 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -488,6 +488,7 @@ static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) { // Non-static for testing. TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); if (!src.IsInitialized()) { status->status = FailedPrecondition( "attempt to use a tensor with an uninitialized value"); diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 3d56268110edbe96616201d15a69cc8c84d3115a..c7abba85521fccec07983cd5ab4f94a8368d6181 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -91,7 +91,7 @@ extern "C" { // -------------------------------------------------------------------------- // TF_Version returns a string describing version information of the // TensorFlow library. TensorFlow using semantic versioning. 
-TF_CAPI_EXPORT extern const char* TF_Version(); +TF_CAPI_EXPORT extern const char* TF_Version(void); // -------------------------------------------------------------------------- // TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. @@ -157,7 +157,7 @@ typedef enum TF_Code { typedef struct TF_Status TF_Status; // Return a new status object. -TF_CAPI_EXPORT extern TF_Status* TF_NewStatus(); +TF_CAPI_EXPORT extern TF_Status* TF_NewStatus(void); // Delete a previously created status object. TF_CAPI_EXPORT extern void TF_DeleteStatus(TF_Status*); @@ -196,7 +196,7 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len); // Useful for passing *out* a protobuf. -TF_CAPI_EXPORT extern TF_Buffer* TF_NewBuffer(); +TF_CAPI_EXPORT extern TF_Buffer* TF_NewBuffer(void); TF_CAPI_EXPORT extern void TF_DeleteBuffer(TF_Buffer*); @@ -305,7 +305,7 @@ TF_CAPI_EXPORT extern size_t TF_StringEncodedSize(size_t len); typedef struct TF_SessionOptions TF_SessionOptions; // Return a new options object. -TF_CAPI_EXPORT extern TF_SessionOptions* TF_NewSessionOptions(); +TF_CAPI_EXPORT extern TF_SessionOptions* TF_NewSessionOptions(void); // Set the target in TF_SessionOptions.options. // target can be empty, a single entry, or a comma separated list of entries. @@ -338,7 +338,7 @@ TF_CAPI_EXPORT extern void TF_DeleteSessionOptions(TF_SessionOptions*); typedef struct TF_Graph TF_Graph; // Return a new graph object. -TF_CAPI_EXPORT extern TF_Graph* TF_NewGraph(); +TF_CAPI_EXPORT extern TF_Graph* TF_NewGraph(void); // Destroy an options object. Graph will be deleted once no more // TFSession's are referencing it. @@ -890,7 +890,8 @@ TF_CAPI_EXPORT extern void TF_GraphVersions(TF_Graph* graph, // TF_GraphImportGraphDef. 
typedef struct TF_ImportGraphDefOptions TF_ImportGraphDefOptions; -TF_CAPI_EXPORT extern TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions(); +TF_CAPI_EXPORT extern TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions( + void); TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefOptions( TF_ImportGraphDefOptions* opts); @@ -1611,7 +1612,7 @@ TF_CAPI_EXPORT extern void TF_DeleteLibraryHandle(TF_Library* lib_handle); // // The data in the buffer will be the serialized OpList proto for ops registered // in this address space. -TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllOpList(); +TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllOpList(void); // TF_ApiDefMap encapsulates a collection of API definitions for an operation. // diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 38e29aa74a90f4e85d1369b6928a5a58c531b2da..f04b285037dff403428ed74fe90eac60339fe36b 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -66,7 +66,8 @@ void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) { } TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation, - unsigned char gpu_memory_allow_growth) { + unsigned char gpu_memory_allow_growth, + unsigned int num_cpu_devices) { tensorflow::ConfigProto config; auto* optimizer_options = config.mutable_graph_options()->mutable_optimizer_options(); @@ -87,6 +88,8 @@ TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation, auto* gpu_options = config.mutable_gpu_options(); gpu_options->set_allow_growth(gpu_memory_allow_growth); + (*config.mutable_device_count())["CPU"] = num_cpu_devices; + // TODO(b/113217601): This is needed for EagerContext::runner_ to use a // threadpool, so that we avoid the possibility of running the runner_ in the // threadpool of GPU event mgr, as that can trigger more callbacks to be @@ -8535,8 +8538,9 @@ TFE_Context* TFE_CreateContextFromSession(TF_Session* session, // Reduce GPU memory allocation, and set 
appropriate config options for TFE // context. - auto* config = - TF_CreateConfig(/*xla*/ false, /* gpu_memory_allow_growth */ true); + auto* config = TF_CreateConfig( + /*xla*/ false, /* gpu_memory_allow_growth */ true, /* num_cpu_devices */ + 10); TFE_ContextOptionsSetConfig(opts, config->data, config->length, status); if (!status->status.ok()) { CHECK(!config); @@ -8886,3 +8890,54 @@ TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType dtype_arg, std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len); return new TFE_TensorHandle(tensor, nullptr, nullptr); } + +namespace { +tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def, + TFE_Context* ctx) { + // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the + // server object (which currently CHECK-fails) and we miss the error, instead, + // we log the error, and then return to allow the user to see the error + // message. +#define LOG_AND_RETURN_IF_ERROR(...) \ + do { \ + const ::tensorflow::Status _status = (__VA_ARGS__); \ + if (TF_PREDICT_FALSE(!_status.ok())) { \ + LOG(ERROR) << _status.error_message(); \ + return _status; \ + } \ + } while (0); + + std::unique_ptr server; + LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &server)); + + tensorflow::GrpcServer* grpc_server = + dynamic_cast(server.get()); + if (grpc_server == nullptr) { + LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal( + "Currently, TFE_NewContext only supports tensorflow::GrpcServer.")); + } + + LOG_AND_RETURN_IF_ERROR(grpc_server->Start()); + + LOG_AND_RETURN_IF_ERROR(ctx->context.StoreCollectiveOpsServer( + std::move(server), grpc_server->worker_env()->device_mgr, + grpc_server->worker_env()->collective_executor_mgr)); + + return tensorflow::Status::OK(); +#undef LOG_AND_RETURN_IF_ERROR +} +} // namespace + +// Set server_def on the context, possibly updating it. 
+TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx, + const void* proto, + size_t proto_len, + TF_Status* status) { + tensorflow::ServerDef server_def; + if (!server_def.ParseFromArray(proto, proto_len)) { + status->status = tensorflow::errors::InvalidArgument( + "Invalid tensorflow.ServerDef protocol buffer"); + return; + } + status->status = EnableCollectiveOps(server_def, ctx); +} diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 80c8bfe594c4c89606efd01bec7f50e7a86b5bda..e6d04d0c2b25a3f7b1ebf50c58268f003595a520 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -67,9 +67,10 @@ TF_CAPI_EXPORT extern void TF_EnableXLACompilation(TF_SessionOptions* options, // a) ConfigProto.optimizer_options.global_jit_level is set to to ON_1 if // `enable_xla_compilation` is non-zero, and OFF otherwise. // b) ConfigProto.gpu_options.allow_growth is set to `gpu_memory_allow_growth`. +// c) ConfigProto.device_count is set to `num_cpu_devices`. TF_CAPI_EXPORT extern TF_Buffer* TF_CreateConfig( - unsigned char enable_xla_compilation, - unsigned char gpu_memory_allow_growth); + unsigned char enable_xla_compilation, unsigned char gpu_memory_allow_growth, + unsigned int num_cpu_devices); // Create a serialized tensorflow.RunOptions proto, where RunOptions.trace_level // is set to FULL_TRACE if `enable_full_trace` is non-zero, and NO_TRACE @@ -239,13 +240,21 @@ TF_CAPI_EXPORT void TF_InitMain(const char* usage, int* argc, char*** argv); // Platform-specific implementation to return an unused port. (This should used // in tests only.) -TF_CAPI_EXPORT int TF_PickUnusedPortOrDie(); +TF_CAPI_EXPORT int TF_PickUnusedPortOrDie(void); // Fast path method that makes constructing a single scalar tensor require less // overhead and copies. 
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromScalar( TF_DataType dtype, void* scalar, size_t len); +// Specify the server_def that enables collective ops. +// This is different to the above function in that it doesn't create remote +// contexts, and remotely executing ops is not possible. It just enables +// communication for collective ops. +TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx, + const void* proto, + size_t proto_len, + TF_Status* status); #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_test.c b/tensorflow/c/c_test.c new file mode 100644 index 0000000000000000000000000000000000000000..c0ed5ccd15d9524e2c14630d8ef92f6b3ef9b059 --- /dev/null +++ b/tensorflow/c/c_test.c @@ -0,0 +1,79 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_experimental.h" +#include "tensorflow/c/env.h" +#include "tensorflow/c/kernels.h" + +// A compute function. This will never actually get called in this test, it's +// just nice to know that it compiles. 
+void compute(void* kernel, TF_OpKernelContext* ctx) { + TF_Tensor* input; + TF_Status* s = TF_NewStatus(); + TF_GetInput(ctx, 0, &input, s); + TF_DeleteTensor(input); + TF_DeleteStatus(s); +} + +// Exercises tensorflow's C API. +int main(int argc, char** argv) { + TF_InitMain(argv[0], &argc, &argv); + + struct TF_StringStream* s = TF_GetLocalTempDirectories(); + const char* path; + + if (!TF_StringStreamNext(s, &path)) { + fprintf(stderr, "TF_GetLocalTempDirectories returned no results\n"); + return 1; + } + + char file_name[100]; + struct timeval t; + if (gettimeofday(&t, NULL)) { + perror("gettimeofday failed"); + return 1; + } + snprintf(file_name, sizeof(file_name), "test-%d-%ld.txt", getpid(), t.tv_sec); + + size_t length = 2 + strlen(path) + strlen(file_name); + char* full_path = malloc(length); + snprintf(full_path, length, "%s/%s", path, file_name); + + TF_WritableFileHandle* h; + TF_Status* status = TF_NewStatus(); + TF_NewWritableFile(full_path, &h, status); + if (TF_GetCode(status) != TF_OK) { + fprintf(stderr, "TF_NewWritableFile failed: %s\n", TF_Message(status)); + return 1; + } + fprintf(stderr, "wrote %s\n", full_path); + free(full_path); + TF_StringStreamDone(s); + + TF_KernelBuilder* b = + TF_NewKernelBuilder("SomeOp", "SomeDevice", NULL, &compute, NULL); + TF_RegisterKernelBuilder("someKernel", b, status); + + TF_DeleteStatus(status); + return 0; +} diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 8d6c8d958d5961fce817156a14eb2b2940c1f2f0..120748ab763a3358b6e38e64bb3b6fd2ea32f7c3 100755 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -48,7 +48,7 @@ extern "C" { typedef struct TFE_ContextOptions TFE_ContextOptions; // Return a new options object. -TF_CAPI_EXPORT extern TFE_ContextOptions* TFE_NewContextOptions(); +TF_CAPI_EXPORT extern TFE_ContextOptions* TFE_NewContextOptions(void); // Set the config in TF_ContextOptions.options. // config should be a serialized tensorflow.ConfigProto proto. 
@@ -170,23 +170,11 @@ TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, TF_Status* status); -// Returns the device of the operation that produced `h`. -// If `h` was produced by a copy, returns the destination device of -// the copy. Note that returned device name is not always the device -// holding the tensor handle's memory. If you want the latter, use -// TFE_TensorHandleBackingDeviceName. -// This function will block till the operation that produces `h` has completed. -// -// Device on which the kernel of the operation that produced `h` ran. -// -// If `h` was produced by a copy, returns the destination device of -// the copy. -// -// Note that returned device name is not always the device that owns the memory -// that backs the tensor handle. For the latter see -// TFE_TensorHandleBackingDeviceName. -// -// This function will block till the operation that produces `h` has completed. +// Returns the device of the operation that produced `h`. If `h` was produced by +// a copy, returns the destination device of the copy. Note that the returned +// device name is not always the device holding the tensor handle's memory. If +// you want the latter, use TFE_TensorHandleBackingDeviceName. This function +// will block till the operation that produces `h` has completed. TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName( TFE_TensorHandle* h, TF_Status* status); diff --git a/tensorflow/c/env.cc b/tensorflow/c/env.cc new file mode 100644 index 0000000000000000000000000000000000000000..1c35ff9001d0ee1ab0fbae9e1bcc07116fab1065 --- /dev/null +++ b/tensorflow/c/env.cc @@ -0,0 +1,183 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/env.h" + +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/types.h" + +struct TF_StringStream { + std::vector<::tensorflow::string>* list; + size_t position; +}; + +void TF_CreateDir(const char* dirname, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->CreateDir(dirname)); +} + +void TF_DeleteDir(const char* dirname, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->DeleteDir(dirname)); +} + +void TF_DeleteRecursively(const char* dirname, int64_t* undeleted_file_count, + int64_t* undeleted_dir_count, TF_Status* status) { + ::tensorflow::int64 f, d; + + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->DeleteRecursively(dirname, &f, &d)); + *undeleted_file_count = f; + *undeleted_dir_count = d; +} + +void TF_FileStat(const char* filename, TF_FileStatistics* stats, + TF_Status* status) { + ::tensorflow::FileStatistics cc_stats; + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Status s = + ::tensorflow::Env::Default()->Stat(filename, &cc_stats); + ::tensorflow::Set_TF_Status_from_Status(status, s); + if (s.ok()) { + stats->length = cc_stats.length; + stats->mtime_nsec = cc_stats.mtime_nsec; + stats->is_directory = 
cc_stats.is_directory; + } +} + +void TF_NewWritableFile(const char* filename, TF_WritableFileHandle** handle, + TF_Status* status) { + std::unique_ptr<::tensorflow::WritableFile> f; + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Status s = + ::tensorflow::Env::Default()->NewWritableFile(filename, &f); + ::tensorflow::Set_TF_Status_from_Status(status, s); + + if (s.ok()) { + *handle = reinterpret_cast(f.release()); + } +} + +void TF_CloseWritableFile(TF_WritableFileHandle* handle, TF_Status* status) { + auto* cc_file = reinterpret_cast<::tensorflow::WritableFile*>(handle); + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status(status, cc_file->Close()); + delete cc_file; +} + +void TF_SyncWritableFile(TF_WritableFileHandle* handle, TF_Status* status) { + auto* cc_file = reinterpret_cast<::tensorflow::WritableFile*>(handle); + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status(status, cc_file->Sync()); +} + +void TF_FlushWritableFile(TF_WritableFileHandle* handle, TF_Status* status) { + auto* cc_file = reinterpret_cast<::tensorflow::WritableFile*>(handle); + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status(status, cc_file->Flush()); +} + +void TF_AppendWritableFile(TF_WritableFileHandle* handle, const char* data, + size_t length, TF_Status* status) { + auto* cc_file = reinterpret_cast<::tensorflow::WritableFile*>(handle); + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, cc_file->Append(::tensorflow::StringPiece{data, length})); +} + +void TF_DeleteFile(const char* filename, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->DeleteFile(filename)); +} + +bool TF_StringStreamNext(TF_StringStream* list, const char** result) { + if (list->position >= list->list->size()) { + *result = nullptr; + return false; + } + + *result = list->list->at(list->position++).c_str(); + 
return true; +} + +void TF_StringStreamDone(TF_StringStream* list) { + delete list->list; + delete list; +} +TF_StringStream* TF_GetChildren(const char* dirname, TF_Status* status) { + auto* children = new std::vector<::tensorflow::string>; + + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->GetChildren(dirname, children)); + + auto* list = new TF_StringStream; + list->list = children; + list->position = 0; + return list; +} + +TF_StringStream* TF_GetLocalTempDirectories() { + auto* tmpdirs = new std::vector<::tensorflow::string>; + + ::tensorflow::Env::Default()->GetLocalTempDirectories(tmpdirs); + + auto* list = new TF_StringStream; + list->list = tmpdirs; + list->position = 0; + return list; +} + +TF_CAPI_EXPORT extern uint64_t TF_NowNanos(void) { + return ::tensorflow::Env::Default()->NowNanos(); +} + +// Returns the number of microseconds since the Unix epoch. +TF_CAPI_EXPORT extern uint64_t TF_NowMicros(void) { + return ::tensorflow::Env::Default()->NowMicros(); +} + +// Returns the number of seconds since the Unix epoch. 
+TF_CAPI_EXPORT extern uint64_t TF_NowSeconds(void) { + return ::tensorflow::Env::Default()->NowSeconds(); +} + +void TF_DefaultThreadOptions(TF_ThreadOptions* options) { + options->stack_size = 0; + options->guard_size = 0; + options->numa_node = -1; +} + +TF_Thread* TF_StartThread(const TF_ThreadOptions* options, + const char* thread_name, void (*work_func)(void*), + void* param) { + ::tensorflow::ThreadOptions cc_options; + cc_options.stack_size = options->stack_size; + cc_options.guard_size = options->guard_size; + cc_options.numa_node = options->numa_node; + return reinterpret_cast(::tensorflow::Env::Default()->StartThread( + cc_options, thread_name, [=]() { (*work_func)(param); })); +} + +void TF_JoinThread(TF_Thread* thread) { + // ::tensorflow::Thread joins on destruction + delete reinterpret_cast<::tensorflow::Thread*>(thread); +} diff --git a/tensorflow/c/env.h b/tensorflow/c/env.h new file mode 100644 index 0000000000000000000000000000000000000000..73078fcbbc5ae4c042f4a992655072a838e42915 --- /dev/null +++ b/tensorflow/c/env.h @@ -0,0 +1,195 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#ifndef TENSORFLOW_C_ENV_H_ +#define TENSORFLOW_C_ENV_H_ + +#include "tensorflow/c/c_api.h" + +// -------------------------------------------------------------------------- +// C API for tensorflow::Env. 
+ +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct TF_WritableFileHandle TF_WritableFileHandle; +typedef struct TF_StringStream TF_StringStream; +typedef struct TF_Thread TF_Thread; + +typedef struct TF_FileStatistics { + // The length of the file in bytes. + int64_t length; + // The last modified time in nanoseconds. + int64_t mtime_nsec; + // Whether the name refers to a directory. + bool is_directory; +} TF_FileStatistics; + +typedef struct TF_ThreadOptions { + // Thread stack size to use (in bytes), zero implies that the system default + // will be used. + size_t stack_size; + + // Guard area size to use near thread stacks to use (in bytes), zero implies + // that the system default will be used. + size_t guard_size; + + // The NUMA node to use, -1 implies that there should be no NUMA affinity for + // this thread. + int numa_node; +} TF_ThreadOptions; + +// Creates the specified directory. Typical status code are: +// * TF_OK - successfully created the directory +// * TF_ALREADY_EXISTS - directory already exists +// * TF_PERMISSION_DENIED - dirname is not writable +TF_CAPI_EXPORT extern void TF_CreateDir(const char* dirname, TF_Status* status); + +// Deletes the specified directory. Typical status codes are: +// * TF_OK - successfully deleted the directory +// * TF_FAILED_PRECONDITION - the directory is not empty +TF_CAPI_EXPORT extern void TF_DeleteDir(const char* dirname, TF_Status* status); + +// Deletes the specified directory and all subdirectories and files underneath +// it. This is accomplished by traversing the directory tree rooted at dirname +// and deleting entries as they are encountered. +// +// If dirname itself is not readable or does not exist, *undeleted_dir_count is +// set to 1, *undeleted_file_count is set to 0 and an appropriate status (e.g. +// TF_NOT_FOUND) is returned. +// +// If dirname and all its descendants were successfully deleted, TF_OK is +// returned and both error counters are set to zero. 
+// +// Otherwise, while traversing the tree, undeleted_file_count and +// undeleted_dir_count are updated if an entry of the corresponding type could +// not be deleted. The returned error status represents the reason that any one +// of these entries could not be deleted. +// +// Typical status codes: +// * TF_OK - dirname exists and we were able to delete everything underneath +// * TF_NOT_FOUND - dirname doesn't exist +// * TF_PERMISSION_DENIED - dirname or some descendant is not writable +// * TF_UNIMPLEMENTED - some underlying functions (like Delete) are not +// implemented +TF_CAPI_EXPORT extern void TF_DeleteRecursively(const char* dirname, + int64_t* undeleted_file_count, + int64_t* undeleted_dir_count, + TF_Status* status); + +// Obtains statistics for the given path. If status is TF_OK, *stats is +// updated, otherwise it is not touched. +TF_CAPI_EXPORT extern void TF_FileStat(const char* filename, + TF_FileStatistics* stats, + TF_Status* status); + +// Creates or truncates the given filename and returns a handle to be used for +// appending data to the file. If status is TF_OK, *handle is updated and the +// caller is responsible for freeing it (see TF_CloseWritableFile). +TF_CAPI_EXPORT extern void TF_NewWritableFile(const char* filename, + TF_WritableFileHandle** handle, + TF_Status* status); + +// Closes the given handle and frees its memory. If there was a problem closing +// the file, it is indicated by status. Memory is freed in any case. +TF_CAPI_EXPORT extern void TF_CloseWritableFile(TF_WritableFileHandle* handle, + TF_Status* status); + +// Syncs content of the handle to the filesystem. Blocks waiting for the +// filesystem to indicate that the content has been persisted. +TF_CAPI_EXPORT extern void TF_SyncWritableFile(TF_WritableFileHandle* handle, + TF_Status* status); + +// Flush local buffers to the filesystem. 
If the process terminates after a +// successful flush, the contents may still be persisted, since the underlying +// filesystem may eventually flush the contents. If the OS or machine crashes +// after a successful flush, the contents may or may not be persisted, depending +// on the implementation. +TF_CAPI_EXPORT extern void TF_FlushWritableFile(TF_WritableFileHandle* handle, + TF_Status* status); + +// Appends the given bytes to the file. Any failure to do so is indicated in +// status. +TF_CAPI_EXPORT extern void TF_AppendWritableFile(TF_WritableFileHandle* handle, + const char* data, + size_t length, + TF_Status* status); + +// Deletes the named file and indicates whether successful in *status. +TF_CAPI_EXPORT extern void TF_DeleteFile(const char* filename, + TF_Status* status); + +// Retrieves the next item from the given TF_StringStream and places a pointer +// to it in *result. If no more items are in the list, *result is set to NULL +// and false is returned. +// +// Ownership of the items retrieved with this function remains with the library. +// Item points are invalidated after a call to TF_StringStreamDone. +TF_CAPI_EXPORT extern bool TF_StringStreamNext(TF_StringStream* list, + const char** result); + +// Frees the resources associated with given string list. All pointers returned +// by TF_StringStreamNext are invalid after this call. +TF_CAPI_EXPORT extern void TF_StringStreamDone(TF_StringStream* list); + +// Retrieves the list of children of the given directory. You can iterate +// through the list with TF_StringStreamNext. The caller is responsible for +// freeing the list (see TF_StringStreamDone). +TF_CAPI_EXPORT extern TF_StringStream* TF_GetChildren(const char* filename, + TF_Status* status); + +// Retrieves a list of directory names on the local machine that may be used for +// temporary storage. You can iterate through the list with TF_StringStreamNext. +// The caller is responsible for freeing the list (see TF_StringStreamDone). 
+TF_CAPI_EXPORT extern TF_StringStream* TF_GetLocalTempDirectories(void); + +// Returns the number of nanoseconds since the Unix epoch. +TF_CAPI_EXPORT extern uint64_t TF_NowNanos(void); + +// Returns the number of microseconds since the Unix epoch. +TF_CAPI_EXPORT extern uint64_t TF_NowMicros(void); + +// Returns the number of seconds since the Unix epoch. +TF_CAPI_EXPORT extern uint64_t TF_NowSeconds(void); + +// Populates a TF_ThreadOptions struct with system-default values. +TF_CAPI_EXPORT extern void TF_DefaultThreadOptions(TF_ThreadOptions* options); + +// Returns a new thread that is running work_func and is identified +// (for debugging/performance-analysis) by thread_name. +// +// The given param (which may be null) is passed to work_func when the thread +// starts. In this way, data may be passed from the thread back to the caller. +// +// Caller takes ownership of the result and must call TF_JoinThread on it +// eventually. +TF_CAPI_EXPORT extern TF_Thread* TF_StartThread(const TF_ThreadOptions* options, + const char* thread_name, + void (*work_func)(void*), + void* param); + +// Waits for the given thread to finish execution, then deletes it. +TF_CAPI_EXPORT extern void TF_JoinThread(TF_Thread* thread); + +#ifdef __cplusplus +} +#endif + +#endif // TENSORFLOW_C_ENV_H_ diff --git a/tensorflow/c/env_test.cc b/tensorflow/c/env_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..687ad024137352662759ec1f43df87e89faca353 --- /dev/null +++ b/tensorflow/c/env_test.cc @@ -0,0 +1,127 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/env.h" + +#include "tensorflow/c/c_api.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x)) + +TEST(TestEnv, TestDirHandling) { + TF_StringStream* tempdirs = TF_GetLocalTempDirectories(); + const char* tempdir; + bool found = false; + while (TF_StringStreamNext(tempdirs, &tempdir)) { + found = true; + + TF_Status* s = TF_NewStatus(); + + ::tensorflow::string dirpath = + ::tensorflow::io::JoinPath(tempdir, "somedir"); + TF_CreateDir(dirpath.c_str(), s); + ASSERT_TF_OK(s) << "TF_CreateDir failed for " << dirpath << ": " + << TF_Message(s); + + ::tensorflow::string filepath = + ::tensorflow::io::JoinPath(dirpath, "somefile.txt"); + TF_WritableFileHandle* handle; + TF_NewWritableFile(filepath.c_str(), &handle, s); + ASSERT_TF_OK(s) << "NewWritableFile failed for " << filepath << ": " + << TF_Message(s); + + const char* data = "Hello, world!\n"; + TF_AppendWritableFile(handle, data, strlen(data), s); + ASSERT_TF_OK(s) << "TF_AppendWritableFile failed to append data to file at " + << filepath << ": " << TF_Message(s); + + TF_CloseWritableFile(handle, s); + ASSERT_TF_OK(s) << "TF_CloseWritableFile failed to close handle to " + << filepath << ": " << TF_Message(s); + + TF_StringStream* children = TF_GetChildren(dirpath.c_str(), s); + ASSERT_TF_OK(s) << "TF_GetChildren failed for " << dirpath; 
+ const char* childpath; + ASSERT_TRUE(TF_StringStreamNext(children, &childpath)); + ASSERT_EQ(::tensorflow::string(childpath), "somefile.txt"); + // There should only be one file in this directory. + ASSERT_FALSE(TF_StringStreamNext(children, &childpath)); + ASSERT_EQ(childpath, nullptr); + TF_StringStreamDone(children); + + TF_FileStatistics stats; + TF_FileStat(filepath.c_str(), &stats, s); + ASSERT_EQ(stats.length, strlen(data)); + ASSERT_FALSE(stats.is_directory); + ASSERT_GT(stats.mtime_nsec, 0); + + // Trying to delete a non-empty directory should fail. + TF_DeleteDir(dirpath.c_str(), s); + ASSERT_NE(TF_OK, TF_GetCode(s)) + << "TF_DeleteDir unexpectedly succeeded with a non-empty directory " + << dirpath; + + TF_DeleteFile(filepath.c_str(), s); + ASSERT_TF_OK(s) << "TF_DeleteFile failed for " << filepath << ": " + << TF_Message(s); + + // Now deleting the directory should work. + TF_DeleteDir(dirpath.c_str(), s); + ASSERT_TF_OK(s) << "TF_DeleteDir failed for " << dirpath << ": " + << TF_Message(s); + + TF_DeleteStatus(s); + break; + } + + ASSERT_TRUE(found) << "expected at least one temp dir"; + + TF_StringStreamDone(tempdirs); +} + +TEST(TestEnv, TestTimeFunctions) { + ASSERT_GE(TF_NowSeconds(), 946684800); // Midnight Jan 1, 2000 + ASSERT_GE(TF_NowMicros(), 946684800 * 1e6); + ASSERT_GE(TF_NowNanos(), 946684800 * 1e9); +} + +namespace { + +struct SomeThreadData { + ::tensorflow::mutex mu; + bool did_work = false; +}; + +void SomeThreadFunc(void* data) { + auto* real_data = static_cast(data); + ::tensorflow::mutex_lock l(real_data->mu); + real_data->did_work = true; +} + +} // namespace + +TEST(TestEnv, TestThreads) { + TF_ThreadOptions options; + TF_DefaultThreadOptions(&options); + SomeThreadData data; + TF_Thread* thread = + TF_StartThread(&options, "SomeThreadName", &SomeThreadFunc, &data); + TF_JoinThread(thread); + ::tensorflow::mutex_lock l(data.mu); + ASSERT_TRUE(data.did_work); +} diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h index 
1a91aa184f11ac8e45b38a1d106c7b445747a7c1..cefc30bcdf89bdc14a4406299cc29f74153e77ac 100644 --- a/tensorflow/c/kernels.h +++ b/tensorflow/c/kernels.h @@ -35,9 +35,9 @@ extern "C" { // `TF_RegisterKernelBuilder`, which will allow TF to construct user-provided // kernels when necessary. -struct TF_KernelBuilder; -struct TF_OpKernelConstruction; -struct TF_OpKernelContext; +typedef struct TF_KernelBuilder TF_KernelBuilder; +typedef struct TF_OpKernelConstruction TF_OpKernelConstruction; +typedef struct TF_OpKernelContext TF_OpKernelContext; // Allocates a new kernel builder and returns a pointer to it. // diff --git a/tensorflow/cc/gradients/image_grad.cc b/tensorflow/cc/gradients/image_grad.cc index 882709e1e2817431a32c453fe0f35f2b2e6c69b0..05c287bdc62cdb8be7208ce3975f280aaa816766 100644 --- a/tensorflow/cc/gradients/image_grad.cc +++ b/tensorflow/cc/gradients/image_grad.cc @@ -69,6 +69,23 @@ Status ResizeBicubicGradHelper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("ResizeBicubic", ResizeBicubicGradHelper); +Status ScaleAndTranslateGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + string kernel_type; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "kernel_type", &kernel_type)); + grad_outputs->push_back(internal::ScaleAndTranslateGrad( + scope, grad_inputs[0], op.input(0), op.input(2), op.input(3), + internal::ScaleAndTranslateGrad::KernelType(kernel_type))); + + grad_outputs->push_back(NoGradient()); + grad_outputs->push_back(NoGradient()); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("ScaleAndTranslate", ScaleAndTranslateGradHelper); + } // anonymous namespace } // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/gradients/image_grad_test.cc b/tensorflow/cc/gradients/image_grad_test.cc index 2e55c7561b030c50bd67bd53fd0d55710085c5d2..1d150226538093467e092e02f38090a327f9c9b6 100644 --- 
a/tensorflow/cc/gradients/image_grad_test.cc +++ b/tensorflow/cc/gradients/image_grad_test.cc @@ -30,6 +30,7 @@ using ops::Const; using ops::ResizeBicubic; using ops::ResizeBilinear; using ops::ResizeNearestNeighbor; +using ops::ScaleAndTranslate; class ImageGradTest : public ::testing::Test { protected: @@ -153,5 +154,45 @@ TEST_F(ImageGradTest, TestBicubic) { TestResize(RESIZE_BICUBIC); } +class ScaleAndTranslateGradTest : public ::testing::Test { + protected: + ScaleAndTranslateGradTest() : scope_(Scope::NewRootScope()) {} + + template + Tensor MakeData(const TensorShape& data_shape) { + DataType data_type = DataTypeToEnum::v(); + Tensor data(data_type, data_shape); + auto data_flat = data.flat(); + for (int i = 0; i < data_flat.size(); ++i) { + data_flat(i) = T(i); + } + return data; + } + + template + void MakeOp(const Tensor& x_data, const Input& y_shape, Output* x, + Output* y) { + *x = Const(scope_, x_data); + *y = ScaleAndTranslate(scope_, *x, y_shape, {1.8f, 2.1f}, {0.5f, 0.7f}); + TF_ASSERT_OK(scope_.status()); + } + + template + void TestResize() { + TensorShape x_shape({1, 2, 3, 1}); + Tensor x_data = MakeData(x_shape); + Output x, y; + MakeOp(x_data, {4, 6}, &x, &y); + JAC_T max_error; + TF_ASSERT_OK((ComputeGradientError( + scope_, x, x_data, y, {1, 4, 6, 1}, &max_error))); + EXPECT_LT(max_error, 1e-3); + } + + Scope scope_; +}; + +TEST_F(ScaleAndTranslateGradTest, Works) { TestResize(); } + } // namespace } // namespace tensorflow diff --git a/tensorflow/compat_template_v1.__init__.py b/tensorflow/compat_template_v1.__init__.py index 7df80ec01245a7fe820c79d5879458c4cd0a93cb..b966c22b2319aef3b87ef54a283911718d37cf84 100644 --- a/tensorflow/compat_template_v1.__init__.py +++ b/tensorflow/compat_template_v1.__init__.py @@ -23,12 +23,14 @@ import os as _os # pylint: disable=g-bad-import-order from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import +# API IMPORTS PLACEHOLDER + from tensorflow.python.tools import 
component_api_helper as _component_api_helper _component_api_helper.package_hook( parent_package_str=__name__, child_package_str=('tensorflow_estimator.python.estimator.api.estimator')) - -# API IMPORTS PLACEHOLDER - +_component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=('tensorflow.python.keras.api._v1.keras')) from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top app.flags = flags # pylint: disable=undefined-variable diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 2dc3e8c9113b37bf9d575ad66783f4ab49478af4..4051664c24cacad4a2d151ad3ac9009015900609 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -283,7 +283,7 @@ def tf_library( ) # Variables used for gen_test and gen_benchmark. - cpp_class_split = cpp_class.rsplit("::", maxsplit = 2) + cpp_class_split = cpp_class.rsplit("::", 2) if len(cpp_class_split) == 1: no_ns_name = cpp_class_split[0] else: diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 7ebcd120f6bc26a1b03f388ec03964cd042c127a..b9a87ba296abfc6b9d9aaeff3b3e26678e4e1b94 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -76,6 +76,7 @@ cc_library( srcs = ["xla_cpu_device.cc"], visibility = [":friends"], deps = [ + ":create_xla_launch_op", # buildcleaner: keep ":flags", ":jit_compilation_passes", ":xla_device", @@ -95,6 +96,7 @@ cc_library( srcs = ["xla_gpu_device.cc"], visibility = [":friends"], deps = [ + ":create_xla_launch_op", # buildcleaner: keep ":jit_compilation_passes", ":xla_device", "//tensorflow/compiler/jit/kernels:xla_ops", @@ -199,6 +201,7 @@ cc_library( "//tensorflow/core/kernels/data:prefetch_dataset_op", "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:optional", ], ) @@ -513,6 +516,7 @@ cc_library( "//tensorflow/compiler/jit/ops:xla_ops", 
"//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:resource_operation_table", + "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/cc:xla_jit_ops", @@ -611,6 +615,7 @@ tf_cc_test( "//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", + "//tensorflow/cc:functional_ops", "//tensorflow/cc:ops", "//tensorflow/cc:resource_variable_ops", "//tensorflow/cc:scope", @@ -623,6 +628,7 @@ tf_cc_test( "//tensorflow/compiler/tf2xla/cc:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/compiler/xla:test", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index f478832781cb1dc045d9163d4a6f5e5f64a8a705..03aba97bbe81a11f6366d118ee5bc573d0c6b31b 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -779,7 +779,8 @@ Status Encapsulator::Subgraph::RecordArg( if (inserted) { NodeDef arg_def; NodeDefBuilder builder( - absl::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp); + absl::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp, + NodeDebugInfo(src_node->def())); DataType dtype = edge->dst()->input_type(edge->dst_input()); builder.Attr("T", dtype); builder.Attr("index", arg_index); @@ -814,7 +815,8 @@ Status Encapsulator::Subgraph::RecordResult( if (inserted) { NodeDef ret_def; NodeDefBuilder builder( - absl::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp); + absl::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp, + NodeDebugInfo(src_node->def())); DataType dtype = src_node->output_type(src_slot); builder.Attr("T", dtype); builder.Attr("index", 
ret_index); @@ -974,6 +976,7 @@ Status Encapsulator::Subgraph::AddHostComputes( } NodeDef host_compute_def; + // TODO(shikharagarwal): What source node should we use for errors? NodeDefBuilder builder(absl::StrCat("outside_compilation_", oc_subgraph_name, "_host_compute"), kHostComputeOp); @@ -1040,6 +1043,7 @@ Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name, Graph* graph_out) { if (sequencer_ == nullptr) { NodeDef seq_def; + // TODO(shikharagarwal): What source node should we use for errors? NodeDefBuilder builder(absl::StrCat(subgraph_name, "_sequencer"), "NoOp"); builder.Attr(kXlaHostTransferSequencerAttr, subgraph_name); builder.Device(device_); @@ -1214,7 +1218,8 @@ Status Encapsulator::Subgraph::AddHostComputeKeyPlaceholder( GraphDefBuilder::Options options(graph_out, /*status=*/nullptr); NodeDef key_def; NodeDefBuilder builder( - absl::StrCat(call_node_def_.name(), "_key_placeholder"), "Placeholder"); + absl::StrCat(call_node_def_.name(), "_key_placeholder"), "Placeholder", + NodeDebugInfo(call_node_def_)); builder.Attr("dtype", DT_STRING); builder.Attr("shape", shape_proto); builder.Attr("_host_compute_call_node", call_node_def_.name()); @@ -1248,6 +1253,7 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode( } NodeDef recv_def; + // TODO(shikharagarwal): What source node should we use for errors? NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name, "_", oc_subgraph_name, "_recv"), kRecvAtHostOp); @@ -1303,6 +1309,7 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( } NodeDef send_def; + // TODO(shikharagarwal): What source node should we use for errors? NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name, "_", oc_subgraph_name, "_send"), kSendFromHostOp); @@ -1833,8 +1840,9 @@ Node* AddDummyShapedNode(const Node* src_node, int src_port, // Add any Enter nodes required to bring the constant to the correct control // flow frame. 
while (!control_flow_info[src_node->id()].frame_name.empty()) { + NodeDebugInfo debug_info(*src_node); NodeBuilder enter_builder(options.GetNameForOp("Enter"), "Enter", - options.op_registry()); + options.op_registry(), &debug_info); enter_builder.Attr("frame_name", control_flow_info[src_node->id()].frame_name); enter_builder.Attr("is_constant", true); @@ -2018,7 +2026,8 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( return errors::InvalidArgument( "Shape inference is not possible for outside_compilation " "SendFromHost node ", - send_node->name(), " because shape of node ", n->name(), + send_node->name(), " because shape of node ", + FormatNodeForError(*n), " will not be known at compilation time."); } } @@ -2047,8 +2056,7 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( return errors::Internal( "Internal assumption failed while rewriting an outside_compilation " "cluster that contains a while loop. Logic assumes back-edge is to " - "port 1 of a 2-input " - "Merge node."); + "port 1 of a 2-input Merge node."); } // Connect the existing edge to both inputs of the Merge node so that the // graph will be well-formed. 
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index de89be9a3555960dabe7bacd17226c15ae888ae6..8617beec004d0fe912155f054442c5b6249bb6b5 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -299,7 +299,7 @@ REGISTER_OP("XlaHostCompute") .Attr("Toutputs: list(type) >= 0") .Attr("ancestors: list(string) >= 0") .Attr("key: string") - .Attr("shape_inference_graph: string = ''") + .Attr("shape_inference_graph: func") .Attr("shapes: list(shape) >= 0") .SetShapeFn(::tensorflow::shape_inference::UnknownShape); @@ -510,11 +510,7 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, s = ConvertGraphDefToGraph(options, *graphdef, graph.get()); if (!s.ok()) return s; - s = PerformStaticShapeInferenceBeforeEncapsulation( - graph.get(), "_encapsulate", "_outside"); - if (!s.ok()) return s; - - s = PreprocessForEncapsulation(graph.get(), "_encapsulate", "_outside"); + s = PerformStaticShapeInferenceBeforeEncapsulation(graph.get()); if (!s.ok()) return s; std::unique_ptr graph_out; @@ -550,6 +546,14 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, graphdef->Swap(&graphdef_out); *library = lib_def->ToProto(); + // Remove "_xla_inferred_shapes" attr. They are added by + // `PerformStaticShapeInferenceBeforeEncapsulation`. 
+ for (FunctionDef& fdef : *library->mutable_function()) { + for (NodeDef& node_def : *fdef.mutable_node_def()) { + node_def.mutable_attr()->erase("_xla_inferred_shapes"); + } + } + return s; } @@ -901,18 +905,22 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { { GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape.opts()); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, shape.opts()); + Node* recv = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + shape.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), shape.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected)); } + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = test::function::XTimesTwo(); *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {}, @@ -931,8 +939,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O1"}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}, {"c"}}, @@ -948,16 +955,18 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv = 
RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, b2.opts()); + Node* recv = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), b2.opts() .WithName("E") - .WithControlInputs({recv, b}) + .WithControlInputs({recv}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, - b2.opts().WithControlInput(e)); + b2.opts().WithControlInput(e).WithAttr( + kXlaHasHostTransferAttrName, true)); Node* s = Sequencer( b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), @@ -966,9 +975,9 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { NodeBuilder node_builder("F1", "F1", lib_def.get()); node_builder.Input(a).Input(b); Node* call = - b2.opts().WithControlInputs({s}).FinalizeBuilder(&node_builder); + b2.opts().WithControlInputs({s, b}).FinalizeBuilder(&node_builder); - Binary(a, call, b2.opts().WithName("G").WithControlInputs({e})); + Binary(a, call, b2.opts().WithName("G").WithControlInputs({call})); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1022,14 +1031,16 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { { GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape1.opts()); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, shape1.opts()); + Node* recv = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), shape1.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + 
shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); } @@ -1037,33 +1048,45 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { { GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape2.opts()); - Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, shape2.opts()); + Node* recv1 = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), shape2.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", - {DT_FLOAT, DT_FLOAT}, shape2.opts()); + Node* recv2 = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT, DT_FLOAT}, + shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* g = Binary(e, ops::NodeOut(recv2, 0), + shape2.opts() + .WithName("G") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2")); Node* h = Binary(ops::NodeOut(recv2, 1), e, shape2.opts() .WithName("H") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O2")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {h}, shape2.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g, h}, + shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape2, "F1_O2", &library_expected)); } + NameAttrList shape_inference_graph1, shape_inference_graph2; + shape_inference_graph1.set_name("_outside_compilation_shape_inference_F1_O1"); + shape_inference_graph2.set_name("_outside_compilation_shape_inference_F1_O2"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", 
"b_0_arg:float"}, + {"g_0_retval_retval:float", "i_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}}, {{"I"}, "UnaryTest", - {"outside_compilation_O2_host_compute:outputs:0"}}, + {"outside_compilation_O2_host_compute:outputs:1"}}, {{"F"}, "BinaryTest", {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"}, @@ -1073,11 +1096,10 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { "XlaHostCompute", {"F:o:0", "D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, - {"Toutputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O2"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O2"}, + {"shape_inference_graph", shape_inference_graph2}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}}, {"F"}}, @@ -1088,13 +1110,13 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O1"}, + {"shape_inference_graph", shape_inference_graph1}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, - {{"i_0_retval_retval", "I:o:0"}}); + {{"g_0_retval_retval", "outside_compilation_O2_host_compute:outputs:0"}, + {"i_0_retval_retval", "I:o:0"}}); { std::unique_ptr lib_def( @@ -1105,19 +1127,22 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, b2.opts()); + Node* recv1 = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), b2.opts() .WithName("E") 
- .WithControlInputs({recv1, b}) + .WithControlInputs({recv1}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, - b2.opts().WithControlInput(e)); + b2.opts().WithControlInput(e).WithAttr( + kXlaHasHostTransferAttrName, true)); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", - {DT_FLOAT, DT_FLOAT}, b2.opts()); + Node* recv2 = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* g = Binary(e, ops::NodeOut(recv2, 0), b2.opts() .WithName("G") @@ -1130,7 +1155,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O2")); Node* send2 = - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {h}, b2.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g, h}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* s = Sequencer(b2.opts() .WithName("F1_sequencer") @@ -1139,12 +1165,13 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { NodeBuilder node_builder("F1", "F1", lib_def.get()); node_builder.Input(a).Input(b); - Node* call = b2.opts().WithControlInput(s).FinalizeBuilder(&node_builder); + Node* call = + b2.opts().WithControlInputs({s, b}).FinalizeBuilder(&node_builder); - Binary(g, call, b2.opts().WithName("J")); + Binary(ops::NodeOut(call, 0), ops::NodeOut(call, 1), + b2.opts().WithName("J")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } - TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); } @@ -1196,7 +1223,9 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"a_0_arg:float", "b_0_arg:float"}, - {"f_0_retval_retval:float", "d_0_retval_retval:float"}, {}, + {"e_0_retval_retval:float", "f_0_retval_retval:float", + 
"d_0_retval_retval:float"}, + {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1212,35 +1241,37 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, - {{"d_0_retval_retval", "D:o:0"}, {"f_0_retval_retval", "F:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"d_0_retval_retval", "D:o:0"}, + {"f_0_retval_retval", "F:o:0"}}); *library_expected.add_function() = FunctionDefHelper::Create( - "F2", {"f_0_arg:float", "bridge_e_g_0_arg:float"}, - {"i_0_retval_retval:float", "g_0_retval_retval:float"}, {}, + "F2", {"e_0_arg:float", "f_0_arg:float", "d_0_arg:float"}, + {"g_0_retval_retval:float", "i_0_retval_retval:float"}, {}, { - {{"G"}, "BinaryTest", {"bridge_e_g_0_arg", "f_0_arg"}}, + {{"G"}, "BinaryTest", {"e_0_arg", "f_0_arg"}}, {{"I"}, "BinaryTest", {"f_0_arg", "outside_compilation_O1_host_compute:outputs:0"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", - {"G:o:0"}, - {{"Tinputs", absl::Span({DT_FLOAT})}, + {"d_0_arg", "G:o:0"}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F2_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}}, }, - {{"i_0_retval_retval", "I:o:0"}, {"g_0_retval_retval", "G:o:0"}}); + {{"g_0_retval_retval", "G:o:0"}, {"i_0_retval_retval", "I:o:0"}}); { std::unique_ptr lib_def( @@ -1251,16 +1282,18 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { Node* key_constant1 = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* 
recv1 = RecvAtHost(ops::NodeOut(key_constant1, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, b2.opts()); + Node* recv1 = RecvAtHost( + ops::NodeOut(key_constant1, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), b2.opts() .WithName("E") - .WithControlInputs({recv1, b}) + .WithControlInputs({recv1}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send1 = SendFromHost(ops::NodeOut(key_constant1, 0), "F1", "O1", {e}, - b2.opts().WithControlInput(e)); + b2.opts().WithControlInput(e).WithAttr( + kXlaHasHostTransferAttrName, true)); Node* s1 = Sequencer( b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}), "F1"); @@ -1268,29 +1301,33 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); Node* call1 = - b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); + b2.opts().WithControlInputs({s1, b}).FinalizeBuilder(&node_builder1); Node* key_constant2 = KeyPlaceholder("F2", b2.opts().WithName("F2_key_placeholder")); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant2, 0), "F2", "O1", - {DT_FLOAT}, b2.opts()); - Node* h = Binary(ops::NodeOut(call1, 1), recv2, + Node* recv2 = RecvAtHost( + ops::NodeOut(key_constant2, 0), "F2", "O1", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* h = Binary(recv2, ops::NodeOut(recv2, 1), b2.opts() .WithName("H") .WithAttr("_encapsulate", "F2") .WithAttr("_outside", "O1")); - Node* send2 = SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "O1", {h}, - b2.opts()); + Node* send2 = + SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "O1", {h}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* s2 = Sequencer( b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2}), "F2"); NodeBuilder node_builder2("F2", "F2", lib_def.get()); - 
node_builder2.Input(call1).Input(e); + node_builder2.Input(call1) + .Input(ops::NodeOut(call1, 1)) + .Input(ops::NodeOut(call1, 2)); Node* call2 = b2.opts() - .WithControlInputs({s2, e, call1}) + .WithControlInputs({s2, call1}) .FinalizeBuilder(&node_builder2); - Binary(ops::NodeOut(call2, 1), call2, b2.opts().WithName("J")); + Binary(call2, ops::NodeOut(call2, 1), b2.opts().WithName("J")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1326,8 +1363,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { Node* h = Unary(g, b1.opts() .WithName("H") .WithAttr("_encapsulate", "F2") - .WithAttr("_outside", "O1") - .WithControlInput(e)); + .WithAttr("_outside", "O1")); Node* i = Unary(h, b1.opts().WithName("I").WithAttr("_encapsulate", "F2")); Binary(f, i, b1.opts().WithName("J")); TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); @@ -1358,7 +1394,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}, @@ -1380,7 +1416,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F2_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}}, @@ -1401,7 +1437,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), b2.opts() .WithName("E") - .WithControlInputs({recv1, b}) + .WithControlInputs({recv1}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send1 = SendFromHost(ops::NodeOut(key_constant1, 0), "F1", 
"O1", {e}, @@ -1413,7 +1449,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); Node* call1 = - b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); + b2.opts().WithControlInputs({s1, b}).FinalizeBuilder(&node_builder1); Node* key_constant2 = KeyPlaceholder("F2", b2.opts().WithName("F2_key_placeholder")); @@ -1422,8 +1458,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { Node* h = Unary(recv2, b2.opts() .WithName("H") .WithAttr("_encapsulate", "F2") - .WithAttr("_outside", "O1") - .WithControlInput(e)); + .WithAttr("_outside", "O1")); Node* send2 = SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "O1", {h}, b2.opts()); @@ -1484,12 +1519,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", - {}, - {{"Tinputs", absl::Span({})}, + {"a_0_arg"}, + {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}}, @@ -1503,16 +1538,19 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { Node* a = InputShaped(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - Node* e = Unary(a, b2.opts() - .WithName("E") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O1")); Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, b2.opts()); + Node* e = Unary(recv1, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); Node* send1 = 
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); Node* s1 = Sequencer( - b2.opts().WithName("F1_sequencer").WithControlInput(send1), "F1"); + b2.opts().WithName("F1_sequencer").WithControlInputs({send1, recv1}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); Node* call1 = @@ -1569,12 +1607,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", - {}, - {{"Tinputs", absl::Span({})}, + {"a_0_arg"}, + {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}, @@ -1591,13 +1629,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv1 = - RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {}, b2.opts()); - Node* e = Unary(a, b2.opts() - .WithName("E") - .WithControlInput(recv1) - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O1")); + Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, b2.opts()); + Node* e = Unary(recv1, b2.opts() + .WithName("E") + .WithControlInput(recv1) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); Node* s1 = Sequencer( @@ -1644,8 +1682,27 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + { + GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); + Node* key_constant = KeyPlaceholder("F1", shape1.opts()); + Node* recv1 = + 
RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); + } + + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"e_0_retval_retval:float", "f_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1654,14 +1711,15 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { "XlaHostCompute", {"D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT})}, - {"Toutputs", absl::Span({})}, + {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, }, - {{"f_0_retval_retval", "F:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"f_0_retval_retval", "F:o:0"}}); { std::unique_ptr lib_def( @@ -1678,14 +1736,17 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); + Node* send1 = + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); Node* s1 = Sequencer( - b2.opts().WithName("F1_sequencer").WithControlInput(recv1), "F1"); + b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}), + "F1"); NodeBuilder 
node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); Node* call1 = b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("G")); + Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1722,8 +1783,27 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + { + GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); + Node* key_constant = KeyPlaceholder("F1", shape1.opts()); + Node* recv1 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); + } + + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"e_0_retval_retval:float", "f_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1736,14 +1816,15 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { "XlaHostCompute", {"D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT})}, - {"Toutputs", absl::Span({})}, + {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, }, - 
{{"f_0_retval_retval", "F:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"f_0_retval_retval", "F:o:0"}}); { std::unique_ptr lib_def( @@ -1760,7 +1841,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {}, + Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts().WithControlInput(e)); Node* s1 = Sequencer( b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}), @@ -1770,7 +1851,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { Node* call1 = b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("G")); + Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1813,22 +1894,45 @@ TEST(EncapsulateSubgraphsTest, FunctionDefLibrary library_expected; GraphDef graphdef_expected; + { + GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); + Node* key_constant = KeyPlaceholder("F1", shape1.opts()); + Node* recv1 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); + } + { GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape2.opts()); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", - {DT_FLOAT}, shape2.opts()); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT}, + 
shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* g = Unary(ops::NodeOut(recv2, 0), shape2.opts() .WithName("G") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O2")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, shape2.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, + shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape2, "F1_O2", &library_expected)); } + NameAttrList shape_inference_graph1; + shape_inference_graph1.set_name("_outside_compilation_shape_inference_F1_O1"); + NameAttrList shape_inference_graph2; + shape_inference_graph2.set_name("_outside_compilation_shape_inference_F1_O2"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"e_0_retval_retval:float", "h_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1836,6 +1940,16 @@ TEST(EncapsulateSubgraphsTest, {{"H"}, "UnaryTest", {"outside_compilation_O2_host_compute:outputs:0"}}, + {{"outside_compilation_O1_host_compute"}, + "XlaHostCompute", + {"a_0_arg"}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", shape_inference_graph1}, + {"shapes", absl::Span({})}, + {"_outside_compilation_subgraph", "O1"}}}, {{"outside_compilation_O2_host_compute"}, "XlaHostCompute", {"F:o:0"}, @@ -1843,12 +1957,12 @@ TEST(EncapsulateSubgraphsTest, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O2"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O2"}, + {"shape_inference_graph", shape_inference_graph2}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}}}, }, - {{"h_0_retval_retval", 
"H:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"h_0_retval_retval", "H:o:0"}}); { std::unique_ptr lib_def( @@ -1856,30 +1970,39 @@ TEST(EncapsulateSubgraphsTest, GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - - Node* e = Unary(a, b2.opts() - .WithName("E") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O1")); Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", - {DT_FLOAT}, b2.opts()); - Node* g = Unary(recv, b2.opts() - .WithName("G") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O2") - .WithControlInput(e)); - Node* send = - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, b2.opts()); - Node* s1 = Sequencer( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), - "F1"); + Node* recv1 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + + Node* e = Unary(recv1, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send1 = + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* g = Unary(recv2, b2.opts() + .WithName("G") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2") + .WithControlInput(e)); + Node* send2 = + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* s1 = Sequencer(b2.opts() + .WithName("F1_sequencer") + .WithControlInputs({recv1, send1, recv2, send2}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); 
node_builder1.Input(a).Input(b).ControlInput(s1); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("I")); + Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("I")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1925,19 +2048,24 @@ TEST(EncapsulateSubgraphsTest, { GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape1.opts()); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, shape1.opts()); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); } + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"e_0_retval_retval:float", "h_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1945,6 +2073,16 @@ TEST(EncapsulateSubgraphsTest, "UnaryTest", {"outside_compilation_O1_host_compute:outputs:0"}}, {{"H"}, "UnaryTest", {"F:o:0"}}, + {{"outside_compilation_O2_host_compute"}, + "XlaHostCompute", + {"a_0_arg"}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({})}, + {"ancestors", absl::Span({})}, + {"key", "host_compute_channel_F1_O2"}, + {"shape_inference_graph", NameAttrList()}, + {"shapes", 
absl::Span({})}, + {"_outside_compilation_subgraph", "O2"}}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"D:o:0"}, @@ -1952,12 +2090,12 @@ TEST(EncapsulateSubgraphsTest, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O1"}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, }, - {{"h_0_retval_retval", "H:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"h_0_retval_retval", "H:o:0"}}); { std::unique_ptr lib_def( @@ -1968,27 +2106,33 @@ TEST(EncapsulateSubgraphsTest, Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, b2.opts()); - Node* e = Unary(recv, b2.opts() - .WithName("E") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O1")); + Node* recv1 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(recv1, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); Node* send = - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); - /*Node* g =*/Unary(a, b2.opts() - .WithName("G") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O2") - .WithControlInput(e)); - Node* s1 = Sequencer( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), - "F1"); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + /*Node* g =*/Unary(recv2, b2.opts() + .WithName("G") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2") 
+ .WithControlInput(e)); + Node* s1 = Sequencer(b2.opts() + .WithName("F1_sequencer") + .WithControlInputs({recv1, recv2, send}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b).ControlInput(s1); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("I")); + Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("I")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -2039,19 +2183,24 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { { GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape1.opts()); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, shape1.opts()); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); } + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"e_0_retval_retval:float", "h_0_retval_retval:float"}, {}, {{{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, {{"F"}, "UnaryTest", {"outside_compilation_O1_host_compute:outputs:0"}}, @@ -2063,8 +2212,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", 
absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O1"}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, {{"outside_compilation_O2_host_compute"}, @@ -2074,7 +2222,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"Toutputs", absl::Span({})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O2"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}}, {}}, @@ -2085,11 +2233,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"Toutputs", absl::Span({})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O3"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O3"}}, {}}}, - {{"h_0_retval_retval", "H:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"h_0_retval_retval", "H:o:0"}}); { std::unique_ptr lib_def( @@ -2100,23 +2249,27 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, b2.opts()); + Node* recv1 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Unary(recv1, b2.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send = - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", - {DT_FLOAT}, b2.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + 
Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* g = Unary(recv2, b2.opts() .WithName("G") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O2") .WithControlInput(e)); - Node* recv3 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O3", - {DT_FLOAT}, b2.opts()); + Node* recv3 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O3", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); /*Node* i =*/Binary(recv3, e, b2.opts() .WithName("I") @@ -2131,7 +2284,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { node_builder1.Input(a).Input(b).ControlInput(s1); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("J")); + Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("J")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -2167,14 +2320,44 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + { + GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); + Node* key_constant = KeyPlaceholder("F1", shape1.opts()); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); + } + + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", 
"b_0_arg:float"}, + {"e_0_retval_retval:float", "f_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, {{"F"}, "UnaryTest", {"D:o:0"}}, + {{"outside_compilation_O1_host_compute"}, + "XlaHostCompute", + {"a_0_arg"}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", shape_inference_graph}, + {"shapes", absl::Span({})}, + {"_outside_compilation_subgraph", "O1"}}}, }, - {{"f_0_retval_retval", "F:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"f_0_retval_retval", "F:o:0"}}); { std::unique_ptr lib_def( @@ -2183,15 +2366,26 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - Node* e = Unary(a, b2.opts() - .WithName("E") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O1")); + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(recv, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send = + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* s = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); - node_builder1.Input(a).Input(b); + node_builder1.Input(a).Input(b).ControlInput(s); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("G")); + Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -2236,20 +2430,22 @@ 
TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { { GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape.opts()); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, shape.opts()); - Node* a = InputShaped(shape.opts().WithName("A")); - Node* c = Unary(a, shape.opts().WithName("C")); - Node* e = BinaryUnknownShape(c, recv, + Node* recv = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + shape.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = BinaryUnknownShape(recv, ops::NodeOut(recv, 1), shape.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected)); } + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = test::function::XTimesTwo(); *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"b_0_arg:float", "c_0_arg:float"}, {"f_0_retval_retval:float"}, {}, @@ -2262,13 +2458,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", - {"c:o:0"}, - {{"Tinputs", absl::Span({DT_FLOAT})}, + {"c_0_arg", "c:o:0"}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O1"}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}, {"c"}}, @@ -2285,16 +2480,18 @@ 
TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, b2.opts()); - Node* e = BinaryUnknownShape(c, ops::NodeOut(recv, 0), + Node* recv = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = BinaryUnknownShape(recv, ops::NodeOut(recv, 1), b2.opts() .WithName("E") - .WithControlInputs({recv, b}) + .WithControlInputs({recv}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, - b2.opts().WithControlInput(e)); + b2.opts().WithControlInput(e).WithAttr( + kXlaHasHostTransferAttrName, true)); Node* s = Sequencer( b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), @@ -2303,9 +2500,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { NodeBuilder node_builder("F1", "F1", lib_def.get()); node_builder.Input(b).Input(c); Node* call = - b2.opts().WithControlInputs({s, c}).FinalizeBuilder(&node_builder); + b2.opts().WithControlInputs({s, b, c}).FinalizeBuilder(&node_builder); - Binary(a, call, b2.opts().WithName("G").WithControlInputs({e})); + Binary(a, call, b2.opts().WithName("G").WithControlInputs({call})); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } diff --git a/tensorflow/compiler/jit/encapsulate_util.cc b/tensorflow/compiler/jit/encapsulate_util.cc index 1f4b9c90a4ff0b1166cdb7b5942771b350740ef3..2264806d6bdabd9f26d9f83b681524399f996317 100644 --- a/tensorflow/compiler/jit/encapsulate_util.cc +++ b/tensorflow/compiler/jit/encapsulate_util.cc @@ -62,517 +62,6 @@ void ReplaceAttr(Node* n, const string& attr_name, const T& value) { n->AddAttr(attr_name, value); } -// Step 1a ~ 1d for PreprocessForEncapsulation(). See comments of -// PreprocessForEncapsulation() for details. 
-Status ProcessControlEdges(Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name) { - // Gather edges to remove. We should not remove the edge while iterating. - std::vector edges_to_remove; - for (const Edge* e : g->edges()) { - if (!e->IsControlEdge()) { - continue; - } - - auto src_xla_computation = - GetStringAttr(*e->src(), xla_computation_attr_name); - auto dst_xla_computation = - GetStringAttr(*e->dst(), xla_computation_attr_name); - auto src_outside_compilation = - GetStringAttr(*e->src(), outside_compilation_attr_name); - auto dst_outside_compilation = - GetStringAttr(*e->dst(), outside_compilation_attr_name); - - if (!src_xla_computation && !dst_xla_computation) { - continue; - } else if (src_xla_computation && !dst_xla_computation) { - if (src_outside_compilation) { - // Case 1c: outside compilation to host computation control edge. - edges_to_remove.push_back(e); - - TF_RETURN_IF_ERROR(AppendToListAttr( - e->dst(), kXlaControlDependenciesAttrName, e->src()->name())); - } - } else if (!src_xla_computation && dst_xla_computation) { - if (dst_outside_compilation) { - // Case 1c: host computation control to outside compilation edge. - edges_to_remove.push_back(e); - - TF_RETURN_IF_ERROR(AppendToListAttr( - e->dst(), kXlaControlDependenciesAttrName, e->src()->name())); - } - } else { // src_xla_computation && dst_xla_computation - if (*src_xla_computation != *dst_xla_computation) { - if (src_outside_compilation && dst_outside_compilation) { - // Case 1b: outside compilation to outside compilation control edge. - edges_to_remove.push_back(e); - - TF_RETURN_IF_ERROR(AppendToListAttr( - e->dst(), kXlaControlDependenciesAttrName, e->src()->name())); - } else if (src_outside_compilation && !dst_outside_compilation) { - // Case 1a: outside compilation to another XLA computaition control - // edge. 
- TF_RETURN_IF_ERROR(AppendToListAttr( - e->src(), kXlaConnectedToOtherXlaComputationAttrName, - *dst_xla_computation)); - } else if (!src_outside_compilation && dst_outside_compilation) { - // Case 1a: another XLA computaition to outside compilation control - // edge. - TF_RETURN_IF_ERROR(AppendToListAttr( - e->dst(), kXlaConnectedFromOtherXlaComputationAttrName, - *src_xla_computation)); - } - } - } - } - - for (auto e : edges_to_remove) { - g->RemoveEdge(e); - } - return Status::OK(); -} - -// Step 2 for PreprocessForEncapsulation(). See comments of -// PreprocessForEncapsulation() for details. -Status ProcessXlaToXlaDataEdges(Graph* g, - const string& xla_computation_attr_name, - const string& outside_compilation_attr_name) { - // Gather edges between XLA computations. Notice that we do not store `Edge*` - // directly because we remove some nodes while adding Identity nodes, and - // those Edge pointers might be invalidated. - struct EdgeInfo { - int dst_input, dst_node_id; - }; - std::vector edges; - for (const Edge* e : g->edges()) { - if (e->IsControlEdge()) { - continue; - } - - auto src_xla_computation = - GetStringAttr(*e->src(), xla_computation_attr_name); - auto dst_xla_computation = - GetStringAttr(*e->dst(), xla_computation_attr_name); - auto src_outside_compilation = - GetStringAttr(*e->src(), outside_compilation_attr_name); - auto dst_outside_compilation = - GetStringAttr(*e->dst(), outside_compilation_attr_name); - if (!src_xla_computation || !dst_xla_computation) { - continue; - } - - if (*src_xla_computation != *dst_xla_computation) { - if (src_outside_compilation || dst_outside_compilation) { - edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()}); - VLOG(4) << "XLA -> XLA edge: " << e->DebugString(); - } - } - } - - // For each XLA -> XLA edge, add an Identity node between src and dst. 
- for (int i = 0; i < edges.size(); i++) { - Node* dst = g->FindNodeId(edges[i].dst_node_id); - const Edge* e; - TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e)); - Node* src = e->src(); - int src_output = e->src_output(), dst_input = e->dst_input(); - g->RemoveEdge(e); - - // Create Identity node, and connect it between `src` and `dst`. - string identity_node_name = - absl::StrCat("bridge_", src->name(), "_", dst->name()); - DataType dtype = src->output_type(src_output); - TF_ASSIGN_OR_RETURN(Node * identity_node, - BuildIdentityNode(g, identity_node_name, dtype, src, - /*requested_device=*/absl::nullopt)); - identity_node->AddAttr(kBridgeSourceNodeAttrName, src->name()); - g->AddEdge(src, src_output, identity_node, 0); - g->AddEdge(identity_node, 0, dst, dst_input); - - // Replace `e->dst()` because its input node changed. - NodeDef new_def = dst->def(); - *new_def.mutable_input(dst_input) = identity_node->name(); - TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def)); - - // Other edge in `edges` might have `e->dst()` as src or dst - // node. Before removing `e->dst()`, replace those edges with corresponding - // edges for `dst_replace_node`. - for (int j = i + 1; j < edges.size(); j++) { - if (edges[j].dst_node_id == edges[i].dst_node_id) { - edges[j].dst_node_id = dst_replace_node->id(); - } - } - } - return Status::OK(); -} - -// Step 3 for PreprocessForEncapsulation(). See comments of -// PreprocessForEncapsulation() for details. -Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation( - Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name) { - // Gather edges between outside compilation and host computation. Notice that - // we do not store `Edge*` directly because we remove some nodes while adding - // Identity nodes, and those Edge pointers might be invalidated. 
- struct EdgeInfo { - int dst_input, dst_node_id; - bool is_host_to_outside_compilation; - }; - std::vector edges; - for (const Edge* e : g->edges()) { - if (e->IsControlEdge()) { - continue; - } - - if (e->src()->attrs().Find(xla_computation_attr_name) == nullptr && - e->dst()->attrs().Find(xla_computation_attr_name) != nullptr && - e->dst()->attrs().Find(outside_compilation_attr_name) != nullptr) { - edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id(), - /*is_host_to_outside_compilation=*/true}); - VLOG(4) << "Host -> oc edge: " << e->DebugString(); - } else if (e->dst()->attrs().Find(xla_computation_attr_name) == nullptr && - e->src()->attrs().Find(xla_computation_attr_name) != nullptr && - e->src()->attrs().Find(outside_compilation_attr_name) != - nullptr) { - edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id(), - /*is_host_to_outside_compilation=*/false}); - VLOG(4) << "Oc -> host edge: " << e->DebugString(); - } - } - - // Remove the edge from host to outside compilation. Add a placeholder as - // outside compilation node input. - std::map, Node*> placeholders; - for (int i = 0; i < edges.size(); i++) { - Node* dst = g->FindNodeId(edges[i].dst_node_id); - const Edge* e; - TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e)); - Node* src = e->src(); - int src_output = e->src_output(), dst_input = e->dst_input(); - g->RemoveEdge(e); - - // Find or create placeholder node. - string new_name = - edges[i].is_host_to_outside_compilation - ? 
absl::StrCat(src->name(), "_host_to_oc_placeholder_", src_output) - : absl::StrCat(src->name(), "_oc_to_host_placeholder_", src_output); - auto placeholder_index = std::make_pair(src->name(), src_output); - auto iter = placeholders.find(placeholder_index); - Node* placeholder_node; - if (iter == placeholders.end()) { - NodeDefBuilder placeholder_builder(new_name, "Placeholder"); - placeholder_builder.Attr("dtype", src->output_type(src_output)); - if (edges[i].is_host_to_outside_compilation) { - placeholder_builder.Attr(kHostToOutsideCompilationOriginalNodeAttrName, - src->name()); - placeholder_builder.Attr(kHostToOutsideCompilationSrcOutputAttrName, - src_output); - // If this placeholder node is in outside compilation, we need to set - // `xla_computation_attr_name` and `outside_compilation_attr_name`. - string xla_computation_attr, outside_compilation_attr; - TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(), xla_computation_attr_name, - &xla_computation_attr)); - TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(), - outside_compilation_attr_name, - &outside_compilation_attr)); - placeholder_builder.Attr(xla_computation_attr_name, - xla_computation_attr); - placeholder_builder.Attr(outside_compilation_attr_name, - outside_compilation_attr); - } else { - placeholder_builder.Attr(kOutsideCompilationToHostOriginalNodeAttrName, - src->name()); - placeholder_builder.Attr(kOutsideCompilationToHostSrcOutputAttrName, - src_output); - } - NodeDef placeholder_def; - TF_RETURN_IF_ERROR(placeholder_builder.Finalize(&placeholder_def)); - Status s; - placeholder_node = g->AddNode(placeholder_def, &s); - TF_RETURN_IF_ERROR(s); - placeholders[placeholder_index] = placeholder_node; - } else { - placeholder_node = iter->second; - } - g->AddEdge(placeholder_node, 0, dst, dst_input); - - // Replace `e->dst()` because its input node changed. 
- NodeDef new_def = dst->def(); - *new_def.mutable_input(dst_input) = placeholder_node->name(); - TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def)); - - // Other edge in `edges` might have `e->dst()` as src or dst - // node. Before removing `e->dst()`, replace those edges with corresponding - // edges for `dst_replace_node`. - for (int j = i + 1; j < edges.size(); j++) { - if (edges[j].dst_node_id == edges[i].dst_node_id) { - edges[j].dst_node_id = dst_replace_node->id(); - } - } - } - return Status::OK(); -} - -// Step 1 for `PostprocessForEncapsulation`. See comments of -// `PostprocessForEncapsulation` for details. -Status RemovePlaceholderBetweenOutsideCompilationAndHostComputation(Graph* g) { - // Gather all outside compilation to host computation nodes. - struct PlaceHolderNodeInfo { - Node* n; - bool is_host_to_oc; - }; - std::vector placeholder_nodes; - for (Node* n : g->nodes()) { - if (n->type_string() == "Placeholder") { - if (HasNodeAttr(n->def(), - kOutsideCompilationToHostOriginalNodeAttrName)) { - placeholder_nodes.push_back({n, false}); - } else if (HasNodeAttr(n->def(), - kHostToOutsideCompilationOriginalNodeAttrName)) { - placeholder_nodes.push_back({n, true}); - } - } - } - - // Remove the placeholder nodes, and reconnect original edge. 
- auto node_name_index = g->BuildNodeNameIndex(); - for (auto placeholder_iter : placeholder_nodes) { - Node* n = placeholder_iter.n; - - string node_name; - int node_src_output; - if (placeholder_iter.is_host_to_oc) { - TF_RETURN_IF_ERROR( - GetNodeAttr(n->attrs(), kHostToOutsideCompilationOriginalNodeAttrName, - &node_name)); - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), - kHostToOutsideCompilationSrcOutputAttrName, - &node_src_output)); - } else { - TF_RETURN_IF_ERROR( - GetNodeAttr(n->attrs(), kOutsideCompilationToHostOriginalNodeAttrName, - &node_name)); - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), - kOutsideCompilationToHostSrcOutputAttrName, - &node_src_output)); - } - auto iter = node_name_index.find(node_name); - if (iter == node_name_index.end()) { - return errors::Internal( - "Cannot find original node for oc -> host placeholder node ", - node_name); - } - - // Change all usage node to use the original node instead. - Node* original_node = iter->second; - std::vector control_edges; - std::vector data_edges; - for (auto e : n->out_edges()) { - if (e->IsControlEdge()) { - control_edges.push_back(e); - } else { - data_edges.push_back({e->dst(), e->src_output(), e->dst_input()}); - } - } - for (const Edge* e : control_edges) { - g->AddControlEdge(original_node, e->dst()); - g->RemoveEdge(e); - } - for (int i = 0; i < data_edges.size(); i++) { - Node* dst = data_edges[i].dst; - NodeDef new_def = dst->def(); - int dst_input = data_edges[i].dst_input; - *new_def.mutable_input(dst_input) = - absl::StrCat(original_node->name(), ":", node_src_output); - TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def)); - - const Edge* edge_to_replace = nullptr; - TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace)); - g->RemoveEdge(edge_to_replace); - g->AddEdge(original_node, node_src_output, replace_node, dst_input); - - // Other edges might have `dst` as dst node. Update those edges with - // `replace_node`. 
- for (int j = i + 1; j < data_edges.size(); j++) { - if (data_edges[j].dst == dst) { - data_edges[j].dst = replace_node; - } - } - - // Other placeholder node might have `dst` as original node. Update - // `node_name_index` with `replace_node`. - node_name_index[replace_node->name()] = replace_node; - } - - // Remove placeholder node. - g->RemoveNode(n); - } - return Status::OK(); -} - -// Step 2 for `PostprocessForEncapsulation`. See comments of -// `PostprocessForEncapsulation` for details. -Status RemoveIdentityBetweenDifferentXlaComputation(Graph* g) { - // Gather Identity nodes to remove. - std::vector bridge_nodes; - for (Node* n : g->nodes()) { - if (n->type_string() == "Identity" && - HasNodeAttr(n->def(), kBridgeSourceNodeAttrName)) { - bridge_nodes.push_back(n); - } - } - - // Remove the identity nodes, and reconnect the original edge. - for (int i = 0; i < bridge_nodes.size(); i++) { - Node* n = bridge_nodes[i]; - const Edge* src_edge = nullptr; - TF_RETURN_IF_ERROR(n->input_edge(0, &src_edge)); - - // Change all usage node to use the original node instead. 
- std::vector control_edges; - std::vector data_edges; - for (auto e : n->out_edges()) { - if (e->IsControlEdge()) { - control_edges.push_back(e); - } else { - data_edges.push_back({e->dst(), e->src_output(), e->dst_input()}); - } - } - for (const Edge* e : control_edges) { - g->AddControlEdge(src_edge->src(), e->dst()); - g->RemoveEdge(e); - } - for (int j = 0; j < data_edges.size(); j++) { - Node* dst = data_edges[j].dst; - NodeDef new_def = dst->def(); - int dst_input = data_edges[j].dst_input; - *new_def.mutable_input(dst_input) = - absl::StrCat(src_edge->src()->name(), ":", src_edge->src_output()); - TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def)); - - const Edge* edge_to_replace = nullptr; - TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace)); - g->RemoveEdge(edge_to_replace); - g->AddEdge(src_edge->src(), src_edge->src_output(), replace_node, - dst_input); - - // Other edges might have `dst` as dst node. Update those edges with - // `replace_node`. - for (int k = j + 1; k < data_edges.size(); k++) { - if (data_edges[k].dst == dst) { - data_edges[k].dst = replace_node; - } - } - - // The node we replaced might be in `bridge_nodes`. If so, update - // `bridge_nodes` to use the replaced node. - for (int k = i + 1; k < bridge_nodes.size(); k++) { - if (bridge_nodes[k] == dst) { - bridge_nodes[k] = replace_node; - } - } - } - - // Remove Identity node. - g->RemoveNode(n); - } - return Status::OK(); -} - -// Step 3 for `PostprocessForEncapsulation`. See comments of -// `PostprocessForEncapsulation` for details. -// We do not need to worry about removed nodes in step 1 and 2; -// `PreprocessForEncapsulation` will not record control dependencies for those -// remvoed nodes in the first place. -Status AddControlDependencies( - Graph* g, const std::unordered_map& cluster_node_names) { - auto node_name_index = g->BuildNodeNameIndex(); - - // Reconnect outside compilation to outside compilation control edge. 
- for (Node* n : g->nodes()) { - std::vector control_deps; - Status s = - GetNodeAttr(n->attrs(), kXlaControlDependenciesAttrName, &control_deps); - if (!s.ok()) { - if (s.code() != error::NOT_FOUND) { - return s; - } else { - continue; - } - } else { - n->ClearAttr(kXlaControlDependenciesAttrName); - for (const string& control_input : control_deps) { - auto iter = node_name_index.find(control_input); - if (iter == node_name_index.end()) { - return errors::Internal("Cannot find original node for ", - control_input); - } - g->AddControlEdge(iter->second, n); - } - } - } - - // Reconnect outside compilation to XLA computation control edge. - for (Node* n : g->nodes()) { - std::vector control_deps; - Status s = GetNodeAttr( - n->attrs(), kXlaConnectedToOtherXlaComputationAttrName, &control_deps); - if (!s.ok()) { - if (s.code() != error::NOT_FOUND) { - return s; - } else { - continue; - } - } else { - n->ClearAttr(kXlaConnectedToOtherXlaComputationAttrName); - for (const string& control_input : control_deps) { - auto iter = cluster_node_names.find(control_input); - if (iter == cluster_node_names.end()) { - return errors::Internal("Cannot find cluster node for ", - control_input); - } - auto iter2 = node_name_index.find(iter->second); - if (iter2 == node_name_index.end()) { - return errors::Internal("Cannot find cluster node for ", - iter->second); - } - g->AddControlEdge(n, iter2->second); - } - } - } - - // Reconnect XLA computation to outside compilation control edge. 
- for (Node* n : g->nodes()) { - std::vector control_deps; - Status s = - GetNodeAttr(n->attrs(), kXlaConnectedFromOtherXlaComputationAttrName, - &control_deps); - if (!s.ok()) { - if (s.code() != error::NOT_FOUND) { - return s; - } else { - continue; - } - } else { - n->ClearAttr(kXlaConnectedFromOtherXlaComputationAttrName); - for (const string& control_input : control_deps) { - auto iter = cluster_node_names.find(control_input); - if (iter == cluster_node_names.end()) { - return errors::Internal("Cannot find cluster node for ", - control_input); - } - auto iter2 = node_name_index.find(iter->second); - if (iter2 == node_name_index.end()) { - return errors::Internal("Cannot find cluster node for ", - iter->second); - } - g->AddControlEdge(iter2->second, n); - } - } - } - - return Status::OK(); -} - // Step 1 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of // `PreprocessEdgesBetweenOutsideCompilations` for details. Status PreprocessControlEdgesBetweenOutsideCompilations( @@ -811,20 +300,6 @@ Status PostprocessControlEdgesBetweenOutsideCompilations( const char kXlaInferredShapesAttrName[] = "_xla_inferred_shapes"; -const char kXlaConnectedToOtherXlaComputationAttrName[] = - "_xla_connected_to_other_xla_computation"; -const char kXlaConnectedFromOtherXlaComputationAttrName[] = - "_xla_connected_from_other_xla_computation"; -const char kXlaControlDependenciesAttrName[] = "_xla_control_dependencies"; -const char kBridgeSourceNodeAttrName[] = "_xla_bridge_src"; -const char kOutsideCompilationToHostOriginalNodeAttrName[] = - "_xla_oc_to_host_node_name"; -const char kOutsideCompilationToHostSrcOutputAttrName[] = - "_xla_oc_to_host_src_output"; -const char kHostToOutsideCompilationOriginalNodeAttrName[] = - "_xla_host_to_oc_node_name"; -const char kHostToOutsideCompilationSrcOutputAttrName[] = - "_xla_host_to_oc_src_output"; const char kXlaConnectedToXlaComputationAttrName[] = "_xla_connected_to_xla_computation"; const char 
kXlaConnectedFromXlaComputationAttrName[] = @@ -835,32 +310,7 @@ const char kOutsideCompilationSrcOutputAttrName[] = "_xla_oc_to_oc_src_output"; const char kXlaControlDependenciesWithinXlaClusterAttrName[] = "_xla_control_dependencies_within_xla_cluster"; -Status PerformStaticShapeInferenceBeforeEncapsulation( - Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name) { - // Find all outside compilation to XLA computation data edges. - std::unordered_set outside_compilation_send_nodes; - for (auto e : g->edges()) { - if (e->IsControlEdge()) { - continue; - } - - auto src_computation = GetStringAttr(*e->src(), xla_computation_attr_name); - auto dst_computation = GetStringAttr(*e->dst(), xla_computation_attr_name); - if (!src_computation || !dst_computation || - *src_computation != *dst_computation) { - continue; - } - - auto src_outside_compilation = - GetStringAttr(*e->src(), outside_compilation_attr_name); - auto dst_outside_compilation = - GetStringAttr(*e->dst(), outside_compilation_attr_name); - if (src_outside_compilation && !dst_outside_compilation) { - outside_compilation_send_nodes.insert(e->src()); - } - } - +Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g) { // Perform shape inference. std::map arg_shapes; GraphShapeInfo shape_info; @@ -868,55 +318,21 @@ Status PerformStaticShapeInferenceBeforeEncapsulation( InferShapes(g, arg_shapes, /*fnlib_def=*/nullptr, &shape_info)); // Add attribute for output shapes. 
- for (Node* n : outside_compilation_send_nodes) { - auto iter = shape_info.find(n->name()); - if (iter == shape_info.end()) { - continue; - } - + auto node_name_index = g->BuildNodeNameIndex(); + for (auto iter : shape_info) { std::vector output_shapes; - std::transform(iter->second.begin(), iter->second.end(), + std::transform(iter.second.begin(), iter.second.end(), std::back_inserter(output_shapes), [](const InferredShape& inferred_shape) { return inferred_shape.shape; }); + Node* n = node_name_index[iter.first]; n->AddAttr(kXlaInferredShapesAttrName, output_shapes); } return Status::OK(); } -Status PreprocessForEncapsulation(Graph* g, - const string& xla_computation_attr_name, - const string& outside_compilation_attr_name) { - TF_RETURN_IF_ERROR(ProcessControlEdges(g, xla_computation_attr_name, - outside_compilation_attr_name)); - TF_RETURN_IF_ERROR(ProcessXlaToXlaDataEdges(g, xla_computation_attr_name, - outside_compilation_attr_name)); - TF_RETURN_IF_ERROR(ProcessDataEdgeBetweenOutsideCompilationAndHostComputation( - g, xla_computation_attr_name, outside_compilation_attr_name)); - return Status::OK(); -} - -Status PostprocessForEncapsulation( - Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name, - const std::unordered_map& clusters) { - // The `node` pointer in `XlaClusterInfo` might be invalidated in step 1/2, - // but the node name won't change. Record cluster node name for - // `AddControlDependencies`. 
- std::unordered_map cluster_node_names; - for (const auto& iter : clusters) { - cluster_node_names[iter.first] = iter.second.node->name(); - } - - TF_RETURN_IF_ERROR( - RemovePlaceholderBetweenOutsideCompilationAndHostComputation(g)); - TF_RETURN_IF_ERROR(RemoveIdentityBetweenDifferentXlaComputation(g)); - TF_RETURN_IF_ERROR(AddControlDependencies(g, cluster_node_names)); - return Status::OK(); -} - Status PreprocessEdgesBetweenOutsideCompilations( Graph* g, const string& outside_compilation_attr_name) { // Remove edges from source node to outside compilation nodes, and edges diff --git a/tensorflow/compiler/jit/encapsulate_util.h b/tensorflow/compiler/jit/encapsulate_util.h index e363bc5754ac395bae262dc67a780a0173efaf5e..c9f16d14168163e11bb19092f566f1de8724aca3 100644 --- a/tensorflow/compiler/jit/encapsulate_util.h +++ b/tensorflow/compiler/jit/encapsulate_util.h @@ -27,51 +27,13 @@ namespace tensorflow { // a list of PartialTensorShape objects. extern const char kXlaInferredShapesAttrName[]; -// Infer output shapes for outside compilation nodes which have output data -// edges to XLA computation nodes. These shapes will be used later by XLA -// compiler as output shapes of the outside compilation's XlaHostCompute op. -// XLA computation nodes will be mark by attr `xla_computation_attr_name`; -// outside compilation nodes will be marked by both attr -// `xla_computation_attr_name` and `outside_compilation_attr_name`. -// -// Those outside compilation nodes will be marked with attribute -// `kXlaInferredShapesAttrName`. +// Infers output shapes for all nodes in graph `g`. The output shapes will be +// stored in node attribute `kXlaInferredShapesAttrName`. // // We have to perform shape inference before encapsulation because after // encapsulation, some nodes will be encapsulated into function call, and shape // inference does not handle function call at the moment. 
-Status PerformStaticShapeInferenceBeforeEncapsulation( - Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name); - -// Attribute indicating that some ops in other XLA computation has control -// dependency on this node. Attribute value will be a list of string (XLA -// computation names). -extern const char kXlaConnectedToOtherXlaComputationAttrName[]; - -// Attribute indicating that this node has control dependency on some ops in -// other XLA computation. Attribute value will be a list of string (XLA -// computation names). -extern const char kXlaConnectedFromOtherXlaComputationAttrName[]; - -// Attribute indicating that this node has control dependencies on some other -// nodes. Attribute value will be a list of string (node names). -extern const char kXlaControlDependenciesAttrName[]; - -// Attribute indicating that this is an Identity node added to act as a bridge -// between different XLA computations. Attribute value will be string (source -// node name). -extern const char kBridgeSourceNodeAttrName[]; - -// Attribute indicating that this is an Placeholder node added to act as a -// temporary input node for an outside compilation node. Attribute value will be -// string (original input node name). -extern const char kOutsideCompilationToHostOriginalNodeAttrName[]; - -// Attribute indicating that this is an Placeholder node added to act as a -// temporary input node for an outside compilation node. Attribute value will be -// int (src_output for original edge). -extern const char kOutsideCompilationToHostSrcOutputAttrName[]; +Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g); // Attribute indicating that some ops in this node's XLA computation has control // dependency on this node. Attribute value will always be "true". @@ -81,16 +43,6 @@ extern const char kXlaConnectedToXlaComputationAttrName[]; // this node's XLA computation. Attribute value will always be "true". 
extern const char kXlaConnectedFromXlaComputationAttrName[]; -// Attribute indicating that this is an Placeholder node added to act as a -// temporary input node for an host node. Attribute value will be string -// (original input node name). -extern const char kHostToOutsideCompilationOriginalNodeAttrName[]; - -// Attribute indicating that this is an Placeholder node added to act as a -// temporary input node for a host node. Attribute value will be int (src_output -// for original edge). -extern const char kHostToOutsideCompilationSrcOutputAttrName[]; - // Attribute indicating that this is an Placeholder node added to act as a // temporary input node for an outside compilation node. Attribute value will be // string (original input node name). @@ -106,27 +58,6 @@ extern const char kOutsideCompilationSrcOutputAttrName[]; // (node names). extern const char kXlaControlDependenciesWithinXlaClusterAttrName[]; -// Preprocesses edges between different XLA clusters for encapsulation. It will -// perform the following operations in order: -// -// 1a. For control edges between outside compilation and another XLA -// computation, add attr "kXlaConnected{From, To}OtherXlaComputationAttrName -// = XLA computation node name" to the outside compilation node. -// 1b. For control edges between different outside compilations (in different -// XLA computations), remove the edge and add attr -// "kXlaControlDependenciesAttrName = src node name" to dst node. -// 1c. For control edges between outside compilation and host computation, -// remove the edge and add attr "kXlaControlDependenciesAttrName = src node -// name" to dst node. -// 2. For data edges between different XLA computations, if either src or dst -// is outside compilation, add an Identity node in between the edge. The -// identity node will have attr kBridgeSourceNodeAttrName. -// 3. For data edges between outside compilation and host computation, remove -// the edge and create a Placeholder node as dst node's input. 
-Status PreprocessForEncapsulation(Graph* g, - const string& xla_computation_attr_name, - const string& outside_compilation_attr_name); - // Information for XLA computation. struct XlaClusterInfo { // Add an explicitly-defined default constructor for this class. @@ -158,24 +89,6 @@ struct XlaClusterInfo { const std::map host_compute_core; }; -// Postprocesses edges between different XLA clusters for encapsulation. This -// function reverts what `PreprocessForEncapsulation` did. It will perform the -// following operations in order: -// -// 1. Remove Placeholder nodes between outside compilation and host computation -// (created in `PreprocessForEncapsulation` step 3). -// 2. Remove Identity nodes created in `PreprocessForEncapsulation` step 2. -// 3a. Reconnect control edges between outside compilation and another XLA -// computation (marked by `PreprocessForEncapsulation` step 1a). -// 3b. Reconnect control edges between different outside compilations (marked by -// `PreprocessForEncapsulation` step 1b). -// 3c. Reconnect control edges between outside compilation and host computation -// (marked by `PreprocessForEncapsulation` step 1c). -Status PostprocessForEncapsulation( - Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name, - const std::unordered_map& clusters); - // Preprocesses edges within the same XLA cluster. It will perform the following // operations in order: // diff --git a/tensorflow/compiler/jit/encapsulate_util_test.cc b/tensorflow/compiler/jit/encapsulate_util_test.cc index 3b8b49cb92f3e453883a8e64e12ce3748a5173f6..3bb979e0698d2d6be42ed5bae66c25267928192c 100644 --- a/tensorflow/compiler/jit/encapsulate_util_test.cc +++ b/tensorflow/compiler/jit/encapsulate_util_test.cc @@ -38,24 +38,11 @@ TEST(PerformStaticShapeInferenceBeforeEncapsulationTest, Basic) { Graph g(OpRegistry::Global()); TF_CHECK_OK(s.ToGraph(&g)); - // "add" node is outside compilation node, "identity" node is XLA node. 
- auto node_index = g.BuildNodeNameIndex(); - Node *add_node = node_index["add"], *identity_node = node_index["identity"]; - add_node->AddAttr("_xla", "cluster"); - add_node->AddAttr("_oc", "cluster"); - identity_node->AddAttr("_xla", "cluster"); - TF_CHECK_OK( - PerformStaticShapeInferenceBeforeEncapsulation(&g, "_xla", "_oc")); + TF_CHECK_OK(PerformStaticShapeInferenceBeforeEncapsulation(&g)); - // Check that only "add" node now has _xla_inferred_shapes attr. - std::vector nodes_with_inferred_shape; - for (Node *n : g.nodes()) { - if (HasNodeAttr(n->def(), kXlaInferredShapesAttrName)) { - nodes_with_inferred_shape.push_back(n); - } - } - EXPECT_EQ(nodes_with_inferred_shape.size(), 1); - EXPECT_EQ(nodes_with_inferred_shape[0], add_node); + // Check that "add" node now has _xla_inferred_shapes attr. + auto node_index = g.BuildNodeNameIndex(); + Node *add_node = node_index["add"]; std::vector output_shapes; TF_CHECK_OK(GetNodeAttr(add_node->attrs(), kXlaInferredShapesAttrName, &output_shapes)); @@ -66,329 +53,4 @@ TEST(PerformStaticShapeInferenceBeforeEncapsulationTest, Basic) { EXPECT_EQ(shape_proto.dim(0).size(), 2); } -TEST(PreprocessForEncapsulationTest, ControlEdges) { - // Build the graph: - // "const_0" and "const_1" in host computation - // "add" = "const_0" + "const_1" in XLA computation 0 - // "identity0" = "add" in XLA computation 0 & outside compilation 0 - // "identity1" = "identity0" in XLA computation 0 - // "identity2" = "identity1" in host computation - // "identity3" = "identity2" in XLA computation 1 - // "identity4" = "identity3" in XLA computation 1 & outside compilation 1 - // "identity5" = "identity4" in XLA computation 1 - // "identity6" = "identity5" in host computation - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {}); - Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {}); - Output add = ops::Add(s.WithOpName("add"), const_0, const_1); - Output identity0 = 
ops::Identity(s.WithOpName("identity0"), add); - Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0); - Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1); - Output identity3 = ops::Identity(s.WithOpName("identity3"), identity2); - Output identity4 = ops::Identity(s.WithOpName("identity4"), identity3); - Output identity5 = ops::Identity(s.WithOpName("identity5"), identity4); - Graph g(OpRegistry::Global()); - TF_CHECK_OK(s.ToGraph(&g)); - auto node_index = g.BuildNodeNameIndex(); - - // Set XLA computation/outside compilation attr, and add control edges. - Node *const0_node = node_index["const_0"], *add_node = node_index["add"], - *identity0_node = node_index["identity0"], - *identity1_node = node_index["identity1"], - *identity2_node = node_index["identity2"], - *identity3_node = node_index["identity3"], - *identity4_node = node_index["identity4"], - *identity5_node = node_index["identity5"]; - add_node->AddAttr("_xla", "0"); - identity0_node->AddAttr("_xla", "0"); - identity0_node->AddAttr("_oc", "0"); - identity1_node->AddAttr("_xla", "0"); - identity3_node->AddAttr("_xla", "1"); - identity4_node->AddAttr("_xla", "1"); - identity4_node->AddAttr("_oc", "0"); - identity5_node->AddAttr("_xla", "1"); - // Case 1a: control edges between outside compilation and another XLA - // computation. - g.AddControlEdge(identity0_node, identity3_node); - g.AddControlEdge(identity1_node, identity4_node); - // Case 1b: control edges between different outside compilations. - g.AddControlEdge(identity0_node, identity4_node); - // Case 1c: control edges between outside compilation and host computation. - g.AddControlEdge(const0_node, identity0_node); - g.AddControlEdge(identity0_node, identity2_node); - - TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc")); - - // Case 1a: add attr "_xla_control_deps_{from/to} = XLA computation node name" - // to the outside compilation node. 
- std::vector attr; - TF_CHECK_OK(GetNodeAttr(identity0_node->def(), - kXlaConnectedToOtherXlaComputationAttrName, &attr)); - EXPECT_EQ(attr.size(), 1); - EXPECT_EQ(attr[0], "1"); - attr.clear(); - TF_CHECK_OK(GetNodeAttr(identity4_node->def(), - kXlaConnectedFromOtherXlaComputationAttrName, &attr)); - EXPECT_EQ(attr.size(), 1); - EXPECT_EQ(attr[0], "0"); - // Case 1b: add attr "_xla_control_deps = src node name" to dst node. - attr.clear(); - TF_CHECK_OK(GetNodeAttr(identity4_node->def(), - kXlaControlDependenciesAttrName, &attr)); - EXPECT_EQ(attr.size(), 1); - EXPECT_EQ(attr[0], "identity0"); - // Case 1c: add attr "_xla_control_deps = src node name" to dst node. - attr.clear(); - TF_CHECK_OK(GetNodeAttr(identity0_node->def(), - kXlaControlDependenciesAttrName, &attr)); - EXPECT_EQ(attr.size(), 1); - EXPECT_EQ(attr[0], "const_0"); - attr.clear(); - TF_CHECK_OK(GetNodeAttr(identity2_node->def(), - kXlaControlDependenciesAttrName, &attr)); - EXPECT_EQ(attr.size(), 1); - EXPECT_EQ(attr[0], "identity0"); -} - -TEST(PreprocessForEncapsulationTest, DataEdges) { - // Build the graph: - // "const_0" and "const_1" in host computation - // "identityn0" = ("const_0", "const_1") in host computation 0 - // "add0" = "const_0" + "const_1" in XLA computation 0 - // "add1" = "add0" + "const_0" in XLA computation 0 & outside compilation 0 - // "identity0" = "add1" in XLA computation 0 - // "add2" = "add1" + "identity0" in host computation - // "add3" = "add1" + "add2" in XLA computation 1 - // "add4" = "identity0" + "add2" in XLA computation 1 & outside compilation 0 - // "add5" = "identityn0"[0] + "identityn0"[1] in XLA computation 1 & - // outside compilation 0 - // "identityn1" = ("identityn0"[0], "identityn0"[1]) in XLA computation 1 & - // outside compilation 0 - // "identity1" = "add4" in XLA computation 1 - // "identity2" = "identity1" in host computation - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output const_0 = ops::Const(s.WithOpName("const_0"), 1, 
{}); - Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {}); - auto identityn0 = - ops::IdentityN(s.WithOpName("identityn_0"), {const_0, const_1}); - Output add0 = ops::Add(s.WithOpName("add0"), const_0, const_1); - Output add1 = ops::Add(s.WithOpName("add1"), add0, const_0); - Output identity0 = ops::Identity(s.WithOpName("identity0"), add1); - Output add2 = ops::Add(s.WithOpName("add2"), add1, identity0); - Output add3 = ops::Add(s.WithOpName("add3"), add1, add2); - Output add4 = ops::Add(s.WithOpName("add4"), identity0, add2); - Output add5 = ops::Add(s.WithOpName("add5"), identityn0[0], identityn0[1]); - auto identityn1 = ops::IdentityN(s.WithOpName("identityn_1"), - {identityn0[0], identityn0[1]}); - Output identity1 = ops::Identity(s.WithOpName("identity1"), add4); - Output identity2 = ops::Identity(s.WithOpName("identity2"), add4); - Graph g(OpRegistry::Global()); - TF_CHECK_OK(s.ToGraph(&g)); - auto node_index = g.BuildNodeNameIndex(); - - // Set XLA computation/outside compilation attr. - Node *add0_node = node_index["add0"], *add1_node = node_index["add1"], - *identity0_node = node_index["identity0"], - *add3_node = node_index["add3"], *add4_node = node_index["add4"], - *add5_node = node_index["add5"], - *identityn1_node = node_index["identityn_1"], - *identity1_node = node_index["identity1"]; - add0_node->AddAttr("_xla", "0"); - add1_node->AddAttr("_xla", "0"); - add1_node->AddAttr("_oc", "0"); - identity0_node->AddAttr("_xla", "0"); - add3_node->AddAttr("_xla", "1"); - add4_node->AddAttr("_xla", "1"); - add4_node->AddAttr("_oc", "0"); - add5_node->AddAttr("_xla", "1"); - add5_node->AddAttr("_oc", "0"); - identityn1_node->AddAttr("_xla", "1"); - identityn1_node->AddAttr("_oc", "0"); - identity1_node->AddAttr("_xla", "1"); - - TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc")); - - // Check input nodes for related data edges. - node_index = g.BuildNodeNameIndex(); - // Step 2: add an Identity node between different XLA computations. 
- Node *bridge_add1_add3 = node_index["bridge_add1_add3"]; - EXPECT_NE(bridge_add1_add3, nullptr); - string str; - TF_CHECK_OK( - GetNodeAttr(bridge_add1_add3->attrs(), kBridgeSourceNodeAttrName, &str)); - EXPECT_EQ(str, "add1"); - Node *bridge_identity0_add4 = node_index["bridge_identity0_add4"]; - EXPECT_NE(bridge_identity0_add4, nullptr); - // Step 3: add placeholder for edges between host computation and outside - // compilation. - EXPECT_EQ(bridge_add1_add3->def().input(0), "add1_oc_to_host_placeholder_0"); - Node *add1_oc_to_host_placeholder = - node_index["add1_oc_to_host_placeholder_0"]; - TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(), - kOutsideCompilationToHostOriginalNodeAttrName, &str)); - EXPECT_EQ(str, "add1"); - int i; - TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(), - kOutsideCompilationToHostSrcOutputAttrName, &i)); - EXPECT_EQ(i, 0); - add4_node = node_index["add4"]; - ASSERT_NE(add4_node, nullptr); - EXPECT_EQ(add4_node->def().input(0), - "bridge_identity0_add4_host_to_oc_placeholder_0"); - Node *identity0_host_to_oc_placeholder = - node_index["bridge_identity0_add4_host_to_oc_placeholder_0"]; - TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(), - kHostToOutsideCompilationOriginalNodeAttrName, &str)); - EXPECT_EQ(str, "bridge_identity0_add4"); - TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(), - kHostToOutsideCompilationSrcOutputAttrName, &i)); - EXPECT_EQ(i, 0); - - // Check different placeholder nodes are created for different src_output. - Node *placeholder0 = node_index["identityn_0_host_to_oc_placeholder_0"], - *placeholder1 = node_index["identityn_0_host_to_oc_placeholder_1"]; - EXPECT_NE(placeholder0, nullptr); - EXPECT_NE(placeholder1, nullptr); - // Check we only have 2 placeholder nodes created for "identityn_0". 
- int placeholder_count = 0; - for (Node *n : g.nodes()) { - if (HasNodeAttr(n->def(), kHostToOutsideCompilationOriginalNodeAttrName)) { - string attr; - TF_CHECK_OK(GetNodeAttr( - n->attrs(), kHostToOutsideCompilationOriginalNodeAttrName, &attr)); - if (attr == "identityn_0") { - ++placeholder_count; - } - } - } - EXPECT_EQ(placeholder_count, 2); -} - -TEST(PostprocessForEncapsulationTest, ControlEdges) { - // Build the graph: - // "const0" - // "identity0" = "const0" (XLA computation 0) - // "identity1" = "identity0" - // "identity2" = "identity1" (XLA computation 1) - // "identity3" = "identity2" - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output const0 = ops::Const(s.WithOpName("const0"), 1, {}); - Output identity0 = ops::Identity(s.WithOpName("identity0"), const0); - Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0); - Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1); - Output identity3 = ops::Identity(s.WithOpName("identity3"), identity2); - Graph g(OpRegistry::Global()); - TF_CHECK_OK(s.ToGraph(&g)); - auto node_index = g.BuildNodeNameIndex(); - - // Set XLA computation/outside compilation attr, and add control edges. - Node *const0_node = node_index["const0"], - *identity0_node = node_index["identity0"], - *identity1_node = node_index["identity1"], - *identity2_node = node_index["identity2"], - *identity3_node = node_index["identity3"]; - identity1_node->AddAttr(kXlaConnectedFromOtherXlaComputationAttrName, - std::vector{"0"}); - identity1_node->AddAttr(kXlaConnectedToOtherXlaComputationAttrName, - std::vector{"1"}); - identity3_node->AddAttr(kXlaControlDependenciesAttrName, - std::vector{"const0", "identity1"}); - - std::unordered_map clusters; - clusters["0"].node = identity0_node; - clusters["1"].node = identity2_node; - TF_CHECK_OK(PostprocessForEncapsulation(&g, "_xla", "_oc", clusters)); - - // Case 3a: we have control edge identity0 -> identity1, and identity1 -> - // identity2. 
- bool edge_identity0_identity1 = false, edge_identity1_identity2 = false; - for (const Edge *e : g.edges()) { - if (!e->IsControlEdge()) { - continue; - } - if (e->src() == identity0_node && e->dst() == identity1_node) { - edge_identity0_identity1 = true; - } else if (e->src() == identity1_node && e->dst() == identity2_node) { - edge_identity1_identity2 = true; - } - } - EXPECT_TRUE(edge_identity0_identity1); - EXPECT_TRUE(edge_identity1_identity2); - // Case 3b: we have control edge const0 -> identity3, and identity1 -> - // identity3. - bool edge_const0_identity3 = false, edge_identity1_identity3 = false; - for (const Edge *e : g.edges()) { - if (!e->IsControlEdge()) { - continue; - } - if (e->src() == const0_node && e->dst() == identity3_node) { - edge_const0_identity3 = true; - } else if (e->src() == identity1_node && e->dst() == identity3_node) { - edge_identity1_identity3 = true; - } - } - EXPECT_TRUE(edge_const0_identity3); - EXPECT_TRUE(edge_identity1_identity3); -} - -TEST(PostprocessForEncapsulationTest, DataEdges) { - // Build the graph: - // "const0" in outside compilation "0" - // "placeholder0" (for "const0") in host computation - // "add0" = "placeholder0" + "placeholder0" in host computation - // "placeholder1" (for "add0") in outside compilation 1 - // "add1" = "placeholder1" + "placeholder1" in outside compilation 1 - // - // "bridge" = "placeholder0" in host computation - // "placeholder2" (for "bridge") in outside compilation 1 - // "add2" = "placeholder2" + "placeholder2" in outside compilation 1 - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output const0 = ops::Const(s.WithOpName("const0"), 1, {}); - Output placeholder0 = - ops::Placeholder(s.WithOpName("placeholder0"), DT_INT32); - Output add0 = ops::Add(s.WithOpName("add0"), placeholder0, placeholder0); - Output placeholder1 = - ops::Placeholder(s.WithOpName("placeholder1"), DT_INT32); - Output add1 = ops::Add(s.WithOpName("add1"), placeholder1, placeholder1); - Output bridge 
= ops::Identity(s.WithOpName("bridge"), placeholder0); - Output placeholder2 = - ops::Placeholder(s.WithOpName("placeholder2"), DT_INT32); - Output add2 = ops::Add(s.WithOpName("add2"), placeholder2, placeholder2); - Graph g(OpRegistry::Global()); - TF_CHECK_OK(s.ToGraph(&g)); - auto node_index = g.BuildNodeNameIndex(); - - // Set related attributes. - Node *placeholder0_node = node_index["placeholder0"]; - placeholder0_node->AddAttr(kOutsideCompilationToHostOriginalNodeAttrName, - "const0"); - placeholder0_node->AddAttr(kOutsideCompilationToHostSrcOutputAttrName, 0); - Node *placeholder1_node = node_index["placeholder1"]; - placeholder1_node->AddAttr(kHostToOutsideCompilationOriginalNodeAttrName, - "add0"); - placeholder1_node->AddAttr(kHostToOutsideCompilationSrcOutputAttrName, 0); - Node *bridge_node = node_index["bridge"]; - bridge_node->AddAttr(kBridgeSourceNodeAttrName, "const0"); - Node *placeholder2_node = node_index["placeholder2"]; - placeholder2_node->AddAttr(kHostToOutsideCompilationOriginalNodeAttrName, - "bridge"); - placeholder2_node->AddAttr(kHostToOutsideCompilationSrcOutputAttrName, 0); - - std::unordered_map clusters; - TF_CHECK_OK(PostprocessForEncapsulation(&g, "_xla", "_oc", clusters)); - - // Result graph should be: - // "add0" = "const0" + "const0" - // "add1" = "add0" + "add0" - // "add2" = "const0" + "const0" - node_index = g.BuildNodeNameIndex(); - EXPECT_EQ(node_index.size(), 6); - EXPECT_EQ(node_index["add0"]->def().input(0), "const0:0"); - EXPECT_EQ(node_index["add0"]->def().input(1), "const0:0"); - EXPECT_EQ(node_index["add1"]->def().input(0), "add0:0"); - EXPECT_EQ(node_index["add1"]->def().input(1), "add0:0"); - EXPECT_EQ(node_index["add2"]->def().input(0), "const0:0"); - EXPECT_EQ(node_index["add2"]->def().input(1), "const0:0"); -} - } // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index 
d334100aa4a915a87fb05d371e0e3379a7ee05f2..ec745cdbb7e237f8b4935dd41e9791fc75f5355d 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -297,6 +297,7 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, NodeDef def; def.set_name(launch->name()); + MergeDebugInfo(NodeDebugInfo(launch->def()), &def); // Target the XLA CPU/GPU backends. VLOG(2) << "Replacing with XlaLaunch"; diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc index e3c7e2f89be9b37b51a633dabb099969c181013f..8b01768c49422b331b52a8ba31bade000c95722e 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc @@ -20,8 +20,10 @@ limitations under the License. #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/encapsulate_util.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" @@ -98,9 +100,12 @@ xla::StatusOr BuildRecvAtHostNode( recv_at_host_builder.Attr("Toutputs", recv_at_host_dtypes); // The correct device_ordinal will be inserted during replication in a // subsequent rewrite. 
- recv_at_host_builder.Attr("device_ordinal", 0); + AttrValue device_ordinal_value; + device_ordinal_value.set_placeholder("device_ordinal"); + recv_at_host_builder.Attr("device_ordinal", device_ordinal_value); recv_at_host_builder.Attr( "key", absl::StrCat("host_compute_channel_", oc_cluster_name)); + recv_at_host_builder.Attr(kXlaHasHostTransferAttrName, true); recv_at_host_builder.Input(key_placeholder->name(), 0, DT_STRING); TF_RETURN_IF_ERROR(recv_at_host_builder.Finalize(&recv_at_host_def)); Status s; @@ -197,9 +202,12 @@ xla::StatusOr BuildSendFromHostNode( send_from_host_builder.Attr("Tinputs", send_from_host_dtypes); // The correct device_ordinal will be inserted during replication in a // subsequent rewrite. - send_from_host_builder.Attr("device_ordinal", 0); + AttrValue device_ordinal_value; + device_ordinal_value.set_placeholder("device_ordinal"); + send_from_host_builder.Attr("device_ordinal", device_ordinal_value); send_from_host_builder.Attr( "key", absl::StrCat("host_compute_channel_", oc_cluster_name)); + send_from_host_builder.Attr(kXlaHasHostTransferAttrName, true); std::vector inputs(send_from_host_dtypes.size()); for (auto* n : ret_nodes) { int index; @@ -322,6 +330,38 @@ xla::StatusOr BuildXlaHostComputeNodeDef( return new_def; } +Status ValidateOutsideCompilationCallNode(Node* call_node) { + // DT_INT64 as input/output for outside compilation is not supported yet: + // b/120809951. + for (const Edge* e : call_node->in_edges()) { + if (e->IsControlEdge()) { + continue; + } + DataType dtype = e->src()->output_type(e->src_output()); + if (dtype == DT_INT64) { + return errors::Unimplemented( + "int64 input for outside compilation is not supported yet: " + "b/120809951. 
Please cast output of node ", + e->src()->DebugString(), + " to int32 before feeding it into outside compilation."); + } + } + for (const Edge* e : call_node->out_edges()) { + if (e->IsControlEdge()) { + continue; + } + DataType dtype = e->dst()->input_type(e->dst_input()); + if (dtype == DT_INT64) { + return errors::Unimplemented( + "int64 output for outside compilation is not supported yet: " + "b/120809951. Please cast input of node ", + e->dst()->DebugString(), + " to int32 before returning it from outside compilation."); + } + } + return Status::OK(); +} + // Replace outside compilation function call node with XlaHostCompute node. // If the function call node has no input/output edges, we will just remove it // and not create a XlaHostCompute node. @@ -357,6 +397,47 @@ Status ReplaceOrRemoveOutsideCompilationCallNode( return Status::OK(); } +// Resets "device_ordinal" attr to placeholder value for related nodes +// (XlaRecvAtHost nodes; XlaSendFromHost nodes; If nodes containing +// XlaRecvAtHost/XlaSendFromHost). 
+Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) { + AttrValue device_ordinal_value; + device_ordinal_value.set_placeholder("device_ordinal"); + for (Node* n : g->nodes()) { + if (!HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) { + continue; + } + + if (n->type_string() == "_XlaRecvAtHost" || + n->type_string() == "_XlaSendFromHost") { + n->ClearAttr("device_ordinal"); + n->AddAttr("device_ordinal", device_ordinal_value); + } else if (n->type_string() == "If") { + for (const string& attr_name : + std::vector{"then_branch", "else_branch"}) { + NameAttrList branch_func; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func)); + (*branch_func.mutable_attr())["device_ordinal"] = device_ordinal_value; + n->ClearAttr(attr_name); + n->AddAttr(attr_name, branch_func); + } + } else if (n->type_string() == "While") { + for (const string& attr_name : std::vector{"cond", "body"}) { + NameAttrList branch_func; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func)); + (*branch_func.mutable_attr())["device_ordinal"] = device_ordinal_value; + n->ClearAttr(attr_name); + n->AddAttr(attr_name, branch_func); + } + } else { + return errors::Internal("Unknown node marked with ", + kXlaHasHostTransferAttrName, ": ", + n->DebugString()); + } + } + return Status::OK(); +} + // For an XLA computation, builds host side graph given all outside compilation // graphs inside it. The host side graph contains: // 1) a "sequencer" node (we will add control edge between XlaRecvAtHost and @@ -368,8 +449,8 @@ Status ReplaceOrRemoveOutsideCompilationCallNode( Status ConstructHostGraph( const string& xla_cluster_name, const string& outside_compilation_attr_name, const std::vector& outside_compilation_host_graphs, - FunctionLibraryDefinition* fld, std::unique_ptr* host_graph) { - host_graph->reset(new Graph(fld)); + FunctionLibraryDefinition* fld, const string& host_graph_func_name) { + Graph host_graph(fld); // Create sequencer node in host graph. 
NodeDefBuilder sequencer_builder(absl::StrCat(xla_cluster_name, "_sequencer"), @@ -378,24 +459,34 @@ Status ConstructHostGraph( NodeDef sequencer_def; TF_RETURN_IF_ERROR(sequencer_builder.Finalize(&sequencer_def)); Status s; - Node* sequencer = (*host_graph)->AddNode(sequencer_def, &s); + Node* sequencer = host_graph.AddNode(sequencer_def, &s); TF_RETURN_IF_ERROR(s); // Create key placeholder in host graph. TF_ASSIGN_OR_RETURN( Node * key_placeholder, - AddHostComputeKeyPlaceholder(xla_cluster_name, host_graph->get())); + AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph)); // For each outside compilation graph, copy them to host graph with the // following changes: // a) Use key_placeholder in host graph instead of its own. - // b) Add control edge from RecvAtHost/SendFromHost to sequencer. + // b) Add control edge from host transfer nodes (XlaRecvAtHost, + // XlaSendFromHost, If/While nodes containing + // XlaRecvAtHost/XlaSendFromHost) to sequencer node. // c) Clear node_def.device(), so device placer won't get confused. for (const string& host_func : outside_compilation_host_graphs) { VLOG(4) << "Expanding host graph " << host_func; + // Temporarily use "0" as "device_ordinal". It will be reset to placeholder + // value after we expanded all host graphs. We cannot just use placeholder + // value here because FunctionDef instantiation does not allow placeholder + // value for attributes. 
+ AttrValue device_ordinal_attr; + device_ordinal_attr.set_i(0); + protobuf::Map attrs; + attrs["device_ordinal"] = device_ordinal_attr; FunctionBody* host_fbody = nullptr; TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( - *fld->Find(host_func), AttrSlice(), fld, + *fld->Find(host_func), AttrSlice(&attrs), fld, [&](const string& op, const OpDef** sig) { return fld->LookUpOpDef(op, sig); }, @@ -408,8 +499,8 @@ Status ConstructHostGraph( FixupSourceAndSinkEdges(host_fbody->graph); std::map node_map; - node_map[host_fbody->graph->source_node()] = (*host_graph)->source_node(); - node_map[host_fbody->graph->sink_node()] = (*host_graph)->sink_node(); + node_map[host_fbody->graph->source_node()] = host_graph.source_node(); + node_map[host_fbody->graph->sink_node()] = host_graph.sink_node(); Status s; ReverseDFS( *host_fbody->graph, /*enter=*/nullptr, @@ -431,7 +522,7 @@ Status ConstructHostGraph( NodeDef copy_def = n->def(); // Change c). copy_def.clear_device(); - copy = (*host_graph)->AddNode(copy_def, &s); + copy = host_graph.AddNode(copy_def, &s); if (!s.ok()) { return; } @@ -446,22 +537,23 @@ Status ConstructHostGraph( e->src()->DebugString()); return; } - (*host_graph) - ->AddEdge(node_map[e->src()], e->src_output(), copy, - e->dst_input()); + host_graph.AddEdge(node_map[e->src()], e->src_output(), copy, + e->dst_input()); } // Change b). - if (copy->type_string() == "_XlaRecvAtHost" || - copy->type_string() == "_XlaSendFromHost") { - (*host_graph)->AddControlEdge(copy, sequencer); + if (HasNodeAttr(copy->def(), kXlaHasHostTransferAttrName)) { + host_graph.AddControlEdge(copy, sequencer); } }, NodeComparatorID()); + if (!s.ok()) { return s; } } + // Reset "device_ordinal" to placeholder value. + TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(&host_graph)); // sequencer and key_placeholder might be dead nodes. Prune them if necessary. 
// - sequencer should be pruned iff it has no input control edges from @@ -470,21 +562,30 @@ Status ConstructHostGraph( // - key_placeholder should be pruned iff there's no RecvAtHost/SendFromHost. // We don't need to do anything special. if (!sequencer->in_edges().empty()) { - (*host_graph)->AddControlEdge(sequencer, (*host_graph)->sink_node()); + host_graph.AddControlEdge(sequencer, host_graph.sink_node()); } PruneForReverseReachability( - host_graph->get(), - std::unordered_set{(*host_graph)->sink_node()}); + &host_graph, std::unordered_set{host_graph.sink_node()}); // Postprocess edges between different outside compilations. TF_RETURN_IF_ERROR(PostprocessEdgesBetweenOutsideCompilations( - host_graph->get(), outside_compilation_attr_name)); + &host_graph, outside_compilation_attr_name)); if (VLOG_IS_ON(4)) { dump_graph::DumpGraphToFile( absl::StrCat("extract_outside_compilation_host_graph_for_", xla_cluster_name), - **host_graph, fld); + host_graph, fld); + } + + FunctionDef host_graph_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(host_graph, host_graph_func_name, &host_graph_fdef)); + if (fld->Find(host_graph_func_name)) { + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(host_graph_func_name, host_graph_fdef)); + } else { + TF_RETURN_IF_ERROR(fld->AddFunctionDef(host_graph_fdef)); } return Status::OK(); @@ -492,8 +593,28 @@ Status ConstructHostGraph( // Expand XLA computation's outside compilation host side graph into main graph. // Add a control edge between sequencer node and the XLA computation node. -Status ExpandHostGraphIntoMainGraph(Graph* main_graph, Graph* host_graph, +Status ExpandHostGraphIntoMainGraph(Graph* main_graph, + FunctionLibraryDefinition* fld, + const string& host_graph_func_name, Node* xla_computation_node) { + // Temporarily use "0" as "device_ordinal". It will be rewritten with the + // correct value in a later pass. 
We cannot just use placeholder value here + // because FunctionDef instantiation does not allow placeholder value for + // attributes. + AttrValue device_ordinal_attr; + device_ordinal_attr.set_i(0); + protobuf::Map attrs; + attrs["device_ordinal"] = device_ordinal_attr; + FunctionBody* fbody = nullptr; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fld->Find(host_graph_func_name), AttrSlice(&attrs), fld, + [&](const string& op, const OpDef** sig) { + return fld->LookUpOpDef(op, sig); + }, + &fbody)); + std::unique_ptr fbody_deleter(fbody); + Graph* host_graph = fbody->graph; + // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse // reachable from sink node so all nodes will be copied. // TODO(b/77601805): consolidate copy graph functions. @@ -545,23 +666,25 @@ Status ExpandHostGraphIntoMainGraph(Graph* main_graph, Graph* host_graph, return s; } -// Rewrites shape inference graph for outside compilation. -// 1. If the outside compilation is a "top-level" one (not in a function of any -// If/While/etc.), this shape inference graph might have host computation to -// outside compilation placeholder nodes, which will cause shape inference to -// fail. However, those nodes are not in `host_graph` any more (because we -// have executed `PostprocessForEncapsultion`). In this case, we clear the -// graph, and copy SendFromHost with all its predecessors from `host_graph`. -// This case is detected by whether the SendFromHost node exists in -// `host_graph` as well. -// 2. Remove control edges, and prune nodes that are not useful for shape -// inference. +// Rewrites shape inference graph for outside compilation: +// 1) If XlaSendFromHost also exists in `host_graph`, copy nodes from +// `host_graph`. Because we might still have outside compilation to outside +// compilation placeholder nodes in shape inference graph, which will prevent +// us from inferring XlaSendFromHost shape. But in `host_graph`, we already +// removed those placeholder nodes. 
+// 2) Remove control edges. +// 3) Prune nodes that are not useful for shape inference. Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name, Graph* host_graph, FunctionLibraryDefinition* fld) { + // Use "0" as "device_ordinal". It does not matter for shape inference. + AttrValue device_ordinal_attr; + device_ordinal_attr.set_i(0); + protobuf::Map attrs; + attrs["device_ordinal"] = device_ordinal_attr; FunctionBody* fbody = nullptr; TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( - *fld->Find(shape_inference_graph_name), AttrSlice(), fld, + *fld->Find(shape_inference_graph_name), AttrSlice(&attrs), fld, [&](const string& op, const OpDef** sig) { return fld->LookUpOpDef(op, sig); }, @@ -650,6 +773,7 @@ Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name, g->RemoveEdge(e); } } + // Nodes that are not reverse reachable from SendFromHost are not useful for // shape inference. Prune them. PruneForReverseReachability(g, @@ -669,6 +793,572 @@ Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name, return Status::OK(); } +// Builds XlaSendToHost node which sends cond predicate to host. +xla::StatusOr BuildSendIfPredNode(const string& name, + const string& host_transfer_key, + Node* pred_node, Graph* g) { + NodeDefBuilder send_pred_builder(name, "XlaSendToHost"); + send_pred_builder.Attr("Tinput", DT_BOOL); + send_pred_builder.Attr("key", absl::StrCat(host_transfer_key, "_dtoh_0")); + send_pred_builder.Attr(kXlaTokenInputNodesAttrName, + std::vector{kXlaTokenArgNodeName}); + send_pred_builder.Input(pred_node->name(), 0, DT_BOOL); + NodeDef send_pred_def; + TF_RETURN_IF_ERROR(send_pred_builder.Finalize(&send_pred_def)); + Status s; + Node* send_pred_node = g->AddNode(send_pred_def, &s); + TF_RETURN_IF_ERROR(s); + g->AddEdge(pred_node, 0, send_pred_node, 0); + return send_pred_node; +} + +// Replaces key placeholder node with an _Arg node. 
+Status ReplaceKeyPlaceholderWithArgNode(const string& xla_cluster_name, + const string& func_name, + FunctionLibraryDefinition* fld) { + // Temporarily use "0" as "device_ordinal". It will be reset to placeholder + // value after rewriting. + AttrValue device_ordinal_attr; + device_ordinal_attr.set_i(0); + protobuf::Map attrs; + attrs["device_ordinal"] = device_ordinal_attr; + FunctionBody* fbody = nullptr; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fld->Find(func_name), AttrSlice(&attrs), fld, + [&](const string& op, const OpDef** sig) { + return fld->LookUpOpDef(op, sig); + }, + &fbody)); + std::unique_ptr fbody_deleter(fbody); + Graph* g = fbody->graph; + + // Find or create the key placeholder node. + Node* key_placeholder = nullptr; + for (Node* n : g->nodes()) { + if (IsKeyPlaceholderNode(*n)) { + key_placeholder = n; + break; + } + } + if (!key_placeholder) { + TF_ASSIGN_OR_RETURN(key_placeholder, + AddHostComputeKeyPlaceholder(xla_cluster_name, g)); + } + + // Build the _Arg node, and replace key placeholder node with it. + NodeDefBuilder arg_builder("key_arg", FunctionLibraryDefinition::kArgOp); + arg_builder.Attr("T", DT_STRING); + arg_builder.Attr("index", 0); + NodeDef arg_def; + TF_RETURN_IF_ERROR(arg_builder.Finalize(&arg_def)); + TF_RETURN_IF_ERROR(ReplaceNode(g, key_placeholder, arg_def).status()); + + // Reset "device_ordinal" to placeholder value. + TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(g)); + + FunctionDef replace_fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(*g, func_name, &replace_fdef)); + TF_RETURN_IF_ERROR(fld->ReplaceFunction(func_name, replace_fdef)); + return Status::OK(); +} + +// Builds host side graph for If node. 
+Status BuildHostGraphForIfNode(const string& xla_cluster_attr_name, + const string& outside_compilation_attr_name, + const string& xla_cluster_name, + const string& if_node_name, + const string& host_transfer_key, + const string& host_graph_func_name, + FunctionLibraryDefinition* fld, + const string& then_branch_host_func_name, + const string& else_branch_host_func_name) { + Graph host_graph(fld); + string outside_compilation_name = absl::StrCat("oc_if_", if_node_name); + AttrValue device_ordinal_value; + device_ordinal_value.set_placeholder("device_ordinal"); + + // Step 1: add key placeholder node. + TF_ASSIGN_OR_RETURN( + Node * key_placeholder, + AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph)); + + // Step 2: build XlaRecvAtHost node to recv predicate. + NodeDefBuilder recv_pred_builder( + absl::StrCat("recv_oc_if_pred_", if_node_name), "_XlaRecvAtHost"); + recv_pred_builder.Attr("Toutputs", std::vector{DT_BOOL}); + recv_pred_builder.Attr("key", host_transfer_key); + recv_pred_builder.Attr("device_ordinal", device_ordinal_value); + recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name); + recv_pred_builder.Attr(outside_compilation_attr_name, + outside_compilation_name); + recv_pred_builder.Attr(kXlaHasHostTransferAttrName, true); + recv_pred_builder.Input(key_placeholder->name(), 0, DT_STRING); + NodeDef recv_pred_def; + TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def)); + Status s; + Node* recv_pred_node = host_graph.AddNode(recv_pred_def, &s); + TF_RETURN_IF_ERROR(s); + host_graph.AddEdge(key_placeholder, 0, recv_pred_node, 0); + + // Step 3: rewrite `{then, else}_branch_host_func_name`, replace key + // placeholder with an _Arg node. 
+ TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( + xla_cluster_name, then_branch_host_func_name, fld)); + TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( + xla_cluster_name, else_branch_host_func_name, fld)); + + // Step 4: build If node to choose between `{then, else}_branch_host_graph`. + NodeDefBuilder if_builder(absl::StrCat("oc_if_", if_node_name), "If"); + if_builder.Attr("Tcond", DT_BOOL); + if_builder.Attr("Tin", std::vector{DT_STRING}); + if_builder.Attr("Tout", std::vector{}); + NameAttrList host_then_branch, host_else_branch; + host_then_branch.set_name(then_branch_host_func_name); + (*host_then_branch.mutable_attr())["device_ordinal"] = device_ordinal_value; + host_else_branch.set_name(else_branch_host_func_name); + (*host_else_branch.mutable_attr())["device_ordinal"] = device_ordinal_value; + if_builder.Attr("then_branch", host_then_branch); + if_builder.Attr("else_branch", host_else_branch); + if_builder.Attr(kXlaHasHostTransferAttrName, true); + if_builder.Attr(xla_cluster_attr_name, xla_cluster_name); + if_builder.Attr(outside_compilation_attr_name, outside_compilation_name); + if_builder.Input(recv_pred_node->name(), 0, DT_BOOL); + std::vector if_inputs{ + {key_placeholder->name(), 0, DT_STRING}}; + if_builder.Input(if_inputs); + NodeDef if_def; + TF_RETURN_IF_ERROR(if_builder.Finalize(&if_def)); + Node* if_node = host_graph.AddNode(if_def, &s); + TF_RETURN_IF_ERROR(s); + host_graph.AddEdge(recv_pred_node, 0, if_node, 0); + host_graph.AddEdge(key_placeholder, 0, if_node, 1); + + // Convert `host_graph` to function, and add a "device_ordinal" attr. 
+ FunctionDef oc_host_graph_fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name, + &oc_host_graph_fdef)); + if (fld->Find(host_graph_func_name)) { + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef)); + } else { + TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef)); + } + + return Status::OK(); +} + +// Rewrites loop cond to add a node which sends loop cond to host. +Status AddSendLoopPredToLoopCond(FunctionLibraryDefinition* fld, + const NameAttrList& loop_cond_func, + const string& while_node_name, + const string& host_transfer_key) { + // Instantiate the loop cond function. + FunctionBody* fbody = nullptr; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fld->Find(loop_cond_func.name()), AttrSlice(&loop_cond_func.attr()), fld, + [&](const string& op, const OpDef** sig) { + return fld->LookUpOpDef(op, sig); + }, + &fbody)); + std::unique_ptr fbody_deleter(fbody); + Graph* g = fbody->graph; + + // Find the _Retval node and the loop cond node. + Node* ret_node = nullptr; + for (Node* n : g->nodes()) { + if (n->type_string() == "_Retval") { + if (ret_node) { + return errors::Internal("Multiple return node for loop cond function ", + loop_cond_func.name(), ": ", + ret_node->DebugString(), " and ", + n->DebugString()); + } else { + ret_node = n; + } + } + } + if (!ret_node) { + return errors::Internal("No _Retval node for loop cond function ", + loop_cond_func.name()); + } + Node* loop_cond; + TF_RETURN_IF_ERROR(ret_node->input_node(0, &loop_cond)); + + // Build the XlaSendToHost node. 
+ NodeDefBuilder send_loop_cond_builder( + absl::StrCat("send_oc_while_cond_", while_node_name), "XlaSendToHost"); + send_loop_cond_builder.Attr("Tinput", DT_BOOL); + send_loop_cond_builder.Attr("key", + absl::StrCat(host_transfer_key, "_dtoh_0")); + send_loop_cond_builder.Attr(kXlaTokenInputNodesAttrName, + std::vector{kXlaTokenArgNodeName}); + send_loop_cond_builder.Input(loop_cond->name(), 0, DT_BOOL); + NodeDef send_loop_cond_def; + TF_RETURN_IF_ERROR(send_loop_cond_builder.Finalize(&send_loop_cond_def)); + Status s; + Node* send_loop_cond_node = g->AddNode(send_loop_cond_def, &s); + TF_RETURN_IF_ERROR(s); + g->AddEdge(loop_cond, 0, send_loop_cond_node, 0); + + // Replace original function. + FunctionDef replace_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*g, loop_cond_func.name(), &replace_fdef)); + TF_RETURN_IF_ERROR(fld->ReplaceFunction(loop_cond_func.name(), replace_fdef)); + + return Status::OK(); +} + +// Rewrites while loop cond function for host. +Status RewriteHostWhileLoopCond( + const string& cond_host_func_name, const string& while_node_name, + const string& host_transfer_key, const string& xla_cluster_attr_name, + const string& xla_cluster_name, const string& outside_compilation_attr_name, + const string& outside_compilation_name, FunctionLibraryDefinition* fld) { + // Replace key placeholder node with _Arg node. + TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( + xla_cluster_name, cond_host_func_name, fld)); + + // Instantiate cond function. 
+ AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map attrs; + attrs["device_ordinal"] = device_ordinal_temp_value; + FunctionBody* cond_fbody = nullptr; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fld->Find(cond_host_func_name), AttrSlice(&attrs), fld, + [&](const string& op, const OpDef** sig) { + return fld->LookUpOpDef(op, sig); + }, + &cond_fbody)); + std::unique_ptr cond_fbody_deleter(cond_fbody); + Graph* cond_graph = cond_fbody->graph; + Node* key_arg = nullptr; + for (Node* n : cond_graph->nodes()) { + if (n->type_string() == "_Arg") { + key_arg = n; + } + } + if (!key_arg) { + return errors::Internal( + "No _Arg node found for host compute key in function ", + cond_host_func_name); + } + + // Add an XlaRecvAtHost node to use as cond function return value. + // We don't need to set kXlaHasHostTransferAttrName for this node, because + // it's already added for the "While" node on the host. + NodeDefBuilder recv_pred_builder( + absl::StrCat("recv_oc_while_cond_", while_node_name), "_XlaRecvAtHost"); + recv_pred_builder.Attr("Toutputs", std::vector{DT_BOOL}); + recv_pred_builder.Attr("key", host_transfer_key); + AttrValue device_ordinal_value; + device_ordinal_value.set_placeholder("device_ordinal"); + recv_pred_builder.Attr("device_ordinal", device_ordinal_value); + recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name); + recv_pred_builder.Attr(outside_compilation_attr_name, + outside_compilation_name); + recv_pred_builder.Input(key_arg->name(), 0, DT_STRING); + NodeDef recv_pred_def; + TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def)); + Status s; + Node* recv_pred_node = cond_graph->AddNode(recv_pred_def, &s); + TF_RETURN_IF_ERROR(s); + cond_graph->AddEdge(key_arg, 0, recv_pred_node, 0); + NodeDefBuilder ret_builder( + absl::StrCat("recv_oc_while_cond_ret_", while_node_name), "_Retval"); + ret_builder.Attr("T", DT_BOOL); + ret_builder.Attr("index", 0); + 
ret_builder.Input(recv_pred_node->name(), 0, DT_BOOL); + NodeDef ret_def; + TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def)); + Node* ret_node = cond_graph->AddNode(ret_def, &s); + TF_RETURN_IF_ERROR(s); + cond_graph->AddEdge(recv_pred_node, 0, ret_node, 0); + + // Reset device_ordinal to placeholder value. + TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(cond_graph)); + + // Replace original function. + FunctionDef cond_replace_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*cond_graph, cond_host_func_name, &cond_replace_fdef)); + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(cond_host_func_name, cond_replace_fdef)); + + return Status::OK(); +} + +// Rewrites while loop body function for host. +Status RewriteHostWhileLoopBody( + const string& body_host_func_name, const string& while_node_name, + const string& host_transfer_key, const string& xla_cluster_attr_name, + const string& xla_cluster_name, const string& outside_compilation_attr_name, + const string& outside_compilation_name, FunctionLibraryDefinition* fld) { + // Replace key placeholder node with _Arg node. + TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( + xla_cluster_name, body_host_func_name, fld)); + + // Instantiate body function. 
+ AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map attrs; + attrs["device_ordinal"] = device_ordinal_temp_value; + FunctionBody* body_fbody = nullptr; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fld->Find(body_host_func_name), AttrSlice(&attrs), fld, + [&](const string& op, const OpDef** sig) { + return fld->LookUpOpDef(op, sig); + }, + &body_fbody)); + std::unique_ptr body_fbody_deleter(body_fbody); + Graph* body_graph = body_fbody->graph; + Node* key_arg = nullptr; + for (Node* n : body_graph->nodes()) { + if (n->type_string() == "_Arg") { + key_arg = n; + } + } + if (!key_arg) { + return errors::Internal( + "No _Arg node found for host compute key in function ", + body_host_func_name); + } + + // Add a _Retval node to loop body. + NodeDefBuilder ret_builder( + absl::StrCat("recv_oc_while_body_ret_", while_node_name), "_Retval"); + ret_builder.Attr("T", DT_STRING); + ret_builder.Attr("index", 0); + ret_builder.Input(key_arg->name(), 0, DT_STRING); + NodeDef ret_def; + TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def)); + Status s; + Node* ret_node = body_graph->AddNode(ret_def, &s); + TF_RETURN_IF_ERROR(s); + body_graph->AddEdge(key_arg, 0, ret_node, 0); + + // Reset device_ordinal to placeholder value. + TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(body_graph)); + + // Replace original function. + FunctionDef body_replace_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*body_graph, body_host_func_name, &body_replace_fdef)); + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(body_host_func_name, body_replace_fdef)); + + return Status::OK(); +} + +// Builds host side graph for while node. 
+Status BuildHostGraphForWhileNode( + const string& xla_cluster_attr_name, + const string& outside_compilation_attr_name, const string& xla_cluster_name, + const string& while_node_name, const string& host_transfer_key, + const string& host_graph_func_name, FunctionLibraryDefinition* fld, + const string& cond_host_func_name, const string& body_host_func_name) { + Graph host_graph(fld); + string outside_compilation_name = absl::StrCat("oc_while_", while_node_name); + + // Step 1: add key placeholder node. + TF_ASSIGN_OR_RETURN( + Node * key_placeholder, + AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph)); + + // Step 2: rewrite cond function. + TF_RETURN_IF_ERROR(RewriteHostWhileLoopCond( + cond_host_func_name, while_node_name, host_transfer_key, + xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name, + outside_compilation_name, fld)); + + // Step 3: rewrite body function. + TF_RETURN_IF_ERROR(RewriteHostWhileLoopBody( + body_host_func_name, while_node_name, host_transfer_key, + xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name, + outside_compilation_name, fld)); + + // Step 4: build While node. 
+ NodeDefBuilder while_builder(absl::StrCat("oc_while_", while_node_name), + "While"); + while_builder.Attr("T", std::vector{DT_STRING}); + NameAttrList func; + AttrValue device_ordinal_value; + device_ordinal_value.set_placeholder("device_ordinal"); + (*func.mutable_attr())["device_ordinal"] = device_ordinal_value; + func.set_name(cond_host_func_name); + while_builder.Attr("cond", func); + func.set_name(body_host_func_name); + while_builder.Attr("body", func); + while_builder.Attr(kXlaHasHostTransferAttrName, true); + while_builder.Attr(xla_cluster_attr_name, xla_cluster_name); + while_builder.Attr(outside_compilation_attr_name, outside_compilation_name); + std::vector while_inputs{ + {key_placeholder->name(), 0, DT_STRING}}; + while_builder.Input(while_inputs); + NodeDef while_def; + TF_RETURN_IF_ERROR(while_builder.Finalize(&while_def)); + Status s; + Node* while_node = host_graph.AddNode(while_def, &s); + TF_RETURN_IF_ERROR(s); + host_graph.AddEdge(key_placeholder, 0, while_node, 0); + + // Convert `host_graph` to function. 
+ FunctionDef oc_host_graph_fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name, + &oc_host_graph_fdef)); + if (fld->Find(host_graph_func_name)) { + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef)); + } else { + TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef)); + } + + return Status::OK(); +} + +Status ExtractOutsideCompilationForNodesWithAssociatedFunctions( + Graph* g, const string& xla_cluster_attr_name, + const string& outside_compilation_attr_name, const string& xla_cluster_name, + const std::map& host_compute_core, + FunctionLibraryDefinition* fld, std::vector* host_graphs, + std::vector* shape_inference_graphs, + bool* has_outside_compilation) { + std::vector if_nodes, while_nodes; + for (Node* n : g->nodes()) { + if (n->type_string() == "If") { + if_nodes.push_back(n); + } else if (n->type_string() == "While") { + while_nodes.push_back(n); + } + } + + for (Node* n : if_nodes) { + // Instantiate "then_branch" and "else_branch". + NameAttrList then_branch, else_branch; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "then_branch", &then_branch)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "else_branch", &else_branch)); + + // Extract outside compilation for then_branch and else_branch. 
+ bool then_branch_has_outside_compilation = false; + bool else_branch_has_outside_compilation = false; + string then_branch_host_func_name = + absl::StrCat("oc_then_branch_host_if_", n->name()), + else_branch_host_func_name = + absl::StrCat("oc_else_branch_host_if_", n->name()); + string then_branch_xla_func_name = absl::StrCat(then_branch.name(), "_oc"), + else_branch_xla_func_name = absl::StrCat(else_branch.name(), "_oc"); + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + then_branch, then_branch_xla_func_name, then_branch_host_func_name, + host_compute_core, fld, shape_inference_graphs, + &then_branch_has_outside_compilation)); + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + else_branch, else_branch_xla_func_name, else_branch_host_func_name, + host_compute_core, fld, shape_inference_graphs, + &else_branch_has_outside_compilation)); + + // If then/else branch do not have outside compilation, nothing to do. + if (!then_branch_has_outside_compilation && + !else_branch_has_outside_compilation) { + continue; + } + + *has_outside_compilation = true; + + // Change If node to call the new functions. + then_branch.set_name(then_branch_xla_func_name); + n->ClearAttr("then_branch"); + n->AddAttr("then_branch", then_branch); + else_branch.set_name(else_branch_xla_func_name); + n->ClearAttr("else_branch"); + n->AddAttr("else_branch", else_branch); + + string host_transfer_key = absl::StrCat("oc_if_pred_", n->name()); + + // XLA computation: add a SendToHost node to send cond predicate. 
+ Node* pred_node; + TF_RETURN_IF_ERROR(n->input_node(0, &pred_node)); + TF_ASSIGN_OR_RETURN( + Node * send_pred_node, + BuildSendIfPredNode(absl::StrCat("send_oc_if_pred_", n->name()), + host_transfer_key, pred_node, g)); + n->AddAttr(kXlaTokenInputNodesAttrName, + std::vector{send_pred_node->name()}); + + // Add a control edge from `send_pred_node` to If node, so XlaCompiler will + // visit If node after `send_pred_node`, thus the token output for + // `send_pred_node` has been generated. + g->AddControlEdge(send_pred_node, n); + + // Build host side graph for the "If" node. + string oc_host_graph_name = absl::StrCat("oc_if_host_graph_", n->name()); + TF_RETURN_IF_ERROR(BuildHostGraphForIfNode( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + n->name(), host_transfer_key, oc_host_graph_name, fld, + then_branch_host_func_name, else_branch_host_func_name)); + host_graphs->push_back(oc_host_graph_name); + } + + for (Node* n : while_nodes) { + // Instantiate "cond" and "body". + NameAttrList cond, body; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "cond", &cond)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "body", &body)); + + // Extract outside compilation for cond and body. 
+ bool cond_has_outside_compilation = false; + bool body_has_outside_compilation = false; + string cond_host_func_name = absl::StrCat("oc_cond_host_while_", n->name()), + body_host_func_name = absl::StrCat("oc_body_host_while_", n->name()); + string cond_xla_func_name = absl::StrCat(cond.name(), "_oc"), + body_xla_func_name = absl::StrCat(body.name(), "_oc"); + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + cond, cond_xla_func_name, cond_host_func_name, host_compute_core, fld, + shape_inference_graphs, &cond_has_outside_compilation)); + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + body, body_xla_func_name, body_host_func_name, host_compute_core, fld, + shape_inference_graphs, &body_has_outside_compilation)); + + // If cond/body do not have outside compilation, nothing to do. + if (!cond_has_outside_compilation && !body_has_outside_compilation) { + continue; + } + + *has_outside_compilation = true; + + // Change While node to call the new functions. + cond.set_name(cond_xla_func_name); + n->ClearAttr("cond"); + n->AddAttr("cond", cond); + body.set_name(body_xla_func_name); + n->ClearAttr("body"); + n->AddAttr("body", body); + + string host_transfer_key = absl::StrCat("oc_while_pred_", n->name()); + + // XLA computation: rewrite cond function to add a SendToHost node to send + // loop predicate. + TF_RETURN_IF_ERROR( + AddSendLoopPredToLoopCond(fld, cond, n->name(), host_transfer_key)); + n->AddAttr(kXlaTokenInputNodesAttrName, + std::vector{kXlaTokenArgNodeName}); + + // Build host side graph for the "While" node. 
+ string oc_host_graph_name = absl::StrCat("oc_while_host_graph_", n->name()); + TF_RETURN_IF_ERROR(BuildHostGraphForWhileNode( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + n->name(), host_transfer_key, oc_host_graph_name, fld, + cond_host_func_name, body_host_func_name)); + host_graphs->push_back(oc_host_graph_name); + } + + return Status::OK(); +} + } // namespace Status RewriteOutsideCompilationSubgraphFn::operator()( @@ -755,12 +1445,15 @@ Status RewriteOutsideCompilationSubgraphFn::operator()( // it with HostCompute node later. AddNodeAttr("_outside_compilation_subgraph", old_name, node_def); if (shapes) { - AddNodeAttr("shape_inference_graph", "", node_def); + NameAttrList shape_inference_graph; + AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def); AddNodeAttr("shapes", *shapes, node_def); } else { string shape_inference_func_name = absl::StrCat("_outside_compilation_shape_inference_", new_name); - AddNodeAttr("shape_inference_graph", shape_inference_func_name, node_def); + NameAttrList shape_inference_graph; + shape_inference_graph.set_name(shape_inference_func_name); + AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def); AddNodeAttr("shapes", std::vector{}, node_def); } AddNodeAttr("ancestors", std::vector{}, node_def); @@ -775,11 +1468,10 @@ Status ExtractOutsideCompilationForFunction( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const string& xla_cluster_name, const NameAttrList& func_name_attrs, const string& new_func_name, + const string& host_graph_func_name, const std::map& host_compute_core, - FunctionLibraryDefinition* fld, std::unique_ptr* host_graph, - std::vector* shape_inference_graphs, + FunctionLibraryDefinition* fld, std::vector* shape_inference_graphs, bool* has_outside_compilation) { - // Early return if function does not have any outside compilation nodes. 
const string& func_name = func_name_attrs.name(); const FunctionDef* fdef = fld->Find(func_name); if (!fdef) { @@ -792,9 +1484,8 @@ Status ExtractOutsideCompilationForFunction( break; } } - if (!has_outside_compilation) { - return Status::OK(); - } + // We cannot early return here, because we might have outside compilation in + // If/While function body. // Convert the function to graph. FunctionBody* fbody = nullptr; @@ -835,11 +1526,11 @@ Status ExtractOutsideCompilationForFunction( // If we could not infer shapes for XlaSendFromHost inputs statically, we // will set the "shape_inference_graph" attribute. In that case, copy // outside compilation subgraph as shape inference graph in `fld`. - string shape_inference_graph; + NameAttrList shape_inference_graph; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "shape_inference_graph", &shape_inference_graph)); - if (!shape_inference_graph.empty()) { - shape_inference_graphs->push_back(shape_inference_graph); + if (!shape_inference_graph.name().empty()) { + shape_inference_graphs->push_back(shape_inference_graph.name()); const FunctionDef* xla_fdef = fld->Find(n->name()); if (!xla_fdef) { @@ -847,9 +1538,9 @@ Status ExtractOutsideCompilationForFunction( } FunctionDef shape_inference_fdef = *xla_fdef; shape_inference_fdef.mutable_signature()->set_name( - shape_inference_graph); - if (fld->Find(shape_inference_graph)) { - TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph, + shape_inference_graph.name()); + if (fld->Find(shape_inference_graph.name())) { + TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph.name(), shape_inference_fdef)); } else { TF_RETURN_IF_ERROR(fld->AddFunctionDef(shape_inference_fdef)); @@ -858,6 +1549,7 @@ Status ExtractOutsideCompilationForFunction( } } for (Node* n : outside_compilation_nodes) { + TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n)); TF_RETURN_IF_ERROR(ReplaceOrRemoveOutsideCompilationCallNode( graph_out.get(), n, host_compute_core)); } @@ -867,12 +1559,17 @@ 
Status ExtractOutsideCompilationForFunction( *graph_out, fld); } + // Handle nodes with associated functions. + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForNodesWithAssociatedFunctions( + graph_out.get(), xla_cluster_attr_name, outside_compilation_attr_name, + xla_cluster_name, host_compute_core, fld, + &outside_compilation_host_graphs, shape_inference_graphs, + has_outside_compilation)); + // Construct host graph. - if (!outside_compilation_host_graphs.empty()) { - TF_RETURN_IF_ERROR( - ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name, - outside_compilation_host_graphs, fld, host_graph)); - } + TF_RETURN_IF_ERROR(ConstructHostGraph( + xla_cluster_name, outside_compilation_attr_name, + outside_compilation_host_graphs, fld, host_graph_func_name)); // Remove the outside compilation graphs from function library. for (const string& func : outside_compilation_host_graphs) { @@ -909,24 +1606,17 @@ Status ExtractOutsideCompilation( auto const& host_compute_core = iter.second.host_compute_core; bool has_outside_compilation; - std::unique_ptr host_graph; + string host_graph_func_name = absl::StrCat("oc_host_graph_", n->name()); TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, - func_name_attrs, func_name_attrs.name(), host_compute_core, fld, - &host_graph, &shape_inference_graphs, &has_outside_compilation)); - if (host_graph) { - TF_RETURN_IF_ERROR(ExpandHostGraphIntoMainGraph(g, host_graph.get(), n)); - } - } - - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile("extract_outside_compilation_expanded", *g, - fld); + func_name_attrs, func_name_attrs.name(), host_graph_func_name, + host_compute_core, fld, &shape_inference_graphs, + &has_outside_compilation)); + TF_RETURN_IF_ERROR( + ExpandHostGraphIntoMainGraph(g, fld, host_graph_func_name, n)); + TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name)); } - TF_RETURN_IF_ERROR(PostprocessForEncapsulation( - g, 
xla_cluster_attr_name, outside_compilation_attr_name, clusters)); - for (auto shape_inference_graph_name : shape_inference_graphs) { TF_RETURN_IF_ERROR( RewriteShapeInferenceGraph(shape_inference_graph_name, g, fld)); diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.h b/tensorflow/compiler/jit/extract_outside_compilation_pass.h index 2a4f07cca213d999202024294f5d8f94527059c3..e07e7c5dd0cd42ddd4d643d8b36583c82056bbb2 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.h +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.h @@ -88,9 +88,10 @@ Status ExtractOutsideCompilationForFunction( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const string& xla_cluster_name, const NameAttrList& func_name_attrs, const string& new_func_name, + const string& host_graph_func_name, const std::map& host_compute_core, - FunctionLibraryDefinition* fld, std::unique_ptr* host_graph, - std::vector* shape_inference_graphs, bool* has_outside_compilation); + FunctionLibraryDefinition* fld, std::vector* shape_inference_graphs, + bool* has_outside_compilation); // Rewrites XLA computation in `clusters` to replace outside compilation nodes // with XlaHostCompute, and moves those outside compilations into `g`. If shapes diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc index bff956100da661b679b4557fce53671e6cef88c5..e9a89e34e0c7b04b4be34e367b2d0bf627c0061a 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc @@ -19,8 +19,10 @@ limitations under the License. 
#include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/functional_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/encapsulate_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/function.h" @@ -109,10 +111,10 @@ TEST(RewriteOutsideCompilationSubgraphFnTest, Basic) { } EXPECT_TRUE(has_control_edge_to_send_from_host); // Verify step 7: necessary attrs added to call_node_def. - string shape_inference_graph; + NameAttrList shape_inference_graph; TF_CHECK_OK(GetNodeAttr(AttrSlice(&call_node_def.attr()), "shape_inference_graph", &shape_inference_graph)); - EXPECT_EQ(shape_inference_graph, + EXPECT_EQ(shape_inference_graph.name(), "_outside_compilation_shape_inference_cluster_0"); } @@ -249,27 +251,26 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) { protobuf::Map attrs; std::map host_compute_core = {{"0", 1}, {"1", 0}}; - std::unique_ptr host_graph; std::vector shape_inference_graphs; bool has_outside_compilation; NameAttrList name_attrs; name_attrs.set_name("cluster"); *name_attrs.mutable_attr() = attrs; TF_CHECK_OK(ExtractOutsideCompilationForFunction( - "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", - host_compute_core, &fld, &host_graph, &shape_inference_graphs, + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, &has_outside_compilation)); // Get rewritten XLA computation function. 
- FunctionBody *fbody = nullptr; - TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"), - AttrSlice(), &fld, - [&](const string &op, const OpDef **sig) { - return fld.LookUpOpDef(op, sig); - }, - &fbody)); - std::unique_ptr fbody_deleter(fbody); - auto node_name_index = fbody->graph->BuildNodeNameIndex(); + FunctionBody *xla_fbody = nullptr; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("cluster_rewritten"), AttrSlice(), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &xla_fbody)); + std::unique_ptr xla_fbody_deleter(xla_fbody); + auto node_name_index = xla_fbody->graph->BuildNodeNameIndex(); // Check XlaHostCompute nodes. Node *host_compute_0 = node_name_index["outside_compilation_0_host_compute"]; @@ -292,18 +293,31 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) { EXPECT_EQ(shapes[0].dim_size(), 1); // Check XlaHostCompute nodes' "shape_inference_graph" attr. Both should have // empty values. - string shape_inference_graph; + NameAttrList shape_inference_graph; TF_CHECK_OK(GetNodeAttr(host_compute_0->attrs(), "shape_inference_graph", &shape_inference_graph)); - EXPECT_EQ(shape_inference_graph, ""); + EXPECT_EQ(shape_inference_graph.name(), ""); TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "shape_inference_graph", &shape_inference_graph)); - EXPECT_EQ(shape_inference_graph, ""); + EXPECT_EQ(shape_inference_graph.name(), ""); // Check `shape_inference_graphs`. EXPECT_EQ(shape_inference_graphs.size(), 0); - // Check `host_graph`: verify we have key placeholder and sequencer. + // Check host graph: verify we have key placeholder and sequencer. 
+ FunctionBody *host_fbody = nullptr; + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map host_func_attrs; + host_func_attrs["device_ordinal"] = device_ordinal_temp_value; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + Graph *host_graph = host_fbody->graph; Node *key_placeholder = nullptr, *sequencer = nullptr; for (Node *n : host_graph->nodes()) { if (n->type_string() == "Placeholder" && @@ -365,25 +379,37 @@ TEST(ExtractOutsideCompilationForFunctionTest, NoHostGraph) { protobuf::Map attrs; std::map host_compute_core = {{"0", 1}, {"1", 0}}; - std::unique_ptr host_graph; std::vector shape_inference_graphs; bool has_outside_compilation; NameAttrList name_attrs; name_attrs.set_name("cluster"); *name_attrs.mutable_attr() = attrs; TF_CHECK_OK(ExtractOutsideCompilationForFunction( - "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", - host_compute_core, &fld, &host_graph, &shape_inference_graphs, + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, &has_outside_compilation)); - // Check `host_graph` is empty. - EXPECT_FALSE(host_graph); + // Check host graph is empty. 
+ FunctionBody *host_fbody = nullptr; + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map host_func_attrs; + host_func_attrs["device_ordinal"] = device_ordinal_temp_value; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + Graph *host_graph = host_fbody->graph; + EXPECT_EQ(host_graph->num_nodes(), 2); } TEST(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) { // Build the XLA computation func. // "const0" - // "const1" (outside compilation clsuter "0") + // "const1" (outside compilation cluster "0") FunctionDefLibrary fdl; { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); @@ -401,31 +427,43 @@ TEST(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) { protobuf::Map attrs; std::map host_compute_core = {{"0", 1}, {"1", 0}}; - std::unique_ptr host_graph; std::vector shape_inference_graphs; bool has_outside_compilation; NameAttrList name_attrs; name_attrs.set_name("cluster"); *name_attrs.mutable_attr() = attrs; TF_CHECK_OK(ExtractOutsideCompilationForFunction( - "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", - host_compute_core, &fld, &host_graph, &shape_inference_graphs, + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, &has_outside_compilation)); // Check rewritten XLA graph: verify that we have no XlaHostCompute. 
- FunctionBody *fbody = nullptr; - TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"), - AttrSlice(), &fld, - [&](const string &op, const OpDef **sig) { - return fld.LookUpOpDef(op, sig); - }, - &fbody)); - std::unique_ptr fbody_deleter(fbody); - for (Node *n : fbody->graph->nodes()) { + FunctionBody *xla_fbody = nullptr; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("cluster_rewritten"), AttrSlice(), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &xla_fbody)); + std::unique_ptr xla_fbody_deleter(xla_fbody); + for (Node *n : xla_fbody->graph->nodes()) { EXPECT_NE(n->type_string(), "XlaHostCompute"); } - // Check `host_graph`: verify we have no placeholder, but we have "const1". + // Check host graph: verify we have no placeholder, but we have "const1". + FunctionBody *host_fbody = nullptr; + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map host_func_attrs; + host_func_attrs["device_ordinal"] = device_ordinal_temp_value; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + Graph *host_graph = host_fbody->graph; int num_key_placeholders = 0; for (Node *n : host_graph->nodes()) { if (n->type_string() == "Placeholder" && @@ -438,4 +476,310 @@ TEST(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) { EXPECT_NE(node_name_index.find("const1"), node_name_index.end()); } +REGISTER_OP("XlaSendToHost") + .Input("input: Tinput") + .Attr("Tinput: type") + .Attr("key: string") + .SetIsStateful(); + +REGISTER_OP("XlaRecvFromHost") + .Output("output: Toutput") + .Attr("Toutput: type") + .Attr("shape: shape") + .Attr("key: string") + .SetIsStateful(); + +TEST(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInIf) { + // Build the XLA computation func. 
+ // "const0" (bool) + // "const1" (int32) + // "if0" (pred = "const0", input = "const1", then_branch = "true_fn", + // else_branch = "false_fn") + FunctionDefLibrary fdl; + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg = ops::_Arg(s.WithOpName("arg"), DT_INT32, 0); + Output identity = ops::Identity(s.WithOpName("identity_true_fn"), arg); + ops::_Retval retval(s.WithOpName("retval"), identity, 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + auto node_name_image = g->BuildNodeNameIndex(); + node_name_image["identity_true_fn"]->AddAttr("_oc", "0"); + PartialTensorShape shape({2}); + node_name_image["identity_true_fn"]->AddAttr( + kXlaInferredShapesAttrName, std::vector{shape}); + + FunctionDef *true_fn_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "true_fn", true_fn_fdef)); + } + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg = ops::_Arg(s.WithOpName("arg"), DT_INT32, 0); + Output identity = ops::Identity(s.WithOpName("identity_false_fn"), arg); + ops::_Retval retval(s.WithOpName("retval"), identity, 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + auto node_name_image = g->BuildNodeNameIndex(); + node_name_image["identity_false_fn"]->AddAttr("_oc", "0"); + PartialTensorShape shape({2}); + node_name_image["identity_false_fn"]->AddAttr( + kXlaInferredShapesAttrName, std::vector{shape}); + + FunctionDef *false_fn_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "false_fn", false_fn_fdef)); + } + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output cond = ops::Const(s.WithOpName("const0"), true, {2}); + Output input = ops::Const(s.WithOpName("const1"), 1, {2}); + NameAttrList true_fn; + true_fn.set_name("true_fn"); + NameAttrList false_fn; + false_fn.set_name("false_fn"); + auto if_op = ops::If(s.WithOpName("if"), cond, + std::initializer_list{cond, input}, {DT_INT32}, + 
true_fn, false_fn); + ops::_Retval retval(s.WithOpName("retval"), if_op.output[0], 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + + FunctionDef *xla_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef)); + } + FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); + + protobuf::Map attrs; + std::map host_compute_core; + std::vector shape_inference_graphs; + bool has_outside_compilation; + NameAttrList name_attrs; + name_attrs.set_name("cluster"); + *name_attrs.mutable_attr() = attrs; + TF_CHECK_OK(ExtractOutsideCompilationForFunction( + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, + &has_outside_compilation)); + + // Check host graph. + { + FunctionBody *host_fbody = nullptr; + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map host_func_attrs; + host_func_attrs["device_ordinal"] = device_ordinal_temp_value; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + Graph *host_graph = host_fbody->graph; + auto node_name_index = host_graph->BuildNodeNameIndex(); + + // Verify we have XlaRecvAtHost to receive "If" predicate. + Node *recv_if_pred_node = node_name_index["recv_oc_if_pred_if"]; + EXPECT_NE(recv_if_pred_node, nullptr); + + // Verify we have an "If" to choose outside compilation between then_branch + // and else_branch, and it has `recv_if_pred_node` as cond input. 
+ Node *if_oc_node = node_name_index["oc_if_if"]; + EXPECT_NE(if_oc_node, nullptr); + Node *if_oc_node_cond_input; + TF_CHECK_OK(if_oc_node->input_node(0, &if_oc_node_cond_input)); + EXPECT_EQ(if_oc_node_cond_input, recv_if_pred_node); + + // Check that then_branch outside compilation has node "identity_true_fn". + const FunctionDef *true_def = fld.Find("oc_then_branch_host_if_if"); + EXPECT_NE(true_def, nullptr); + bool has_identity_true_fn_node = false; + for (const auto &node_def : true_def->node_def()) { + if (node_def.name() == "identity_true_fn") { + has_identity_true_fn_node = true; + break; + } + } + EXPECT_TRUE(has_identity_true_fn_node); + + // Check that else_branch outside compilation has node "identity_false_fn". + const FunctionDef *false_def = fld.Find("oc_else_branch_host_if_if"); + EXPECT_NE(false_def, nullptr); + bool has_identity_false_fn_node = false; + for (const auto &node_def : false_def->node_def()) { + if (node_def.name() == "identity_false_fn") { + has_identity_false_fn_node = true; + break; + } + } + EXPECT_TRUE(has_identity_false_fn_node); + } + + // Check XLA graph. + { + FunctionBody *xla_fbody = nullptr; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("cluster_rewritten"), AttrSlice(), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &xla_fbody)); + std::unique_ptr xla_fbody_deleter(xla_fbody); + Graph *xla_graph = xla_fbody->graph; + auto node_name_index = xla_graph->BuildNodeNameIndex(); + + // Check that we have XlaSendToHost to send cond predicate to host, and + // there is a control edge to If node. 
+ Node *send_if_pred_node = node_name_index["send_oc_if_pred_if"]; + EXPECT_NE(send_if_pred_node, nullptr); + bool has_control_edge_to_if = false; + for (const Edge *e : send_if_pred_node->out_edges()) { + if (e->IsControlEdge() && e->dst()->name() == "if") { + has_control_edge_to_if = true; + break; + } + } + EXPECT_TRUE(has_control_edge_to_if); + + // Check that the "If" node now has `send_if_pred_node` as attribute + // _xla_token_input_nodes. + Node *if_node = node_name_index["if"]; + EXPECT_NE(if_node, nullptr); + std::vector token_inputs; + TF_CHECK_OK( + GetNodeAttr(if_node->def(), "_xla_token_input_nodes", &token_inputs)); + EXPECT_THAT(token_inputs, ::testing::ElementsAre("send_oc_if_pred_if")); + } +} + +TEST(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInWhile) { + // Build the XLA computation func. + // "const0" (bool) + // "while0" (input = "const0", cond = "cond_fn", body = "body_fn") + FunctionDefLibrary fdl; + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg = ops::_Arg(s.WithOpName("arg"), DT_BOOL, 0); + Output identity = ops::Identity(s.WithOpName("identity_cond_fn"), arg); + ops::_Retval retval(s.WithOpName("retval"), identity, 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + auto node_name_image = g->BuildNodeNameIndex(); + node_name_image["identity_cond_fn"]->AddAttr("_oc", "0"); + PartialTensorShape shape({2}); + node_name_image["identity_cond_fn"]->AddAttr( + kXlaInferredShapesAttrName, std::vector{shape}); + + FunctionDef *cond_fn_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "cond_fn", cond_fn_fdef)); + } + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg = ops::_Arg(s.WithOpName("arg"), DT_BOOL, 0); + Output identity = ops::Identity(s.WithOpName("identity_body_fn"), arg); + ops::_Retval retval(s.WithOpName("retval"), identity, 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + 
TF_CHECK_OK(s.ToGraph(g.get())); + auto node_name_image = g->BuildNodeNameIndex(); + node_name_image["identity_body_fn"]->AddAttr("_oc", "0"); + PartialTensorShape shape({2}); + node_name_image["identity_body_fn"]->AddAttr( + kXlaInferredShapesAttrName, std::vector{shape}); + + FunctionDef *body_fn_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "body_fn", body_fn_fdef)); + } + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output input = ops::Const(s.WithOpName("const0"), true, {2}); + NameAttrList cond_fn; + cond_fn.set_name("cond_fn"); + NameAttrList body_fn; + body_fn.set_name("body_fn"); + auto while_op = + ops::While(s.WithOpName("while"), std::initializer_list{input}, + cond_fn, body_fn); + ops::_Retval retval(s.WithOpName("retval"), while_op.output[0], 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + + FunctionDef *xla_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef)); + } + FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); + + protobuf::Map attrs; + std::map host_compute_core; + std::vector shape_inference_graphs; + bool has_outside_compilation; + NameAttrList name_attrs; + name_attrs.set_name("cluster"); + *name_attrs.mutable_attr() = attrs; + TF_CHECK_OK(ExtractOutsideCompilationForFunction( + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, + &has_outside_compilation)); + + // Check host graph. 
+ { + FunctionBody *host_fbody = nullptr; + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map host_func_attrs; + host_func_attrs["device_ordinal"] = device_ordinal_temp_value; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + Graph *host_graph = host_fbody->graph; + auto node_name_index = host_graph->BuildNodeNameIndex(); + + // Verify we have an "While" to execute outside compilation. + Node *while_oc_node = node_name_index["oc_while_while"]; + EXPECT_NE(while_oc_node, nullptr); + + // Check that cond outside compilation has node "identity_cond_fn". + const FunctionDef *cond_def = fld.Find("oc_cond_host_while_while"); + EXPECT_NE(cond_def, nullptr); + bool has_identity_cond_fn_node = false; + for (const auto &node_def : cond_def->node_def()) { + if (node_def.name() == "identity_cond_fn") { + has_identity_cond_fn_node = true; + break; + } + } + EXPECT_TRUE(has_identity_cond_fn_node); + + // Check that body outside compilation has node "identity_body_fn". + const FunctionDef *body_def = fld.Find("oc_body_host_while_while"); + EXPECT_NE(body_def, nullptr); + bool has_identity_body_fn_node = false; + for (const auto &node_def : body_def->node_def()) { + if (node_def.name() == "identity_body_fn") { + has_identity_body_fn_node = true; + break; + } + } + EXPECT_TRUE(has_identity_body_fn_node); + } + + // Check XLA graph. + { + // Verify that rewritten cond fn has XlaSendToHost to send loop predicate to + // host. 
+ const FunctionDef *cond_def = fld.Find("cond_fn_oc"); + EXPECT_NE(cond_def, nullptr); + bool has_send_oc_while_cond_node = false; + for (const auto &node_def : cond_def->node_def()) { + if (node_def.name() == "send_oc_while_cond_while") { + has_send_oc_while_cond_node = true; + break; + } + } + EXPECT_TRUE(has_send_oc_while_cond_node); + } +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index 42ea3926e16ae791dbe1bede3b8742383db7667c..e1fd2aaee2822daeffb415d053c9c4f56002a856 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -120,6 +120,7 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) { NodeDef ndef = n->def(); ndef.set_name(absl::StrCat(n->name(), "/declustered")); + MergeDebugInfo(NodeDebugInfo(n->def()), &ndef); RemoveFromXlaCluster(&ndef); Status s; Node* cloned_node = graph->AddNode(ndef, &s); diff --git a/tensorflow/compiler/jit/shape_inference.cc b/tensorflow/compiler/jit/shape_inference.cc index 80c691fe490c1092315708a2da754d367d585300..a27e0d9f2a6ecddfdbdb29be673084d77a178d8a 100644 --- a/tensorflow/compiler/jit/shape_inference.cc +++ b/tensorflow/compiler/jit/shape_inference.cc @@ -53,7 +53,15 @@ Status PropagateShapes(const Graph& graph, // shapes, even if no shape function is registered for a node. 
Status status = shape_refiner->AddNode(n); if (!status.ok()) { - VLOG(1) << "Shape inference failed for node: " << status; + VLOG(1) << "Shape inference failed for node " << n->name() << ": " + << status; + } else { + shape_inference::InferenceContext* context = shape_refiner->GetContext(n); + for (int i = 0; i < n->num_outputs(); i++) { + shape_inference::ShapeHandle handle = context->output(i); + VLOG(4) << "Output " << i << " for node " << n->name() << ": " + << context->DebugString(handle); + } } if (n->type_string() == "_Arg") { diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index 1fe612d43d10030675cf307b109e4dcc89cb2d79..c7e8d61d280a33a83c3386d8ef801018634d31ec 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -142,11 +142,22 @@ Status XlaCompileOnDemandOp::Compile( TF_RETURN_IF_ERROR(ctx->allocate_temp( device_tensor.dtype(), device_tensor.shape(), &host_tensor, attrs)); Notification n; + Status status; ctx->op_device_context()->CopyDeviceTensorToCPU( &device_tensor, "ConstantArgument", reinterpret_cast(ctx->device()), &host_tensor, - [&](Status status) { n.Notify(); }); + [&](Status s) { + status = s; + n.Notify(); + }); n.WaitForNotification(); + if (!status.ok()) { + LOG(ERROR) << "Copying tensor of shape " + << device_tensor.shape().DebugString() << " from " + << ctx->device()->name() << "to CPU failed with " + << status.ToString(); + return status; + } constant_arguments[i] = host_tensor; } } @@ -189,6 +200,7 @@ Status XlaCompileOnDemandOp::Compile( std::map variable_args = GetVariables(ctx); std::vector args; + TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments( constant_arguments, variable_args, ctx, &args)); diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index 
7df898ad12a15345f45fc96e0ec3d42b6e51731b..e9770647e7ba96cc1db026d12d5f11f52ce98d35 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -63,7 +63,19 @@ Status XlaCpuDeviceFactory::CreateDevices( options.device_ordinal = 0; options.compilation_device_name = DEVICE_CPU_XLA_JIT; options.use_multiple_streams = false; - devices->push_back(absl::make_unique(session_options, options)); + auto device = absl::make_unique(session_options, options); + + // Setting GpuDeviceInfo because eager runtime relies on the device + // context in tensorflow_gpu_device_info(). Also, + // tensorflow_gpu_device_info() == nullptr is used as an IsCPU test. + // We need XlaCpuDevice to be treated not as CPU because it allocates + // XlaTensors, not regular Tensors. + Status status = device->UseGpuDeviceInfo(); + if (!status.ok()) { + errors::AppendToMessage(&status, "while setting up ", DEVICE_GPU_XLA_JIT); + return status; + } + devices->push_back(std::move(device)); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 4201ff91a89b1bee370e6a43337c51abe3bf974a..77cd2f44628677942da9e576070d1d295194cead 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -201,7 +201,8 @@ XlaDevice::XlaDevice(const SessionOptions& session_options, jit_device_name_(options.compilation_device_name), platform_(options.platform), use_multiple_streams_(options.use_multiple_streams), - shape_representation_fn_(options.shape_representation_fn) { + shape_representation_fn_(options.shape_representation_fn), + allowed_devices_(options.allowed_devices) { VLOG(1) << "Created XLA device " << options.compilation_device_name << " " << this; thread_pool_.reset(new thread::ThreadPool(session_options.env, "xla_device", @@ -234,7 +235,8 @@ xla::LocalClient* XlaDevice::client() const { // TODO(b/78468222): This can fail, at least when the backend is GPU and // there is 
no GPU on the host. - return xla::ClientLibrary::GetOrCreateLocalClient(platform_).ValueOrDie(); + return xla::ClientLibrary::GetOrCreateLocalClient(platform_, allowed_devices_) + .ValueOrDie(); } Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) { diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index c8bb276cdb9673fdcba4cc15a9f33ecd3ae96dbb..45f18ac9ee6d403c192bd421d7823f2d408d994b 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -24,7 +24,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ +#include +#include "absl/types/optional.h" #include "tensorflow/compiler/jit/xla_device_context.h" #include "tensorflow/compiler/jit/xla_tensor.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -123,6 +125,11 @@ class XlaDevice : public LocalDevice { // If padded_shape_fn is empty, a default implementation that returns // the logical on-device shape without padding is used. PaddedShapeFn padded_shape_fn; + + // Set of devices to use. This controls which of the devices on the given + // platform will have resources allocated. For GPUs this will be + // filled from visible_gpu_devices list from session configuration. + absl::optional> allowed_devices; }; // Creates a new XLA Device. @@ -256,6 +263,11 @@ class XlaDevice : public LocalDevice { // completion. int64 outstanding_asynchronous_operations_ GUARDED_BY(mu_) = 0; condition_variable outstanding_asynchronous_operations_cv_; + + // Set of devices to use. This controls which of the devices on the given + // platform will have resources allocated. For GPUs this will be + // filled from visible_gpu_devices list from session configuration. 
+ absl::optional> allowed_devices_; }; // Builds OpKernel registrations on 'device' for the JIT operators diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 6e6532731e64bd42ee56aa719748988f321e0f17..1f3afe8822d441a5ce37617fe18d7767e9bc72e4 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -79,6 +79,13 @@ XlaDeviceContext::XlaDeviceContext( } } +void XlaDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor, + Device* device, + Tensor* output_tensor, + StatusCallback done) const { + done(errors::Unimplemented("XLA->XLA same-device copies not implemented.")); +} + void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index 1e18df197a2dd65590c5181b4dae4481dca36641..e45db989fac720df6c3458c93a6b8dbb0919f930 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -62,6 +62,9 @@ class XlaDeviceContext : public DeviceContext { void CopyDeviceTensorToCPU(const Tensor* device_tensor, absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) override; + void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device, + Tensor* output_tensor, + StatusCallback done) const override; xla::LocalClient* client() const { return client_; } se::Stream* stream() const { return stream_.get(); } diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index adf0f994b84d9fbf918a5b2478aa7d106853e038..927f983ba9ef23c8509523f42366c0c89c29db9f 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -203,6 +203,8 @@ class XlaAssignVariableOp : public OpKernel { .HostMemory("output") \ .TypeConstraint("T"), \ ArgOp); \ + 
REGISTER_KERNEL_BUILDER( \ + Name(kArgOp).Device(DEVICE).TypeConstraint("T"), ArgOp); \ \ REGISTER_KERNEL_BUILDER(Name(kRetOp) \ .Device(DEVICE) \ diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 0191315a66f4d331e54fadc9dc6a073a05fd67ef..b29f6a009b9e9fdba76ac55386a4bec2f339cc0e 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -29,6 +29,30 @@ limitations under the License. namespace tensorflow { +// Returns a set containing the device ids contained in visible_device_list or +// nullopt if it is empty. It returns error in case of malformed configuration +// string. +static xla::StatusOr>> ParseVisibleDeviceList( + const string& visible_device_list) { + std::set gpu_ids; + if (visible_device_list.empty()) { + return {{absl::nullopt}}; + } + const std::vector visible_devices = + absl::StrSplit(visible_device_list, ','); + for (const string& platform_gpu_id_str : visible_devices) { + int32 platform_gpu_id; + if (!absl::SimpleAtoi(platform_gpu_id_str, &platform_gpu_id)) { + return errors::InvalidArgument( + "Could not parse entry in 'visible_device_list': '", + platform_gpu_id_str, + "'. visible_device_list = ", visible_device_list); + } + gpu_ids.insert(platform_gpu_id); + } + return {{gpu_ids}}; +} + class XlaGpuDeviceFactory : public DeviceFactory { public: Status CreateDevices(const SessionOptions& options, const string& name_prefix, @@ -57,33 +81,16 @@ Status XlaGpuDeviceFactory::CreateDevices( } string allowed_gpus = session_options.config.gpu_options().visible_device_list(); - std::set gpu_ids; - int num_visible_devices = platform.ValueOrDie()->VisibleDeviceCount(); - if (allowed_gpus.empty()) { - for (int i = 0; i < num_visible_devices; ++i) { - gpu_ids.insert(i); - } - } else { - // For loop below is copied from gpu/gpu_device.cc. It validates - // the visible_device_list and populates gpu_ids set. 
- const std::vector visible_devices = - absl::StrSplit(allowed_gpus, ','); - for (const string& platform_gpu_id_str : visible_devices) { - int32 platform_gpu_id; - if (!absl::SimpleAtoi(platform_gpu_id_str, &platform_gpu_id)) { - return errors::InvalidArgument( - "Could not parse entry in 'visible_device_list': '", - platform_gpu_id_str, "'. visible_device_list = ", allowed_gpus); - } - if (platform_gpu_id < 0 || platform_gpu_id >= num_visible_devices) { - return errors::InvalidArgument( - "'visible_device_list' listed an invalid GPU id '", platform_gpu_id, - "' but visible device count is ", num_visible_devices); - } - gpu_ids.insert(platform_gpu_id); + absl::optional> gpu_ids = + ParseVisibleDeviceList(allowed_gpus).ValueOrDie(); + if (!gpu_ids) { + gpu_ids.emplace(); + // Fill the gpu_ids set with all devices if config string is empty. + for (int i = 0; i < platform.ValueOrDie()->VisibleDeviceCount(); ++i) { + gpu_ids->insert(i); } } - for (int i : gpu_ids) { + for (int i : *gpu_ids) { XlaDevice::Options options; options.platform = platform.ValueOrDie(); options.device_name_prefix = name_prefix; @@ -91,6 +98,7 @@ Status XlaGpuDeviceFactory::CreateDevices( options.device_ordinal = i; options.compilation_device_name = DEVICE_GPU_XLA_JIT; options.use_multiple_streams = true; + options.allowed_devices = gpu_ids; auto device = absl::make_unique(session_options, options); Status status = device->UseGpuDeviceInfo(); diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index bc3d60b90e58b4018f1c52b09941dedba7ef348a..fa02cf9cbef45188a6dc2f861ff036649ea92b03 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -408,13 +408,6 @@ tf_xla_py_test( name = "eager_test", size = "large", srcs = ["eager_test.py"], - disabled_backends = [ - # TODO(b/78199195) Support XLA CPU devices in eager runtime - "cpu", - "cpu_ondemand", - # TODO(b/78468222) Enable GPU backend - "gpu", - ], deps = [ ":xla_test", 
"//tensorflow/python:array_ops", @@ -1195,11 +1188,18 @@ tf_xla_py_test( tf_xla_py_test( name = "quantized_ops_test", - size = "small", + size = "medium", srcs = ["quantized_ops_test.py"], + disabled_backends = [ + "cpu", + "cpu_ondemand", + ], deps = [ ":xla_test", + "//tensorflow/compiler/tf2xla/python:xla", "//tensorflow/python:array_ops", + "//tensorflow/python:bitwise_ops", + "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py index 174bfa9efbcd7dcb4f895237eb01c17bc4a3a6b4..90146e6b27ca31304a2549ec247412341efe390c 100644 --- a/tensorflow/compiler/tests/depthwise_conv_op_test.py +++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py @@ -350,8 +350,13 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): self._CompareBackpropInput(input_size, filter_size, output_size, stride, padding) - def _CompareBackpropFilter(self, input_sizes, filter_sizes, output_sizes, - stride, padding): + def _CompareBackpropFilter(self, + input_sizes, + filter_sizes, + output_sizes, + stride, + padding, + data_format="NHWC"): x0 = np.random.rand(*input_sizes).astype(np.float32) x2 = np.random.rand(*output_sizes).astype(np.float32) @@ -360,13 +365,30 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): t0 = array_ops.placeholder(np.float32, shape=input_sizes) t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)]) t2 = array_ops.placeholder(np.float32, shape=output_sizes) + native_t0 = t0 + native_t2 = t2 + strides = [1, stride, stride, 1] + if use_xla: + if data_format == "NCHW": + # Transpose from NWHC input to NCHW + # Ex. 
[4, 5, 5, 48] to [4, 48, 5, 5] + native_t0 = array_ops.transpose(t0, [0, 3, 1, 2]) + native_t2 = array_ops.transpose(t2, [0, 3, 1, 2]) + strides = [1, 1, stride, stride] with self.test_scope(): backprop = nn_ops.depthwise_conv2d_native_backprop_filter( - t0, t1, t2, strides=[1, stride, stride, 1], padding=padding) + native_t0, + t1, + native_t2, + strides=strides, + padding=padding, + data_format=data_format) else: + # For CPU, the format NCHW is not supported. Therefore we always use + # NHWC here. backprop = nn_ops.depthwise_conv2d_native_backprop_filter( - t0, t1, t2, strides=[1, stride, stride, 1], padding=padding) + native_t0, t1, native_t2, strides=strides, padding=padding) ret = backprop.eval({t0: x0, t2: x2}) self.assertShapeEqual(ret, backprop) return ret @@ -379,11 +401,24 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): for index, (input_size, filter_size, output_size, stride, padding) in enumerate(ConfigsToTest()): print("Testing DepthwiseConv2DFilterGradCompare,", index, "th config:", - input_size, "*", filter_size, "stride:", stride, "padding:", - padding) + input_size, "*", filter_size, "producing output", output_size, + "stride:", stride, "padding:", padding) self._CompareBackpropFilter(input_size, filter_size, output_size, stride, padding) + def testDepthwiseConv2DFilterGradFormatNCHWCompare(self): + for index, (input_size, filter_size, output_size, stride, + padding) in enumerate(ConfigsToTest()): + print("Testing DepthwiseConv2DFilterGradFormatNCHWCompare,", index, + "th config:", input_size, "*", filter_size, "producing output", + output_size, "stride:", stride, "padding:", padding) + self._CompareBackpropFilter( + input_size, + filter_size, + output_size, + stride, + padding, + data_format="NCHW") if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/quantized_ops_test.py b/tensorflow/compiler/tests/quantized_ops_test.py index 80c338513bc9ff6b8e56c5ad6b904af9e06a3715..cd9b728ab314d29e4eb585e00a9131024ea3a207 100644 
--- a/tensorflow/compiler/tests/quantized_ops_test.py +++ b/tensorflow/compiler/tests/quantized_ops_test.py @@ -18,11 +18,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import math import numpy as np from tensorflow.compiler.tests import xla_test +from tensorflow.compiler.tf2xla.python import xla +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import bitwise_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest @@ -44,5 +49,55 @@ class QuantizedOpsTest(xla_test.XLATestCase): self.assertAllEqual(value, expected) +class DeuantizedOpsTest(xla_test.XLATestCase): + + def pack_uint8_r2_to_uint32(self, test_input): + num_rows, num_columns = test_input.get_shape().as_list() + num_output_columns = int(math.ceil(num_columns / 4.0)) + padding_input = array_ops.pad( + math_ops.cast(test_input, dtype=dtypes.uint8), + constant_op.constant([[ + 0, + 0, + ], [0, num_output_columns * 4 - num_columns]])) + output = array_ops.zeros([num_rows, num_output_columns], + dtype=dtypes.uint32) + num_elements_per_pack = 4 + shift_bits = 8 + + iota_r1 = math_ops.range(num_output_columns * num_elements_per_pack) + + for p in range(num_elements_per_pack): + selected_index = math_ops.equal( + math_ops.mod(iota_r1, num_elements_per_pack), p) + gather_index = array_ops.boolean_mask(iota_r1, selected_index) + gathered_input = array_ops.gather(padding_input, gather_index, axis=1) + total_shift_bits = shift_bits * (num_elements_per_pack - p - 1) + left_shift_input = bitwise_ops.left_shift( + math_ops.cast(gathered_input, dtype=dtypes.uint32), total_shift_bits) + output = bitwise_ops.bitwise_or(output, left_shift_input) + return output + + def testDequantizeQuint8(self): + num_rows = 100 + num_columns = 3547 + random_input = 
np.random.normal(128.0, 10.0, [num_rows, num_columns]) + with self.cached_session() as session: + with ops.device("CPU"): + test_input = ops.convert_to_tensor(random_input, dtype=dtypes.float32) + transposed_input = array_ops.transpose(test_input, [1, 0]) + quantized_input = array_ops.quantize(transposed_input, 0.0, 255.0, + dtypes.quint8) + packed_input = self.pack_uint8_r2_to_uint32(quantized_input.output) + with self.test_scope(): + transposed_quantized_output = xla.dequantize(packed_input, 0.0, 255.0, + "MIN_COMBINED", True) + quantized_output = array_ops.slice(transposed_quantized_output, [0, 0], + [num_rows, num_columns]) + + value = session.run(quantized_output) + self.assertAllClose(value, random_input, 1.0) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index 4cf88fc523735cc2d22e085afb83790c7ebb48e4..28274ff799de2c85e1e80512cadbe0206cb640a4 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -319,7 +319,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): session.run(output) self.assertRegexpMatches( invalid_arg_error.exception.message, - (r'^start_indices must be a vector with length equal to input rank, ' + (r'start_indices must be a vector with length equal to input rank, ' r'but input rank is 3 and start_indices has shape \[2\].*')) def testDynamicSliceWithIncorrectSizeIndicesShape(self): @@ -332,7 +332,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): session.run(output) self.assertRegexpMatches( invalid_arg_error.exception.message, - (r'^size_indices must be a vector with length equal to input rank, ' + (r'size_indices must be a vector with length equal to input rank, ' r'but input rank is 3 and size_indices has shape \[2\].*')) diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 
5a0d9b9af9d55a8dee809d3cf909bce39c3b8b6c..d8123e956fac04912b4fed5bf75cc9cb55c5baf9 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -244,6 +244,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index c693e42d26712d55852f45c806215fc1f1b9a030..7ae96e1d484900e28e8c23c3bb2232401144ad82 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -640,7 +640,8 @@ Status Conditional::ExtractBodies(Graph* graph) { Status Conditional::BuildIfNode(Graph* graph, FunctionLibraryDefinition* library) { VLOG(2) << "Build cond function for " << name(); - NodeDefBuilder builder(name(), "If", library); + NodeDebugInfo debug_info((*merges_.begin())->def()); + NodeDefBuilder builder(name(), "If", library, &debug_info); const string branch_name[] = {"else_branch", "then_branch"}; for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { int branch_index = static_cast(branch); diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 8bc329229648c5aced8d06c99b170803bb3a90f8..47209d285f1a077fd80f779a406e6980892f1646 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -1,16 +1,11 @@ +load("//tensorflow:tensorflow.bzl", "tf_copts", "tf_kernel_library") + licenses(["notice"]) # Apache 2.0 package( default_visibility = ["//tensorflow/compiler/tf2xla:internal"], ) -load("//tensorflow:tensorflow.bzl", "tf_copts") -load("//tensorflow:tensorflow.bzl", "tf_kernel_library") -load( - "//third_party/mkl:build_defs.bzl", - "if_mkl", -) - tf_kernel_library( name = "xla_ops", srcs 
= [ @@ -106,6 +101,7 @@ tf_kernel_library( "variable_ops.cc", "xla_broadcast_helper_op.cc", "xla_conv_op.cc", + "xla_dequantize_op.cc", "xla_dot_op.cc", "xla_pad_op.cc", "xla_reduce_op.cc", @@ -122,12 +118,9 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/lib:broadcast", - "//tensorflow/compiler/tf2xla/lib:cholesky", - "//tensorflow/compiler/tf2xla/lib:qr", "//tensorflow/compiler/tf2xla/lib:random", "//tensorflow/compiler/tf2xla/lib:scatter", "//tensorflow/compiler/tf2xla/lib:util", - "//tensorflow/compiler/tf2xla/lib:while_loop", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal", @@ -140,11 +133,15 @@ tf_kernel_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:cholesky", "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:loops", "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/lib:matrix", "//tensorflow/compiler/xla/client/lib:pooling", "//tensorflow/compiler/xla/client/lib:prng", + "//tensorflow/compiler/xla/client/lib:qr", + "//tensorflow/compiler/xla/client/lib:quantize", "//tensorflow/compiler/xla/client/lib:sorting", "//tensorflow/compiler/xla/client/lib:triangular_solve", "//tensorflow/core:framework", diff --git a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc index 9fcbc86adc0967cbb7fb73da8bdabc58b60953da..0ed3044efa5b1060d2b0ad2d5563b0e02ebf66ec 100644 --- a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/cholesky.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/cholesky.h" namespace tensorflow { namespace { @@ -24,7 +24,7 @@ class CholeskyOp : public XlaOpKernel { public: explicit CholeskyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - ctx->SetOutput(0, Cholesky(ctx->Input(0))); + ctx->SetOutput(0, xla::Cholesky(ctx->Input(0))); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index 641fefafb357f6ad10483c454600f3dadd4f8cb7..b0bc7640307149459a29e6b0b2e8e8132e4141c9 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -392,23 +392,31 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( builder->GetShape(activations)); TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape, builder->GetShape(gradients)); + xla::XlaOp filter_backprop; + + xla::Shape input_shape = activations_shape; + xla::Shape output_shape = out_backprop_shape; + + TensorShape input_tensor_shape, filter_tensor_shape, output_tensor_shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape)); + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape)); + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(output_shape, &output_tensor_shape)); + const xla::Shape expanded_filter_shape = attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape) : filter_shape; - // Reuse dimension computation logic from conv_grad_ops.cc. 
ConvBackpropDimensions dims; - TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes( - type_string, attrs.num_spatial_dims, activations_shape, - expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides, - attrs.padding, attrs.data_format, &dims)); - // The filter gradients are computed by a convolution of the input // activations and the output gradients, with some appropriate padding. // See the comment at the top of conv_grad_ops.h for details. - xla::ConvolutionDimensionNumbers dnums; + TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes( + type_string, attrs.num_spatial_dims, activations_shape, + expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides, + attrs.padding, attrs.data_format, &dims)); + // The activations (inputs) form the LHS of the convolution. // Activations have shape: [batch, in_rows, in_cols, ..., in_depth] // For the gradient computation, we flip the roles of the batch and @@ -420,6 +428,23 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( int n_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format); int c_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format); + // The conversion logic below assumes that the data format is NHWC, so we also + // check that here. + bool use_batch_group_count = + filter_tensor_shape.dim_size(num_dims - 1) == 1 && attrs.depthwise && + attrs.data_format == FORMAT_NHWC; + + std::vector> padding(attrs.num_spatial_dims); + std::vector rhs_dilation(attrs.num_spatial_dims); + std::vector window_strides(attrs.num_spatial_dims); + std::vector ones(attrs.num_spatial_dims, 1); + + // The activations (inputs) form the LHS of the convolution. + // Activations have shape: [batch, in_rows, in_cols, ..., in_depth] + // For the gradient computation, we flip the roles of the batch and + // feature dimensions. + // Each spatial entry has size in_depth * batch + // Swap n_dim and c_dim in the activations. 
dnums.set_input_batch_dimension(c_dim); dnums.set_input_feature_dimension(n_dim); @@ -430,19 +455,21 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( dnums.set_kernel_input_feature_dimension(n_dim); dnums.set_kernel_output_feature_dimension(c_dim); - std::vector> padding(attrs.num_spatial_dims); - std::vector rhs_dilation(attrs.num_spatial_dims); - std::vector window_strides(attrs.num_spatial_dims); - std::vector ones(attrs.num_spatial_dims, 1); + // The dimension swap below is needed because filter shape is KH,KW,F,DM. + if (use_batch_group_count) { + dnums.set_output_batch_dimension(attrs.num_spatial_dims + 1); + dnums.set_output_feature_dimension(attrs.num_spatial_dims); + } else { + dnums.set_output_batch_dimension(attrs.num_spatial_dims); + dnums.set_output_feature_dimension(attrs.num_spatial_dims + 1); + } // Tensorflow filter shape is [ H, W, ..., inC, outC ]. for (int i = 0; i < attrs.num_spatial_dims; ++i) { dnums.add_output_spatial_dimensions(i); } - dnums.set_output_batch_dimension(attrs.num_spatial_dims); - dnums.set_output_feature_dimension(attrs.num_spatial_dims + 1); - for (int i = 0; i < attrs.num_spatial_dims; ++i) { + for (int64 i = 0; i < attrs.num_spatial_dims; ++i) { int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i); dnums.add_input_spatial_dimensions(dim); dnums.add_kernel_spatial_dimensions(dim); @@ -496,11 +523,14 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( // // This is done by specifying the window dilation factors in the // convolution HLO below. - auto filter_backprop = - xla::ConvGeneralDilated(activations, gradients, window_strides, padding, - /*lhs_dilation=*/ones, rhs_dilation, dnums); - if (attrs.depthwise) { + filter_backprop = xla::ConvGeneralDilated( + activations, gradients, window_strides, padding, /*lhs_dilation=*/ones, + rhs_dilation, dnums, + /*feature_group_count=*/1, + /*batch_group_count=*/use_batch_group_count ? 
dims.in_depth : 1); + + if (!use_batch_group_count && attrs.depthwise) { filter_backprop = ContractFilterForDepthwiseBackprop( filter_shape, filter_backprop, activations.builder()); } diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 20b0de193dc060197f3062d3be0b8d45f7dcb9b1..41c31d0ed58fe9bc9bbde0bd58993c975f04fd60 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index b5e083912555c865b5eadc7697075c9ca4451ca9..4f0f0fd9aefecc3d31f8bd9c8ca40ebb0860c82d 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -56,6 +56,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { VLOG(1) << "Building If: " << input_types_.size() << " inputs"; std::vector arguments(input_types_.size()); + int num_resource_args = 0; for (int i = 0; i < input_types_.size(); ++i) { XlaCompiler::Argument& arg = arguments[i]; DataType type = ctx->input_type(i + 1); @@ -81,6 +82,8 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { << " type: " << DataTypeString(arg.type) << " shape: " << arg.shape.DebugString() << " initialized: " << arg.initialized; + + num_resource_args++; } else { arg.kind = XlaCompiler::Argument::kParameter; arg.type = input_types_[i]; @@ -236,9 +239,13 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { ctx->SetOutput(i, output_handle); } if (has_token_input_output_) { - // Set token output for 
this "if" op. + // Set token output for this "If" op. Token output is the last output of + // XLA computation, which comes after all "normal" TF outputs and resource + // updates. For "If" node, num of resource updates equals to number of + // resource args because we set `return_updated_values_for_all_resources` + // to true in XlaCompiler option. xla::XlaOp token_output = - xla::GetTupleElement(outputs, output_types_.size()); + xla::GetTupleElement(outputs, output_types_.size() + num_resource_args); auto shape_or = b->GetShape(token_output); OP_REQUIRES_OK(ctx, shape_or.status()); OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()), diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index e9bb0a77e99d144863b027bd214081316d61c314..96ddd42e2ae04d454e4fb85628d139e17a543d2e 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -15,12 +15,12 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/lib/sorting.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -505,9 +505,9 @@ class NonMaxSuppressionOp : public XlaOpKernel { init_values.push_back(included_iou); auto suppress_loop_result = - XlaWhileLoop(WhileCondFn(num_boxes, output_size), - SuppressBodyFn(num_boxes), init_values, "suppress_loop", - builder) + xla::WhileLoopHelper(WhileCondFn(num_boxes, output_size), + SuppressBodyFn(num_boxes), init_values, + "suppress_loop", builder) .ValueOrDie(); xla::XlaOp included_score = diff --git a/tensorflow/compiler/tf2xla/kernels/qr_op.cc b/tensorflow/compiler/tf2xla/kernels/qr_op.cc index 7ea0afc1f53cbe4cfcc3f6121a4ecd55864c1b52..66ec40a946b8a063d84acd33daf81f52ea2c35ed 100644 --- a/tensorflow/compiler/tf2xla/kernels/qr_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/qr_op.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/qr.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/qr.h" namespace tensorflow { namespace { @@ -26,7 +26,7 @@ class QROp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices_)); } void Compile(XlaOpKernelContext* ctx) override { - auto result = QRDecomposition(ctx->Input(0), full_matrices_); + auto result = xla::QRDecomposition(ctx->Input(0), full_matrices_); if (!result.ok()) { ctx->SetStatus(result.status()); return; diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 8822e29f7e77b1cbc6fa6ca61d0062d9b1b0c36e..2d92056e4f522f6206e7d632f0fa1e8b793fd6e3 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -20,12 +20,12 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/lib/random.h" #include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -175,8 +175,8 @@ class RandomShuffleOp : public XlaOpKernel { }; // for i in range(n): auto swap_loop_result = - XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices}, - "indices_swap_loop", builder) + xla::ForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices}, + "indices_swap_loop", builder) .ValueOrDie(); auto swapped_indices = swap_loop_result[1]; diff --git a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc index 54d34a38abc4948a1a08197d72e3e7f763649093..f9985d526033ca675c701a508a3d1576e46bc5f7 100644 --- a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc @@ -125,7 +125,7 @@ XlaOp ConcatenateIota(xla::XlaBuilder* b, XlaOp indices, dimensions.back() = 1; auto batch_indices = - xla::Iota(b, xla::ShapeUtil::MakeShape(xla::U32, dimensions), + xla::Iota(b, xla::ShapeUtil::MakeShape(xla::S32, dimensions), /*iota_dimension=*/0); return xla::ConcatInDim(b, {batch_indices, indices}, dimensions.size() - 1); @@ -189,11 +189,53 @@ XlaOp ScatterToGradData(XlaOpKernelContext* ctx, XlaOp grad_data, XlaOp indices, scatter_dim_numbers); } +// Bounds samples to 0 if the warp image indices are out of the (-1, image_size) +// bound. 
+// The resulting dimension is given by 'result_dims'. +XlaOp BoundSamples(XlaOpKernelContext* ctx, XlaOp warp, + xla::PrimitiveType warp_type, TensorShape warp_shape, + std::vector result_dims, + std::vector broadcasted_dims, int64 last_warp_dim, + xla::Shape data_shape, XlaOp sample) { + auto is_gt_minus_one = + xla::Gt(warp, + xla::ConvertElementType( + xla::ConstantR1(ctx->builder(), {-1, -1}), warp_type), + /*broadcast_dimensions=*/{warp_shape.dims() - 1}); + auto is_lt_image_size = xla::Lt( + warp, + xla::ConvertElementType( + xla::ConstantR1( + ctx->builder(), + {/*width=*/static_cast(data_shape.dimensions(2)), + /*height=*/static_cast(data_shape.dimensions(1))}), + warp_type), + /*broadcast_dimensions=*/{warp_shape.dims() - 1}); + + auto is_in_bound_padded_x_y = xla::And(is_gt_minus_one, is_lt_image_size); + // Reduce along last dimension. The resulting dimension is: + // [batch, dim_0, ...dim_n]. + auto is_in_bound = xla::Reduce( + is_in_bound_padded_x_y, xla::ConstantR0(ctx->builder(), true), + xla::CreateScalarAndComputation(xla::PrimitiveType::PRED, ctx->builder()), + {last_warp_dim}); + + // Broadcast 'is_in_bound' to the same dimension as 'result_dims'. + auto broadcasted_is_in_bound = + xla::BroadcastInDim(is_in_bound, result_dims, broadcasted_dims); + + // Set out of bound samples to zero. + auto zeros = + xla::Broadcast(xla::Zero(ctx->builder(), warp_type), result_dims); + return xla::Select(broadcasted_is_in_bound, sample, zeros); +} + // Build computation the backprop into input 'data'. 
// Where input: // grad_output is of dimension [batch, dim_0, ...dim_n, channel] // ratio is of dimension [batch, dim_0, ...dim_n, 2] // gather_indices is of dimension [batch, dim_0, ...dim_n, 3] +// data_shape is of dimension [batch, x(width), y(height), channel] // // Output: // scatter-add to each 2x2 grad_data neighbor: @@ -201,10 +243,12 @@ XlaOp ScatterToGradData(XlaOpKernelContext* ctx, XlaOp grad_data, XlaOp indices, // grad_data[cx, fy, chan] += output_grad * (1 - dx) * dy // grad_data[fx, cy, chan] += output_grad * dx * (1 - dy) // grad_data[cx, cy, chan] += output_grad * (1 - dx) * (1 - dy) -// where (dx, dy) is (1 - ratio). +// where (dx, dy) is (1 - ratio). If (dx, dy) is out of bound, then their +// contribution is 0 to 'grad_data'. XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, - XlaOp gather_indices, xla::PrimitiveType warp_type, - TensorShape warp_shape, int64 data_channels, + XlaOp gather_indices, XlaOp warp, + xla::PrimitiveType warp_type, TensorShape warp_shape, + int64 last_warp_dim, int64 data_channels, xla::Shape data_shape) { // Weights tensor has dimension [batch, dim_0, ... dim_n, 4]. auto weights = BilinearWeights(ctx, ratio, warp_shape, warp_type); @@ -229,6 +273,18 @@ XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, std::iota(reshaped_weights_indices.begin(), reshaped_weights_indices.end(), 0); + // Set out of bound weights to 0. + // The dimension of the reshaped_weight: [batch, dim_0, ...dim_n, 2, 2]. 
+ std::vector reshaped_result_dims(warp_dims.begin(), + warp_dims.end() - 1); + reshaped_result_dims.push_back(2); + reshaped_result_dims.push_back(2); + std::vector broadcasted_dims(warp_dims.size() - 1); + std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0); + reshaped_weights = BoundSamples(ctx, warp, warp_type, warp_shape, + reshaped_result_dims, broadcasted_dims, + last_warp_dim, data_shape, reshaped_weights); + // The dimension is [batch, dim_0, ..., dim_n, 2, 2, data_channel]. auto broadcast_reshaped_weights = xla::BroadcastInDim( reshaped_weights, weights_with_channels_dims, reshaped_weights_indices); @@ -245,18 +301,41 @@ XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, auto grad_data = xla::ConstantLiteral( ctx->builder(), xla::Literal::CreateFromShape(data_shape)); - return ScatterToGradData(ctx, grad_data, gather_indices, - grad_output_multiply_weights, warp_shape.dims(), - warp_type); + // Pad grad data then slice it back. + // + // After left and right column 0-padding, the new dimension of padded data + // will be [batch, x+2, y+2, channel]. + auto padded_grad_data = + xla::Pad(grad_data, xla::Zero(ctx->builder(), warp_type), + xla::MakeEdgePaddingConfig({{0, 0}, {1, 1}, {1, 1}, {0, 0}})); + + auto shifting_value = xla::ConstantR1( + ctx->builder(), {/*batch=*/0, /*x(width)=*/1, /*y(height)=*/1}); + auto shifted_gather_indices = + xla::Add(gather_indices, shifting_value, {last_warp_dim}); + + auto updated_grad_data = ScatterToGradData( + ctx, padded_grad_data, shifted_gather_indices, + grad_output_multiply_weights, warp_shape.dims(), warp_type); + + const int64 batch_size = data_shape.dimensions(0); + const int64 width = data_shape.dimensions(1); + const int64 height = data_shape.dimensions(2); + // Slice out the result accounting for the padding. 
+ return xla::Slice( + updated_grad_data, /*start_indices=*/{0, 1, 1, 0}, + /*limit_indices=*/{batch_size, width + 1, height + 1, data_channels}, + /*strides=*/{1, 1, 1, 1}); } // Build computation for the backprop into input 'warp'. // Where input: -// warp is of dimension [batch, dim_0, ...dim_n, 2] -// grad_output is of dimension [batch, dim_0, ...dim_n, channel] -// ratio is of dimension [batch, dim_0, ...dim_n, 2] -// gather_indices is of dimension [batch, dim_0, ...dim_n, 3] -// data is of dimension [batch, x, y, channel] +// warp is of dimension [batch, dim_0, ...dim_n, 2] +// grad_output is of dimension [batch, dim_0, ...dim_n, channel] +// ratio is of dimension [batch, dim_0, ...dim_n, 2] +// gather_indices is of dimension [batch, dim_0, ...dim_n, 3] where the last +// dimension of size 3 is for {batch, x(width), y(height)}. +// data is of dimension [batch, x, y, channel] // // Output (simplified by ignoring the batch dimensions): // Since the forward path has: @@ -275,12 +354,12 @@ XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, // grad_warp_x = py * (img_cxcy - img_fxcy) + (1-py) * (img_cxfy-img_fxfy) // grad_warp_y = px * (img_cxcy - img_cxfy) + (1-px) * (img_fxcy-img_fxfy) // -// where (px, py) is warp, (fx, fy) is the left top corner and (cx, cy) is the +// where (px, py) is warp, (fx, fy) is the top left corner and (cx, cy) is the // bottom right corner in a 2x2 neighborhood. 
XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, XlaOp gather_indices, XlaOp data, TensorShape warp_shape, int64 data_channels, - xla::PrimitiveType data_type) { + xla::PrimitiveType data_type, xla::Shape data_shape) { auto warp_dims = warp_shape.dim_sizes(); std::vector warp_dims_without_last_dims(warp_dims.begin(), warp_dims.end() - 1); @@ -289,12 +368,30 @@ XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, std::vector neighbor_broadcast_dims = warp_dims_without_last_dims; neighbor_broadcast_dims.push_back(4); - // The dimension is [batch, dim_0, ... dim_n, 4, data_channels] - auto neighbors_data = Gather2by2Neighbors( - ctx->builder(), data, gather_indices, data_channels, warp_shape.dims()); + // With dimension [batch, dim_0, ...dim_n, 4] + auto neighbor_broadcast_shape = + xla::ShapeUtil::MakeShape(data_type, neighbor_broadcast_dims); const int64 last_warp_dim = warp_shape.dims() - 1; + // Pad data with 0, before gathering such that 0 will be returned for samples + // in the range of (-1, 0) or (image_dimension-1, image_dimension). + // After left and right column 0-padding, the new dimension of padded data + // will be [batch, x+2, y+2, channel]. + auto padded_data = + xla::Pad(data, xla::Zero(ctx->builder(), data_type), + xla::MakeEdgePaddingConfig({{0, 0}, {1, 1}, {1, 1}, {0, 0}})); + + auto shifting_value = xla::ConstantR1( + ctx->builder(), {/*batch=*/0, /*x(width)=*/1, /*y(height)=*/1}); + auto shifted_gather_indices = + xla::Add(gather_indices, shifting_value, {last_warp_dim}); + + // The dimension is [batch, dim_0, ... 
dim_n, 4, data_channels] + auto neighbors_data = + Gather2by2Neighbors(ctx->builder(), padded_data, shifted_gather_indices, + data_channels, warp_shape.dims()); + // Since we will be creating the dot product of: // lhs: [batch, dim_0, ...dim_n, 4] // and @@ -417,7 +514,7 @@ class ResamplerOp : public XlaOpKernel { // Find the coordinates of the top left corner for the 2x2 region to be // sampled from. The dimensions are [batch, dim_0, ... dim_n, 2] where the // last dimension of size 2 in turn is [x, y]. - XlaOp top_left = xla::ConvertElementType(warp, xla::U32); + XlaOp top_left = xla::ConvertElementType(warp, xla::S32); auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape); @@ -526,7 +623,8 @@ class ResamplerGradOp : public XlaOpKernel { size, "]")); } // Last dimension of warp shape must be of size 2. - OP_REQUIRES(ctx, warp_shape.dim_size(warp_shape.dims() - 1) == 2, + const int64 last_warp_dim = warp_shape.dims() - 1; + OP_REQUIRES(ctx, warp_shape.dim_size(last_warp_dim) == 2, errors::InvalidArgument( "the last dimension of warp must be exactly size 2.")); xla::PrimitiveType warp_type = ctx->input_xla_type(1); @@ -549,24 +647,32 @@ class ResamplerGradOp : public XlaOpKernel { // Find the top left corner coordinate for the region to be sampled from. // The dimensions are [batch, dim_0, ... dim_n, 2] where the last dimension // of size 2 in turn is [x, y]. - XlaOp top_left = xla::ConvertElementType(warp, xla::U32); + XlaOp top_left = xla::ConvertElementType(xla::Floor(warp), xla::S32); - // Dimensions are [batch, dim_0, ... dim_n, 2] + // Dimensions are [batch, dim_0, ... dim_n, 2]. XlaOp ratio = warp - xla::ConvertElementType(top_left, warp_type); // Indices for gathering neighboring pixels. 
auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape); - auto grad_data = - CalculateGradData(ctx, grad_output, ratio, gather_indices, warp_type, - warp_shape, data_channels, data_shape); + auto grad_data = CalculateGradData( + ctx, grad_output, ratio, gather_indices, warp, warp_type, warp_shape, + last_warp_dim, data_channels, data_shape); auto grad_warp = CalculateGradWarp(ctx, grad_output, ratio, gather_indices, data, - warp_shape, data_channels, data_type); + warp_shape, data_channels, data_type, data_shape); + auto warp_dims = warp_shape.dim_sizes(); + std::vector result_dims(warp_dims.begin(), warp_dims.end() - 1); + result_dims.push_back(2); + std::vector broadcasted_dims(warp_dims.size() - 1); + std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0); + auto grad_warp_bounded = + BoundSamples(ctx, warp, warp_type, warp_shape, result_dims, + broadcasted_dims, last_warp_dim, data_shape, grad_warp); ctx->SetOutput(0, grad_data); - ctx->SetOutput(1, grad_warp); + ctx->SetOutput(1, grad_warp_bounded); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index 960c1462ceb8c00a2d6c96564f6c985fd1caef0f..26d4214099d1d07c1b2e275d783654d9cd948e28 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -172,6 +172,65 @@ class ResourceApplyMomentum : public XlaOpKernel { REGISTER_XLA_OP(Name("ResourceApplyMomentum").TypeConstraint("T", kFloatTypes), ResourceApplyMomentum); +class ResourceApplyKerasMomentum : public XlaOpKernel { + public: + explicit ResourceApplyKerasMomentum(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + DataType type = ctx->input_type(2); + + TensorShape var_shape, accum_shape; + xla::XlaOp var, accum; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, 
&var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum)); + + OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), + errors::InvalidArgument( + "var and accum do not have the same shape", + var_shape.DebugString(), " ", accum_shape.DebugString())); + + TensorShape lr_shape = ctx->InputShape(2); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), + errors::InvalidArgument("lr is not a scalar: ", + lr_shape.DebugString())); + + TensorShape grad_shape = ctx->InputShape(3); + OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), + errors::InvalidArgument( + "var and grad do not have the same shape", + var_shape.DebugString(), " ", grad_shape.DebugString())); + + TensorShape momentum_shape = ctx->InputShape(4); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape), + errors::InvalidArgument("momentum is not a scalar: ", + momentum_shape.DebugString())); + + xla::XlaOp lr = ctx->Input(2); + xla::XlaOp grad = ctx->Input(3); + xla::XlaOp momentum = ctx->Input(4); + + accum = accum * momentum - grad * lr; + if (use_nesterov_) { + // See https://github.com/tensorflow/tensorflow/pull/2798 for an + // explanation of the reparameterization used here. 
+ var = var + accum * momentum - grad * lr; + } else { + var = var + accum; + } + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum)); + } + + private: + bool use_nesterov_; +}; +REGISTER_XLA_OP( + Name("ResourceApplyKerasMomentum").TypeConstraint("T", kFloatTypes), + ResourceApplyKerasMomentum); + class ResourceApplyAdagrad : public XlaOpKernel { public: explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index ce007fc04a818869686b9936a1607cee42665e87..ff5255028bd012ea4d839faa59ef5930a17c5767 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -41,8 +41,7 @@ Status MakeXlaCompilerArgumentsFromInputs( *has_uninitialized_vars = false; *has_tensor_arrays = false; for (int i = 0; i < ctx->num_inputs(); ++i) { - VLOG(2) << " Input " << i - << " type: " << DataTypeString(ctx->input_type(i)) + VLOG(2) << " Input " << i << " type: " << DataTypeString(ctx->input_type(i)) << " shape: " << ctx->InputShape(i).DebugString(); XlaCompiler::Argument& arg = (*args)[i]; DataType type = ctx->input_type(i); @@ -233,13 +232,22 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { xla::ShapeUtil::HumanString(body_input_shape), " vs. 
", xla::ShapeUtil::HumanString(body.xla_output_shape))); - xla::Shape expected_cond_output_shape = xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::PRED, {})}); + xla::Shape expected_cond_output_shape_without_side_effect = + xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::PRED, {})}); + xla::Shape expected_cond_output_shape_with_side_effect = + xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::PRED, {}), + xla::ShapeUtil::MakeTokenShape()}); OP_REQUIRES(ctx, - xla::ShapeUtil::Compatible(cond.xla_output_shape, - expected_cond_output_shape), + xla::ShapeUtil::Compatible( + cond.xla_output_shape, + expected_cond_output_shape_without_side_effect) || + xla::ShapeUtil::Compatible( + cond.xla_output_shape, + expected_cond_output_shape_with_side_effect), errors::InvalidArgument( - "Output shape of loop condition should be (pred[]), got: ", + "Output shape of loop condition should be (pred[]) or " + "(pred[], token[]), got: ", xla::ShapeUtil::HumanString(cond.xla_output_shape))); int num_inputs = body.input_mapping.size(); @@ -283,6 +291,15 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { xla::XlaOp while_result = xla::While(cond_wrapper, *body.computation, init); + auto while_shape_or = builder->GetShape(while_result); + OP_REQUIRES_OK(ctx, while_shape_or.status()); + auto count = xla::ShapeUtil::TupleElementCount(while_shape_or.ValueOrDie()); + int max_index = body.outputs.size() + body.resource_updates.size() - 1; + OP_REQUIRES( + ctx, max_index < count, + errors::Internal("Max tuple element requested (", max_index, + ") needs to be less than tuple size (", count, ")")); + // Sets non-variable outputs. 
for (int i = 0; i < ctx->num_outputs(); ++i) { if (ctx->input_type(i) != DT_RESOURCE) { diff --git a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc index 4612f19971a3ce6994aef303f751748b77ccda9a..b20adc592a0d3d2129c897218ddbfc891b4cd40a 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc @@ -78,7 +78,7 @@ class XlaConvOp : public XlaOpKernel { xla::XlaOp output = xla::ConvGeneralDilated( context->Input(0), context->Input(1), window_strides, padding, lhs_dilation, rhs_dilation, dnums_, feature_group_count, - &precision_config_); + /*batch_group_count=*/1, &precision_config_); context->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/xla_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_dequantize_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..a30b4861f6b3a964c0c874a3affab7d6198264d7 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_dequantize_op.cc @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/quantize.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class XlaDequantizeOp : public XlaOpKernel { + public: + explicit XlaDequantizeOp(OpKernelConstruction* context) + : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("min_range", &min_range_)); + OP_REQUIRES_OK(context, context->GetAttr("max_range", &max_range_)); + OP_REQUIRES_OK(context, context->GetAttr("mode", &mode_)); + OP_REQUIRES_OK(context, + context->GetAttr("transpose_output", &transpose_output_)); + } + + void Compile(XlaOpKernelContext* context) override { + const xla::XlaOp& input = context->Input(0); + + xla::QuantizedRange range(min_range_, max_range_); + + xla::XlaOp output = + xla::Dequantize(input, range, mode_, transpose_output_); + context->SetOutput(0, output); + } + + private: + float min_range_; + float max_range_; + bool transpose_output_; + string mode_; + TF_DISALLOW_COPY_AND_ASSIGN(XlaDequantizeOp); +}; + +REGISTER_XLA_OP(Name("XlaDequantize"), XlaDequantizeOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 3e7a761120317ff85947559b7b2e52be9232afb7..3d7b0bc959f9dbf3c1b9749379e2ea0d285b302b 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -15,8 +15,6 @@ filegroup( ]), ) -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") - cc_library( name = "broadcast", srcs = ["broadcast.cc"], @@ -33,27 +31,6 @@ cc_library( ], ) 
-cc_library( - name = "cholesky", - srcs = ["cholesky.cc"], - hdrs = ["cholesky.h"], - deps = [ - ":util", - ":while_loop", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:matrix", - "//tensorflow/compiler/xla/client/lib:slicing", - "//tensorflow/compiler/xla/client/lib:triangular_solve", - "//tensorflow/core:lib", - ], -) - cc_library( name = "random", srcs = ["random.cc"], @@ -69,35 +46,12 @@ cc_library( ], ) -cc_library( - name = "qr", - srcs = ["qr.cc"], - hdrs = ["qr.h"], - deps = [ - ":util", - ":while_loop", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:math", - "//tensorflow/compiler/xla/client/lib:matrix", - "//tensorflow/compiler/xla/client/lib:slicing", - "//tensorflow/core:lib", - ], -) - cc_library( name = "scatter", srcs = ["scatter.cc"], hdrs = ["scatter.h"], deps = [ ":util", - ":while_loop", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -128,19 +82,3 @@ cc_library( "@com_google_absl//absl/types:span", ], ) - -cc_library( - name = "while_loop", - srcs = ["while_loop.cc"], - hdrs = ["while_loop.h"], - deps = [ - ":util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:xla_builder", - 
"//tensorflow/compiler/xla/client:xla_computation", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ], -) diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index 2b1c2ced925d9fee7392986015a6e716a94d356f..688056791f9750e6b22df4b2cd4643de0b780651 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -20,7 +20,6 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index bd2c0a5ee88869ba60701c0a7ace05857452eed9..ab77984684db4525f4d3f42b2c9c0f093c82ec45 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -409,5 +409,29 @@ body: A function that takes a list of tensors and returns another list of tensors. Both lists have the same types as specified by T. )doc"); +REGISTER_OP("XlaDequantize") + .Input("input: uint32") + .Output("output: bfloat16") + .Attr("min_range: float") + .Attr("max_range: float") + .Attr("mode: string") + .Attr("transpose_output: bool") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Takes the packed uint32 input and unpacks the input to uint8 to do +Dequantization on deivce. + +input: Input tensors whose types is uint32, shape is [d0, ..., dn]. +output: Output tensors whose types is bloat16. If transpose_output is true, + output shape is [dn * 4, dn-1, ..., d1, d0]. If transpose_output + is false, output shape is [d0,..., dn * 4]. +min_range: The minimum scalar value possibly produced for the input. +max_range: The maximum scalar value possibly produced for the input. 
+mode: String to determine the dequantize mode in {"MIN_COMBINED", "MIN_FIRST", "SCALED"}. +transpose_output: Boolean to determine if output is transposed. transpose_output + is faster when input is large and rank of input is higher than 1. +)doc"); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 147e562658bbfc445f99268812e2c3ae1ee61e30..345193c936a885e5a9e468979c4b73b5b0c9e5c2 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -386,3 +386,4 @@ def slice(x, start_dims, limit_dims, strides): sort = gen_xla_ops.xla_sort key_value_sort = gen_xla_ops.xla_key_value_sort while_loop = gen_xla_ops.xla_while +dequantize = gen_xla_ops.xla_dequantize diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc index 72b240996fb4d9dcb5f5dfd919da618cbae08c16..ff9f1b9ccba2c4f3307890d5aac4ddb6cfaafcd9 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc @@ -65,6 +65,7 @@ CreateResourceOpInfoMap() { add("ResourceApplyFtrlV2" , kReadWrite, kVariable); add("ResourceApplyGradientDescent" , kReadWrite, kVariable); add("ResourceApplyMomentum" , kReadWrite, kVariable); + add("ResourceApplyKerasMomentum" , kReadWrite, kVariable); add("ResourceApplyPowerSign" , kReadWrite, kVariable); add("ResourceApplyProximalAdagrad" , kReadWrite, kVariable); add("ResourceApplyProximalGradientDescent" , kReadWrite, kVariable); diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index b589512dcdfa32050281120aba6a5ae89a980c2f..ec604af13867171d558cd7324919fb9531caf460 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -18,10 +18,33 @@ limitations under the License. 
#include #include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { +namespace { + +Status PopulateInfeedLayoutVector(const xla::Shape& shape, + std::vector* layouts) { + if (xla::ShapeUtil::IsTuple(shape)) { + int64 tuple_elements = xla::ShapeUtil::TupleElementCount(shape); + for (int64 i = 0; i < tuple_elements; ++i) { + const xla::Shape& subshape = + xla::ShapeUtil::GetTupleElementShape(shape, i); + TF_RETURN_IF_ERROR(PopulateInfeedLayoutVector(subshape, layouts)); + } + } else if (xla::LayoutUtil::HasLayout(shape)) { + for (auto dim : xla::LayoutUtil::MinorToMajor(shape)) { + layouts->push_back(dim); + } + } else { + layouts->insert(layouts->end(), xla::ShapeUtil::Rank(shape), -1); + } + return Status::OK(); +} + +} // namespace // Convert an XLA Shape into the equivalent TensorFlow shape. Status XLAShapeToTensorShape(const xla::Shape& shape, @@ -61,4 +84,10 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, return xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout); } +xla::StatusOr> GetInfeedLayoutVector(const xla::Shape& shape) { + std::vector layouts; + TF_RETURN_IF_ERROR(PopulateInfeedLayoutVector(shape, &layouts)); + return layouts; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/shape_util.h b/tensorflow/compiler/tf2xla/shape_util.h index 0b231ea8e7a2d8e303e91911e2e0a36fc83e78b4..cf52bf46e7c2a237d57f4c87e7d6efbf3fa9b1c2 100644 --- a/tensorflow/compiler/tf2xla/shape_util.h +++ b/tensorflow/compiler/tf2xla/shape_util.h @@ -18,7 +18,10 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_ +#include + #include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" @@ -41,6 +44,14 @@ Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape, xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, const TensorShape& tensor_shape); +// Given an XLA shape with layouts, builds a layout vector in the form able to +// be fed to InfeedEnqueue/InfeedEnqueueTuple ops. +// The returned vector is a linearized sequence of the minor-to-major values of +// the layouts held within the input shape. +// In case the input shape is a tuple, the minor-to-major values will be in the +// order of the tuple elements within the tuple shape. +xla::StatusOr> GetInfeedLayoutVector(const xla::Shape& shape); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc index b233e6b2c28e1968bb74901fc684e808ae45ab60..b62f8e9115229ac35c657d374c68336f1168ff77 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.cc +++ b/tensorflow/compiler/tf2xla/side_effect_util.cc @@ -24,6 +24,8 @@ const char kXlaTokenInputNodesAttrName[] = "_xla_token_input_nodes"; const char kXlaTokenArgNodeName[] = "_xla_token_arg_node"; +const char kXlaHasHostTransferAttrName[] = "_xla_has_host_transfer"; + std::set CalculateTokenInputsForOutputToken(const Graph& g) { std::set results; Node* first_side_effecting_node_on_path = nullptr; diff --git a/tensorflow/compiler/tf2xla/side_effect_util.h b/tensorflow/compiler/tf2xla/side_effect_util.h index f22ddb2f58e1fa5c10ca0fdb956d9136942388b7..7081b362c36c4785164b29003a5f89cd73bcf3af 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.h +++ 
b/tensorflow/compiler/tf2xla/side_effect_util.h @@ -35,6 +35,9 @@ extern const char kXlaTokenInputNodesAttrName[]; // node has side-effect dependency on current graph's token input. extern const char kXlaTokenArgNodeName[]; +// This node has XlaRecvAtHost/XlaSendFromHost in its associated functions. +extern const char kXlaHasHostTransferAttrName[]; + // Calculates side-effect dependencies for the graph's token output. // Returns a set of node names representing these dependencies. std::set CalculateTokenInputsForOutputToken(const Graph& g); diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index ab26d939ccba75ce58609ffd71c7ccadbe90cfa8..24afe595b18b823818bd8fe65bc599af8bce040a 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -91,7 +91,7 @@ TEST(ConvertGraphDefToXla, Sum) { client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()}); TF_EXPECT_OK(result_or.status()); xla::Literal result = std::move(result_or.ValueOrDie()); - EXPECT_EQ("(s32[]) (\n42\n)", result.ToString()); + EXPECT_EQ("(\ns32[] 42\n)", result.ToString()); config.mutable_feed(0)->mutable_id()->set_output_index( 123); /* invalid output_index */ diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index cc81772e8c5da710bc733f7e4f5fe820b2c2d110..18d87727c500619bf386be7d8c7085724f44aba3 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -364,6 +364,7 @@ Status AddPlaceholdersForFeeds( GraphDef gd; *gd.mutable_versions() = graph_def->versions(); *gd.add_node() = *existing; + MergeDebugInfo(NodeDebugInfo(*existing), gd.mutable_node(0)); TF_RETURN_IF_ERROR( AddDefaultAttrsToGraphDef(&gd, *op_registry, 0 /*node_offset*/)); @@ -390,6 +391,7 @@ Status AddPlaceholdersForFeeds( // in this code. 
for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) { const PlaceholderInfo& info = it->second; + // TODO(shikharagarwal): Add original node information. NodeDef* d = graph_def->add_node(); d->set_name(info.placeholder_name); d->set_op("PlaceholderV2"); @@ -557,6 +559,12 @@ bool HasAssociatedFunction(const NodeDef& node_def, return true; } + if (node_def.op() == "XlaHostCompute") { + // XlaHostCompute has "shape_inference_graph" func attr, but that's not + // related to graph execution. + return false; + } + for (const auto& iter : node_def.attr()) { if (iter.second.has_func()) { return true; @@ -578,6 +586,9 @@ std::vector GetAssociatedFunctions( // This is a SymbolicGradient op. AttrValueMap attrs(node.attrs().begin(), node.attrs().end()); results.emplace_back(AssociatedFunctionInfo::SymbolicGradient(op, attrs)); + } else if (node.type_string() == "XlaHostCompute") { + // XlaHostCompute has "shape_inference_graph" func attr, but that's not + // related to graph execution. } else { // Collect all function attrs for the node. for (auto& iter : node.attrs()) { @@ -599,7 +610,9 @@ Status RewriteAssociatedFunction( switch (associated_function.type()) { case AssociatedFunctionInfo::kFunctionCallNode: { // Change this node to call the new function. 
- NodeDefBuilder builder(node->name(), rewritten_function_name, fld); + NodeDebugInfo debug_info(*node); + NodeDefBuilder builder(node->name(), rewritten_function_name, fld, + &debug_info); for (auto attr : node->attrs()) { builder.Attr(attr.first, attr.second); } diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 4360e0857964b0ac63fc887e269b04a4b00d854a..722d1376687efa1c04158e3fd9ce539aac9d0122 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -109,7 +109,7 @@ cc_library( name = "status_macros", srcs = ["status_macros.cc"], hdrs = ["status_macros.h"], - visibility = [":friends"], + visibility = ["//visibility:public"], deps = [ ":statusor", ":types", @@ -224,6 +224,7 @@ cc_library( name = "shape_util", srcs = [ "index_util.cc", + "layout.cc", "layout_util.cc", "primitive_util.cc", "shape.cc", @@ -231,6 +232,7 @@ cc_library( ], hdrs = [ "index_util.h", + "layout.h", "layout_util.h", "primitive_util.h", "shape.h", @@ -290,6 +292,22 @@ tf_cc_test( ], ) +tf_cc_test( + name = "primitive_util_test", + srcs = ["primitive_util_test.cc"], + deps = [ + ":shape_util", + ":status_macros", + ":test", + ":test_helpers", + ":types", + ":util", + ":xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "layout_util_test", srcs = ["layout_util_test.cc"], @@ -301,6 +319,22 @@ tf_cc_test( ], ) +tf_cc_test( + name = "layout_test", + srcs = ["layout_test.cc"], + deps = [ + ":shape_util", + ":status_macros", + ":test", + ":test_helpers", + ":types", + ":util", + ":xla_data_proto", + "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", + ], +) + tf_cc_test( name = "index_util_test", srcs = ["index_util_test.cc"], @@ -575,6 +609,7 @@ cc_library( ":types", ":util", ":xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/memory", @@ -705,7 +740,6 @@ cc_library( 
"//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_evaluator", "//tensorflow/compiler/xla/service:shape_inference", - "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index fe99564d3c671cd7890e1fa26fcd2e3384972983..27c075e8f13f6777af4e837501d97a33034313f5 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -3,7 +3,7 @@ licenses(["notice"]) # Apache 2.0 -package(default_visibility = [":friends"]) +package(default_visibility = ["//visibility:public"]) package_group( name = "friends", @@ -170,6 +170,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:optional", ], ) diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 74b76f929949d3300a5d0ff45d5fa4cd9f162642..43127cae1e5d81521003a28288e27d291e33c9b9 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -186,7 +186,7 @@ StatusOr Client::ComputeConstant(const XlaComputation& computation, ComputeConstantGraphRequest request; *request.mutable_computation() = computation.proto(); if (output_layout != nullptr) { - *request.mutable_output_layout() = *output_layout; + *request.mutable_output_layout() = output_layout->ToProto(); } ComputeConstantResponse response; diff --git a/tensorflow/compiler/xla/client/client_library.cc b/tensorflow/compiler/xla/client/client_library.cc index 27b7fa7b29206affa9f9c2e4becd9e4ea66484ab..42aae026229a49fd801cc90562fa51f604336148 100644 --- a/tensorflow/compiler/xla/client/client_library.cc +++ b/tensorflow/compiler/xla/client/client_library.cc @@ -24,12 +24,14 @@ limitations under the License. 
namespace xla { -LocalClientOptions::LocalClientOptions(se::Platform* platform, - int number_of_replicas, - int intra_op_parallelism_threads) +LocalClientOptions::LocalClientOptions( + se::Platform* platform, int number_of_replicas, + int intra_op_parallelism_threads, + const absl::optional>& allowed_devices) : platform_(platform), number_of_replicas_(number_of_replicas), - intra_op_parallelism_threads_(intra_op_parallelism_threads) {} + intra_op_parallelism_threads_(intra_op_parallelism_threads), + allowed_devices_(allowed_devices) {} LocalClientOptions& LocalClientOptions::set_platform(se::Platform* platform) { platform_ = platform; @@ -58,6 +60,17 @@ int LocalClientOptions::intra_op_parallelism_threads() const { return intra_op_parallelism_threads_; } +LocalClientOptions& LocalClientOptions::set_allowed_devices( + const absl::optional>& allowed_devices) { + allowed_devices_ = allowed_devices; + return *this; +} + +const absl::optional>& LocalClientOptions::allowed_devices() + const { + return allowed_devices_; +} + /* static */ ClientLibrary& ClientLibrary::Singleton() { static ClientLibrary* c = new ClientLibrary; return *c; @@ -67,9 +80,10 @@ ClientLibrary::ClientLibrary() = default; ClientLibrary::~ClientLibrary() = default; /* static */ StatusOr ClientLibrary::GetOrCreateLocalClient( - se::Platform* platform) { + se::Platform* platform, const absl::optional>& device_set) { LocalClientOptions default_options; default_options.set_platform(platform); + default_options.set_allowed_devices(device_set); return GetOrCreateLocalClient(default_options); } @@ -94,7 +108,7 @@ ClientLibrary::~ClientLibrary() = default; service_options.set_number_of_replicas(replica_count); service_options.set_intra_op_parallelism_threads( options.intra_op_parallelism_threads()); - + service_options.set_allowed_devices(options.allowed_devices()); auto instance = absl::make_unique(); TF_ASSIGN_OR_RETURN(instance->service, LocalService::NewService(service_options)); diff --git 
a/tensorflow/compiler/xla/client/client_library.h b/tensorflow/compiler/xla/client/client_library.h index 3ad558fa532931937fab898f7b855f0a3370eaec..62d225c6c298b26bbbd248fc1f4be64fc8efcf6b 100644 --- a/tensorflow/compiler/xla/client/client_library.h +++ b/tensorflow/compiler/xla/client/client_library.h @@ -23,9 +23,11 @@ limitations under the License. #include #include +#include #include #include +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/client/compile_only_client.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/service/compile_only_service.h" @@ -43,9 +45,10 @@ namespace xla { // Options to configure the local client when it is created. class LocalClientOptions { public: - LocalClientOptions(se::Platform* platform = nullptr, - int number_of_replicas = 1, - int intra_op_parallelism_threads = -1); + LocalClientOptions( + se::Platform* platform = nullptr, int number_of_replicas = 1, + int intra_op_parallelism_threads = -1, + const absl::optional>& allowed_devices = absl::nullopt); // Set the platform backing the service, or nullptr for the default platform. LocalClientOptions& set_platform(se::Platform* platform); @@ -60,10 +63,17 @@ class LocalClientOptions { LocalClientOptions& set_intra_op_parallelism_threads(int num_threads); int intra_op_parallelism_threads() const; + // Sets the allowed_devices set for selectively constructing stream executors + // on the platform. + LocalClientOptions& set_allowed_devices( + const absl::optional>& allowed_devices); + const absl::optional>& allowed_devices() const; + private: se::Platform* platform_; int number_of_replicas_; int intra_op_parallelism_threads_; + absl::optional> allowed_devices_; }; class ClientLibrary { @@ -73,8 +83,11 @@ class ClientLibrary { // // platform : The platform the underlying XLA service should target. If // null then default platform is used. 
+ // device_set: Set of device IDs for which the stream executor will be + // created, for the given platform. static StatusOr GetOrCreateLocalClient( - se::Platform* platform = nullptr); + se::Platform* platform = nullptr, + const absl::optional>& allowed_devices = absl::nullopt); static StatusOr GetOrCreateLocalClient( const LocalClientOptions& options); diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 41db8de29ff0085a30847ff41db4ffbfc774e2a1..6192b89b4abf24d2f21daa0f4a3faf9c405b9fa5 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -1,5 +1,7 @@ # Common computation builders for XLA. +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites", "xla_test") + licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow/compiler/xla/client:friends"]) @@ -13,9 +15,6 @@ filegroup( ]), ) -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites") - # Generate test_suites for all backends, named "${backend}_tests". 
generate_backend_suites() @@ -35,6 +34,48 @@ cc_library( ], ) +cc_library( + name = "cholesky", + srcs = ["cholesky.cc"], + hdrs = ["cholesky.h"], + deps = [ + ":math", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:loops", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/client/lib:slicing", + "//tensorflow/compiler/xla/client/lib:triangular_solve", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "cholesky_test", + srcs = ["cholesky_test.cc"], + tags = ["optonly"], + deps = [ + ":arithmetic", + ":cholesky", + ":matrix", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "constants", srcs = ["constants.cc"], @@ -75,6 +116,22 @@ cc_library( ], ) +cc_library( + name = "loops", + srcs = ["loops.cc"], + hdrs = ["loops.h"], + deps = [ + ":constants", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "math", srcs = ["math.cc"], @@ -177,6 +234,48 @@ cc_library( ], ) +cc_library( + name = "qr", + srcs = ["qr.cc"], + hdrs = ["qr.h"], + deps 
= [ + ":arithmetic", + ":constants", + ":loops", + ":math", + ":matrix", + ":slicing", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "qr_test", + srcs = ["qr_test.cc"], + tags = ["optonly"], + deps = [ + ":matrix", + ":qr", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array3d", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "slicing", srcs = ["slicing.cc"], @@ -237,6 +336,34 @@ xla_test( ], ) +cc_library( + name = "quantize", + hdrs = ["quantize.h"], + deps = [ + ":constants", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "quantize_test", + srcs = ["quantize_test.cc"], + tags = ["enable_for_xla_interpreter"], + deps = [ + ":quantize", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "testing", srcs = ["testing.cc"], @@ -285,6 +412,8 @@ xla_test( srcs = ["triangular_solve_test.cc"], tags = ["noasan"], # sometimes times out, http://b/78650012 deps = 
[ + ":math", + ":matrix", ":triangular_solve", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/xla/client/lib/cholesky.cc similarity index 68% rename from tensorflow/compiler/tf2xla/lib/cholesky.cc rename to tensorflow/compiler/xla/client/lib/cholesky.cc index 550ab5b05693b79e60e49577309328ac6846d3f9..fd98049968491d80b9717a2de1f34997bd9d18c1 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/xla/client/lib/cholesky.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/cholesky.h" +#include "tensorflow/compiler/xla/client/lib/cholesky.h" #include #include -#include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/lib/triangular_solve.h" @@ -31,7 +31,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/errors.h" -namespace tensorflow { +namespace xla { namespace { @@ -50,26 +50,25 @@ namespace { // l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) / // l[..., j, j] // return l -xla::XlaOp CholeskyUnblocked(xla::XlaOp a, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int n_dims = xla::ShapeUtil::Rank(a_shape); - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); - auto major_dims = xla::AsInt64Slice(a_shape.dimensions()) +XlaOp CholeskyUnblocked(XlaOp a, PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int n_dims = ShapeUtil::Rank(a_shape); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); + auto major_dims = AsInt64Slice(a_shape.dimensions()) .subspan( /*pos=*/0, /*len=*/n_dims - 2); - xla::XlaOp l = xla::ZerosLike(a); + XlaOp l = ZerosLike(a); // Construct the for loop body to iterate over rows. 
- auto body_fn = [&](xla::XlaOp i, absl::Span loop_vars, - xla::XlaBuilder* body_builder) - -> xla::StatusOr> { - xla::Shape col_shape; - xla::Shape row_shape; + auto body_fn = + [&](XlaOp i, absl::Span loop_vars, + XlaBuilder* body_builder) -> StatusOr> { + Shape col_shape; + Shape row_shape; for (int64 d : major_dims) { row_shape.add_dimensions(d); col_shape.add_dimensions(d); @@ -77,43 +76,40 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, row_shape.add_dimensions(1); row_shape.add_dimensions(n); row_shape.set_element_type(a_shape.element_type()); - auto mask_zeros_row = xla::Zeros(body_builder, row_shape); + auto mask_zeros_row = Zeros(body_builder, row_shape); col_shape.add_dimensions(n); col_shape.add_dimensions(1); col_shape.set_element_type(a_shape.element_type()); - auto mask_zeros_col = xla::Zeros(body_builder, col_shape); + auto mask_zeros_col = Zeros(body_builder, col_shape); std::vector mask_vector(n); std::iota(mask_vector.begin(), mask_vector.end(), 0); - auto mask_range = xla::ConstantR1(body_builder, mask_vector); + auto mask_range = ConstantR1(body_builder, mask_vector); auto mask_range_row = - xla::Broadcast(xla::Reshape(mask_range, {0}, {1, n}), major_dims); + Broadcast(Reshape(mask_range, {0}, {1, n}), major_dims); auto mask_range_col = - xla::Broadcast(xla::Reshape(mask_range, {0}, {n, 1}), major_dims); + Broadcast(Reshape(mask_range, {0}, {n, 1}), major_dims); auto body_a = loop_vars[0]; auto body_l = loop_vars[1]; // row = l[..., i, :i] // select the whole i-th row, then mask out all columns past i-1 - auto zero = xla::ConstantR0(body_builder, 0); + auto zero = ConstantR0(body_builder, 0); auto l_i = DynamicSliceInMinorDims(body_l, {i, zero}, {1, n}); - auto row = xla::Select(xla::Ge(mask_range_row, i), mask_zeros_row, l_i); + auto row = Select(Ge(mask_range_row, i), mask_zeros_row, l_i); // a[..., i, i] auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1}); // np.dot(row, np.swapaxes(row, -1, -2)) auto diag_dot = BatchDot(row, 
TransposeInMinorDims(row), precision); // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, // np.swapaxes(row, -1, -2))) - auto l_ii = - xla::Pow(a_ii - diag_dot, - FloatLiteral(body_builder, a_shape.element_type(), 0.5)); + auto l_ii = Sqrt(a_ii - diag_dot); // a[..., i+1:, i] // select the whole i-th column, then mask out all rows above i+1 auto a_0i = DynamicSliceInMinorDims(body_a, {i}, {1}); - auto a_ip1i = - xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, a_0i); + auto a_ip1i = Select(Le(mask_range_col, i), mask_zeros_col, a_0i); // l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) / // l[..., i, i] @@ -122,8 +118,7 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, // r.T) auto dot = BatchDot(body_l, TransposeInMinorDims(row), precision); // np.dot(l[..., i+1:, :i], r.T) - auto dot_ip1 = - xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot); + auto dot_ip1 = Select(Le(mask_range_col, i), mask_zeros_col, dot); body_l = DynamicUpdateSliceInMinorDims(body_l, (a_ip1i - dot_ip1) / l_ii, {i}); @@ -131,12 +126,12 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, // column assign will wrap around and overwrite the diagonal assign. 
body_l = DynamicUpdateSliceInMinorDims(body_l, l_ii, {i, i}); - return std::vector{body_a, body_l}; + return std::vector{body_a, body_l}; }; TF_ASSIGN_OR_RETURN( auto cholesky_while, - XlaForEachIndex(n, xla::S32, body_fn, {a, l}, "unblocked", builder)); + ForEachIndex(n, S32, body_fn, {a, l}, "unblocked", builder)); return cholesky_while[1]; }); @@ -144,34 +139,35 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, } // namespace -xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int ndims = xla::ShapeUtil::Rank(a_shape); +XlaOp Cholesky(XlaOp a, int64 block_size, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int ndims = ShapeUtil::Rank(a_shape); if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to Cholesky must have rank >= 2: ", ndims); + return InvalidArgument( + "Argument to Cholesky must have rank >= 2; shape was %s", + a_shape.ToString()); } - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); - if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) { - return errors::InvalidArgument( - "Arguments to Cholesky must be square matrices: ", - xla::ShapeUtil::HumanString(a_shape)); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); + if (n != ShapeUtil::GetDimension(a_shape, -2)) { + return InvalidArgument( + "Argument to Cholesky must be batched square matrices; got shape %s", + ShapeUtil::HumanString(a_shape)); } if (block_size < 1) { - return errors::InvalidArgument( - "block_size argument to Cholesky must be >= 1; got ", block_size); + return InvalidArgument( + "block_size argument to Cholesky must be >= 1; got %d", block_size); } // Blocked left-looking Cholesky 
factorization. // Algorithm 1 from // Haidar, Azzam, et al. "High-performance Cholesky factorization for // GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017. - xla::XlaOp l = xla::ZerosLike(a); + XlaOp l = ZerosLike(a); for (int64 i = 0; i < n; i += block_size) { int64 k = std::min(block_size, n - i); if (i > 0) { @@ -207,4 +203,4 @@ xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size, }); } -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/xla/client/lib/cholesky.h similarity index 87% rename from tensorflow/compiler/tf2xla/lib/cholesky.h rename to tensorflow/compiler/xla/client/lib/cholesky.h index 9a561c34b92ee45059f2a05336e682838f8e36e2..0bae26837c0f14dd0cfab82cf426becc787ec11c 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.h +++ b/tensorflow/compiler/xla/client/lib/cholesky.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -namespace tensorflow { +namespace xla { // Computes the Cholesky decompositions of a batch of symmetric positive // definite matrices. 
@@ -34,6 +34,6 @@ xla::XlaOp Cholesky( xla::XlaOp a, int64 block_size = 256, xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); -} // namespace tensorflow +} // namespace xla -#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_ diff --git a/tensorflow/compiler/xla/client/lib/cholesky_test.cc b/tensorflow/compiler/xla/client/lib/cholesky_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ba9580a3d32225625acc1447344b7d2c16c5d8a5 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/cholesky_test.cc @@ -0,0 +1,166 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/cholesky.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace { + +using xla::int64; + +using CholeskyTest = xla::ClientLibraryTestBase; + +XLA_TEST_F(CholeskyTest, Simple) { + xla::XlaBuilder builder(TestName()); + + xla::Array2D a_vals({ + {4, 6, 8, 10}, + {6, 45, 54, 63}, + {8, 54, 146, 166}, + {10, 63, 166, 310}, + }); + + xla::XlaOp a; + auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); + xla::Cholesky(a, /*block_size=*/2); + + xla::Array2D expected({ + {2, 0, 0, 0}, + {3, 6, 0, 0}, + {4, 7, 9, 0}, + {5, 8, 10, 11}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +XLA_TEST_F(CholeskyTest, Simple2) { + xla::XlaBuilder builder(TestName()); + + xla::Array2D a_vals({ + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 456, 106}, + {12, 48, 106, 62}, + }); + + xla::XlaOp a; + auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); + xla::Cholesky(a); + + xla::Array2D expected( + {{4, 0, 0, 0}, {6, 5, 0, 0}, {2, 14, 16, 0}, {3, 6, 1, 4}}); + + ComputeAndCompareR2(&builder, expected, {a_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +XLA_TEST_F(CholeskyTest, SimpleBatched) { + xla::XlaBuilder builder(TestName()); + + xla::Array3D a_vals({ + 
{ + {4, 6, 8, 10}, + {6, 45, 54, 63}, + {8, 54, 146, 166}, + {10, 63, 166, 310}, + }, + { + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 456, 106}, + {12, 48, 106, 62}, + }, + }); + + xla::XlaOp a; + auto a_data = CreateR3Parameter(a_vals, 0, "a", &builder, &a); + xla::Cholesky(a); + + xla::Array3D expected({ + { + {2, 0, 0, 0}, + {3, 6, 0, 0}, + {4, 7, 9, 0}, + {5, 8, 10, 11}, + }, + {{4, 0, 0, 0}, {6, 5, 0, 0}, {2, 14, 16, 0}, {3, 6, 1, 4}}, + }); + + ComputeAndCompareR3(&builder, expected, {a_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +using CholeskyTestCase = std::tuple; + +class RandomCholeskyTest + : public xla::ClientLibraryTestBase, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(RandomCholeskyTest, Random) { + xla::XlaBuilder builder(TestName()); + + auto test_params = GetParam(); + std::vector dimensions = {std::get<0>(test_params), + std::get<1>(test_params), + std::get<1>(test_params)}; + xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, dimensions); + TF_ASSERT_OK_AND_ASSIGN( + auto literal, + xla::LiteralUtil::CreateRandomLiteral(shape, 0.0, 1.0)); + + auto input = xla::Parameter(&builder, 0, shape, "input"); + // Form a random positive definite matrix. 
+ auto matrix = xla::BatchDot(input, TransposeInMinorDims(input), + xla::PrecisionConfig::HIGHEST); + + auto cholesky = xla::Cholesky(matrix, /*block_size=*/4); + + // Verify that ||matrix - cholesky * cholesky_t||_2 ~= 0 + auto verification = xla::BatchDot(cholesky, TransposeInMinorDims(cholesky), + xla::PrecisionConfig::HIGHEST); + auto delta = matrix - verification; + xla::Reduce(delta * delta, xla::ConstantR0(&builder, 0.0), + CreateScalarAddComputation(xla::F32, &builder), {0, 1, 2}); + + TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(literal)); + ComputeAndCompareR0(&builder, 0.0, {input_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +INSTANTIATE_TEST_CASE_P(RandomCholeskyTestInstance, RandomCholeskyTest, + ::testing::Values(CholeskyTestCase{1, 1}, + CholeskyTestCase{1, 2}, + CholeskyTestCase{10, 5}, + CholeskyTestCase{2, 20})); + +} // namespace diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/xla/client/lib/loops.cc similarity index 50% rename from tensorflow/compiler/tf2xla/lib/while_loop.cc rename to tensorflow/compiler/xla/client/lib/loops.cc index 594ab1dfd0700f47501712183f6efe62d17e15e7..721f987628a8ac7da3f3f872939c3f0457d6bbe2 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.cc +++ b/tensorflow/compiler/xla/client/lib/loops.cc @@ -13,44 +13,43 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" -#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" + +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" -namespace tensorflow { +namespace xla { -xla::StatusOr> XlaWhileLoop( - const LoopConditionFunction& condition_function, - const LoopBodyFunction& body_function, - absl::Span initial_values, absl::string_view name, - xla::XlaBuilder* builder) { +StatusOr> WhileLoopHelper( + const WhileLoopHelperConditionFunction& condition_function, + const WhileLoopHelperBodyFunction& body_function, + absl::Span initial_values, absl::string_view name, + XlaBuilder* builder) { int arity = initial_values.size(); - std::vector var_shapes; + std::vector var_shapes; var_shapes.reserve(arity); - for (const xla::XlaOp& input : initial_values) { + for (const XlaOp& input : initial_values) { TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(input)); var_shapes.push_back(std::move(shape)); } - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(var_shapes); + Shape tuple_shape = ShapeUtil::MakeTupleShape(var_shapes); // Unpacks a tuple into its component parts. - auto unpack_tuple = [](xla::XlaOp tuple, int arity, - xla::XlaBuilder* builder) { - std::vector elements(arity); + auto unpack_tuple = [](XlaOp tuple, int arity, XlaBuilder* builder) { + std::vector elements(arity); for (int i = 0; i < arity; ++i) { - elements[i] = xla::GetTupleElement(tuple, i); + elements[i] = GetTupleElement(tuple, i); } return elements; }; // Build the condition. 
- std::unique_ptr cond_builder = + std::unique_ptr cond_builder = builder->CreateSubBuilder(absl::StrCat(name, "_condition")); { - auto parameter = - xla::Parameter(cond_builder.get(), 0, tuple_shape, "parameter"); + auto parameter = Parameter(cond_builder.get(), 0, tuple_shape, "parameter"); TF_RETURN_IF_ERROR( condition_function(unpack_tuple(parameter, arity, cond_builder.get()), @@ -60,11 +59,10 @@ xla::StatusOr> XlaWhileLoop( TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build()); // Build the body. - std::unique_ptr body_builder = + std::unique_ptr body_builder = builder->CreateSubBuilder(absl::StrCat(name, "_body")); { - auto parameter = - xla::Parameter(body_builder.get(), 0, tuple_shape, "parameter"); + auto parameter = Parameter(body_builder.get(), 0, tuple_shape, "parameter"); TF_ASSIGN_OR_RETURN( auto result, @@ -72,56 +70,54 @@ xla::StatusOr> XlaWhileLoop( body_builder.get())); TF_RET_CHECK(result.size() == initial_values.size()); - xla::Tuple(body_builder.get(), result); + Tuple(body_builder.get(), result); } TF_ASSIGN_OR_RETURN(auto body, body_builder->Build()); - auto outputs = xla::While(cond, body, xla::Tuple(builder, initial_values)); + auto outputs = While(cond, body, Tuple(builder, initial_values)); return unpack_tuple(outputs, arity, builder); } -xla::StatusOr> XlaForEachIndex( - int64 num_iterations, xla::PrimitiveType num_iterations_type, +StatusOr> ForEachIndex( + int64 num_iterations, PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - absl::Span initial_values, absl::string_view name, - xla::XlaBuilder* builder) { - auto while_cond_fn = - [&](absl::Span values, - xla::XlaBuilder* cond_builder) -> xla::StatusOr { - return xla::Lt(values[0], IntegerLiteral(cond_builder, num_iterations_type, - num_iterations)); + absl::Span initial_values, absl::string_view name, + XlaBuilder* builder) { + auto while_cond_fn = [&](absl::Span values, + XlaBuilder* cond_builder) -> StatusOr { + return Lt(values[0], 
ConstantR0WithType(cond_builder, num_iterations_type, + num_iterations)); }; - auto while_body_fn = [&](absl::Span values, - xla::XlaBuilder* body_builder) - -> xla::StatusOr> { - xla::XlaOp iteration = values[0]; + auto while_body_fn = + [&](absl::Span values, + XlaBuilder* body_builder) -> StatusOr> { + XlaOp iteration = values[0]; - std::vector updated_values; + std::vector updated_values; updated_values.reserve(values.size()); - updated_values.push_back(xla::Add( + updated_values.push_back(Add( iteration, - xla::ConstantLiteral(body_builder, - xla::LiteralUtil::One(num_iterations_type)))); + ConstantLiteral(body_builder, LiteralUtil::One(num_iterations_type)))); values.remove_prefix(1); - TF_ASSIGN_OR_RETURN(std::vector body_outputs, + TF_ASSIGN_OR_RETURN(std::vector body_outputs, body_function(iteration, values, body_builder)); updated_values.insert(updated_values.end(), body_outputs.begin(), body_outputs.end()); return updated_values; }; - std::vector values; + std::vector values; values.reserve(initial_values.size() + 1); - values.push_back(xla::ConstantLiteral( - builder, xla::LiteralUtil::Zero(num_iterations_type))); + values.push_back( + ConstantLiteral(builder, LiteralUtil::Zero(num_iterations_type))); values.insert(values.end(), initial_values.begin(), initial_values.end()); - TF_ASSIGN_OR_RETURN(values, XlaWhileLoop(while_cond_fn, while_body_fn, values, - name, builder)); + TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn, + values, name, builder)); values.erase(values.begin(), values.begin() + 1); return values; } -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.h b/tensorflow/compiler/xla/client/lib/loops.h similarity index 62% rename from tensorflow/compiler/tf2xla/lib/while_loop.h rename to tensorflow/compiler/xla/client/lib/loops.h index f2134bb4495a12b8342961d96f70e7737f816c7d..e11de59493e9c1de51fbdb6c45dab6d82b85a62a 100644 --- 
a/tensorflow/compiler/tf2xla/lib/while_loop.h +++ b/tensorflow/compiler/xla/client/lib/loops.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_ #include #include @@ -25,19 +25,18 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" -namespace tensorflow { +namespace xla { // Function that builds a loop condition. Takes as input a sequence of input // values, and returns a boolean value representing if the condition succeeds. -typedef std::function(absl::Span, - xla::XlaBuilder*)> - LoopConditionFunction; +typedef std::function(absl::Span, XlaBuilder*)> + WhileLoopHelperConditionFunction; // Function that builds a loop body. Takes as input a sequence of input values // and returns a sequence of output values. -typedef std::function>( - absl::Span, xla::XlaBuilder*)> - LoopBodyFunction; +typedef std::function>(absl::Span, + XlaBuilder*)> + WhileLoopHelperBodyFunction; // Helper function for building an XLA while loop, where the values carried by // the loop are a tuple of values, e.g., (a, b, c): @@ -47,27 +46,27 @@ typedef std::function>( // init: (a, b, c) // ) // 'name' is a descriptive name for the loop. 
-xla::StatusOr> XlaWhileLoop( - const LoopConditionFunction& condition_function, - const LoopBodyFunction& body_function, - absl::Span initial_values, absl::string_view name, - xla::XlaBuilder* builder); +StatusOr> WhileLoopHelper( + const WhileLoopHelperConditionFunction& condition_function, + const WhileLoopHelperBodyFunction& body_function, + absl::Span initial_values, absl::string_view name, + XlaBuilder* builder); // Builds an XLA loop that repeats a computation `num_iterations` times. // // The body function (ForEachIndexBodyFunction) takes as input a pair of // (current iteration number, loop-carried values), and returns an updated // vector of the loop-carried values. -typedef std::function>( - xla::XlaOp, absl::Span, xla::XlaBuilder*)> +typedef std::function>( + XlaOp, absl::Span, XlaBuilder*)> ForEachIndexBodyFunction; -xla::StatusOr> XlaForEachIndex( - int64 num_iterations, xla::PrimitiveType num_iterations_type, +StatusOr> ForEachIndex( + int64 num_iterations, PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - absl::Span initial_values, absl::string_view name, - xla::XlaBuilder* builder); + absl::Span initial_values, absl::string_view name, + XlaBuilder* builder); -} // namespace tensorflow +} // namespace xla -#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_ diff --git a/tensorflow/compiler/xla/client/lib/matrix.cc b/tensorflow/compiler/xla/client/lib/matrix.cc index ffd744d190885b8e3f4149a48a706498b3787618..16c177b4e2219adf079070a52b08e5884023908f 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.cc +++ b/tensorflow/compiler/xla/client/lib/matrix.cc @@ -64,7 +64,7 @@ XlaOp GetMatrixDiagonal(XlaOp x) { }); } -XlaOp Triangle(XlaOp x, bool lower) { +XlaOp TriangleMask(XlaOp x, int diagonal) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); @@ -74,20 +74,19 @@ 
XlaOp Triangle(XlaOp x, bool lower) { const int64 n = shape.dimensions(n_dims - 1); absl::Span major_dims = AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); - auto a = Iota(builder, U32, n); - auto b = Iota(builder, U32, m); + auto a = Iota(builder, S32, n); + auto b = Iota(builder, S32, m) + ConstantR0(builder, diagonal); XlaOp indicator; - if (lower) { - indicator = Ge(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); - } else { - indicator = Le(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); - } - auto mask = Broadcast(indicator, major_dims); - - return Select(mask, x, Zeros(builder, shape)); + indicator = Ge(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); + return Broadcast(indicator, major_dims); }); } +XlaOp Triangle(XlaOp x, bool lower) { + return lower ? Select(TriangleMask(x, 0), x, ZerosLike(x)) + : Select(TriangleMask(x, -1), ZerosLike(x), x); +} + XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); } XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); } diff --git a/tensorflow/compiler/xla/client/lib/matrix.h b/tensorflow/compiler/xla/client/lib/matrix.h index 8856f99c7a0fee8f315aac11fab392cf5536f57b..916cd83748e7028c474065b86bf02d85166d2c9c 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.h +++ b/tensorflow/compiler/xla/client/lib/matrix.h @@ -31,6 +31,10 @@ XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n); // diagonal elements (i.e., with indices [..., i, i]). XlaOp GetMatrixDiagonal(XlaOp x); +// Returns a lower-triangular mask, i.e., true below the `diagonal`-th diagonal +// and false above that diagonal. 
+XlaOp TriangleMask(XlaOp x, int diagonal); + // Get the upper or lower triangle part of the last two dimensions XlaOp Triangle(XlaOp x, bool lower); diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/xla/client/lib/qr.cc similarity index 62% rename from tensorflow/compiler/tf2xla/lib/qr.cc rename to tensorflow/compiler/xla/client/lib/qr.cc index d6007748609fdd161cb89692a167eb7ed12fe00c..72ca653173b78d9338f632c41779f2a30db1e978 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.cc +++ b/tensorflow/compiler/xla/client/lib/qr.cc @@ -13,15 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/qr.h" +#include "tensorflow/compiler/xla/client/lib/qr.h" #include #include -#include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" @@ -32,10 +31,18 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/errors.h" -namespace tensorflow { +namespace xla { namespace { +std::vector ConcatVectors(absl::Span xs, + absl::Span ys) { + std::vector output(xs.size() + ys.size()); + std::copy(xs.begin(), xs.end(), output.begin()); + std::copy(ys.begin(), ys.end(), output.begin() + xs.size()); + return output; +} + // Computes a Householder reflection of the form: // H = I - tau v v.T. 
// such that @@ -65,52 +72,47 @@ namespace { // return (v, tau, beta) // TODO(phawkins): LAPACK's xLARFG implementation has code for handling // overflows in the norm/beta calculations. Perhaps do the same here. -xla::Status House(xla::XlaOp x, xla::XlaOp k, - absl::Span batch_dims, const int64 m, - xla::XlaOp* v, xla::XlaOp* tau, xla::XlaOp* beta) { - xla::XlaBuilder* const builder = x.builder(); - TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); - const xla::PrimitiveType type = x_shape.element_type(); +Status House(XlaOp x, XlaOp k, absl::Span batch_dims, + const int64 m, XlaOp* v, XlaOp* tau, XlaOp* beta) { + XlaBuilder* const builder = x.builder(); + TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); + const PrimitiveType type = x_shape.element_type(); std::vector batch_dim_ids(batch_dims.size()); std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0); const int64 minor_dim = batch_dims.size(); - xla::XlaOp zero = xla::ScalarLike(x, 0.0); - xla::XlaOp one = xla::ScalarLike(x, 1.0); + XlaOp zero = ScalarLike(x, 0.0); + XlaOp one = ScalarLike(x, 1.0); // alpha = x[k] - xla::XlaOp alpha = - xla::Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims); + XlaOp alpha = Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims); // Compute x[k+1:] (padded with zeros in elements 0..k) - xla::XlaOp iota = xla::Iota(builder, xla::S32, m); - xla::XlaOp x_after_k = - xla::Mul(x, xla::ConvertElementType(xla::Gt(iota, k), type), - /*broadcast_dimensions=*/{minor_dim}); + XlaOp iota = Iota(builder, S32, m); + XlaOp x_after_k = Mul(x, ConvertElementType(Gt(iota, k), type), + /*broadcast_dimensions=*/{minor_dim}); // sigma = np.dot(x[k+1:], x[k+1:]) - auto sigma = - xla::Reduce(x_after_k * x_after_k, zero, - xla::CreateScalarAddComputation(type, builder), {minor_dim}); + auto sigma = Reduce(x_after_k * x_after_k, zero, + CreateScalarAddComputation(type, builder), {minor_dim}); // mu = np.sqrt(x[k]*x[k] + sigma) - auto mu = xla::Sqrt(xla::Square(alpha) 
+ sigma); + auto mu = Sqrt(Square(alpha) + sigma); - auto sigma_is_zero = xla::Eq(sigma, zero); + auto sigma_is_zero = Eq(sigma, zero); - *beta = xla::Select(sigma_is_zero, alpha, -xla::Sign(alpha) * mu); - *tau = xla::Select(sigma_is_zero, xla::Broadcast(zero, batch_dims), - (*beta - alpha) / *beta); - auto divisor = xla::Select(sigma_is_zero, xla::Broadcast(one, batch_dims), - alpha - *beta); + *beta = Select(sigma_is_zero, alpha, -Sign(alpha) * mu); + *tau = Select(sigma_is_zero, Broadcast(zero, batch_dims), + (*beta - alpha) / *beta); + auto divisor = + Select(sigma_is_zero, Broadcast(one, batch_dims), alpha - *beta); - auto e_k = xla::Broadcast(xla::ConvertElementType(xla::Eq(iota, k), type), - std::vector(batch_dims.size(), 1)); + auto e_k = Broadcast(ConvertElementType(Eq(iota, k), type), + std::vector(batch_dims.size(), 1)); // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor. - *v = e_k + - xla::Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids); + *v = e_k + Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids); return Status::OK(); } @@ -143,90 +145,86 @@ xla::Status House(xla::XlaOp x, xla::XlaOp k, // return (q, vs, taus) struct QRBlockResult { // The factored R value - xla::XlaOp r; + XlaOp r; // Representation of the Householder matrices I - beta v v.T - xla::XlaOp taus; // Shape: [..., n] - xla::XlaOp vs; // Shape: [..., m, n] + XlaOp taus; // Shape: [..., n] + XlaOp vs; // Shape: [..., m, n] }; -xla::StatusOr QRBlock( - xla::XlaOp a, xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int num_dims = xla::ShapeUtil::Rank(a_shape); +StatusOr QRBlock(XlaOp a, PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int num_dims = ShapeUtil::Rank(a_shape); if (num_dims < 2) { 
- return errors::InvalidArgument("Arguments to QR must have rank >= 2: ", - num_dims); + return InvalidArgument("Argument to QR must have rank >= 2; got shape %s", + a_shape.ToString()); } - xla::PrimitiveType type = a_shape.element_type(); + PrimitiveType type = a_shape.element_type(); - const int64 m = xla::ShapeUtil::GetDimension(a_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + const int64 m = ShapeUtil::GetDimension(a_shape, -2); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); const int64 num_batch_dims = num_dims - 2; std::vector batch_dims(num_batch_dims); for (int i = 0; i < num_batch_dims; ++i) { - batch_dims[i] = xla::ShapeUtil::GetDimension(a_shape, i); + batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); } std::vector batch_dim_indices(num_batch_dims); std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); - auto qr_body_fn = - [&](xla::XlaOp j, absl::Span values, - xla::XlaBuilder* builder) -> xla::StatusOr> { + auto qr_body_fn = [&](XlaOp j, absl::Span values, + XlaBuilder* builder) -> StatusOr> { auto a = values[0]; auto vs = values[1]; auto taus = values[2]; // v, beta = house(a[:, j], j) auto x = DynamicSliceInMinorDims(a, {j}, {1}); - xla::XlaOp v, tau, beta; - TF_RETURN_IF_ERROR(House(xla::Collapse(x, {num_dims - 2, num_dims - 1}), j, + XlaOp v, tau, beta; + TF_RETURN_IF_ERROR(House(Collapse(x, {num_dims - 2, num_dims - 1}), j, batch_dims, m, &v, &tau, &beta)); std::vector shape = batch_dims; shape.push_back(1); shape.push_back(m); - auto v_broadcast = xla::Reshape(v, shape); + auto v_broadcast = Reshape(v, shape); // a[:, :] -= tau * np.dot(v[:, np.newaxis], // np.dot(v[np.newaxis, :], a[:, :])) auto vva = BatchDot(v_broadcast, a, precision); vva = BatchDot(TransposeInMinorDims(v_broadcast), vva, precision); - a = a - xla::Mul(tau, vva, - /*broadcast_dimensions=*/batch_dim_indices); + a = a - Mul(tau, vva, + /*broadcast_dimensions=*/batch_dim_indices); // It is more precise to populate column 'k' 
explicitly, rather than // computing it implicitly by applying the Householder transformation. // a[k,k] = beta // a[k+1:,k] = np.zeros([m-k-1], dtype=a.dtype) - auto iota = xla::Reshape(xla::Iota(a.builder(), xla::S32, m), {m, 1}); - auto predecessor_mask = xla::ConvertElementType(xla::Lt(iota, j), type); - auto mask = xla::Broadcast(xla::ConvertElementType(xla::Eq(iota, j), type), - std::vector(batch_dims.size(), 1)); - auto new_x = - xla::Mul(x, predecessor_mask, - /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) + - xla::Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices); + auto iota = Reshape(Iota(a.builder(), S32, m), {m, 1}); + auto predecessor_mask = ConvertElementType(Lt(iota, j), type); + auto mask = Broadcast(ConvertElementType(Eq(iota, j), type), + std::vector(batch_dims.size(), 1)); + auto new_x = Mul(x, predecessor_mask, + /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) + + Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices); a = DynamicUpdateSliceInMinorDims(a, new_x, {j}); // vs[:, j] = v vs = DynamicUpdateSliceInMinorDims( - vs, xla::Reshape(v, ConcatVectors(batch_dims, {m, 1})), {j}); + vs, Reshape(v, ConcatVectors(batch_dims, {m, 1})), {j}); // taus[j] = tau taus = DynamicUpdateSliceInMinorDims( - taus, xla::Reshape(tau, ConcatVectors(batch_dims, {1})), {j}); - return std::vector{a, vs, taus}; + taus, Reshape(tau, ConcatVectors(batch_dims, {1})), {j}); + return std::vector{a, vs, taus}; }; - auto vs = xla::Zeros(builder, xla::ShapeUtil::MakeShape( - type, ConcatVectors(batch_dims, {m, n}))); - auto taus = xla::Zeros( - builder, xla::ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n}))); + auto vs = Zeros( + builder, ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n}))); + auto taus = Zeros(builder, + ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n}))); - TF_ASSIGN_OR_RETURN(auto values, - XlaForEachIndex(std::min(m, n), xla::S32, qr_body_fn, - {a, vs, taus}, "qr", builder)); + 
TF_ASSIGN_OR_RETURN(auto values, ForEachIndex(std::min(m, n), S32, qr_body_fn, + {a, vs, taus}, "qr", builder)); QRBlockResult result; result.r = values[0]; @@ -250,24 +248,23 @@ xla::StatusOr QRBlock( // return W // There is no need to return Y since at termination of the loop it is equal to // vs. -xla::StatusOr ComputeWYRepresentation( - xla::PrimitiveType type, absl::Span batch_dims, xla::XlaOp vs, - xla::XlaOp taus, int64 m, int64 n, - xla::PrecisionConfig::Precision precision) { +StatusOr ComputeWYRepresentation(PrimitiveType type, + absl::Span batch_dims, + XlaOp vs, XlaOp taus, int64 m, int64 n, + PrecisionConfig::Precision precision) { std::vector batch_dim_indices(batch_dims.size()); std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); int64 n_index = batch_dims.size() + 1; - auto body_fn = - [&](xla::XlaOp j, absl::Span values, - xla::XlaBuilder* builder) -> xla::StatusOr> { + auto body_fn = [&](XlaOp j, absl::Span values, + XlaBuilder* builder) -> StatusOr> { auto w = values[0]; auto y = values[1]; const auto vs = values[2]; const auto taus = values[3]; // Want j values in range [1, ... n). 
- j = j + xla::ConstantR0(builder, 1); + j = j + ConstantR0(builder, 1); // vs has shape [..., m, 1] auto v = DynamicSliceInMinorDims(vs, {j}, {1}); // beta has shape [..., 1] @@ -278,31 +275,31 @@ xla::StatusOr ComputeWYRepresentation( // wyv has shape [..., m, 1] auto wyv = BatchDot(w, yv, precision); - auto z = xla::Mul( + auto z = Mul( -beta, v + wyv, /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index})); w = DynamicUpdateSliceInMinorDims(w, z, {j}); y = DynamicUpdateSliceInMinorDims(y, v, {j}); - return std::vector{w, y, vs, taus}; + return std::vector{w, y, vs, taus}; }; - xla::XlaBuilder* builder = vs.builder(); - auto w = xla::Zeros(builder, xla::ShapeUtil::MakeShape( - type, ConcatVectors(batch_dims, {m, n}))); + XlaBuilder* builder = vs.builder(); + auto w = Zeros(builder, + ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n}))); auto y = w; auto v = SliceInMinorDims(vs, {0}, {1}); auto beta = SliceInMinorDims(taus, {0}, {1}); y = UpdateSliceInMinorDims(y, v, {0}); - auto bv = xla::Mul( - -beta, v, - /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index})); + auto bv = + Mul(-beta, v, + /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index})); w = UpdateSliceInMinorDims(w, bv, {0}); TF_ASSIGN_OR_RETURN( - auto values, XlaForEachIndex(n - 1, xla::S32, body_fn, {w, y, vs, taus}, - "wy", builder)); + auto values, + ForEachIndex(n - 1, S32, body_fn, {w, y, vs, taus}, "wy", builder)); return values[0]; } @@ -323,34 +320,34 @@ xla::StatusOr ComputeWYRepresentation( // return (q, a) // TODO(phawkins): consider using UT transformations (in the form I - V U V') // rather than WY transformations. 
-xla::StatusOr QRDecomposition( - xla::XlaOp a, bool full_matrices, int64 block_size, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int num_dims = xla::ShapeUtil::Rank(a_shape); +StatusOr QRDecomposition( + XlaOp a, bool full_matrices, int64 block_size, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int num_dims = ShapeUtil::Rank(a_shape); if (num_dims < 2) { - return errors::InvalidArgument("Arguments to QR must have rank >= 2: ", - num_dims); + return InvalidArgument("Arguments to QR must have rank >= 2: got shape %s", + a_shape.ToString()); } - xla::PrimitiveType type = a_shape.element_type(); + PrimitiveType type = a_shape.element_type(); - const int64 m = xla::ShapeUtil::GetDimension(a_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + const int64 m = ShapeUtil::GetDimension(a_shape, -2); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); const int64 p = std::min(m, n); if (block_size < 1) { - return errors::InvalidArgument( - "block_size argument to QR must be >= 1; got ", block_size); + return InvalidArgument("block_size argument to QR must be >= 1; got %d", + block_size); } const int64 num_batch_dims = num_dims - 2; std::vector batch_dims(num_batch_dims); for (int i = 0; i < num_batch_dims; ++i) { - batch_dims[i] = xla::ShapeUtil::GetDimension(a_shape, i); + batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); } - auto q = xla::Broadcast(xla::IdentityMatrix(builder, type, m, m), batch_dims); + auto q = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims); for (int64 i = 0; i < p; i += block_size) { int64 k = std::min(block_size, p - i); @@ -393,4 +390,4 @@ xla::StatusOr QRDecomposition( return result; } -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/qr.h 
b/tensorflow/compiler/xla/client/lib/qr.h similarity index 74% rename from tensorflow/compiler/tf2xla/lib/qr.h rename to tensorflow/compiler/xla/client/lib/qr.h index 24b537ac8b63b93e734c3d0e335ea455f7d51a54..827c8eeca05ef09a0d77363eb3c40961b95813d8 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.h +++ b/tensorflow/compiler/xla/client/lib/qr.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QR_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QR_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -namespace tensorflow { +namespace xla { // Computes the QR decompositions of a batch of matrices. That is, // given a (batched) matrix a, computes an orthonormal matrix Q and an @@ -29,14 +29,14 @@ namespace tensorflow { // the block size to use. // TODO(phawkins): handle the complex case. 
struct QRDecompositionResult { - xla::XlaOp q; - xla::XlaOp r; + XlaOp q; + XlaOp r; }; -xla::StatusOr QRDecomposition( - xla::XlaOp a, bool full_matrices, int64 block_size = 128, - xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); +StatusOr QRDecomposition( + XlaOp a, bool full_matrices, int64 block_size = 128, + PrecisionConfig::Precision precision = PrecisionConfig::HIGHEST); -} // namespace tensorflow +} // namespace xla -#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QR_H_ diff --git a/tensorflow/compiler/xla/client/lib/qr_test.cc b/tensorflow/compiler/xla/client/lib/qr_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b27d364b62444d6d5fb1278b6e6461affc15b2e6 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/qr_test.cc @@ -0,0 +1,93 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/qr.h" + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace { + +using QrTest = xla::ClientLibraryTestBase; + +XLA_TEST_F(QrTest, Simple) { + xla::XlaBuilder builder(TestName()); + + xla::Array2D a_vals({ + {4, 6, 8, 10}, + {6, 45, 54, 63}, + {8, 54, 146, 166}, + {10, 63, 166, 310}, + }); + + xla::XlaOp a; + auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); + TF_ASSERT_OK_AND_ASSIGN( + auto result, + xla::QRDecomposition(a, /*full_matrices=*/true, /*block_size=*/2)); + + // Verifies that the decomposition composes back to the original matrix. + // + // This isn't a terribly demanding test, (e.g., we should verify that Q is + // orthonormal and R is upper-triangular) but it's awkward to write such tests + // without more linear algebra libraries. It's easier to test the numerics + // from Python, anyway, where we have access to numpy and scipy. 
+ xla::BatchDot(result.q, result.r, xla::PrecisionConfig::HIGHEST); + + ComputeAndCompareR2(&builder, a_vals, {a_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +XLA_TEST_F(QrTest, SimpleBatched) { + xla::XlaBuilder builder(TestName()); + + xla::Array3D a_vals({ + { + {4, 6, 8, 10}, + {6, 45, 54, 63}, + {8, 54, 146, 166}, + {10, 63, 166, 310}, + }, + { + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 456, 106}, + {12, 48, 106, 62}, + }, + }); + + xla::XlaOp a; + auto a_data = CreateR3Parameter(a_vals, 0, "a", &builder, &a); + TF_ASSERT_OK_AND_ASSIGN( + auto result, + xla::QRDecomposition(a, /*full_matrices=*/true, /*block_size=*/2)); + + xla::BatchDot(result.q, result.r, xla::PrecisionConfig::HIGHEST); + + ComputeAndCompareR3(&builder, a_vals, {a_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +} // namespace diff --git a/tensorflow/compiler/xla/client/lib/quantize.h b/tensorflow/compiler/xla/client/lib/quantize.h new file mode 100644 index 0000000000000000000000000000000000000000..26dbbd5b00bd1a29f4047c9a4294fcac7340cf6c --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/quantize.h @@ -0,0 +1,186 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QUANTIZE_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QUANTIZE_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/bfloat16/bfloat16.h" + +namespace xla { + +constexpr int64 kBitsOfByte = 8; + +// Represents the range used for quantization +struct QuantizedRange { + QuantizedRange() = default; + QuantizedRange(float min_in, float max_in) : min(min_in), max(max_in) {} + + bool operator==(const QuantizedRange& rhs) const { + return this->min == rhs.min && this->max == rhs.max; + } + + bool operator!=(const QuantizedRange& rhs) const { return !(*this == rhs); } + + tensorflow::bfloat16 min = tensorflow::bfloat16(0.0f); + tensorflow::bfloat16 max = tensorflow::bfloat16(0.0f); +}; + +template +inline std::vector PackToUint32(absl::Span input) { + const int64 kElementsPerPack = sizeof(uint32) / sizeof(T); + const int64 input_size = input.size(); + const int64 output_size = CeilOfRatio(input_size, kElementsPerPack); + + std::vector output_vec; + constexpr int64 kShiftBits = sizeof(T) / sizeof(uint8) * kBitsOfByte; + + for (int64 i = 0; i < output_size; i++) { + uint32 result = 0; + for (int64 p = 0; p < kElementsPerPack; p++) { + int64 index = i * kElementsPerPack + p; + if (index < input_size) { + int64 total_shift_bits = kShiftBits * (kElementsPerPack - p - 1); + result |= (input[index] << total_shift_bits); + } + } + output_vec.push_back(result); + } + + return output_vec; +} + +// Dequantize the quantized input of packed uint32 to bfloat16. +// Only uint8 or uint16 is supported for the original unpacked input. 
+// Returns a tensor of shape [d0,..., dn * unpack_size] if +// input shape is [d0, ..., dn], where unpack_size = sizeof(unit32) / sizeof(T). +// If transpose_output is true, will return a tensor of shape +// [dn * unpack_size, dn-1, ..., d1, d0]. transpose_output is faster when +// input's rank higher than 1. The input needs to be transposed to use +// transpose_output feature. +template +inline XlaOp Dequantize(XlaOp input, const QuantizedRange& range, + absl::string_view mode_string = "MIN_COMBINED", + bool transpose_output = false) { + XlaBuilder* const builder = input.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + float half_range = + !std::is_signed::value + ? 0.0f + : (static_cast(std::numeric_limits::max()) - + std::numeric_limits::min() + 1) / + 2.0f; + const int64 unpack_size = sizeof(uint32) / sizeof(T); + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(input)); + + auto element_type = shape.element_type(); + if (element_type != U32) { + return InvalidArgument( + "Only U32 is supported for input type of xla::Dequantize Op."); + } + + // Broadcast the input to [unpack_size, d0, ..., dn] if input size is + // [d0, ..., dn]. + auto broadcast_input = Broadcast(input, {unpack_size}); + + XlaOp iota_r1 = Iota(builder, U32, unpack_size); + // Highest significant bytes needs to shift more bytes than lower + // significant bytes. + XlaOp shift_bytes = + xla::ConstantR0(builder, unpack_size - 1) - iota_r1; + + const int bytes_of_type = sizeof(T) / sizeof(uint8); + std::vector shift_vec(unpack_size, kBitsOfByte * bytes_of_type); + XlaOp shift_bits = + shift_bytes * xla::ConstantR1(builder, shift_vec); + + // Make bit_mask for different data type T. 
+ uint32 bit_mask = 0x00000000; + for (int i = 0; i < bytes_of_type; i++) { + bit_mask <<= kBitsOfByte; + bit_mask |= 0x000000ff; + } + + std::vector shift_transpose_dimensions(shape.dimensions_size()); + std::iota(shift_transpose_dimensions.begin(), + shift_transpose_dimensions.end(), 0); + shift_transpose_dimensions.insert(shift_transpose_dimensions.begin(), 1, + shape.dimensions_size()); + + // Shift the input by sizeof(T) bytes and apply bit_mask to unpack. + XlaOp shifted_input = ShiftRightLogical( + broadcast_input, Transpose(Broadcast(shift_bits, shape.dimensions()), + shift_transpose_dimensions)); + XlaOp unpack_input = + And(shifted_input, xla::ConstantR0(builder, bit_mask)); + + XlaOp result; + + if (mode_string == "MIN_COMBINED") { + const tensorflow::bfloat16 scale_factor = + (range.max - range.min) / + (static_cast(std::numeric_limits::max() - + std::numeric_limits::min())); + // result = bfloat16(input + half_range) * scale_factor + range.min + XlaOp unpack_input_bf16 = ConvertElementType(unpack_input, BF16); + XlaOp half_range_bf16 = xla::ConstantR0( + builder, static_cast(half_range)); + XlaOp sum = unpack_input_bf16 + half_range_bf16; + + result = + sum * xla::ConstantR0(builder, scale_factor) + + xla::ConstantR0(builder, range.min); + } else { + // TODO(wangtao): support other modes. + return InvalidArgument( + "Only MIN_COMBINED mode is supported in xla::Dequantize Op."); + } + + std::vector transpose_dimensions(shape.dimensions_size()); + std::iota(transpose_dimensions.begin(), transpose_dimensions.end(), 1); + std::reverse(transpose_dimensions.begin(), transpose_dimensions.end()); + transpose_dimensions.insert(transpose_dimensions.begin() + 1, 1, 0); + + // Transpose the result to be [dn, unpack_size, dn-1, ..., d1, d0]. + XlaOp transposed_result = Transpose(result, transpose_dimensions); + + // Reshape to be [dn * unpack_size, dn-1, ..., d1, d0]. 
+ XlaOp reshaped_result = Collapse(transposed_result, {0, 1}); + + // Return the transpose result if transpose_output is true. + if (transpose_output) { + return reshaped_result; + } + + // Transpose the result to be [d0, d1, ..., dn-1, dn * unpack_size]. + std::vector result_dimensions(shape.dimensions_size()); + std::iota(result_dimensions.begin(), result_dimensions.end(), 0); + std::reverse(result_dimensions.begin(), result_dimensions.end()); + + return Transpose(reshaped_result, result_dimensions); + }); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QUANTIZE_H_ diff --git a/tensorflow/compiler/xla/client/lib/quantize_test.cc b/tensorflow/compiler/xla/client/lib/quantize_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..be3603d9e11670913c21a834d2216a999306d582 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/quantize_test.cc @@ -0,0 +1,337 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/quantize.h" + +#include + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +namespace { + +using bfloat16 = tensorflow::bfloat16; + +template +std::vector GenerateInput() { + std::vector input; + + for (int64 i = std::numeric_limits::min(); + i < std::numeric_limits::max(); ++i) { + input.push_back(static_cast(i)); + } + + return input; +} + +template +Array2D GenerateLargeSizeInput(int num_columns, int num_rows) { + Array2D input(num_columns, num_rows); + + input.FillRandom(6, 128); + + return input; +} + +template +Array2D PackLargeInput(Array2D &input) { + const int64 size_per_pack = sizeof(uint32) / sizeof(NativeT); + int64 width = input.width(); + + int64 padded_output_width = CeilOfRatio(width, size_per_pack); + + Array2D pack_input(input.height(), padded_output_width); + + for (int h = 0; h < input.height(); h++) { + std::vector input_row; + for (int w = 0; w < width; w++) { + input_row.push_back(input({h, w})); + } + + auto pack_input_vec = PackToUint32(input_row); + + for (int w = 0; w < padded_output_width; w++) { + pack_input(h, w) = pack_input_vec[w]; + } + } + + return pack_input; +} + +template +Array2D GenerateLargeSizeMinCombinedOutput( + Array2D &input, const QuantizedRange &range, + bool transpose_output = false) { + const int64 size_per_pack = sizeof(uint32) / sizeof(NativeT); + int64 width = input.width(); + + int64 padded_output_width = CeilOfRatio(width, size_per_pack) * size_per_pack; + + int64 output_height; + int64 output_width; + + if (transpose_output) { + output_height = padded_output_width; + output_width = input.height(); + } else 
{ + output_height = input.height(); + output_width = padded_output_width; + } + + Array2D output(output_height, output_width, bfloat16(0.0)); + + float half_range = + !std::is_signed::value + ? 0.0f + : (static_cast(std::numeric_limits::max() - + std::numeric_limits::min() + 1)) / + 2.0f; + const bfloat16 scale_factor = + (range.max - range.min) / + (static_cast(std::numeric_limits::max() - + std::numeric_limits::min())); + + for (int h = 0; h < input.height(); h++) { + std::vector input_row; + for (int w = 0; w < width; w++) { + bfloat16 result = + static_cast(input(h, w) + half_range) * scale_factor + + range.min; + if (transpose_output) { + output(w, h) = result; + } else { + output(h, w) = result; + } + } + } + + return output; +} + +template +std::vector GenerateMinCombinedOutput(const QuantizedRange &range) { + float half_range = + !std::is_signed::value + ? 0.0f + : (static_cast(std::numeric_limits::max() - + std::numeric_limits::min() + 1)) / + 2.0f; + const bfloat16 scale_factor = + (range.max - range.min) / + (static_cast(std::numeric_limits::max() - + std::numeric_limits::min())); + std::vector output; + for (int64 i = std::numeric_limits::min(); + i < std::numeric_limits::max(); ++i) { + bfloat16 result = + static_cast(i + half_range) * scale_factor + range.min; + output.push_back(result); + } + + const int64 pack_size = sizeof(uint32) / sizeof(NativeT); + const int64 output_size = output.size(); + + int64 num_tailing_zeros = + CeilOfRatio(output_size, pack_size) * pack_size - output_size; + + output.insert(output.end(), num_tailing_zeros, bfloat16(0.0)); + return output; +} + +// TODO(wangtao): add a test to make sure this op is the inverse of the existing +// TF quantize op defined in: third_party/tensorflow/core/kernels/quantize_op.cc + +using DequantizeTest = ClientLibraryTestBase; + +TEST(PackTest, PackUint8ToUint32) { + std::vector input = {0xAB, 0x0B, 0x00, 0xF0, 0x01}; + auto output = PackToUint32(input); + EXPECT_THAT(output, 
::testing::ElementsAre(0xAB0B00F0, 0x01000000)); +} + +TEST(PackTest, PackInt8ToUint32) { + std::vector input = {static_cast(0x81), 0x0B, 0x00, 0x20, + 0x01}; + auto output = PackToUint32(input); + EXPECT_THAT(output, ::testing::ElementsAre(0x810B0020, 0x01000000)); +} + +TEST(PackTest, PackUint8ToUint32PerfectSize) { + std::vector input = {3, 2, 1, 0}; + auto output = PackToUint32(input); + EXPECT_THAT(output, ::testing::ElementsAre(0x03020100)); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint16R1) { + XlaBuilder builder(TestName()); + auto input = GenerateInput(); + auto x = ConstantR1(&builder, PackToUint32(input)); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED"); + auto expected = GenerateMinCombinedOutput(range); + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8R1) { + XlaBuilder builder(TestName()); + auto input = GenerateInput(); + auto x = ConstantR1(&builder, PackToUint32(input)); + QuantizedRange range(0, 127.0f); + xla::Dequantize(x, range, "MIN_COMBINED"); + auto expected = GenerateMinCombinedOutput(range); + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8R2) { + XlaBuilder builder(TestName()); + std::vector> input = { + {0, 1, 2, 3}, + {4, 5, 6, 7}, + {8, 9, 10, 11}, + {12, 13, 16, 15}, + }; + auto x = ConstantR2(&builder, {{PackToUint32(input[0])[0]}, + {PackToUint32(input[1])[0]}, + {PackToUint32(input[2])[0]}, + {PackToUint32(input[3])[0]}}); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED"); + const Array2D expected = { + {bfloat16(0.0), bfloat16(1.0), bfloat16(2.0), bfloat16(3.0)}, + {bfloat16(4.0), bfloat16(5.0), bfloat16(6.0), bfloat16(7.0)}, + {bfloat16(8.0), bfloat16(9.0), bfloat16(10.0), bfloat16(11.0)}, + {bfloat16(12.0), bfloat16(13.0), bfloat16(16.0), bfloat16(15.0)}, + }; + ComputeAndCompareR2(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8R2TransposeOutput) { 
+ XlaBuilder builder(TestName()); + std::vector> input = { + {0, 1, 2, 3}, + {4, 5, 6, 7}, + {8, 9, 10, 11}, + {12, 13, 16, 15}, + }; + auto x = ConstantR2(&builder, {{PackToUint32(input[0])[0]}, + {PackToUint32(input[1])[0]}, + {PackToUint32(input[2])[0]}, + {PackToUint32(input[3])[0]}}); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED", /*transpose_output=*/true); + const Array2D expected = { + {bfloat16(0.0), bfloat16(4.0), bfloat16(8.0), bfloat16(12.0)}, + {bfloat16(1.0), bfloat16(5.0), bfloat16(9.0), bfloat16(13.0)}, + {bfloat16(2.0), bfloat16(6.0), bfloat16(10.0), bfloat16(16.0)}, + {bfloat16(3.0), bfloat16(7.0), bfloat16(11.0), bfloat16(15.0)}, + }; + ComputeAndCompareR2(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8R2TailingZero) { + XlaBuilder builder(TestName()); + std::vector> input = { + {0, 1, 2, 3, 16}, + {4, 5, 6, 7, 17}, + {8, 9, 10, 11, 18}, + {12, 13, 16, 15, 19}, + }; + auto x = ConstantR2( + &builder, + {{PackToUint32(input[0])[0], PackToUint32(input[0])[1]}, + {PackToUint32(input[1])[0], PackToUint32(input[1])[1]}, + {PackToUint32(input[2])[0], PackToUint32(input[2])[1]}, + {PackToUint32(input[3])[0], PackToUint32(input[3])[1]}}); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED"); + + const Array2D expected = { + {bfloat16(0.0), bfloat16(1.0), bfloat16(2.0), bfloat16(3.0), + bfloat16(16.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + {bfloat16(4.0), bfloat16(5.0), bfloat16(6.0), bfloat16(7.0), + bfloat16(17.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + {bfloat16(8.0), bfloat16(9.0), bfloat16(10.0), bfloat16(11.0), + bfloat16(18.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + {bfloat16(12.0), bfloat16(13.0), bfloat16(16.0), bfloat16(15.0), + bfloat16(19.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + }; + ComputeAndCompareR2(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8R2TailingZeroTransposeOutput) { + XlaBuilder 
builder(TestName()); + std::vector> input = { + {0, 1, 2, 3, 16}, + {4, 5, 6, 7, 17}, + {8, 9, 10, 11, 18}, + {12, 13, 16, 15, 19}, + }; + auto x = ConstantR2( + &builder, + {{PackToUint32(input[0])[0], PackToUint32(input[0])[1]}, + {PackToUint32(input[1])[0], PackToUint32(input[1])[1]}, + {PackToUint32(input[2])[0], PackToUint32(input[2])[1]}, + {PackToUint32(input[3])[0], PackToUint32(input[3])[1]}}); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED", /*transpose_output=*/true); + + const Array2D expected = { + {bfloat16(0.0), bfloat16(4.0), bfloat16(8.0), bfloat16(12.0)}, + {bfloat16(1.0), bfloat16(5.0), bfloat16(9.0), bfloat16(13.0)}, + {bfloat16(2.0), bfloat16(6.0), bfloat16(10.0), bfloat16(16.0)}, + {bfloat16(3.0), bfloat16(7.0), bfloat16(11.0), bfloat16(15.0)}, + {bfloat16(16.0), bfloat16(17.0), bfloat16(18.0), bfloat16(19.0)}, + {bfloat16(0.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + {bfloat16(0.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + {bfloat16(0.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + }; + ComputeAndCompareR2(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8LargeSizeTest) { + XlaBuilder builder(TestName()); + Array2D input = GenerateLargeSizeInput(500, 3547); + Array2D input_packed = PackLargeInput(input); + + auto x = ConstantR2FromArray2D(&builder, input_packed); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED"); + + const Array2D expected = + GenerateLargeSizeMinCombinedOutput(input, range); + ComputeAndCompareR2(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8LargeSizeTestTransposeOutput) { + XlaBuilder builder(TestName()); + Array2D input = GenerateLargeSizeInput(500, 3547); + Array2D input_packed = PackLargeInput(input); + + auto x = ConstantR2FromArray2D(&builder, input_packed); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED", /*transpose_output=*/true); + + const Array2D expected = 
GenerateLargeSizeMinCombinedOutput( + input, range, /*transpose_output=*/true); + ComputeAndCompareR2(&builder, expected, {}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index a95bbf2c8c860914877d3195b97342097dafc725..5db9d10dff4c50d71cde934b3f3c345bee571f29 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -59,22 +59,25 @@ XlaOp BuildFakeDataOpOnDevice(const Shape& shape, XlaBuilder* builder) { return Tuple(builder, parts); } -std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, - Client* client) { +std::unique_ptr MakeFakeDataViaDeviceOrDie( + const Shape& shape, Client* client, DebugOptions* debug_opts) { XlaBuilder b(absl::StrCat("make_fake_", ShapeUtil::HumanString(shape))); BuildFakeDataOpOnDevice(shape, &b); XlaComputation computation = b.Build().ConsumeValueOrDie(); auto execution_options = CreateDefaultExecutionOptions(); *execution_options.mutable_shape_with_output_layout() = shape.ToProto(); + if (debug_opts) { + *execution_options.mutable_debug_options() = *debug_opts; + } return client->Execute(computation, /*arguments=*/{}, &execution_options) .ConsumeValueOrDie(); } } // namespace -std::unique_ptr MakeFakeDataOrDie(const Shape& shape, - Client* client) { +std::unique_ptr MakeFakeDataOrDie( + const Shape& shape, Client* client, DebugOptions* debug_opts /*=nullptr*/) { if (DataSizeOfShape(shape) < (1LL << 20)) { StatusOr literal_status = MakeFakeLiteral(shape); if (!literal_status.ok()) { @@ -82,24 +85,25 @@ std::unique_ptr MakeFakeDataOrDie(const Shape& shape, // an on-device computation. 
CHECK_EQ(literal_status.status().code(), tensorflow::error::UNIMPLEMENTED); - return MakeFakeDataViaDeviceOrDie(shape, client); + return MakeFakeDataViaDeviceOrDie(shape, client, debug_opts); } return client->TransferToServer(literal_status.ValueOrDie()).ValueOrDie(); } // If the data is large, generate it on-device. - return MakeFakeDataViaDeviceOrDie(shape, client); + return MakeFakeDataViaDeviceOrDie(shape, client, debug_opts); } std::vector> MakeFakeArgumentsOrDie( - const XlaComputation& computation, Client* client) { + const XlaComputation& computation, Client* client, + DebugOptions* debug_opts /*=nullptr*/) { CHECK(computation.proto().has_host_program_shape()) << "Computation should have progran shape."; auto program_shape = computation.proto().host_program_shape(); std::vector> results; for (const ShapeProto& shape : program_shape.parameters()) { - results.push_back(MakeFakeDataOrDie(Shape(shape), client)); + results.push_back(MakeFakeDataOrDie(Shape(shape), client, debug_opts)); } return results; } diff --git a/tensorflow/compiler/xla/client/lib/testing.h b/tensorflow/compiler/xla/client/lib/testing.h index 03695ce2a339735e3e49522f4fe1bbf2d83a3834..428fa3e93d1b46983aae60176e7c2242d2552fdb 100644 --- a/tensorflow/compiler/xla/client/lib/testing.h +++ b/tensorflow/compiler/xla/client/lib/testing.h @@ -29,14 +29,19 @@ namespace xla { // Generates fake data of the given shape on the device or dies. The fake data // is created by performing a computation on the device rather than transferring // data from the host to the device. -std::unique_ptr MakeFakeDataOrDie(const Shape& shape, - Client* client); +// +// The optional DebugOptions are used when generating fake data on the device. +std::unique_ptr MakeFakeDataOrDie( + const Shape& shape, Client* client, DebugOptions* debug_opts = nullptr); // Returns vector of GlobalData handles of fake data (created using // MakeFakeDataOrDie) that are correctly shaped arguments for the given // xla computation. 
+// +// The optional DebugOptions are used when generating fake data on the device. std::vector> MakeFakeArgumentsOrDie( - const XlaComputation& computation, Client* client); + const XlaComputation& computation, Client* client, + DebugOptions* debug_opts = nullptr); } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/triangular_solve.cc b/tensorflow/compiler/xla/client/lib/triangular_solve.cc index c5a1d34cc66e6f8c1a832f8a8437163b846a5431..4bc2f3d121884541c497361695e3ddb9423e6238 100644 --- a/tensorflow/compiler/xla/client/lib/triangular_solve.cc +++ b/tensorflow/compiler/xla/client/lib/triangular_solve.cc @@ -62,15 +62,26 @@ XlaOp DiagonalBlocks(XlaOp a, int64 block_size) { /*broadcast_sizes=*/{2}), /*permutation=*/{1, 0}); + PaddingConfig padding_config = + MakeEdgePaddingConfig({{0, 0}, {ndims - 2, 0}}); + start_indices = + Pad(start_indices, ConstantR0(builder, 0), padding_config); + // Gather the diagonal blocks + std::vector slice_sizes(ndims); GatherDimensionNumbers dim_numbers; + for (int i = 0; i < ndims - 2; ++i) { + dim_numbers.add_offset_dims(i); + dim_numbers.add_start_index_map(i); + slice_sizes[i] = ShapeUtil::GetDimension(shape, i); + } + slice_sizes[ndims - 2] = slice_sizes[ndims - 1] = block_size; dim_numbers.add_offset_dims(ndims - 1); dim_numbers.add_offset_dims(ndims); dim_numbers.add_start_index_map(ndims - 2); dim_numbers.add_start_index_map(ndims - 1); dim_numbers.set_index_vector_dim(1); - diag_blocks = Gather(a, start_indices, dim_numbers, - /*slice_sizes=*/{block_size, block_size}); + diag_blocks = Gather(a, start_indices, dim_numbers, slice_sizes); } // The last block might be smaller than the block size, @@ -393,6 +404,12 @@ XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, block_size); } + if (ShapeUtil::IsZeroElementArray(b_shape)) { + // The output has the same shape as 'b', and since the output has zero + // elements, any such array will do. 
+ return b; + } + // We find the diagonal blocks of the coefficient matrix auto diag_blocks = DiagonalBlocks(a, block_size); @@ -400,6 +417,11 @@ XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, transpose_a, conjugate_a, precision); + // Mask off the ignored elements of the triangular matrix a. + // TODO(phawkins): it would probably be preferable to perform this masking + // block by block inside SolveWithInvertedDiagonalBlocks. + a = Triangle(a, lower); + // We now find the solution using GEMMs auto x = SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side, lower, diff --git a/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc b/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc index f6a70d64a788d95a456774ccbbcf67f2e5cac98b..703227c94944feb6858de9464758e024c55b323d 100644 --- a/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc +++ b/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" @@ -33,56 +35,78 @@ limitations under the License. 
namespace xla { namespace { -using TriangularSolveTest = xla::ClientLibraryTestBase; -using TriangularSolveLeftLookingTest = xla::ClientLibraryTestBase; -using complex64 = xla::complex64; +using TriangularSolveTest = ClientLibraryTestBase; +using TriangularSolveLeftLookingTest = ClientLibraryTestBase; -xla::Array2D AValsLower() { - return {{2, 0, 0, 0}, {3, 6, 0, 0}, {4, 7, 9, 0}, {5, 8, 10, 11}}; +static constexpr float kNan = std::numeric_limits::quiet_NaN(); + +Array2D AValsLower() { + return {{2, kNan, kNan, kNan}, + {3, 6, kNan, kNan}, + {4, 7, 9, kNan}, + {5, 8, 10, 11}}; } -xla::Array2D AValsUpper() { - return {{2, 3, 4, 5}, {0, 6, 7, 8}, {0, 0, 9, 10}, {0, 0, 0, 11}}; +Array2D AValsUpper() { + return {{2, 3, 4, 5}, + {kNan, 6, 7, 8}, + {kNan, kNan, 9, 10}, + {kNan, kNan, kNan, 11}}; } -xla::Array2D BValsRight() { +Array2D BValsRight() { return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; } -xla::Array2D BValsLeft() { +Array2D BValsLeft() { return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; } -xla::Array2D AValsLowerComplex() { - return {{2, 0, 0, 0}, - {complex64(3, 1), 6, 0, 0}, - {4, complex64(7, 2), 9, 0}, +static constexpr complex64 kNanC64 = complex64(kNan, kNan); + +Array2D AValsLowerComplex() { + return {{2, kNanC64, kNanC64, kNanC64}, + {complex64(3, 1), 6, kNanC64, kNanC64}, + {4, complex64(7, 2), 9, kNanC64}, {5, 8, complex64(10, 3), 11}}; } -xla::Array2D AValsUpperComplex() { +Array2D AValsUpperComplex() { return {{2, 3, complex64(4, 3), 5}, - {0, 6, complex64(7, 2), 8}, - {0, 0, complex64(9, 1), 10}, - {0, 0, 0, 11}}; + {kNanC64, 6, complex64(7, 2), 8}, + {kNanC64, kNanC64, complex64(9, 1), 10}, + {kNanC64, kNanC64, kNanC64, 11}}; } -xla::Array2D BValsRightComplex() { +Array2D BValsRightComplex() { return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; } -xla::Array2D BValsLeftComplex() { +Array2D BValsLeftComplex() { return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; } -xla::Array2D AValsFull() { - return {{2, 0, 1, 2}, {3, 6, 0, 1}, 
{4, 7, 9, 0}, {5, 8, 10, 11}}; +XLA_TEST_F(TriangularSolveTest, EmptyArrays) { + XlaBuilder builder(TestName()); + + XlaOp a, b; + auto a_data = + CreateR2Parameter(Array2D(0, 0), 0, "a", &builder, &a); + auto b_data = + CreateR2Parameter(Array2D(0, 10), 1, "b", &builder, &b); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/true, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); + + ComputeAndCompareR2(&builder, Array2D(0, 10), + {a_data.get(), b_data.get()}); } XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -90,20 +114,20 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { /*transpose_a=*/true, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {0.5, 0.08333334, 0.04629629, 0.03367003}, {2.5, -0.25, -0.1388889, -0.1010101}, {4.5, -0.58333331, -0.32407406, -0.23569024}, }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -111,20 +135,20 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { /*transpose_a=*/false, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, {0.64393939, 0.06565657, -0.03030303, 0.72727273}, {1.4520202, 0.2003367, 0.01010101, 1.09090909}, }); ComputeAndCompareR2(&builder, expected, 
{a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -132,20 +156,20 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { /*transpose_a=*/true, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, {0.64393939, 0.06565657, -0.03030303, 0.72727273}, {1.4520202, 0.2003367, 0.01010101, 1.09090909}, }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -153,20 +177,20 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { /*transpose_a=*/false, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {0.5, 0.08333334, 0.04629629, 0.03367003}, {2.5, -0.25, -0.1388889, -0.1010101}, {4.5, -0.58333331, -0.32407406, -0.23569024}, }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); 
TriangularSolve(a, b, @@ -174,7 +198,7 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { /*transpose_a=*/true, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {-0.89646465, -0.69444444, -0.49242424}, {-0.27441077, -0.24074074, -0.20707071}, {-0.23232323, -0.22222222, -0.21212121}, @@ -182,13 +206,13 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -196,7 +220,7 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { /*transpose_a=*/false, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {0.5, 1.0, 1.5}, {0.41666667, 0.33333333, 0.25}, {0.23148148, 0.18518519, 0.13888889}, @@ -204,13 +228,13 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -218,7 +242,7 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) { /*transpose_a=*/false, /*conjugate_a=*/false, /*block_size=*/3); - xla::Array2D expected({ + Array2D expected({ {0.5, 1.0, 1.5}, {0.41666667, 0.33333333, 0.25}, {0.23148148, 0.18518519, 
0.13888889}, @@ -226,13 +250,13 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) { }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -240,7 +264,7 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { /*transpose_a=*/true, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {0.5, 1.0, 1.5}, {0.41666667, 0.33333333, 0.25}, {0.23148148, 0.18518519, 0.13888889}, @@ -248,13 +272,13 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -262,7 +286,7 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { /*transpose_a=*/false, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {-0.89646465, -0.69444444, -0.49242424}, {-0.27441077, -0.24074074, -0.20707071}, {-0.23232323, -0.22222222, -0.21212121}, @@ -270,13 +294,13 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { - 
xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsLowerComplex(), 0, "a", &builder, &a); auto b_data = @@ -286,7 +310,7 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { /*transpose_a=*/true, /*conjugate_a=*/true, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {0.5, complex64(0.08333333, 0.08333333), complex64(0.02777778, -0.0462963), complex64(0.06313131, -0.01094276)}, {2.5, complex64(-0.25, 0.41666667), complex64(-0.23148148, -0.37962963), @@ -295,15 +319,14 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { complex64(0.11026936, -0.03114478)}, }); - ComputeAndCompareR2(&builder, expected, - {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ComputeAndCompareR2( + &builder, expected, {a_data.get(), b_data.get()}, ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpperComplex(), 0, "a", &builder, &a); auto b_data = @@ -313,7 +336,7 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { /*transpose_a=*/true, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {0.5, 1., 1.5}, {0.41666667, 0.33333333, 0.25}, {complex64(0.20020325, -2.81504065e-01), @@ -324,10 +347,101 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { complex64(0.15798226, 5.12749446e-01)}, }); - ComputeAndCompareR2(&builder, expected, - {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ComputeAndCompareR2( + &builder, expected, {a_data.get(), b_data.get()}, ErrorSpec(1e-2, 1e-2)); } +XLA_TEST_F(TriangularSolveTest, BatchedLeftUpper) { + XlaBuilder builder(TestName()); + + Array3D bvals(7, 5, 5); + bvals.FillIota(1.); + + // Set avals to the upper triangle of 
bvals. + Array3D avals = bvals; + avals.Each([](absl::Span indices, float* value) { + if (indices[1] > indices[2]) { + *value = 0; + } + }); + + XlaOp a, b; + auto a_data = CreateR3Parameter(avals, 0, "a", &builder, &a); + auto b_data = CreateR3Parameter(bvals, 1, "b", &builder, &b); + BatchDot(ConstantR3FromArray3D(&builder, avals), + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/false, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2)); + + ComputeAndCompareR3(&builder, bvals, {a_data.get(), b_data.get()}, + ErrorSpec(1e-2, 1e-2)); +} + +struct TriangularSolveTestSpec { + int m, n; // A is mxm, B is mxn + bool left_side; + bool lower; + bool transpose_a; +}; + +class TriangularSolveParametricTest + : public ClientLibraryTestBase, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(TriangularSolveParametricTest, Random) { + TriangularSolveTestSpec spec = GetParam(); + + XlaBuilder builder(TestName()); + + Array2D avals(spec.m, spec.m); + avals.FillRandom(1.0); + for (int i = 0; i < spec.m; ++i) { + avals(i, i) += 10; + } + + std::pair bdims = spec.left_side ? 
std::make_pair(spec.m, spec.n) + : std::make_pair(spec.n, spec.m); + Array2D bvals(bdims.first, bdims.second); + bvals.FillRandom(1.0); + + XlaOp a, b; + auto a_data = CreateR2Parameter(avals, 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(bvals, 1, "b", &builder, &b); + auto x = TriangularSolve(a, b, spec.left_side, spec.lower, spec.transpose_a, + /*conjugate_a=*/false, + /*block_size=*/3); + auto a_tri = Triangle(a, spec.lower); + a_tri = MaybeTransposeInMinorDims(a_tri, spec.transpose_a); + if (spec.left_side) { + BatchDot(a_tri, x); + } else { + BatchDot(x, a_tri); + } + + ComputeAndCompareR2(&builder, bvals, {a_data.get(), b_data.get()}, + ErrorSpec(1e-2, 1e-2)); +} + +std::vector TriangularSolveTests() { + std::vector specs; + for (int m : {5, 10}) { + for (int n : {5, 10}) { + for (bool left_side : {false, true}) { + for (bool lower : {false, true}) { + for (bool transpose_a : {false, true}) { + specs.push_back({m, n, left_side, lower, transpose_a}); + } + } + } + } + } + return specs; +} + +INSTANTIATE_TEST_CASE_P(TriangularSolveParametricTestInstantiation, + TriangularSolveParametricTest, + ::testing::ValuesIn(TriangularSolveTests())); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 60df2ec3959216b0564846ad47c21c5bcc01ea57..622fc158e11161b5b1167ccb432f51775767e3a1 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -211,7 +211,7 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle, // Non functional ops. case HloOpcode::kRng: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: // TODO(b/33009255): Implmement constant folding for cross replica sum. 
case HloOpcode::kInfeed: case HloOpcode::kOutfeed: @@ -959,27 +959,29 @@ Status XlaBuilder::VerifyConvolution( XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, - int64 feature_group_count, + int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* precision_config) { return ConvWithGeneralDimensions( lhs, rhs, window_strides, padding, CreateDefaultConvDimensionNumbers(window_strides.size()), - feature_group_count, precision_config); + feature_group_count, batch_group_count, precision_config); } XlaOp XlaBuilder::ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, - int64 feature_group_count, const PrecisionConfig* precision_config) { + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config) { return ConvGeneral(lhs, rhs, window_strides, padding, CreateDefaultConvDimensionNumbers(window_strides.size()), - feature_group_count, precision_config); + feature_group_count, batch_group_count, precision_config); } XlaOp XlaBuilder::ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, const PrecisionConfig* precision_config) { + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -1007,7 +1009,7 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions( MakePadding(base_area_dimensions, window_dimensions, window_strides, padding), dimension_numbers, feature_group_count, - precision_config); + batch_group_count, precision_config); }); } @@ -1015,10 +1017,11 @@ XlaOp XlaBuilder::ConvGeneral( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, 
const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, const PrecisionConfig* precision_config) { + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config) { return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {}, dimension_numbers, feature_group_count, - precision_config); + batch_group_count, precision_config); } XlaOp XlaBuilder::ConvGeneralDilated( @@ -1026,7 +1029,8 @@ XlaOp XlaBuilder::ConvGeneralDilated( absl::Span> padding, absl::Span lhs_dilation, absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, const PrecisionConfig* precision_config) { + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -1045,14 +1049,15 @@ XlaOp XlaBuilder::ConvGeneralDilated( MakeWindow(window_dimensions, window_strides, padding, lhs_dilation, rhs_dilation)); - TF_ASSIGN_OR_RETURN(Shape shape, - ShapeInference::InferConvolveShape( - lhs_shape, rhs_shape, feature_group_count, - instr.window(), dimension_numbers)); + TF_ASSIGN_OR_RETURN( + Shape shape, ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, feature_group_count, + batch_group_count, instr.window(), dimension_numbers)); *instr.mutable_shape() = shape.ToProto(); *instr.mutable_convolution_dimension_numbers() = dimension_numbers; instr.set_feature_group_count(feature_group_count); + instr.set_batch_group_count(batch_group_count); if (precision_config != nullptr) { *instr.mutable_precision_config() = *precision_config; @@ -2015,8 +2020,8 @@ XlaOp XlaBuilder::CrossReplicaSum( return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCrossReplicaSumShape( - 
{&operand_shape})); + TF_ASSIGN_OR_RETURN(Shape shape, + ShapeInference::InferAllReduceShape({&operand_shape})); *instr.mutable_shape() = shape.ToProto(); for (const ReplicaGroup& group : replica_groups) { @@ -2029,8 +2034,7 @@ XlaOp XlaBuilder::CrossReplicaSum( AddCalledComputation(computation, &instr); - return AddInstruction(std::move(instr), HloOpcode::kCrossReplicaSum, - {operand}); + return AddInstruction(std::move(instr), HloOpcode::kAllReduce, {operand}); }); } @@ -2786,38 +2790,42 @@ XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, - int64 feature_group_count, const PrecisionConfig* precision_config) { + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config) { return lhs.builder()->Conv(lhs, rhs, window_strides, padding, - feature_group_count, precision_config); + feature_group_count, batch_group_count, + precision_config); } XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, - int64 feature_group_count, + int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* precision_config) { return lhs.builder()->ConvWithGeneralPadding( - lhs, rhs, window_strides, padding, feature_group_count, precision_config); + lhs, rhs, window_strides, padding, feature_group_count, batch_group_count, + precision_config); } XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, const PrecisionConfig* precision_config) { + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config) { return lhs.builder()->ConvWithGeneralDimensions( lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count, - precision_config); + batch_group_count, precision_config); } XlaOp ConvGeneral(const XlaOp& 
lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, + int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* precision_config) { return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count, - precision_config); + batch_group_count, precision_config); } XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, @@ -2826,11 +2834,12 @@ XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, absl::Span lhs_dilation, absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, + int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* precision_config) { return lhs.builder()->ConvGeneralDilated( lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, - dimension_numbers, feature_group_count, precision_config); + dimension_numbers, feature_group_count, batch_group_count, + precision_config); } XlaOp Fft(const XlaOp& operand, FftType fft_type, diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 098efb60f9bdca8306ff771a505f4a225dea9f7d..6e9b025e5d70c03e9f4c7e7fbc89976f314d48d7 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -387,28 +387,28 @@ class XlaBuilder { XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, - int64 feature_group_count = 1, + int64 feature_group_count = 1, int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, - int64 feature_group_count = 1, + int64 feature_group_count = 1, int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const 
XlaOp& rhs, absl::Span window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, + int64 feature_group_count = 1, int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, + int64 feature_group_count = 1, int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, @@ -418,6 +418,7 @@ class XlaBuilder { absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, + int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); XlaOp Fft(const XlaOp& operand, FftType fft_type, @@ -881,23 +882,25 @@ class XlaBuilder { const PrecisionConfig* precision_config); friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, - int64 feature_group_count, + int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* precision_config); friend XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, - int64 feature_group_count, const PrecisionConfig* precision_config); + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config); friend XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, const PrecisionConfig* precision_config); + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config); friend XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, const ConvolutionDimensionNumbers& 
dimension_numbers, - int64 feature_group_count, + int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* precision_config); friend XlaOp ConvGeneralDilated( const XlaOp& lhs, const XlaOp& rhs, @@ -906,7 +909,8 @@ class XlaBuilder { absl::Span lhs_dilation, absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, const PrecisionConfig* precision_config); + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config); friend XlaOp Fft(const XlaOp& operand, FftType fft_type, absl::Span fft_length); friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape, @@ -1372,7 +1376,7 @@ XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, // default convolution dimension numbers. XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, - int64 feature_group_count = 1, + int64 feature_group_count = 1, int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller @@ -1381,6 +1385,7 @@ XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, int64 feature_group_count = 1, + int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller @@ -1388,7 +1393,7 @@ XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs, XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, + int64 feature_group_count = 1, int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller @@ -1397,7 +1402,7 @@ XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, absl::Span 
window_strides, absl::Span> padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, + int64 feature_group_count = 1, int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller @@ -1409,6 +1414,7 @@ XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, + int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); // Enqueues an FFT instruction onto the computation, of the given type and diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index d7e7b9e621894f1c363734d6415a38d2e8165463..a9a91648ac377987e7f226116e11c9c697ace103 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -22,49 +22,49 @@ limitations under the License. 
#include "tensorflow/compiler/xla/parse_flags_from_env.h" namespace xla { -namespace { -DebugOptions* flag_values; -std::vector* flag_objects; -std::once_flag flags_init; - -void SetDebugOptionsDefaults(DebugOptions* flags) { - flags->set_xla_llvm_enable_alias_scope_metadata(true); - flags->set_xla_llvm_enable_noalias_metadata(true); - flags->set_xla_llvm_enable_invariant_load_metadata(true); - flags->set_xla_llvm_disable_expensive_passes(false); - flags->set_xla_backend_optimization_level(3); - flags->set_xla_cpu_multi_thread_eigen(true); - flags->set_xla_gpu_cuda_data_dir("./cuda_sdk_lib"); - flags->set_xla_eliminate_hlo_implicit_broadcast(true); +DebugOptions DefaultDebugOptionsIgnoringFlags() { + DebugOptions opts; + opts.set_xla_llvm_enable_alias_scope_metadata(true); + opts.set_xla_llvm_enable_noalias_metadata(true); + opts.set_xla_llvm_enable_invariant_load_metadata(true); + opts.set_xla_llvm_disable_expensive_passes(false); + opts.set_xla_backend_optimization_level(3); + opts.set_xla_cpu_multi_thread_eigen(true); + opts.set_xla_gpu_cuda_data_dir("./cuda_sdk_lib"); + opts.set_xla_eliminate_hlo_implicit_broadcast(true); + opts.set_xla_hlo_dump_as_html(false); #ifdef INTEL_MKL - flags->set_xla_cpu_use_mkl_dnn(true); + opts.set_xla_cpu_use_mkl_dnn(true); #endif // INTEL_MKL - flags->set_xla_gpu_max_kernel_unroll_factor(4); + opts.set_xla_gpu_max_kernel_unroll_factor(4); // Set cudnn batchnorm off by default; it does not provide a performance win // on average. - flags->set_xla_gpu_use_cudnn_batchnorm(false); + opts.set_xla_gpu_use_cudnn_batchnorm(false); // Run all GPU work on one stream by default. Using multiple streams // increases memory usage and we lack strong motivating benchmarks for tuning // the heuristics needed to decide when to run on multiple streams. See // b/77879207. 
- flags->set_xla_gpu_disable_multi_streaming(true); + opts.set_xla_gpu_disable_multi_streaming(true); // TODO(jlebar): Disable fastmath once doing so is not a performance // regression. - flags->set_xla_cpu_enable_fast_math(true); - flags->set_xla_gpu_enable_fast_min_max(true); + opts.set_xla_cpu_enable_fast_math(true); + opts.set_xla_gpu_enable_fast_min_max(true); - flags->set_xla_force_host_platform_device_count(1); + opts.set_xla_force_host_platform_device_count(1); + return opts; } +static DebugOptions* flag_values; +static std::vector* flag_objects; +static std::once_flag flags_init; + // Allocates flag_values and flag_objects; this function must not be called more // than once - its call done via call_once. -void AllocateFlags() { - flag_values = new DebugOptions; - - SetDebugOptionsDefaults(flag_values); +static void AllocateFlags() { + flag_values = new DebugOptions(DefaultDebugOptionsIgnoringFlags()); // Returns a lambda that calls "member_setter" on "flag_values" with the // argument passed in to the lambda. @@ -133,6 +133,11 @@ void AllocateFlags() { bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_graphdef), flag_values->xla_hlo_dump_as_graphdef(), "Dump HLO graphs as TensorFlow GraphDefs."), + tensorflow::Flag("xla_hlo_dump_as_html", + bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_html), + flag_values->xla_hlo_dump_as_html(), + "Dump HLO graphs as an HTML (DOT rendered into SVG " + "inlined in HTML)."), tensorflow::Flag( "xla_hlo_graph_sharding_color", bool_setter_for(&DebugOptions::set_xla_hlo_graph_sharding_color), @@ -202,6 +207,16 @@ void AllocateFlags() { "Comma-separated list of hlo passes to be disabled. These names " "must exactly match the passes' names; no whitespace around " "commas."), + tensorflow::Flag( + "xla_disable_all_hlo_passes", + bool_setter_for(&DebugOptions::set_xla_disable_all_hlo_passes), false, + "Disables all HLO passes. 
Notes that some passes are necessary for " + "correctness and the invariants that must be satisfied by 'fully " + "optimized' HLO are different for different devices and may change " + "over time. The only 'guarantee', such as it is, is that if you " + "compile XLA and dump the optimized HLO for some graph, you should " + "be able to run it again on the same device with the same build of " + "XLA."), tensorflow::Flag( "xla_embed_ir_in_executable", bool_setter_for(&DebugOptions::set_xla_embed_ir_in_executable), @@ -334,12 +349,16 @@ void AllocateFlags() { "overhead from context switching but we let the user override this " "behavior to help run tests on the host that run models in parallel " "across multiple devices."), + tensorflow::Flag( + "xla_gpu_disable_ptxas_optimizations", + bool_setter_for( + &DebugOptions::set_xla_gpu_disable_ptxas_optimizations), + flag_values->xla_gpu_disable_ptxas_optimizations(), + "In XLA:GPU run ptxas in -O0 (default is -O3)."), }); ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", *flag_objects); } -} // namespace - void AppendDebugOptionsFlags(std::vector* flag_list) { std::call_once(flags_init, &AllocateFlags); flag_list->insert(flag_list->end(), flag_objects->begin(), diff --git a/tensorflow/compiler/xla/debug_options_flags.h b/tensorflow/compiler/xla/debug_options_flags.h index 60e59abc2a2e0f1cce3de1afc928f9fe36f75b33..dbf86a40f052af09c61da0e1abb3116ef5214357 100644 --- a/tensorflow/compiler/xla/debug_options_flags.h +++ b/tensorflow/compiler/xla/debug_options_flags.h @@ -29,7 +29,10 @@ void AppendDebugOptionsFlags(std::vector* flag_list); // Fetches a DebugOptions proto message from flags provided to the program. // Flags must be registered with the flags parser using AppendDebugOptionsFlags // first. -xla::DebugOptions GetDebugOptionsFromFlags(); +DebugOptions GetDebugOptionsFromFlags(); + +// Gets a DebugOptions proto that reflects the defaults as if no flags were set. 
+DebugOptions DefaultDebugOptionsIgnoringFlags(); } // namespace xla diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py index 1fea816a803bfb75b9721393cef8c4dfc249268d..c34e84efc80ba970624d80802841d6ec534b6fd0 100644 --- a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py +++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py @@ -104,9 +104,9 @@ class Sharding(object): ValueError: The tensor to split was smaller in the split dimension than the number of devices to split over. """ - tensor.shape.assert_is_fully_defined() shape = tensor.shape.as_list() - if shape[split_dimension] < num_devices: + if (shape[split_dimension] is not None and + shape[split_dimension] < num_devices): raise ValueError('Split dimension was smaller than the required number ' 'of splits: shape=%r, dimension=%r, num_devices=%r' % (shape, split_dimension, num_devices)) diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index d888b1f23f36f33ef94ef0e22374e0c796e47a89..9a9cd08c301502cbda8858225182d95fca4bf7ae 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -38,25 +38,25 @@ Alltoall is a collective operation that sends data from all cores to all cores. It has two phases: 1. the scatter phase. On each core, the operand is split into `split_count` - number of blocks along the `split_dimensions`, and the blocks are scattered - to all cores, e.g., the ith block is send to the ith core. +number of blocks along the `split_dimensions`, and the blocks are scattered +to all cores, e.g., the ith block is sent to the ith core. 2. the gather phase. Each core concatenates the received blocks along the - `concat_dimension`. +`concat_dimension`. 
The participating cores can be configured by: - `replica_groups`: each ReplicaGroup contains a list of replica id. If empty, - all replicas belong to one group in the order of 0 - (n-1). Alltoall will be - applied within subgroups in the specified order. For example, replica - groups = {{1,2,3},{4,5,0}} means, an Alltoall will be applied within replica - 1, 2, 3, and in the gather phase, the received blocks will be concatenated - in the order of 1, 2, 3; another Alltoall will be applied within replica 4, - 5, 0, and the concatenation order is 4, 5, 0. +all replicas belong to one group in the order of 0 - (n-1). Alltoall will be +applied within subgroups in the specified order. For example, replica +groups = {{1,2,3},{4,5,0}} means, an Alltoall will be applied within replica +1, 2, 3, and in the gather phase, the received blocks will be concatenated +in the order of 1, 2, 3; another Alltoall will be applied within replica 4, +5, 0, and the concatenation order is 4, 5, 0. Prerequisites: - The dimension size of the operand on the split_dimension is divisible by - split_count. +split_count. - The operand's shape is not tuple. `AllToAll(operand, split_dimension, concat_dimension, split_count, @@ -93,7 +93,7 @@ AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/4); ```
- +
In this example, there are 4 cores participating the Alltoall. On each core, the @@ -387,34 +387,34 @@ For example, let v be an array of 24 elements: ``` let v = f32[4x2x3] {{{10, 11, 12}, {15, 16, 17}}, - {{20, 21, 22}, {25, 26, 27}}, - {{30, 31, 32}, {35, 36, 37}}, - {{40, 41, 42}, {45, 46, 47}}}; +{{20, 21, 22}, {25, 26, 27}}, +{{30, 31, 32}, {35, 36, 37}}, +{{40, 41, 42}, {45, 46, 47}}}; // Collapse to a single dimension, leaving one dimension. let v012 = Collapse(v, {0,1,2}); then v012 == f32[24] {10, 11, 12, 15, 16, 17, - 20, 21, 22, 25, 26, 27, - 30, 31, 32, 35, 36, 37, - 40, 41, 42, 45, 46, 47}; +20, 21, 22, 25, 26, 27, +30, 31, 32, 35, 36, 37, +40, 41, 42, 45, 46, 47}; // Collapse the two lower dimensions, leaving two dimensions. let v01 = Collapse(v, {0,1}); then v01 == f32[4x6] {{10, 11, 12, 15, 16, 17}, - {20, 21, 22, 25, 26, 27}, - {30, 31, 32, 35, 36, 37}, - {40, 41, 42, 45, 46, 47}}; +{20, 21, 22, 25, 26, 27}, +{30, 31, 32, 35, 36, 37}, +{40, 41, 42, 45, 46, 47}}; // Collapse the two higher dimensions, leaving two dimensions. let v12 = Collapse(v, {1,2}); then v12 == f32[8x3] {{10, 11, 12}, - {15, 16, 17}, - {20, 21, 22}, - {25, 26, 27}, - {30, 31, 32}, - {35, 36, 37}, - {40, 41, 42}, - {45, 46, 47}}; +{15, 16, 17}, +{20, 21, 22}, +{25, 26, 27}, +{30, 31, 32}, +{35, 36, 37}, +{40, 41, 42}, +{45, 46, 47}}; ``` @@ -441,9 +441,9 @@ replicas. Note that there are the following restrictions on the `source_target_pair`: - Any two pairs should not have the same target replica id, and they should - not have the same source replica id. +not have the same source replica id. - If a replica id is not a target in any pair, then the output on that replica - is a tensor consists of 0(s) with the same shape as the input. +is a tensor consists of 0(s) with the same shape as the input. 
## Concatenate @@ -480,25 +480,25 @@ Concat({{2, 3}, {4, 5}, {6, 7}}, 0) ``` let a = { - {1, 2}, - {3, 4}, - {5, 6}, +{1, 2}, +{3, 4}, +{5, 6}, }; let b = { - {7, 8}, +{7, 8}, }; Concat({a, b}, 0) >>> { - {1, 2}, - {3, 4}, - {5, 6}, - {7, 8}, +{1, 2}, +{3, 4}, +{5, 6}, +{7, 8}, } ``` Diagram:
- +
## Conditional @@ -548,17 +548,23 @@ Computes a convolution of the kind used in neural networks. Here, a convolution can be thought of as a n-dimensional window moving across a n-dimensional base area and a computation is performed for each possible position of the window. -| Arguments | Type | Semantics | -| --------------------- | -------------------- | ----------------------------- | -| `lhs` | `XlaOp` | rank n+2 array of inputs | -| `rhs` | `XlaOp` | rank n+2 array of kernel | -: : : weights : -| `window_strides` | `ArraySlice` | n-d array of kernel strides | -| `padding` | `ArraySlice< | n-d array of (low, high) | -: : pair>` : padding : -| `lhs_dilation` | `ArraySlice` | n-d lhs dilation factor array | -| `rhs_dilation` | `ArraySlice` | n-d rhs dilation factor array | -| `feature_group_count` | int64 | the number of feature groups | +| Arguments | Type | Semantics | +| --------------------- | ------------------------ | ------------------------ | +| `lhs` | `XlaOp` | rank n+2 array of inputs | +| `rhs` | `XlaOp` | rank n+2 array of kernel | +: : : weights : +| `window_strides` | `ArraySlice` | n-d array of kernel | +: : : strides : +| `padding` | `ArraySlice< pair>` : padding : +| `lhs_dilation` | `ArraySlice` | n-d lhs dilation factor | +: : : array : +| `rhs_dilation` | `ArraySlice` | n-d rhs dilation factor | +: : : array : +| `feature_group_count` | int64 | the number of feature | +: : : groups : +| `batch_group_count` | int64 | the number of batch | +: : : groups : Let n be the number of spatial dimensions. The `lhs` argument is a rank n+2 array describing the base area. This is called the input, even though of course @@ -566,20 +572,20 @@ the rhs is also an input. In a neural network, these are the input activations. The n+2 dimensions are, in this order: * `batch`: Each coordinate in this dimension represents an independent input - for which convolution is carried out. +for which convolution is carried out. 
* `z/depth/features`: Each (y,x) position in the base area has a vector - associated to it, which goes into this dimension. +associated to it, which goes into this dimension. * `spatial_dims`: Describes the `n` spatial dimensions that define the base - area that the window moves across. +area that the window moves across. The `rhs` argument is a rank n+2 array describing the convolutional filter/kernel/window. The dimensions are, in this order: * `output-z`: The `z` dimension of the output. * `input-z`: The size of this dimension times `feature_group_count` should - equal the size of the `z` dimension in lhs. +equal the size of the `z` dimension in lhs. * `spatial_dims`: Describes the `n` spatial dimensions that define the n-d - window that moves across the base area. +window that moves across the base area. The `window_strides` argument specifies the stride of the convolutional window in the spatial dimensions. For example, if the stride in the first spatial @@ -628,9 +634,18 @@ input feature dimension, and the filter would be reshaped from `[filter_height, filter_width, 1, in_channels * channel_multiplier]`. For more details, see `tf.nn.depthwise_conv2d`. +The `batch_group_count` (default value 1) argument can be used for depthwise +filters during backpropagation. `batch_group_count` needs to be a divisor of the +size of the `lhs` batch dimension. If `batch_group_count` is greater than 1, it +means that conceptually the output batch dimension is split evenly into +`batch_group_count` groups, such that each group consists of a consecutive +subsequence of batches. Each output batch element is the reduced value of the +batch group size. + The output shape has these dimensions, in this order: -* `batch`: Same size as `batch` on the input (`lhs`). +* `batch`: The size of this dimension times `batch_group_count` should equal + the size of the `batch` dimension in lhs. * `z`: Same size as `output-z` on the kernel (`rhs`). 
* `spatial_dims`: One value for each valid placement of the convolutional window. @@ -658,15 +673,15 @@ Here is pseudo-code for a 2d convolution with padding and striding: ``` for (b, oz, oy, ox) { // output coordinates - value = 0; - for (iz, ky, kx) { // kernel coordinates and input z - iy = oy*stride_y + ky - pad_low_y; - ix = ox*stride_x + kx - pad_low_x; - if ((iy, ix) inside the base area considered without padding) { - value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx); - } - } - output(b, oz, oy, ox) = value; +value = 0; +for (iz, ky, kx) { // kernel coordinates and input z +iy = oy*stride_y + ky - pad_low_y; +ix = ox*stride_x + kx - pad_low_x; +if ((iy, ix) inside the base area considered without padding) { +value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx); +} +} +output(b, oz, oy, ox) = value; } ``` @@ -777,19 +792,19 @@ Here is an example of an implementation of `myfunc`: ``` extern "C" void myfunc(void* out, void** in) { - float (&x)[2] = *static_cast(in[0]); - float (&y)[2][3] = *static_cast(in[1]); - EXPECT_EQ(1, x[0]); - EXPECT_EQ(2, x[1]); - EXPECT_EQ(10, y[0][0]); - EXPECT_EQ(20, y[0][1]); - EXPECT_EQ(30, y[0][2]); - EXPECT_EQ(40, y[1][0]); - EXPECT_EQ(50, y[1][1]); - EXPECT_EQ(60, y[1][2]); - float (&z)[3][3] = *static_cast(out); - z[0][0] = x[1] + y[1][0]; - // ... +float (&x)[2] = *static_cast(in[0]); +float (&y)[2][3] = *static_cast(in[1]); +EXPECT_EQ(1, x[0]); +EXPECT_EQ(2, x[1]); +EXPECT_EQ(10, y[0][0]); +EXPECT_EQ(20, y[0][1]); +EXPECT_EQ(30, y[0][2]); +EXPECT_EQ(40, y[1][0]); +EXPECT_EQ(50, y[1][1]); +EXPECT_EQ(60, y[1][2]); +float (&z)[3][3] = *static_cast(out); +z[0][0] = x[1] + y[1][0]; +// ... 
} ``` @@ -864,17 +879,17 @@ Example with contracting dimension numbers: ``` lhs = { {1.0, 2.0, 3.0}, - {4.0, 5.0, 6.0} } +{4.0, 5.0, 6.0} } rhs = { {1.0, 1.0, 1.0}, - {2.0, 2.0, 2.0} } +{2.0, 2.0, 2.0} } DotDimensionNumbers dnums; dnums.add_lhs_contracting_dimensions(1); dnums.add_rhs_contracting_dimensions(1); DotGeneral(lhs, rhs, dnums) -> { {6.0, 12.0}, - {15.0, 30.0} } +{15.0, 30.0} } ``` Associated batch dimension numbers from the 'lhs' and 'rhs' must have the same @@ -886,14 +901,14 @@ Example with batch dimension numbers (batch size 2, 2x2 matrices): ``` lhs = { { {1.0, 2.0}, - {3.0, 4.0} }, - { {5.0, 6.0}, - {7.0, 8.0} } } +{3.0, 4.0} }, +{ {5.0, 6.0}, +{7.0, 8.0} } } rhs = { { {1.0, 0.0}, - {0.0, 1.0} }, - { {1.0, 0.0}, - {0.0, 1.0} } } +{0.0, 1.0} }, +{ {1.0, 0.0}, +{0.0, 1.0} } } DotDimensionNumbers dnums; dnums.add_lhs_contracting_dimensions(2); @@ -902,9 +917,9 @@ dnums.add_lhs_batch_dimensions(0); dnums.add_rhs_batch_dimensions(0); DotGeneral(lhs, rhs, dnums) -> { { {1.0, 2.0}, - {3.0, 4.0} }, - { {5.0, 6.0}, - {7.0, 8.0} } } +{3.0, 4.0} }, +{ {5.0, 6.0}, +{7.0, 8.0} } } ``` | Input | Output | Semantics | @@ -963,22 +978,22 @@ let a = {0.0, 1.0, 2.0, 3.0, 4.0} let s = {2} DynamicSlice(a, s, {2}) produces: - {2.0, 3.0} +{2.0, 3.0} ``` 2-dimensional example: ``` let b = - { {0.0, 1.0, 2.0}, - {3.0, 4.0, 5.0}, - {6.0, 7.0, 8.0}, - {9.0, 10.0, 11.0} } +{ {0.0, 1.0, 2.0}, +{3.0, 4.0, 5.0}, +{6.0, 7.0, 8.0}, +{9.0, 10.0, 11.0} } let s = {2, 1} DynamicSlice(b, s, {2, 2}) produces: - { { 7.0, 8.0}, - {10.0, 11.0} } +{ { 7.0, 8.0}, +{10.0, 11.0} } ``` ## DynamicUpdateSlice @@ -1027,29 +1042,29 @@ let u = {5.0, 6.0} let s = {2} DynamicUpdateSlice(a, u, s) produces: - {0.0, 1.0, 5.0, 6.0, 4.0} +{0.0, 1.0, 5.0, 6.0, 4.0} ``` 2-dimensional example: ``` let b = - { {0.0, 1.0, 2.0}, - {3.0, 4.0, 5.0}, - {6.0, 7.0, 8.0}, - {9.0, 10.0, 11.0} } +{ {0.0, 1.0, 2.0}, +{3.0, 4.0, 5.0}, +{6.0, 7.0, 8.0}, +{9.0, 10.0, 11.0} } let u = - { {12.0, 13.0}, - {14.0, 15.0}, - 
{16.0, 17.0} } +{ {12.0, 13.0}, +{14.0, 15.0}, +{16.0, 17.0} } let s = {1, 1} DynamicUpdateSlice(b, u, s) produces: - { {0.0, 1.0, 2.0}, - {3.0, 12.0, 13.0}, - {6.0, 14.0, 15.0}, - {9.0, 16.0, 17.0} } +{ {0.0, 1.0, 2.0}, +{3.0, 12.0, 13.0}, +{6.0, 14.0, 15.0}, +{9.0, 16.0, 17.0} } ``` ## Element-wise binary arithmetic operations @@ -1235,42 +1250,42 @@ shape of `start_indices` to be `[6,7,1]`). The bounds for the output array along dimension `i` is computed as follows: - 1. If `i` is present in `batch_dims` (i.e. is equal to `batch_dims[k]` for - some `k`) then we pick the corresponding dimension bounds out of - `start_indices.shape`, skipping `index_vector_dim` (i.e. pick - `start_indices.shape.dims`[`k`] if `k` < `index_vector_dim` and - `start_indices.shape.dims`[`k`+`1`] otherwise). +1. If `i` is present in `batch_dims` (i.e. is equal to `batch_dims[k]` for +some `k`) then we pick the corresponding dimension bounds out of +`start_indices.shape`, skipping `index_vector_dim` (i.e. pick +`start_indices.shape.dims`[`k`] if `k` < `index_vector_dim` and +`start_indices.shape.dims`[`k`+`1`] otherwise). - 2. If `i` is present in `offset_dims` (i.e. equal to `offset_dims`[`k`] for - some `k`) then we pick the corresponding bound out of `slice_sizes` after - accounting for `collapsed_slice_dims` (i.e. we pick - `adjusted_slice_sizes`[`k`] where `adjusted_slice_sizes` is `slice_sizes` - with the bounds at indices `collapsed_slice_dims` removed). +2. If `i` is present in `offset_dims` (i.e. equal to `offset_dims`[`k`] for +some `k`) then we pick the corresponding bound out of `slice_sizes` after +accounting for `collapsed_slice_dims` (i.e. we pick +`adjusted_slice_sizes`[`k`] where `adjusted_slice_sizes` is `slice_sizes` +with the bounds at indices `collapsed_slice_dims` removed). Formally, the operand index `In` corresponding to an output index `Out` is computed as follows: - 1. Let `G` = { `Out`[`k`] for `k` in `batch_dims` }. 
Use `G` to slice out - vector `S` such that `S`[`i`] = `start_indices`[Combine(`G`, `i`)] where - Combine(A, b) inserts b at position `index_vector_dim` into A. Note that - this is well defined even if `G` is empty -- if `G` is empty then `S` = - `start_indices`. - - 2. Create a starting index, `S``in`, into `operand` using `S` by - scattering `S` using `start_index_map`. More precisely: - 1. `S``in`[`start_index_map`[`k`]] = `S`[`k`] if `k` < - `start_index_map.size`. - 2. `S``in`[`_`] = `0` otherwise. - - 3. Create an index `O``in` into `operand` by scattering the indices - at the offset dimensions in `Out` according to the `collapsed_slice_dims` - set. More precisely: - 1. `O``in`[`expand_offset_dims`(`k`)] = - `Out`[`offset_dims`[`k`]] if `k` < `offset_dims.size` - (`expand_offset_dims` is defined below). - 2. `O``in`[`_`] = `0` otherwise. - 4. `In` is `O``in` + `S``in` where + is element-wise - addition. +1. Let `G` = { `Out`[`k`] for `k` in `batch_dims` }. Use `G` to slice out +vector `S` such that `S`[`i`] = `start_indices`[Combine(`G`, `i`)] where +Combine(A, b) inserts b at position `index_vector_dim` into A. Note that +this is well defined even if `G` is empty -- if `G` is empty then `S` = +`start_indices`. + +2. Create a starting index, `S``in`, into `operand` using `S` by +scattering `S` using `start_index_map`. More precisely: +1. `S``in`[`start_index_map`[`k`]] = `S`[`k`] if `k` < +`start_index_map.size`. +2. `S``in`[`_`] = `0` otherwise. + +3. Create an index `O``in` into `operand` by scattering the indices +at the offset dimensions in `Out` according to the `collapsed_slice_dims` +set. More precisely: +1. `O``in`[`expand_offset_dims`(`k`)] = +`Out`[`offset_dims`[`k`]] if `k` < `offset_dims.size` +(`expand_offset_dims` is defined below). +2. `O``in`[`_`] = `0` otherwise. +4. `In` is `O``in` + `S``in` where + is element-wise +addition. 
`expand_offset_dims` is the monotonic function with domain [`0`, `offset.size`) and range [`0`, `operand.rank`) \ `collapsed_slice_dims`. So if, e.g., @@ -1282,21 +1297,21 @@ and range [`0`, `operand.rank`) \ `collapsed_slice_dims`. So if, e.g., Informally, every index `Out` in the output array corresponds to an element `E` in the operand array, computed as follows: - - We use the batch dimensions in `Out` to look up a starting index from - `start_indices`. +- We use the batch dimensions in `Out` to look up a starting index from +`start_indices`. - - We use `start_index_map` to map the starting index (which may have size less - than operand.rank) to a "full" starting index into operand. +- We use `start_index_map` to map the starting index (which may have size less +than operand.rank) to a "full" starting index into operand. - - We dynamic-slice out a slice with size `slice_sizes` using the full starting - index. +- We dynamic-slice out a slice with size `slice_sizes` using the full starting +index. - - We reshape the slice by collapsing the `collapsed_slice_dims` dimensions. - Since all collapsed slice dimensions have to have bound 1 this reshape is - always legal. +- We reshape the slice by collapsing the `collapsed_slice_dims` dimensions. +Since all collapsed slice dimensions have to have bound 1 this reshape is +always legal. - - We use the offset dimensions in `Out` to index into this slice to get the - input element, `E`, corresponding to output index `Out`. +- We use the offset dimensions in `Out` to index into this slice to get the +input element, `E`, corresponding to output index `Out`. `index_vector_dim` is set to `start_indices.rank` - `1` in all of the examples that follow. More interesting values for `index_vector_dim` does not @@ -1315,7 +1330,7 @@ the output shape, and maps it to an element in the input array in the following way:
- +
We first select an (`X`,`Y`) vector from the gather indices array using `G`. @@ -1334,7 +1349,7 @@ version of the example above using a "gather indices" array of shape `[4,5,2]` would translate indices like this:
- +
Again, this acts as a batch dynamic slice `G``0` and @@ -1343,27 +1358,27 @@ Again, this acts as a batch dynamic slice `G``0` and The gather operation in XLA generalizes the informal semantics outlined above in the following ways: - 1. We can configure which dimensions in the output shape are the offset - dimensions (dimensions containing `O``0`, `O``1` in - the last example). The output batch dimensions (dimensions containing - `G``0`, `G``1` in the last example) are defined to be - the output dimensions that are not offset dimensions. +1. We can configure which dimensions in the output shape are the offset +dimensions (dimensions containing `O``0`, `O``1` in +the last example). The output batch dimensions (dimensions containing +`G``0`, `G``1` in the last example) are defined to be +the output dimensions that are not offset dimensions. - 2. The number of output offset dimensions explicitly present in the output - shape may be smaller than the input rank. These "missing" dimensions, which - are listed explicitly as `collapsed_slice_dims`, must have a slice size of - `1`. Since they have a slice size of `1` the only valid index for them is - `0` and eliding them does not introduce ambiguity. +2. The number of output offset dimensions explicitly present in the output +shape may be smaller than the input rank. These "missing" dimensions, which +are listed explicitly as `collapsed_slice_dims`, must have a slice size of +`1`. Since they have a slice size of `1` the only valid index for them is +`0` and eliding them does not introduce ambiguity. - 3. The slice extracted from the "Gather Indices" array ((`X`, `Y`) in the last - example) may have fewer elements than the input array rank, and an explicit - mapping dictates how the index should be expanded to have the same rank as - the input. +3. 
The slice extracted from the "Gather Indices" array ((`X`, `Y`) in the last +example) may have fewer elements than the input array rank, and an explicit +mapping dictates how the index should be expanded to have the same rank as +the input. As a final example, we use (2) and (3) to implement `tf.gather_nd`:
- +
`G``0` and `G``1` are used to slice out a starting index @@ -1442,11 +1457,11 @@ dependency between the while loops. ``` result1 = while (condition, init = init_value) { - Infeed(shape) +Infeed(shape) } result2 = while (condition, init = result1) { - Infeed(shape) +Infeed(shape) } ``` @@ -1464,7 +1479,9 @@ Infeed of the device. Builds a constant literal on device rather than a potentially large host transfer. Creates a rank 1 array of values starting at zero and incrementing by -one. +one. For floating-point types, the produced array is equivalent to +`ConvertElementType(Iota(...))` where the `Iota` is of integral type and the +conversion is to the floating-point type. Arguments | Type | Semantics ---------------- | --------------- | ------------------------------------ diff --git a/tensorflow/compiler/xla/layout.cc b/tensorflow/compiler/xla/layout.cc new file mode 100644 index 0000000000000000000000000000000000000000..e3b5fcd5274881cec31ecf906e3461685f82a1f4 --- /dev/null +++ b/tensorflow/compiler/xla/layout.cc @@ -0,0 +1,96 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/layout.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/layout_util.h" + +namespace xla { + +TileProto Tile::ToProto() const { + TileProto tile_proto; + for (int64 i : dimensions()) { + tile_proto.add_dimensions(i); + } + return tile_proto; +} + +string Tile::ToString() const { + return absl::StrCat("(", absl::StrJoin(dimensions(), ","), ")"); +} + +/* static */ Layout Layout::CreateFromProto(const LayoutProto& proto) { + Layout layout; + layout.set_format(proto.format()); + layout.minor_to_major_.reserve(proto.minor_to_major_size()); + for (const int64 dimension : proto.minor_to_major()) { + layout.add_minor_to_major(dimension); + } + layout.set_max_sparse_elements(proto.max_sparse_elements()); + for (const TileProto& tile_proto : proto.tiles()) { + *layout.add_tiles() = Tile::CreateFromProto(tile_proto); + } + layout.set_element_size_in_bits(proto.element_size_in_bits()); + return layout; +} + +LayoutProto Layout::ToProto() const { + LayoutProto proto; + proto.set_format(format_); + proto.mutable_minor_to_major()->Reserve(minor_to_major_size()); + for (const int64 dimension : minor_to_major()) { + proto.add_minor_to_major(dimension); + } + proto.set_max_sparse_elements(max_sparse_elements_); + for (const Tile& tile : tiles()) { + *proto.add_tiles() = tile.ToProto(); + } + proto.set_element_size_in_bits(element_size_in_bits()); + return proto; +} + +string Layout::ToString() const { + // TODO(b/119839262): Emit tiles in string. 
+ if (format() == SPARSE) { + return absl::StrCat("sparse{", max_sparse_elements(), "}"); + } else if (format() == DENSE) { + return absl::StrCat("{", absl::StrJoin(minor_to_major(), ","), "}"); + } else { + CHECK_EQ(format(), INVALID_FORMAT); + return "invalid{}"; + } +} + +bool Layout::operator==(const Layout& other) const { + return (other.format() == format() && + other.minor_to_major() == minor_to_major() && + other.element_size_in_bits() == element_size_in_bits() && + other.max_sparse_elements() == max_sparse_elements() && + other.tiles() == tiles()); +} + +std::ostream& operator<<(std::ostream& out, const Tile& tile) { + out << tile.ToString(); + return out; +} + +std::ostream& operator<<(std::ostream& out, const Layout& layout) { + out << layout.ToString(); + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/layout.h b/tensorflow/compiler/xla/layout.h new file mode 100644 index 0000000000000000000000000000000000000000..313368c39e4c976fc481941eb17325101f2ba69a --- /dev/null +++ b/tensorflow/compiler/xla/layout.h @@ -0,0 +1,187 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_LAYOUT_H_ +#define TENSORFLOW_COMPILER_XLA_LAYOUT_H_ + +#include + +#include "absl/types/span.h" + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Describes a tile used in tiling-based layout. Refer to +// g3doc/third_party/tensorflow/compiler/xla/g3doc/layout_with_tiling.md for +// details. +class Tile { + public: + Tile() = default; + explicit Tile(absl::Span dimensions) + : dimensions_(dimensions.begin(), dimensions.end()) {} + + // De/Serialize a Tile to and from a TileProto. + static Tile CreateFromProto(const TileProto& tile_proto) { + return Tile(AsInt64Slice(tile_proto.dimensions())); + } + TileProto ToProto() const; + + bool operator==(const Tile& other) const { + return dimensions() == other.dimensions(); + } + bool operator!=(const Tile& other) const { return !(*this == other); } + + string ToString() const; + + // Returns the bound of the tile in the given dimension index. + int64 dimension(int i) const { return dimensions_.at(i); } + + // Returns the dimensions of the tile. + const std::vector& dimensions() const { return dimensions_; } + + private: + // The bounds of the tile. + std::vector dimensions_; +}; + +class Layout { + public: + Layout() = default; + + // Constructs a dense layout with the given minor-to-major order. + explicit Layout(absl::Span minor_to_major) + : format_(DENSE), + minor_to_major_(minor_to_major.begin(), minor_to_major.end()) {} + + // Constructs a dense tiled layout with the given minor-to-major order and + // tiles. + Layout(absl::Span minor_to_major, absl::Span tiles) + : format_(DENSE), + minor_to_major_(minor_to_major.begin(), minor_to_major.end()), + tiles_(tiles.begin(), tiles.end()) {} + + // Construct a shape from a LayoutProto. 
+ static Layout CreateFromProto(const LayoutProto& proto); + + // Returns a LayoutProto representation of the Layout. + LayoutProto ToProto() const; + + // Returns a human-readable string that represents this layout. + string ToString() const; + + bool operator==(const Layout& other) const; + bool operator!=(const Layout& other) const { return !(*this == other); } + + // The following methods mirror the protobuf generated code interface for the + // message LayoutProto. This enabled easy migration of this data structure + // from a proto to a proper C++ class. + // + // TODO(b/29771030): Replace or augment these methods with a more ergonomic + // interface. + + // Methods for accessing the format. + Format format() const { return format_; } + Layout& set_format(Format value) { + format_ = value; + return *this; + } + + // Methods for accessing the minor-to-major array. + int minor_to_major_size() const { return minor_to_major_.size(); } + int64 minor_to_major(int index) const { return minor_to_major_.at(index); } + Layout& set_minor_to_major(int index, int64 value) { + minor_to_major_.at(index) = value; + return *this; + } + Layout& add_minor_to_major(int64 value) { + minor_to_major_.push_back(value); + return *this; + } + Layout& clear_minor_to_major() { + minor_to_major_.clear(); + return *this; + } + const std::vector& minor_to_major() const { return minor_to_major_; } + std::vector* mutable_minor_to_major() { return &minor_to_major_; } + + // Methods for accessing the tile field. + int tiles_size() const { return tiles_.size(); } + const Tile& tiles(int index) const { return tiles_.at(index); } + Tile* mutable_tiles(int index) { return &tiles_.at(index); } + Tile* add_tiles() { + tiles_.push_back(Tile()); + return &tiles_.back(); + } + Layout& clear_tiles() { + tiles_.clear(); + return *this; + } + const std::vector& tiles() const { return tiles_; } + std::vector* mutable_tiles() { return &tiles_; } + + // Methods for accessing the int64 fields. 
+ int64 max_sparse_elements() const { return max_sparse_elements_; } + Layout& set_max_sparse_elements(int64 value) { + max_sparse_elements_ = value; + return *this; + } + int64 element_size_in_bits() const { return element_size_in_bits_; } + Layout& set_element_size_in_bits(int64 value) { + element_size_in_bits_ = value; + return *this; + } + + void Swap(Layout* other) { + using std::swap; + swap(*this, *other); + } + + void Clear() { + format_ = INVALID_FORMAT; + minor_to_major_.clear(); + max_sparse_elements_ = 0; + element_size_in_bits_ = 0; + } + + public: + // The format of this layout. + Format format_ = INVALID_FORMAT; + + // Sequence of dimension numbers, from minor (fastest varying index) to major + // (slowest varying index). + std::vector minor_to_major_; + + // The maximum number of elements that can be stored for SPARSE formats. This + // can be used to determine the maximum size in bytes of arrays stored in + // memory. This field must be zero unless the format is SPARSE. + int64 max_sparse_elements_ = 0; + + // The number of bits used to store an individual array element. + int64 element_size_in_bits_ = 0; + + // The tiles used in tiling-based layout. + std::vector tiles_; +}; + +std::ostream& operator<<(std::ostream& out, const Tile& Tile); +std::ostream& operator<<(std::ostream& out, const Layout& layout); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LAYOUT_H_ diff --git a/tensorflow/compiler/xla/layout_test.cc b/tensorflow/compiler/xla/layout_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..fb6abd3f6523b978e72b21ec082ae06973e86243 --- /dev/null +++ b/tensorflow/compiler/xla/layout_test.cc @@ -0,0 +1,104 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/layout.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +class LayoutTest : public ::testing::Test {}; + +TEST_F(LayoutTest, ToString) { + EXPECT_EQ(Layout().ToString(), "invalid{}"); + EXPECT_EQ(Layout({4, 5, 6}).ToString(), "{4,5,6}"); + EXPECT_EQ(Layout().set_format(SPARSE).set_max_sparse_elements(123).ToString(), + "sparse{123}"); + EXPECT_EQ(Layout({4, 5, 6}).ToString(), "{4,5,6}"); + EXPECT_EQ(Layout({3, 2, 1, 0}, {Tile({42, 123}), Tile({4, 5})}).ToString(), + "{3,2,1,0}"); + EXPECT_EQ( + Layout({1, 0}, {Tile({2, 55})}).set_element_size_in_bits(42).ToString(), + "{1,0}"); +} + +TEST_F(LayoutTest, StreamOut) { + { + std::ostringstream oss; + oss << Tile({7, 8}); + EXPECT_EQ(oss.str(), "(7,8)"); + } + + { + std::ostringstream oss; + oss << Layout({0, 1, 2}); + EXPECT_EQ(oss.str(), "{0,1,2}"); + } +} + +TEST_F(LayoutTest, SparseLayoutMaxElements) { + EXPECT_EQ(LayoutUtil::MaxSparseElements(LayoutUtil::MakeSparseLayout(101)), + 101); +} + +TEST_F(LayoutTest, Equality) { + EXPECT_EQ(Layout(), Layout()); + const 
std::vector empty_dims; + EXPECT_EQ(Layout(empty_dims), Layout(empty_dims)); + EXPECT_NE(Layout(), Layout(empty_dims)); + EXPECT_EQ(Layout({0, 1, 2, 3}), Layout({0, 1, 2, 3})); + EXPECT_NE(Layout({0, 1, 2, 3}), Layout({0, 1, 2})); + EXPECT_EQ(Layout({0, 1, 2}, {Tile({42, 44})}), + Layout({0, 1, 2}, {Tile({42, 44})})); + EXPECT_NE(Layout({0, 1, 2}, {Tile({42, 44})}), + Layout({0, 1, 2}, {Tile({42, 45})})); + EXPECT_NE(Layout({0, 1, 2}, {Tile({42, 44})}), Layout({0, 1, 2, 3})); + EXPECT_EQ(Layout({0, 1, 2}).set_element_size_in_bits(33), + Layout({0, 1, 2}).set_element_size_in_bits(33)); + EXPECT_NE(Layout({0, 1, 2}).set_element_size_in_bits(33), + Layout({0, 1, 2}).set_element_size_in_bits(7)); + EXPECT_EQ(Layout().set_format(SPARSE), Layout().set_format(SPARSE)); + EXPECT_EQ(Layout().set_format(SPARSE).set_max_sparse_elements(42), + Layout().set_format(SPARSE).set_max_sparse_elements(42)); + EXPECT_NE(Layout().set_format(SPARSE).set_max_sparse_elements(42), + Layout().set_format(SPARSE).set_max_sparse_elements(24)); +} + +TEST_F(LayoutTest, LayoutToFromProto) { + // Round-trips a Layout through proto de/serialization. 
+ auto expect_unchanged = [](const Layout& layout) { + EXPECT_EQ(layout, Layout::CreateFromProto(layout.ToProto())); + }; + + expect_unchanged(Layout()); + expect_unchanged(Layout({1, 3, 2, 0})); + expect_unchanged(Layout().set_format(SPARSE)); + expect_unchanged(Layout().set_format(SPARSE).set_max_sparse_elements(123)); + expect_unchanged(Layout({0, 1}).set_element_size_in_bits(42)); + expect_unchanged(Layout({3, 2, 1, 0}, {Tile({42, 123}), Tile({4, 5})})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index dbb81381acde645f08639737b6e7b6f6ad971f9b..ddccd8c798df5b926d2e5aea8975cb6cb6640824 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -41,15 +41,13 @@ namespace { // Internal helper for GetDefaultLayoutForShape and SetToDefaultLayout. Sets // minor_to_major to the value that represents the default layout. -void SetDefaultLayoutToContainer( - tensorflow::protobuf::RepeatedField* - minor_to_major) { +void SetDefaultLayoutToContainer(std::vector* minor_to_major) { // The default XLA layout is major-to-minor (dim 0 is major). 
// For more information on XLA layouts, see: // https://www.tensorflow.org/performance/xla/shapes const int64 size = minor_to_major->size(); for (int64 i = 0; i < size; ++i) { - minor_to_major->Set(i, size - 1 - i); + (*minor_to_major)[i] = size - 1 - i; } } @@ -94,9 +92,8 @@ namespace { Layout CreateDefaultLayoutForRank(int64 rank) { Layout layout; layout.set_format(DENSE); - tensorflow::protobuf::RepeatedField* - minor_to_major = layout.mutable_minor_to_major(); - minor_to_major->Resize(rank, 0); + std::vector* minor_to_major = layout.mutable_minor_to_major(); + minor_to_major->resize(rank, 0); SetDefaultLayoutToContainer(minor_to_major); return layout; } @@ -139,9 +136,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) { shape->clear_layout(); } else if (ShapeUtil::IsArray(*shape)) { shape->mutable_layout()->set_format(DENSE); - tensorflow::protobuf::RepeatedField* - minor_to_major = shape->mutable_layout()->mutable_minor_to_major(); - minor_to_major->Resize(shape->dimensions_size(), 0); + auto* minor_to_major = shape->mutable_layout()->mutable_minor_to_major(); + minor_to_major->resize(shape->dimensions_size(), 0); SetDefaultLayoutToContainer(minor_to_major); } else { // Opaque, token types etc. have no layout. 
@@ -210,9 +206,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } if (layout.format() == INVALID_FORMAT || !Format_IsValid(layout.format())) { - return InvalidArgument( - "Layout has an invalid format (%d) in layout {%s}, shape {%s}", - layout.format(), layout.ShortDebugString(), shape.ShortDebugString()); + return InvalidArgument("Layout has an invalid format (%d)", + layout.format()); } if (layout.format() == DENSE) { @@ -316,7 +311,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ bool LayoutUtil::Equal(const Layout& lhs, const Layout& rhs) { - return protobuf_util::ProtobufEquals(lhs, rhs); + return lhs == rhs; } /* static */ absl::Span LayoutUtil::MinorToMajor( @@ -358,11 +353,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ string LayoutUtil::HumanString(const Layout& layout) { - if (IsSparse(layout)) { - return absl::StrCat("sparse{", layout.max_sparse_elements(), "}"); - } - CHECK(IsDense(layout)); - return absl::StrCat("{", absl::StrJoin(layout.minor_to_major(), ","), "}"); + return layout.ToString(); } namespace { @@ -444,11 +435,6 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { return true; } -std::ostream& operator<<(std::ostream& out, const Layout& layout) { - out << LayoutUtil::HumanString(layout); - return out; -} - /*static*/ size_t LayoutUtil::Hash(const Layout& layout) { using tensorflow::hash; using tensorflow::Hash64Combine; diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 6c298e57252449ce3f1f9055436e918f2d9f17f1..609dba67bcdbcb11be0906b7d87a52a17ba0dfbd 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -21,6 +21,7 @@ limitations under the License. 
#include #include "absl/types/span.h" +#include "tensorflow/compiler/xla/layout.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" @@ -195,8 +196,6 @@ class LayoutUtil { TF_DISALLOW_COPY_AND_ASSIGN(LayoutUtil); }; -std::ostream& operator<<(std::ostream& out, const Layout& layout); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_LAYOUT_UTIL_H_ diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index 12ce2d2d7c6fa8c590035f9ff2af50001ccf80d8..4cc94c270cd64eb19761cc1044861c7d185b7888 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -317,17 +317,6 @@ TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) { ShapeUtil::MakeShape(F32, {10, 20, 30, 15, 25})))); } -TEST_F(LayoutUtilTest, SparseLayoutMaxElements) { - EXPECT_EQ(LayoutUtil::MaxSparseElements(LayoutUtil::MakeSparseLayout(101)), - 101); -} - -TEST_F(LayoutUtilTest, StreamOut) { - std::ostringstream oss; - oss << LayoutUtil::MakeLayout({0, 1, 2}); - EXPECT_EQ(oss.str(), "{0,1,2}"); -} - TEST_F(LayoutUtilTest, ValidateLayout_ValidArrayLayout) { Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {0, 1}); auto status = diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 8f480c1f1079b4e1a5be53958ebdf6e004ad9ebe..277c98721e59ac12965392500fdfdc3d91e59a8b 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -1028,20 +1028,21 @@ string ShapeToString(bool print_layout, const Shape& shape) { } void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, - bool print_layout, std::vector* pieces); + bool print_shape, bool print_layout, + std::vector* pieces); void TupleToStringHelper(const LiteralBase& literal, - const ShapeIndex& shape_index, bool print_layout, - std::vector* pieces) { + const ShapeIndex& shape_index, bool 
print_shape, + bool print_layout, std::vector* pieces) { const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); - pieces->push_back(ShapeToString(print_layout, subshape)); - pieces->push_back(" (\n"); + pieces->push_back("(\n"); std::vector tuple_pieces; for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) { ShapeIndex element_index = shape_index; element_index.push_back(i); std::vector element_pieces; - ToStringHelper(literal, element_index, print_layout, &element_pieces); + ToStringHelper(literal, element_index, print_shape, print_layout, + &element_pieces); tuple_pieces.push_back(absl::StrJoin(element_pieces, "")); } pieces->push_back(absl::StrJoin(tuple_pieces, ",\n")); @@ -1049,9 +1050,11 @@ void TupleToStringHelper(const LiteralBase& literal, } void SparseArrayToStringHelper(const LiteralBase& literal, - const Shape& subshape, bool print_layout, - std::vector* pieces) { - pieces->push_back(ShapeToString(print_layout, subshape)); + const Shape& subshape, bool print_shape, + bool print_layout, std::vector* pieces) { + if (print_shape) { + pieces->push_back(ShapeToString(print_layout, subshape)); + } pieces->push_back("{"); int64 rank = ShapeUtil::Rank(subshape); int64 num_elements = literal.sparse_element_count(); @@ -1073,8 +1076,8 @@ void SparseArrayToStringHelper(const LiteralBase& literal, } void DenseArrayToStringHelper(const LiteralBase& literal, - const ShapeIndex& shape_index, bool print_layout, - std::vector* pieces) { + const ShapeIndex& shape_index, bool print_shape, + bool print_layout, std::vector* pieces) { const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); int64 rank = ShapeUtil::Rank(subshape); @@ -1135,7 +1138,7 @@ void DenseArrayToStringHelper(const LiteralBase& literal, } }; - if (rank > 1) { + if (print_shape) { pieces->push_back(ShapeToString(print_layout, subshape)); pieces->push_back(" "); } @@ -1146,19 +1149,23 @@ void DenseArrayToStringHelper(const LiteralBase& literal, } 
void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, - bool print_layout, std::vector* pieces) { + bool print_shape, bool print_layout, + std::vector* pieces) { const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); CHECK(LayoutUtil::HasLayout(literal.shape())); CHECK(LayoutUtil::HasLayout(subshape)); if (ShapeUtil::IsTuple(subshape)) { - TupleToStringHelper(literal, shape_index, print_layout, pieces); + TupleToStringHelper(literal, shape_index, print_shape, print_layout, + pieces); } else if (ShapeUtil::IsToken(subshape)) { pieces->push_back("token"); } else if (LayoutUtil::IsSparseArray(subshape)) { - SparseArrayToStringHelper(literal, subshape, print_layout, pieces); + SparseArrayToStringHelper(literal, subshape, print_shape, print_layout, + pieces); } else { CHECK(LayoutUtil::IsDenseArray(subshape)); - DenseArrayToStringHelper(literal, shape_index, print_layout, pieces); + DenseArrayToStringHelper(literal, shape_index, print_shape, print_layout, + pieces); } } @@ -1169,10 +1176,27 @@ int64 LiteralBase::sparse_element_count() const { return sparse_indices()->index_count(); } -string LiteralBase::ToString(bool print_layout) const { +string LiteralBase::ToString() const { + std::vector pieces; + CHECK(LayoutUtil::HasLayout(this->shape())); + ToStringHelper(*this, {}, /*print_shape=*/true, + /*print_layout=*/false, &pieces); + return absl::StrJoin(pieces, ""); +} + +string LiteralBase::ToStringWithoutShape() const { + std::vector pieces; + CHECK(LayoutUtil::HasLayout(this->shape())); + ToStringHelper(*this, {}, /*print_shape=*/false, + /*print_layout=*/false, &pieces); + return absl::StrJoin(pieces, ""); +} + +string LiteralBase::ToStringWithLayout() const { std::vector pieces; CHECK(LayoutUtil::HasLayout(this->shape())); - ToStringHelper(*this, {}, print_layout, &pieces); + ToStringHelper(*this, {}, /*print_shape=*/true, + /*print_layout=*/true, &pieces); return absl::StrJoin(pieces, ""); } diff --git 
a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index fa9a71af4ceb998a7a289443cbef70eb52cb1a11..67e908e7ec4d4346f4e26a99a42aac26928ec0c2 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -92,9 +92,20 @@ class LiteralBase { // array. string GetR1U8AsString() const; - // Returns a string representation of the literal value. - // Warning: this function can take minutes for multi-million element Literals. - string ToString(bool print_layout = false) const; + // Returns a string representation of the literal value. The Shape of the + // literal is a prefix of the literal value in the string. + + // Warning: this function can take minutes for multi-million + // element Literals. + string ToString() const; + + // Returns a string representation of the literal value which does *not* + // include the shape string. + string ToStringWithoutShape() const; + + // Returns a string representation of the literal value which includes the + // shape string with its layout. + string ToStringWithLayout() const; // Gets an element in the literal at the given index. The multi_index is // CHECKed against the dimension sizes. diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index b044f0ad73f13a0599e77f1f43888bc974e31f73..1ac9a48e805daa86f0dc65b54626195c89241020 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -46,68 +46,102 @@ uint16 GetRawValue(Eigen::half val) { return val.x; } // between the left-hand-side and right-hand-side, by bit-casting to UnsignedT // -- on miscompare, a nice error message is given in the AssertionFailure. 
template -Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs, - absl::Span multi_index) { +bool CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs, + absl::Span multi_index) { + auto ulhs = absl::bit_cast(GetRawValue(lhs)); + auto urhs = absl::bit_cast(GetRawValue(rhs)); + return ulhs == urhs; +} + +// Templated comparator that specializes for float equality comparison with the +// bitwise helper above (this is the un-specialized fallback, to just use the +// default gunit implementation). +template +bool CompareEqual(NativeT lhs, NativeT rhs, + absl::Span multi_index) { + return lhs == rhs; +} + +// Specializations for floating types that do bitwise comparisons when equality +// comparison is requested. +template <> +bool CompareEqual(bfloat16 lhs, bfloat16 rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +} +template <> +bool CompareEqual(Eigen::half lhs, Eigen::half rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +} +template <> +bool CompareEqual(float lhs, float rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +} +template <> +bool CompareEqual(double lhs, double rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +} +template <> +bool CompareEqual(complex64 lhs, complex64 rhs, + absl::Span multi_index) { + return CompareEqual(lhs.real(), rhs.real(), multi_index) && + CompareEqual(lhs.imag(), rhs.imag(), multi_index); +} + +template +Status MakeBitwiseErrorStatus(NativeT lhs, NativeT rhs, + absl::Span multi_index) { auto ulhs = absl::bit_cast(GetRawValue(lhs)); auto urhs = absl::bit_cast(GetRawValue(rhs)); auto lhs_double = static_cast(lhs); auto rhs_double = static_cast(rhs); - if (ulhs != urhs) { return InvalidArgument( "floating values are not bitwise-equal; and equality testing " "was requested: %s=%g=%a vs %s=%g=%a at array index %s", StrCat(absl::Hex(ulhs)), lhs_double, lhs_double, 
StrCat(absl::Hex(urhs)), rhs_double, rhs_double, LiteralUtil::MultiIndexAsString(multi_index)); - } - return Status::OK(); } -// Templated comparator that specializes for float equality comparison with the -// bitwise helper above (this is the un-specialized fallback, to just use the -// default gunit implementation). template -Status CompareEqual(NativeT lhs, NativeT rhs, - absl::Span multi_index) { - if (lhs == rhs) { - return Status::OK(); - } +Status MakeErrorStatus(NativeT lhs, NativeT rhs, + absl::Span multi_index) { return InvalidArgument( "first mismatch at array index %s:\n expected value: %s\n actual " "value: %s", LiteralUtil::MultiIndexAsString(multi_index), StrCat(lhs), StrCat(rhs)); } -// Specializations for floating types that do bitwise comparisons when equality -// comparison is requested. template <> -Status CompareEqual(bfloat16 lhs, bfloat16 rhs, - absl::Span multi_index) { - return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +Status MakeErrorStatus(bfloat16 lhs, bfloat16 rhs, + absl::Span multi_index) { + return MakeBitwiseErrorStatus(lhs, rhs, multi_index); } template <> -Status CompareEqual(Eigen::half lhs, Eigen::half rhs, - absl::Span multi_index) { - return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +Status MakeErrorStatus(Eigen::half lhs, Eigen::half rhs, + absl::Span multi_index) { + return MakeBitwiseErrorStatus(lhs, rhs, multi_index); } template <> -Status CompareEqual(float lhs, float rhs, - absl::Span multi_index) { - return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +Status MakeErrorStatus(float lhs, float rhs, + absl::Span multi_index) { + return MakeBitwiseErrorStatus(lhs, rhs, multi_index); } template <> -Status CompareEqual(double lhs, double rhs, - absl::Span multi_index) { - return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +Status MakeErrorStatus(double lhs, double rhs, + absl::Span multi_index) { + return MakeBitwiseErrorStatus(lhs, rhs, multi_index); } template <> -Status CompareEqual(complex64 
lhs, complex64 rhs, - absl::Span multi_index) { - auto res = CompareEqual(lhs.real(), rhs.real(), multi_index); - if (!res.ok()) { - return res; +Status MakeErrorStatus(complex64 lhs, complex64 rhs, + absl::Span multi_index) { + if (!CompareEqual(lhs.real(), rhs.real(), multi_index)) { + return MakeErrorStatus(lhs.real(), rhs.real(), multi_index); } - return CompareEqual(lhs.imag(), rhs.imag(), multi_index); + return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index); } // A recursive function which iterates through every index of expected and @@ -119,7 +153,11 @@ Status Equal(LiteralSlice expected, LiteralSlice actual, if (dimension == expected.shape().dimensions_size()) { NativeT expected_value = expected.Get(multi_index); NativeT actual_value = actual.Get(multi_index); - return CompareEqual(expected_value, actual_value, multi_index); + bool result = + CompareEqual(expected_value, actual_value, multi_index); + return result ? Status::OK() + : MakeErrorStatus(expected_value, actual_value, + multi_index); } Status result; @@ -330,7 +368,7 @@ class NearComparator { NanMismatch(expected, actual, error_.relaxed_nans); float abs_error; float rel_error; - if (CompareEqual(expected, actual, {linear_index}).ok()) { + if (CompareEqual(expected, actual, {linear_index})) { abs_error = 0; rel_error = 0; } else if (is_nan_mismatch) { @@ -344,7 +382,7 @@ class NearComparator { } else if (IsInf(expected) || IsInf(actual)) { // If either the expected or actual value is infinity but not both, // then both absolute and relative error are regarded as inifity. 
- CHECK(!CompareEqual(expected, actual, {linear_index}).ok()); + CHECK(!CompareEqual(expected, actual, {linear_index})); abs_error = std::numeric_limits::infinity(); rel_error = std::numeric_limits::infinity(); } else { diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index 49363ad802ddb9520f89b53257216bc7ddaf8ff5..d8c7141cacb8f60cb4ce56d07ac5827a8dbf9b20 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -98,42 +98,42 @@ class LiteralUtilTest : public ::testing::Test { TEST_F(LiteralUtilTest, LiteralScalarToString) { auto true_lit = LiteralUtil::CreateR0(true); - EXPECT_EQ("true", true_lit.ToString()); + EXPECT_EQ("pred[] true", true_lit.ToString()); auto false_lit = LiteralUtil::CreateR0(false); - EXPECT_EQ("false", false_lit.ToString()); + EXPECT_EQ("pred[] false", false_lit.ToString()); auto u32_lit = LiteralUtil::CreateR0(42); - EXPECT_EQ("42", u32_lit.ToString()); + EXPECT_EQ("u32[] 42", u32_lit.ToString()); auto s32_lit = LiteralUtil::CreateR0(-999); - EXPECT_EQ("-999", s32_lit.ToString()); + EXPECT_EQ("s32[] -999", s32_lit.ToString()); auto f32_lit = LiteralUtil::CreateR0(3.14f); - EXPECT_EQ("3.14", f32_lit.ToString()); + EXPECT_EQ("f32[] 3.14", f32_lit.ToString()); auto f16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); - EXPECT_EQ("0.5", f16_lit.ToString()); + EXPECT_EQ("f16[] 0.5", f16_lit.ToString()); auto c64_lit = LiteralUtil::CreateR0({3.14f, 2.78f}); - EXPECT_EQ("(3.14, 2.78)", c64_lit.ToString()); + EXPECT_EQ("c64[] (3.14, 2.78)", c64_lit.ToString()); auto bf16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); - EXPECT_EQ("0.5", bf16_lit.ToString()); + EXPECT_EQ("bf16[] 0.5", bf16_lit.ToString()); // 3.14 will be rounded to 3.14062 in bfloat16 format. 
auto bf16_lit_truncated = LiteralUtil::CreateR0(static_cast(3.14f)); - ASSERT_EQ("3.14062", bf16_lit_truncated.ToString()); + ASSERT_EQ("bf16[] 3.14062", bf16_lit_truncated.ToString()); auto bf16_lit_truncated2 = LiteralUtil::CreateR0(static_cast(9.001f)); - EXPECT_EQ("9", bf16_lit_truncated2.ToString()); + EXPECT_EQ("bf16[] 9", bf16_lit_truncated2.ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { auto pred_vec = LiteralUtil::CreateR1({true, false, true}); - EXPECT_EQ("{1, 0, 1}", pred_vec.ToString()); + EXPECT_EQ("pred[3] {1, 0, 1}", pred_vec.ToString()); } TEST_F(LiteralUtilTest, R2ToString) { @@ -210,8 +210,8 @@ TEST_F(LiteralUtilTest, TupleToString) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); - const string expected = R"((f32[], f32[2,2]) ( -1, + const string expected = R"(( +f32[] 1, f32[2,2] { { 1, 2 }, { 3, 4 } @@ -1890,7 +1890,7 @@ TEST_F(LiteralUtilTest, SortSparseElements) { literal.AppendSparseElement({3, 4, 5}, 3.0); literal.AppendSparseElement({1, 2, 3}, 1.0); literal.SortSparseElements(); - EXPECT_EQ(literal.ToString(false), + EXPECT_EQ(literal.ToString(), "f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}"); } diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index 0f86f9f35e105713aa3072a9ebf572d33d35d66d..339660cf44fd64fc5859e72255d63762fcf20efe 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -42,8 +42,7 @@ PackedLiteralReader::~PackedLiteralReader() { delete file_; } StatusOr PackedLiteralReader::Read(const Shape& shape, const Layout* layout) { VLOG(3) << "reading shape from file: " << ShapeUtil::HumanString(shape) - << " layout: " - << (layout == nullptr ? "" : layout->ShortDebugString()); + << " layout: " << (layout == nullptr ? 
"" : layout->ToString()); Shape literal_shape = shape; if (layout != nullptr) { TF_RETURN_IF_ERROR( diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc index b16147e3be71771269d8b7a18528bef3a8c72d99..00ad01fc407017624a9183d69e61cb0d382e3f11 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/primitive_util.h" +#include "absl/strings/ascii.h" +#include "absl/strings/numbers.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -90,5 +93,65 @@ bool IsArrayType(PrimitiveType primitive_type) { primitive_type != OPAQUE && primitive_type != TOKEN; } +// Class to memoize the computation of +// absl::AsciiStrToLower(PrimitiveType_Name(p)) +// for all PrimitiveType values "p" +class PrimitiveTypeNameGenerator { + public: + PrimitiveTypeNameGenerator() { + for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { + if (PrimitiveType_IsValid(i)) { + lowercase_name_[i] = absl::AsciiStrToLower( + PrimitiveType_Name(static_cast(i))); + } + } + } + const string& LowercaseName(PrimitiveType t) { + return lowercase_name_[static_cast(t)]; + } + + private: + string lowercase_name_[PrimitiveType_ARRAYSIZE]; +}; + +const string& LowercasePrimitiveTypeName(PrimitiveType s) { + static auto* gen = new PrimitiveTypeNameGenerator(); + return gen->LowercaseName(s); +} + +namespace { + +// Returns a map from lower-case primitive type name to primitive type. 
+const std::unordered_map& GetPrimitiveTypeStringMap() { + static std::unordered_map* name_to_type = [] { + static auto* map = new std::unordered_map; + for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { + if (PrimitiveType_IsValid(i) && i != PRIMITIVE_TYPE_INVALID) { + auto value = static_cast(i); + (*map)[LowercasePrimitiveTypeName(value)] = value; + } + } + return map; + }(); + return *name_to_type; +} + +} // namespace + +StatusOr StringToPrimitiveType(absl::string_view name) { + const auto& map = GetPrimitiveTypeStringMap(); + auto found = map.find(string(name)); + if (found == map.end()) { + return InvalidArgument("Invalid element type string: \"%s\".", name); + } + return found->second; +} + +bool IsPrimitiveTypeName(absl::string_view name) { + const auto& map = GetPrimitiveTypeStringMap(); + auto found = map.find(string(name)); + return found != map.end(); +} + } // namespace primitive_util } // namespace xla diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index 889e9a1ceca675689406d255d348c82c398563aa..70603b6fed1be50c427799e6dce7b8bf9631a6f4 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -20,6 +20,9 @@ limitations under the License. #include +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -221,6 +224,17 @@ template <> struct PrimitiveTypeToNative { using type = complex64; }; + +// Returns the lower-case name of the given primitive type. +const string& LowercasePrimitiveTypeName(PrimitiveType s); + +// Returns the PrimitiveType matching the given name. The given name is expected +// to be lower-case. +StatusOr StringToPrimitiveType(absl::string_view name); + +// Returns true if the given name is a primitive type string (lower-case). 
+bool IsPrimitiveTypeName(absl::string_view name); + } // namespace primitive_util } // namespace xla diff --git a/tensorflow/compiler/xla/primitive_util_test.cc b/tensorflow/compiler/xla/primitive_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1f765d6da9ef65849fe8ede56ced7597d623cb59 --- /dev/null +++ b/tensorflow/compiler/xla/primitive_util_test.cc @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/primitive_util.h" + +#include +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +TEST(PrimitiveUtilTest, StringToPrimitiveType) { + auto expect_ok_and_equal = [](const string& str, PrimitiveType expected) { + TF_ASSERT_OK_AND_ASSIGN(PrimitiveType actual, + primitive_util::StringToPrimitiveType(str)); + EXPECT_EQ(expected, actual); + }; + expect_ok_and_equal("f32", F32); + expect_ok_and_equal("tuple", TUPLE); + expect_ok_and_equal("pred", PRED); + expect_ok_and_equal("s32", S32); + + EXPECT_IS_NOT_OK(primitive_util::StringToPrimitiveType("F32").status()); + EXPECT_IS_NOT_OK(primitive_util::StringToPrimitiveType("Pred").status()); + EXPECT_IS_NOT_OK(primitive_util::StringToPrimitiveType("preD").status()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 63ac1c6649210cbae9e238a74e0a45fb8ee4da63..ddffafa9017a565f01c3214360a958e6840e9148 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -4,6 +4,7 @@ package(default_visibility = ["//tensorflow:internal"]) load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") +load("//tensorflow/core:platform/default/build_config.bzl", "pyx_library") py_library( name = "xla_client", @@ -17,6 +18,12 @@ py_library( ], ) +pyx_library( + name = "custom_call_for_test", + testonly = True, + srcs = ["custom_call_for_test.pyx"], +) + py_test( name = "xla_client_test", srcs = ["xla_client_test.py"], @@ -24,6 +31,7 @@ py_test( srcs_version = "PY2AND3", tags = 
["no_oss"], deps = [ + ":custom_call_for_test", ":xla_client", "//tensorflow/python:platform_test", ], @@ -66,13 +74,18 @@ cc_library( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:cholesky", "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/lib:qr", + "//tensorflow/compiler/xla/client/lib:triangular_solve", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", "//tensorflow/compiler/xrt:xrt_proto", "//tensorflow/compiler/xrt/cc:xrt_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//third_party/python_runtime:headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], diff --git a/tensorflow/compiler/xla/python/custom_call_for_test.pyx b/tensorflow/compiler/xla/python/custom_call_for_test.pyx new file mode 100644 index 0000000000000000000000000000000000000000..530dffd1755d8438f52569c223525000c97df6ea --- /dev/null +++ b/tensorflow/compiler/xla/python/custom_call_for_test.pyx @@ -0,0 +1,21 @@ +# distutils: language = c++ + +# Test case for defining a XLA custom call target in Cython, and registering +# it via the xla_client SWIG API. 
+ +from cpython.pycapsule cimport PyCapsule_New + +cdef void test_subtract_f32(void* out_ptr, void** data_ptr) nogil: + cdef float a = ((data_ptr[0]))[0] + cdef float b = ((data_ptr[1]))[0] + cdef float* out = (out_ptr) + out[0] = a - b + + +cpu_custom_call_targets = {} + +cdef register_custom_call_target(fn_name, void* fn): + cdef const char* name = "xla._CPU_CUSTOM_CALL_TARGET" + cpu_custom_call_targets[fn_name] = PyCapsule_New(fn, name, NULL) + +register_custom_call_target(b"test_subtract_f32", (test_subtract_f32)) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 6e2ee866321a070d55a7221c7c68024ceaa93448..657a09f92ad14d959416c768b09c392ff17f96eb 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -24,12 +24,16 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/xla/client/lib/cholesky.h" #include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/qr.h" +#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" @@ -112,6 +116,20 @@ LocalClient* GetOrCreateLocalClient() { return g_local_client; } +Status RegisterCpuCustomCallTarget(const string& fn_name, PyObject* capsule) { + const char* name = 
"xla._CPU_CUSTOM_CALL_TARGET"; + if (!PyCapsule_IsValid(capsule, name)) { + return InvalidArgument( + "Argument to RegisterCpuCustomCallTargetRegistry was not a " + "xla._CPU_CUSTOM_CALL_TARGET capsule."); + } + void* fn_ptr = PyCapsule_GetPointer(capsule, name); + CHECK(fn_ptr != nullptr); + cpu::CustomCallTargetRegistry::Global()->Register( + std::string(fn_name.begin(), fn_name.end()), fn_ptr); + return Status::OK(); +} + Status TransferToInfeedLocal(const Literal& literal) { VLOG(1) << "Infeeding literal without replica number; shape: " << literal.shape(); @@ -242,7 +260,6 @@ XrtAllocation::~XrtAllocation() { StatusOr XrtAllocation::FromLiteral( const Literal& argument, const string& session_target) { xrt::XLAAllocation alloc; - alloc.set_device_ordinal(0); *alloc.mutable_value() = argument.ToProto(); tensorflow::Scope root = tensorflow::Scope::NewRootScope(); @@ -644,6 +661,15 @@ LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) { return xla::ConstantLiteral(&builder_, literal); } +LocalOp LocalComputationBuilder::Iota(PrimitiveType element_type, int64 size) { + return xla::Iota(&builder_, element_type, size); +} + +LocalOp LocalComputationBuilder::BroadcastedIota(const Shape& shape, + int64 dimension) { + return xla::Iota(&builder_, shape, dimension); +} + LocalOp LocalComputationBuilder::Broadcast( const LocalOp& operand, absl::Span broadcast_sizes) { return xla::Broadcast(operand.op(), broadcast_sizes); @@ -780,6 +806,21 @@ LocalOp LocalComputationBuilder::Call(const LocalComputation& local_computation, return xla::Call(&builder_, local_computation.computation(), xla_ops); } +LocalOp LocalComputationBuilder::CustomCall( + const string& call_target_name, absl::Span operands, + const Shape& shape_with_layout, + const std::vector& operand_shapes_with_layout, + const string& opaque) { + std::vector xla_ops; + xla_ops.reserve(operands.size()); + for (const auto& op : operands) { + xla_ops.push_back(op.op()); + } + return 
xla::CustomCallWithLayout(&builder_, call_target_name, xla_ops, + shape_with_layout, + operand_shapes_with_layout, opaque); +} + LocalOp LocalComputationBuilder::Transpose( const LocalOp& operand, absl::Span permutation) { return xla::Transpose(operand.op(), permutation); @@ -865,6 +906,27 @@ LocalOp LocalComputationBuilder::SortKeyVal(const LocalOp& keys, return xla::Sort(keys.op(), {values.op()}, dimension); } +LocalOp LocalComputationBuilder::Cholesky(const LocalOp& a) { + return xla::Cholesky(a.op()); +} + +LocalOp LocalComputationBuilder::QR(const LocalOp& a, bool full_matrices) { + XlaBuilder* builder = a.op().builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto qr, xla::QRDecomposition(a.op(), full_matrices)); + return xla::Tuple(builder, {qr.q, qr.r}); + }); +} + +LocalOp LocalComputationBuilder::TriangularSolve(const LocalOp& a, + const LocalOp& b, + bool left_side, bool lower, + bool transpose_a, + bool conjugate_a) { + return xla::TriangularSolve(a.op(), b.op(), left_side, lower, transpose_a, + conjugate_a); +} + StatusOr LocalComputationBuilder::BuildConstantSubGraph( const LocalOp& operand) { TF_ASSIGN_OR_RETURN(XlaComputation computation, diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 149e44570df5c6a3df88bbe2ffa779be47842d82..5e8341592100bc1eba4d1c17b0c2dd0e0888fdb1 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include + #include "absl/types/span.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" @@ -49,6 +51,11 @@ Status InitializePlatformName(const string& platform_name); // local XLA service has been instantiated yet or not. int GetReplicaCount(); +// Registers a 'fn_capsule' as a CPU custom call target. 
+// 'fn_capsule' is a void* pointer encapsulated in a PyCapsule object, with name +// "xla._CPU_CUSTOM_CALL_TARGET". +Status RegisterCpuCustomCallTarget(const string& name, PyObject* fn_capsule); + // Wraps the local client's infeed-transfer function. // // The default device ordinal (0) is used. @@ -286,6 +293,10 @@ class LocalComputationBuilder { LocalOp ConstantLiteral(const Literal& literal); + LocalOp Iota(PrimitiveType element_type, int64 size); + + LocalOp BroadcastedIota(const Shape& shape, int64 dimension); + LocalOp Broadcast(const LocalOp& operand, absl::Span broadcast_sizes); @@ -352,6 +363,12 @@ class LocalComputationBuilder { LocalOp Call(const LocalComputation& local_computation, absl::Span operands); + LocalOp CustomCall(const string& call_target_name, + absl::Span operands, + const Shape& shape_with_layout, + const std::vector& operand_shapes_with_layout, + const string& opaque); + LocalOp Transpose(const LocalOp& operand, absl::Span permutation); @@ -394,6 +411,13 @@ class LocalComputationBuilder { LocalOp SortKeyVal(const LocalOp& keys, const LocalOp& values, int64 dimension); + LocalOp QR(const LocalOp& a, bool full_matrices); + + LocalOp Cholesky(const LocalOp& a); + + LocalOp TriangularSolve(const LocalOp& a, const LocalOp& b, bool left_side, + bool lower, bool transpose_a, bool conjugate_a); + StatusOr BuildConstantSubGraph(const LocalOp& operand); #define _FORWARD(method_name, return_sig, args_sig) \ diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index d23d693c1e5bde43b52959e4397aa311268411bb..bf5d667c6a12972845735983a74264ea05675971 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -1010,6 +1010,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::InitializeReplicaCount; %unignore xla::swig::InitializePlatformName; %unignore xla::swig::GetReplicaCount; +%unignore 
xla::swig::RegisterCpuCustomCallTarget; %unignore xla::swig::TransferToInfeedLocal; %unignore xla::swig::TransferToInfeedLocalReplica; %unignore xla::swig::TransferFromOutfeedLocalReplica; @@ -1051,6 +1052,8 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Outfeed; %unignore xla::swig::LocalComputationBuilder::ConstantLiteral; %unignore xla::swig::LocalComputationBuilder::ConstantR0; +%unignore xla::swig::LocalComputationBuilder::Iota; +%unignore xla::swig::LocalComputationBuilder::BroadcastedIota; %unignore xla::swig::LocalComputationBuilder::Broadcast; %unignore xla::swig::LocalComputationBuilder::BroadcastInDim; %unignore xla::swig::LocalComputationBuilder::Pad; @@ -1144,6 +1147,10 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Imag; %unignore xla::swig::LocalComputationBuilder::Conj; %unignore xla::swig::LocalComputationBuilder::Complex; +%unignore xla::swig::LocalComputationBuilder::Cholesky; +%unignore xla::swig::LocalComputationBuilder::QR; +%unignore xla::swig::LocalComputationBuilder::TriangularSolve; +%unignore xla::swig::LocalComputationBuilder::CustomCall; %unignore xla::swig::DeleteLocalComputation; %unignore xla::swig::DestructureLocalShapedBufferTuple; %unignore xla::swig::DestructureXrtAllocationTuple; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index c91a2aaf56dfe2127168628c78e0c4b868a28055..378bbdcb175f10d73da87f5286cf5129477a124c 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -415,7 +415,7 @@ class Shape(object): assert mtm is None, self if mtm is not None: assert self.rank() == len(mtm), self - assert sorted(mtm) == range(len(mtm)), self + assert sorted(mtm) == list(range(len(mtm))), self def update_minor_to_major(self, minor_to_major): if not self.is_array(): @@ -831,6 +831,33 @@ class ComputationBuilder(object): return self.ParameterWithShape( 
Shape.from_pyval(value), name=name, parameter_num=parameter_num) + def Iota(self, dtype, size): + """Enqueues an iota constant onto the computation. + + Args: + dtype: expected numpy dtype of the output. + size: integer, the number of elements in the array. + + Returns: + A LocalOp representing the added iota constant. + """ + element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))] + return self._client.Iota(element_type, size) + + def BroadcastedIota(self, dtype, shape, dimension): + """Enqueues a broadcasted iota constant onto the computation. + + Args: + dtype: expected numpy dtype of the output. + shape: tuple of integers, the expected output shape (dimensions). + dimension: positive integer, dimension along which to increment values. + + Returns: + A LocalOp representing the added broadcasted iota constant. + """ + xla_shape = Shape.array_shape(dtype, shape) + return self._client.BroadcastedIota(xla_shape, dimension) + def Broadcast(self, operand, sizes): """Enqueues a broadcast operation onto the computation. @@ -1102,6 +1129,31 @@ class ComputationBuilder(object): """ return self._client.Call(computation_to_apply.computation, operands) + def CustomCall(self, + call_target_name, + operands, + shape_with_layout, + operand_shapes_with_layout, + opaque=None): + """Enqueues a custom call operation onto the computation. + + Args: + call_target_name: the name of the function to call. + operands: an iterable of LocalOp. The number and types of operands must + match the arity of `operand_shapes_with_layout`. + shape_with_layout: the shape of the operator's output, with layout. + operand_shapes_with_layout: the shapes of `operands`, including the + expected layouts. + opaque: an opaque string passed to the backend. + + Returns: + A LocalOp representing the added custom call op. 
+ """ + opaque = opaque or b'' + return self._client.CustomCall(call_target_name, operands, + shape_with_layout, + operand_shapes_with_layout, opaque) + def Map(self, operands, computation_to_apply, dimensions): """Enqueues a map operation onto the computation. @@ -1411,6 +1463,20 @@ class ComputationBuilder(object): """Enqueues a key-value sort operation onto the computation.""" return self._client.SortKeyVal(keys, values, dimension) + def Cholesky(self, a): + """Enqueues a Cholesky decomposition onto the computation.""" + return self._client.Cholesky(a) + + def QR(self, a, full_matrices=True): + """Enqueues a QR decomposition onto the computation.""" + return self._client.QR(a, full_matrices) + + def TriangularSolve(self, a, b, left_side=False, lower=False, + transpose_a=False, conjugate_a=False): + """Enqueues a triangular-solve operation onto the computation.""" + return self._client.TriangularSolve( + a, b, left_side, lower, transpose_a, conjugate_a) + def _forward_methods_to_local_builder(): """Forward remaining ComputationBuilder methods to the C API. @@ -1486,6 +1552,16 @@ def get_replica_count(): return c_api.GetReplicaCount() +def register_cpu_custom_call_target(name, fn): + """Registers a CPU custom call target. + + Args: + name: bytes containing the name of the function. + fn: a PyCapsule object containing the function pointer. 
+ """ + c_api.RegisterCpuCustomCallTarget(name, fn) + + def GetPaddingConfigFromTriples(triples): """Create PaddingConfig proto from list of triples of integers.""" padding_config = xla_data_pb2.PaddingConfig() diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 21b5c93b615ec429a5da0b4ffe89e8f75f59ef1b..002a20e60a9fbe117af991731a555e60eef9397a 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -18,11 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import itertools import threading import numpy as np +from tensorflow.compiler.xla.python import custom_call_for_test from tensorflow.compiler.xla.python import xla_client import unittest @@ -51,9 +53,11 @@ class LocalComputationTest(unittest.TestCase): def _ExecuteAndCompareExact(self, c, arguments=(), expected=None): self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, expected) - def _ExecuteAndCompareClose(self, c, arguments=(), expected=None): - self._ExecuteAndAssertWith(np.testing.assert_allclose, c, arguments, - expected) + def _ExecuteAndCompareClose(self, c, arguments=(), expected=None, rtol=1e-7, + atol=0): + self._ExecuteAndAssertWith( + functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol), + c, arguments, expected) def NumpyArrayF32(*args, **kwargs): @@ -143,6 +147,17 @@ class ComputationsWithConstantsTest(LocalComputationTest): c.Pow(c.Constant(NumpyArrayF64([1.5, 2.5, 3.0])), c.ConstantF64Scalar(2.)) self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.]) + def testIota(self): + c = self._NewComputation() + c.Iota(np.float32, 10) + self._ExecuteAndCompareExact(c, expected=np.arange(10, dtype=np.float32)) + + def testBroadcastedIota(self): + c = self._NewComputation() + c.BroadcastedIota(np.int64, (2, 3), 1) + expected = np.array([[0, 1, 2], [0, 1, 2]], 
dtype=np.int64) + self._ExecuteAndCompareExact(c, expected=expected) + def testBooleanAnd(self): c = self._NewComputation() c.And( @@ -268,6 +283,20 @@ class ComputationsWithConstantsTest(LocalComputationTest): c.Constant(NumpyArrayF64([100, -100, 200, -200]))) self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189]) + def testCustomCall(self): + c = self._NewComputation() + for name, fn in custom_call_for_test.cpu_custom_call_targets.items(): + xla_client.register_cpu_custom_call_target(name, fn) + c.CustomCall( + b"test_subtract_f32", + operands=(c.ConstantF32Scalar(1.25), c.ConstantF32Scalar(0.5)), + shape_with_layout=xla_client.Shape.array_shape(np.float32, (), ()), + operand_shapes_with_layout=( + xla_client.Shape.array_shape(np.float32, (), ()), + xla_client.Shape.array_shape(np.float32, (), ()), + )) + self._ExecuteAndCompareClose(c, expected=0.75) + class ParametersTest(LocalComputationTest): """Tests focusing on Parameter ops and argument-passing.""" @@ -1057,6 +1086,38 @@ class SingleOpTest(LocalComputationTest): self.assertTrue(np.all(lo <= result)) self.assertTrue(np.all(result < hi)) + def testCholesky(self): + l = np.array([[4, 0, 0, 0], [6, 5, 0, 0], [2, 14, 16, 0], [3, 6, 1, 4]], + dtype=np.float32) + c = self._NewComputation() + c.Cholesky(c.Constant(np.dot(l, l.T))) + self._ExecuteAndCompareClose(c, expected=l, rtol=1e-4) + + def testQR(self): + a = np.array( + [[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]], + dtype=np.float32) + c = self._NewComputation() + c.QR(c.Constant(a), full_matrices=True) + q, r = self._Execute(c, ()) + np.testing.assert_allclose(np.dot(q, r), a, rtol=1e-4) + + def testTriangularSolve(self): + a_vals = np.array( + [[2, 0, 0, 0], [3, 6, 0, 0], [4, 7, 9, 0], [5, 8, 10, 11]], + dtype=np.float32) + b_vals = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], + dtype=np.float32) + + c = self._NewComputation() + c.TriangularSolve(c.Constant(a_vals), c.Constant(b_vals), left_side=False, 
+ lower=True, transpose_a=True) + self._ExecuteAndCompareClose(c, expected=np.array([ + [0.5, 0.08333334, 0.04629629, 0.03367003], + [2.5, -0.25, -0.1388889, -0.1010101], + [4.5, -0.58333331, -0.32407406, -0.23569024], + ], dtype=np.float32), rtol=1e-4) + def testIsConstant(self): c = self._NewComputation() a = c.ConstantS32Scalar(3) diff --git a/tensorflow/compiler/xla/python_api/xla_shape.py b/tensorflow/compiler/xla/python_api/xla_shape.py index 95b2bf300ec67e9f034f77450416544cb088ae55..bdcd4abd6cc708795416b15412f37dde10d7fe97 100644 --- a/tensorflow/compiler/xla/python_api/xla_shape.py +++ b/tensorflow/compiler/xla/python_api/xla_shape.py @@ -20,6 +20,8 @@ from __future__ import print_function import numpy as _np # Avoids becoming a part of public Tensorflow API. +from six.moves import xrange + from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python_api import types diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index ceb5e74db7c3b9305e9d77068df9ae0a3690af8a..92f28a9f8aaa3106b9a58ae1ee93ef8841ab58ef 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -21,7 +21,6 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -32,48 +31,19 @@ limitations under the License. 
namespace xla { -namespace { - -template -std::unique_ptr> MatmulArray2DImpl( - const Array2D& lhs, const Array2D& rhs, - const std::function& impl_fn) { - CHECK_EQ(lhs.width(), rhs.height()); - int m = lhs.height(); - int n = rhs.width(); - int k = lhs.width(); - auto result = absl::make_unique>(m, n); - // Because Eigen is a header-oriented library, make sure that the Eigen code - // is the same as the code used by the CPU backend (otherwise the linker will - // randomly pick *some* definition). - impl_fn( - /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m, - k, - /*transpose_lhs=*/0, - /*transpose_rhs=*/0); - return result; -} - -} // namespace - /* static */ std::unique_ptr> ReferenceUtil::MatmulArray2D( const Array2D& lhs, const Array2D& rhs) { - return MatmulArray2DImpl( - lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF16); + return HloEvaluator::MatmulArray2D(lhs, rhs); } /* static */ std::unique_ptr> ReferenceUtil::MatmulArray2D( const Array2D& lhs, const Array2D& rhs) { - return MatmulArray2DImpl( - lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF32); + return HloEvaluator::MatmulArray2D(lhs, rhs); } /* static */ std::unique_ptr> ReferenceUtil::MatmulArray2D( const Array2D& lhs, const Array2D& rhs) { - return MatmulArray2DImpl( - lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF64); + return HloEvaluator::MatmulArray2D(lhs, rhs); } /* static */ std::unique_ptr> ReferenceUtil::Array2DF32ToF64( @@ -557,10 +527,11 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( dim2.set_base_dilation(lhs_dilation.second); *window.add_dimensions() = dim2; - const Shape& shape = ShapeInference::InferConvolveShape( - lhs_literal.shape(), rhs_literal.shape(), - /*feature_group_count=*/1, window, dnums) - .ConsumeValueOrDie(); + const Shape& shape = + ShapeInference::InferConvolveShape( + lhs_literal.shape(), rhs_literal.shape(), + /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums) + .ConsumeValueOrDie(); 
HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); @@ -572,7 +543,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( /*new_size=*/2, PrecisionConfig::DEFAULT); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, - window, dnums, precision_config)); + /*batch_group_count=*/1, window, dnums, precision_config)); HloModuleConfig config; HloModule module("ReferenceUtil", config); auto computation = module.AddEntryComputation(b.Build()); diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 4c21ae2a427477caa86fb4130616c38eb3bcf006..d8736c819687482a9dead57bdeacff8e75dce105 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -241,6 +241,7 @@ cc_library( ":hlo_casting_utils", ":hlo_query", ":shape_inference", + "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -249,6 +250,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", @@ -1012,6 +1014,7 @@ cc_library( srcs = ["name_uniquer.cc"], hdrs = ["name_uniquer.h"], deps = [ + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", @@ -1412,6 +1415,7 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", ], ) @@ -1576,6 +1580,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", 
"//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -1695,9 +1700,9 @@ tf_cc_test( ) cc_library( - name = "convolution_feature_group_converter", - srcs = ["convolution_feature_group_converter.cc"], - hdrs = ["convolution_feature_group_converter.h"], + name = "convolution_group_converter", + srcs = ["convolution_group_converter.cc"], + hdrs = ["convolution_group_converter.h"], deps = [ ":hlo", ":hlo_pass", @@ -1719,7 +1724,7 @@ tf_cc_test( size = "small", srcs = ["convolution_feature_group_converter_test.cc"], deps = [ - ":convolution_feature_group_converter", + ":convolution_group_converter", ":hlo", ":hlo_matchers", ":hlo_parser", @@ -1782,6 +1787,7 @@ tf_cc_test( ":hlo_cse", ":hlo_dce", ":hlo_matchers", + ":hlo_parser", ":hlo_pass", ":hlo_pass_pipeline", ":tuple_simplifier", @@ -3163,6 +3169,7 @@ cc_library( name = "hlo_graph_dumper", srcs = [ "hlo_graph_dumper.cc", + "hlo_graph_html_renderer.cc", ], hdrs = ["hlo_graph_dumper.h"], deps = [ @@ -3624,7 +3631,6 @@ cc_library( srcs = ["hlo_lexer.cc"], hdrs = [ "hlo_lexer.h", - "hlo_token.h", ], deps = [ "//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 985c5af1c4d89425dd6693585e42e22510fe21f8..9e453203ce17cceb606cac06d0ebfaccbf912126 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -25,6 +26,7 @@ limitations under the License. 
#include #include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/types/optional.h" @@ -41,6 +43,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -239,6 +242,13 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // more fusion than leaving the nodes as Dot operations. StatusOr HandleDotStrengthReduction(HloInstruction* dot); + // Removes dimension dim from hlo. + HloInstruction* StripDim(HloInstruction* hlo, int64 dim) { + CHECK_EQ(hlo->shape().dimensions(dim), 1); + return computation_->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::DeleteDimension(dim, hlo->shape()), hlo)); + } + // Reshapes an instruction to rank 1 if it is not already rank 1. 
HloInstruction* Flatten(HloInstruction* hlo) { if (ShapeUtil::Rank(hlo->shape()) == 1) { @@ -908,21 +918,51 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( HloInstruction* dot) { HloInstruction *lhs, *rhs; CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); - int64 lhs_collapsing_dim = - dot->dot_dimension_numbers().lhs_contracting_dimensions(0); + + const auto kept_dim = [](int64 rank, int64 contracting_dimension, + absl::Span batch_dimensions) -> int64 { + for (int64 i = 0; i < rank; ++i) { + if (i != contracting_dimension && + !absl::c_linear_search(batch_dimensions, i)) { + return i; + } + } + return -1; + }; + + const int64 dot_rank = ShapeUtil::Rank(dot->shape()); + const int64 rhs_rank = ShapeUtil::Rank(rhs->shape()); + const int64 lhs_rank = ShapeUtil::Rank(lhs->shape()); + const auto& dnums = dot->dot_dimension_numbers(); + if (dnums.rhs_contracting_dimensions_size() > 1) { + return false; + } + if (dot_rank > 2 && (lhs_rank != rhs_rank || lhs_rank != dot_rank)) { + return false; + } + int64 lhs_collapsing_dim = dnums.lhs_contracting_dimensions(0); + int64 lhs_kept_dim = kept_dim(lhs_rank, lhs_collapsing_dim, + AsInt64Slice(dnums.lhs_batch_dimensions())); + // If there is no non-contracting dimension in rank 2, do not strength reduce. + if (lhs_kept_dim == -1 && lhs_rank > 1) { + return false; + } if (lhs->IsRank2Transpose()) { lhs = lhs->mutable_operand(0); - lhs_collapsing_dim = 1 - lhs_collapsing_dim; + std::swap(lhs_collapsing_dim, lhs_kept_dim); } - const int64 lhs_kept_dim = 1 - lhs_collapsing_dim; - int64 rhs_collapsing_dim = - dot->dot_dimension_numbers().rhs_contracting_dimensions(0); + int64 rhs_collapsing_dim = dnums.rhs_contracting_dimensions(0); + int64 rhs_kept_dim = kept_dim(rhs_rank, rhs_collapsing_dim, + AsInt64Slice(dnums.rhs_batch_dimensions())); + // If there is no non-contracting dimension in rank 2, do not strength reduce. 
+ if (rhs_kept_dim == -1 && rhs_rank > 1) { + return false; + } if (rhs->IsRank2Transpose()) { rhs = rhs->mutable_operand(0); - rhs_collapsing_dim = 1 - rhs_collapsing_dim; + std::swap(rhs_collapsing_dim, rhs_kept_dim); } - const int64 rhs_kept_dim = 1 - rhs_collapsing_dim; auto as_type = [&](HloInstruction* hlo, const PrimitiveType element_type) { if (hlo->shape().element_type() == element_type) { @@ -945,10 +985,15 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( return AddReduce(as_type(hlo, F32), dim); }; + auto broadcast = [&](HloInstruction* hlo, const Shape& shape, + absl::Span dims) { + return computation_->AddInstruction( + HloInstruction::CreateBroadcast(shape, hlo, dims)); + }; + auto broadcast_to_dim = [&](HloInstruction* hlo, const Shape& shape, int64 dim) { - return computation_->AddInstruction( - HloInstruction::CreateBroadcast(shape, hlo, {dim})); + return broadcast(hlo, shape, {dim}); }; auto multiply = [&](HloInstruction* local_lhs, HloInstruction* local_rhs) { @@ -959,11 +1004,9 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( // Strength reduce dot(a[K] , b[K]) = // reshape(result.shape, // reduce_sum(multiply(a, b), {0})) - if (ShapeUtil::Rank(rhs->shape()) == 1 && - ShapeUtil::Rank(lhs->shape()) == 1) { - TF_RETURN_IF_ERROR( - ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32( - multiply(Flatten(lhs), Flatten(rhs)), 0)))); + if (rhs_rank == 1 && lhs_rank == 1) { + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary(add_reduce_in_f32(multiply(lhs, rhs), 0)))); return true; } @@ -977,8 +1020,7 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( // Simplify outer product into multiply with implicit broadcasting. 
// // A dot(a[M, 1], b[1, N]) = multiply(a [M,1], b [1, N]) - if (ShapeUtil::Rank(rhs->shape()) == 2 && - rhs->shape().dimensions(rhs_collapsing_dim) == 1) { + if (rhs_rank == 2 && rhs->shape().dimensions(rhs_collapsing_dim) == 1) { TF_RETURN_IF_ERROR(ReplaceInstruction( dot, multiply(broadcast_to_dim(Flatten(lhs), dot->shape(), 0), broadcast_to_dim(Flatten(rhs), dot->shape(), 1)))); @@ -992,9 +1034,8 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( // {0}) // ) // ) - if (ShapeUtil::Rank(lhs->shape()) == 1 || - (ShapeUtil::Rank(lhs->shape()) == 2 && - lhs->shape().dimensions(lhs_kept_dim) == 1)) { + if (lhs_rank == 1 || + (lhs_rank == 2 && lhs->shape().dimensions(lhs_kept_dim) == 1)) { if (ShapeUtil::Rank(rhs->shape()) == 1) { TF_RETURN_IF_ERROR( ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32( @@ -1014,9 +1055,8 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( // reshape(result.shape, // reduce_sum(multiply(a, broadcast(reshape([K],b), {1})), {0}) // ) - if (ShapeUtil::Rank(rhs->shape()) == 1 || - (ShapeUtil::Rank(rhs->shape()) == 2 && - rhs->shape().dimensions(rhs_kept_dim) == 1)) { + if (rhs_rank == 1 || + (rhs_rank == 2 && rhs->shape().dimensions(rhs_kept_dim) == 1)) { TF_RETURN_IF_ERROR(ReplaceInstruction( dot, reshape_if_necessary(add_reduce_in_f32( multiply(lhs, broadcast_to_dim(Flatten(rhs), lhs->shape(), @@ -1024,6 +1064,97 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( lhs_collapsing_dim)))); return true; } + + // Only consider kDot with batch dimension. + if (dot_rank <= 2) { + return false; + } + + CHECK_EQ(rhs_rank, lhs_rank); + CHECK_EQ(dot_rank, lhs_rank); + // If there is more than one non-contracting dimension or the batch dimensions + // are not equal, bail out since transposes may be required to do a strength + // reduction. 
+ if (dnums.rhs_batch_dimensions_size() + 2 != dot_rank || + !absl::c_equal(dnums.lhs_batch_dimensions(), + dnums.rhs_batch_dimensions())) { + return false; + } + + auto broadcast_dims = [](int64 rank, int64 non_broadcast_dim) { + absl::InlinedVector dims; + for (int64 i = 0; i < rank; ++i) { + if (i != non_broadcast_dim) { + dims.push_back(i); + } + } + return dims; + }; + + // If the contracting dimension is 1, remove the degnerate dimnesions from the + // lhs and rhs, broadcast each to the result shape and multiply. + if (lhs->shape().dimensions(lhs_collapsing_dim) == 1 && + (rhs_kept_dim == rhs_rank - 1 || + (rhs_collapsing_dim == rhs_rank - 1 && rhs_kept_dim == rhs_rank - 2))) { + CHECK_EQ(rhs->shape().dimensions(rhs_collapsing_dim), 1); + const int64 lhs_kept_dim_in_output = + lhs_kept_dim > lhs_collapsing_dim ? (lhs_kept_dim - 1) : lhs_kept_dim; + absl::InlinedVector lhs_broadcast_dims; + for (const int64 dim : dnums.lhs_batch_dimensions()) { + lhs_broadcast_dims.push_back(dim > lhs_collapsing_dim ? (dim - 1) : dim); + } + absl::InlinedVector rhs_broadcast_dims = lhs_broadcast_dims; + lhs_broadcast_dims.push_back(lhs_kept_dim_in_output); + absl::c_sort(lhs_broadcast_dims); + rhs_broadcast_dims.push_back(dot_rank - 1); + absl::c_sort(rhs_broadcast_dims); + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary( + multiply(broadcast(StripDim(lhs, lhs_collapsing_dim), + dot->shape(), lhs_broadcast_dims), + broadcast(StripDim(rhs, rhs_collapsing_dim), + dot->shape(), rhs_broadcast_dims))))); + return true; + } + + // If the lhs and rhs non-contracting dimensions are both one, strip each one, + // multiply and then reduce the collapsing dimension + if (lhs->shape().dimensions(lhs_kept_dim) == 1 && + rhs->shape().dimensions(rhs_kept_dim) == 1 && + lhs_kept_dim == rhs_kept_dim) { + auto new_lhs = StripDim(lhs, lhs_kept_dim); + auto new_rhs = StripDim(rhs, rhs_kept_dim); + const int64 reduce_dim = rhs_kept_dim < rhs_collapsing_dim + ? 
(rhs_collapsing_dim - 1) + : rhs_collapsing_dim; + TF_RETURN_IF_ERROR( + ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32( + multiply(new_lhs, new_rhs), reduce_dim)))); + return true; + } + + // If the lhs non-contracting dimensions is one, strip the one, brodcast to + // the rhs shape, multiply and then reduce the collapsing dimension + if (lhs->shape().dimensions(lhs_kept_dim) == 1) { + auto new_lhs = broadcast(StripDim(lhs, lhs_kept_dim), rhs->shape(), + broadcast_dims(rhs_rank, rhs_kept_dim)); + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary(add_reduce_in_f32(multiply(new_lhs, rhs), + rhs_collapsing_dim)))); + return true; + } + + // If the rhs non-contracting dimensions is one, strip the one, brodcast to + // the lhs shape, multiply and then reduce the collapsing dimension + if (rhs->shape().dimensions(rhs_kept_dim) == 1) { + auto new_rhs = broadcast(StripDim(rhs, rhs_kept_dim), lhs->shape(), + broadcast_dims(lhs_rank, lhs_kept_dim)); + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary(add_reduce_in_f32(multiply(lhs, new_rhs), + lhs_collapsing_dim)))); + return true; + } + return false; } @@ -1302,25 +1433,31 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { HloInstruction *lhs, *rhs; CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); - // Only optimize F32 or BF16 dot operations where the dot, rhs and lhs are - // rank 2 or below. - if ((dot->shape().element_type() != F32 && - dot->shape().element_type() != BF16) || - ShapeUtil::Rank(lhs->shape()) > 2 || ShapeUtil::Rank(rhs->shape()) > 2 || - ShapeUtil::Rank(dot->shape()) > 2) { - return Status::OK(); - } - // Replace a zero element dot with a broadcast of the constant 0. 
if (ShapeUtil::IsZeroElementArray(dot->shape()) || ShapeUtil::IsZeroElementArray(lhs->shape()) || ShapeUtil::IsZeroElementArray(rhs->shape())) { - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + auto zero = computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(dot->shape().element_type()))); return ReplaceWithNewInstruction( dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); } + // Only optimize F32 or BF16 dot operations where the dot, rhs and lhs are + // rank 2 or below. + if (dot->shape().element_type() != F32 && + dot->shape().element_type() != BF16) { + return Status::OK(); + } + if (ShapeUtil::Rank(lhs->shape()) > 2 || ShapeUtil::Rank(rhs->shape()) > 2 || + ShapeUtil::Rank(dot->shape()) > 2) { + if (options_.enable_dot_strength_reduction() && + !options_.is_layout_sensitive()) { + TF_RETURN_IF_ERROR(HandleDotStrengthReduction(dot).status()); + } + return Status::OK(); + } + TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_concat_optimized, OptimizeDotOfConcat(dot)); if (dot_of_concat_optimized) { @@ -2026,6 +2163,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { reshape, HloInstruction::CreateReshape(reshape->shape(), operand->mutable_operand(0))); } + if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) { *operand->mutable_shape() = reshape->shape(); return ReplaceInstruction(reshape, operand); @@ -2459,6 +2597,53 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( function)); } + // A reduce window can be expressed as a reduce and a reshape if all + // dimensions either have a window size of one or the entire dimension. If + // there is no stride, dilation, or padding, this is as easy as checking the + // size of the output shape and window dimension. + // + // The reshape is a bitcast since it adds one-sized dimensions. Often these + // ones are immediately removed as well with another reshape. 
The + // implementation of reduce tends to be slightly more efficient at reducing + // entire dimensions compared to reduce window. + auto effective_reduce_dims = [&] { + if (window_util::HasStride(window) || window_util::HasDilation(window) || + window_util::HasPadding(window)) { + return absl::InlinedVector{}; + } + absl::InlinedVector reduce_dims; + for (int64 i = 0; i < window.dimensions_size(); ++i) { + if (window.dimensions(i).size() == 1) { + continue; + } else if (reduce_window->shape().dimensions(i) == 1) { + reduce_dims.push_back(i); + } else { + return absl::InlinedVector{}; + } + } + return reduce_dims; + }(); + + // If a reduce window can be expressed as a reduce, do so and reshape the + // output. + if (!effective_reduce_dims.empty()) { + Shape reduce_shape = ShapeUtil::FilterDimensions( + [&](int64 dim) { + return !absl::c_linear_search(effective_reduce_dims, dim); + }, + reduce_window->shape()); + HloInstruction* reduce = + computation_->AddInstruction(HloInstruction::CreateReduce( + /*shape=*/reduce_shape, + /*operand=*/operand, + /*init_value=*/reduce_window->mutable_operand(1), + /*dimensions_to_reduce=*/effective_reduce_dims, + /*reduce_computation=*/function)); + return ReplaceWithNewInstruction( + reduce_window, + HloInstruction::CreateReshape(reduce_window->shape(), reduce)); + } + // This optimization folds a pad op into reduce_window. 
HloInstruction* pad; const HloInstruction* convert = nullptr; @@ -2748,6 +2933,22 @@ Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) { return Status::OK(); } +namespace { +bool OnlyPermutesMoreThanOneDegenerateDim(const Shape& shape, + absl::Span perm) { + std::vector new_permutation; + int64 degenerate_count = 0; + for (int64 i = 0; i < perm.size(); ++i) { + if (shape.dimensions(i) != 1) { + new_permutation.push_back(perm[i]); + } else { + ++degenerate_count; + } + } + return degenerate_count > 1 && absl::c_is_sorted(new_permutation); +} +} // namespace + Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { auto operand = transpose->mutable_operand(0); if (std::is_sorted(transpose->dimensions().begin(), @@ -2764,6 +2965,15 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { transpose->dimensions()))); } + // Replace transpose with a reshape if more than one degenerate method is + // permuted. + if (OnlyPermutesMoreThanOneDegenerateDim(transpose->shape(), + transpose->dimensions())) { + return ReplaceWithNewInstruction( + transpose, HloInstruction::CreateReshape( + transpose->shape(), transpose->mutable_operand(0))); + } + if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) { *operand->mutable_shape() = transpose->shape(); return ReplaceInstruction(transpose, operand); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 14ce519b6a0fd221070006d336d23bddeb6cd621..a9d617cbf6dcd02283d5d66655c0fa6ddf6dc27f 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -1273,7 +1273,7 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { // Create add computation. 
builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, /*feature_group_count=*/1, - window, dnums, DefaultPrecisionConfig(2))); + /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m->AddEntryComputation(builder.Build()); HloPassFix simplifier(default_options_); EXPECT_THAT(m->entry_computation()->root_instruction(), @@ -1283,6 +1283,51 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { GmockMatch(m::Broadcast(m::Constant()))); } +TEST_F(AlgebraicSimplifierTest, ReduceWindowIsReduceAndReshape) { + auto m = CreateNewVerifiedModule(); + auto builder = HloComputation::Builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 2, 3, 4}), "param")); + Window window; + for (int64 i = 0; i < 4; ++i) { + WindowDimension* dim = window.add_dimensions(); + // Makes 1x2x3x1 window. + dim->set_size((i % 3) + 1); + dim->set_stride(1); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + } + // Create add computation. 
+ HloComputation* add_computation = nullptr; + { + HloComputation::Builder builder(TestName() + ".add"); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); + add_computation = m->AddEmbeddedComputation(builder.Build()); + } + builder.AddInstruction(HloInstruction::CreateReduceWindow( + ShapeUtil::MakeShape(F32, {1, 1, 1, 4}), param, + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))), + window, add_computation)); + m->AddEntryComputation(builder.Build()); + HloPassFix simplifier(default_options_); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::ReduceWindow(m::Parameter(0), m::Constant()))); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Reshape(m::Reduce(m::Parameter(0), m::Constant())))); +} + TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) { auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); @@ -2047,6 +2092,27 @@ TEST_F(AlgebraicSimplifierTest, TransposesMerged) { computation->root_instruction()->dimensions()); } +TEST_F(AlgebraicSimplifierTest, TransposeIsReshape) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param = f32[10] parameter(0) + reshaped = f32[1,1,10] reshape(f32[10] param) + transposed = f32[10,1,1] transpose(f32[1,1,10] reshaped), dimensions={2,1,0} + ROOT reshaped_again = f32[10] reshape(f32[10,1,1] transposed) + } + )"; + TF_ASSERT_OK_AND_ASSIGN( + auto module, + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + + HloPassFix simplifier(default_options_); + 
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Parameter())); +} + // Test merging reshape and broadcast. TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) { auto m = CreateNewVerifiedModule(); @@ -2950,11 +3016,11 @@ TEST_P(ConvInputPaddingTest, DoTest) { .ValueOrDie(); builder.AddInstruction(HloInstruction::CreateConvolve( ShapeInference::InferConvolveShape(lhs_pad->shape(), filter->shape(), - /*feature_group_count=*/1, window, - dnums) + /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums) .ValueOrDie(), - lhs_pad, filter, /*feature_group_count=*/1, window, dnums, - DefaultPrecisionConfig(2))); + lhs_pad, filter, /*feature_group_count=*/1, /*batch_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); @@ -3067,11 +3133,11 @@ TEST_P(ConvFilterPaddingTest, DoIt) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(), - /*feature_group_count=*/1, window, - dnums) + /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums) .ValueOrDie(), - input, rhs_pad, /*feature_group_count=*/1, window, dnums, - precision_config)); + input, rhs_pad, /*feature_group_count=*/1, /*batch_group_count=*/1, + window, dnums, precision_config)); auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); @@ -3219,7 +3285,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { b.AddInstruction(HloInstruction::CreateConvolve( out_shape, input, filter, - /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); + /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums, + DefaultPrecisionConfig(2))); // TODO(b/80488902): verify this module. 
auto module = CreateNewUnverifiedModule(); @@ -4065,9 +4132,6 @@ PadReduceWindowEffectiveBroadcastCases() { {/*input_spatials=*/{2, 2}, /*symmetric_pad_amount=*/{6, 6}, /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/true, /*should_become_broadcast=*/false}, // - {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{2, 2}, - /*reduce_window_spatials=*/{5, 5}, /*prepend_a=*/true, - /*should_become_broadcast=*/true}, // {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{2, 2}, /*reduce_window_spatials=*/{1, 1}, /*prepend_a=*/true, /*should_become_broadcast=*/false}, // @@ -4083,6 +4147,57 @@ INSTANTIATE_TEST_CASE_P( PadReduceWindowEffectiveBroadcastTest, ::testing::ValuesIn(PadReduceWindowEffectiveBroadcastCases())); +class BatchDotStrengthReductionTest + : public AlgebraicSimplifierTest, + public ::testing::WithParamInterface< + ::testing::tuple> {}; +TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) { + auto module = CreateNewVerifiedModule(); + int m, k, n; + PrimitiveType element_type; + std::tie(m, k, n, element_type) = GetParam(); + + Shape dot_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, m, n}); + Shape lhs_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, m, k}); + Shape rhs_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, k, n}); + HloComputation::Builder builder(TestName()); + + auto lhs = builder.AddInstruction( + HloInstruction::CreateParameter(0, lhs_shape, "lhs")); + auto rhs = builder.AddInstruction( + HloInstruction::CreateParameter(1, rhs_shape, "rhs")); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_batch_dimensions(0); + dot_dnums.add_lhs_batch_dimensions(1); + dot_dnums.add_lhs_batch_dimensions(2); + dot_dnums.add_rhs_batch_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + dot_dnums.add_rhs_batch_dimensions(2); + dot_dnums.add_lhs_contracting_dimensions(4); + dot_dnums.add_rhs_contracting_dimensions(3); + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, rhs, dot_dnums, 
DefaultPrecisionConfig(2))); + auto computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get())); + const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1; + const bool computation_should_be_modified = dot_should_be_transformed; + EXPECT_EQ(changed, computation_should_be_modified); + bool has_no_dot = true; + for (const auto& hlo : computation->instructions()) { + if (hlo->opcode() == HloOpcode::kDot) { + has_no_dot = false; + break; + } + } + EXPECT_EQ(has_no_dot, dot_should_be_transformed); +} + +INSTANTIATE_TEST_CASE_P( + BatchDotStrengthReductionTestInstantiation, BatchDotStrengthReductionTest, + ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2), + ::testing::Values(1, 2), ::testing::Values(F32, BF16))); + class DotStrengthReductionTest : public AlgebraicSimplifierTest, public ::testing::WithParamInterface< diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.cc b/tensorflow/compiler/xla/service/ar_crs_combiner.cc index 24de69382262cabd30c34eea95e77aa0df2947cb..47d2c7e35705698d49950c2fa042af1c6327d521 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.cc @@ -36,31 +36,47 @@ namespace { namespace m = match; -// If the argument instruction is a CRS in the sequence -// AR -> Convert -> Add -> CRS -// then return the AR in the sequence. -// TODO(b/117554291): Rewrite this to recognize more general patterns, -// not just the specific one of AR -> Add -> Convert -> CRS. 
-absl::optional MatchesArCrsPattern( - HloInstruction* instruction) { - HloInstruction *ar, *convert, *add, *crs; - if (Match(instruction, - m::CrossReplicaSum( - &crs, m::Add(&add, m::Op(), - m::Convert(&convert, - m::CrossReplicaSum(&ar, m::Op()))))) && - ar->users().size() == 1 && ar->shape().element_type() == BF16 && - convert->shape().element_type() == F32 && !crs->all_reduce_id()) { - return ar; +// Returns true iff the argument instruction is an AllReduce, followed by a +// certain sequence of instructions and then a CRS. It must be possible to move +// the AR past each instruction in the sequence. +bool MatchesArCrsPattern(HloInstruction* instruction) { + auto can_ar_move_past_instruction = [](HloInstruction* instruction) -> bool { + if (instruction->user_count() != 1) { + return false; + } + auto opcode = instruction->opcode(); + return opcode == HloOpcode::kBitcast || opcode == HloOpcode::kTranspose || + opcode == HloOpcode::kReshape || opcode == HloOpcode::kConvert || + opcode == HloOpcode::kAdd || opcode == HloOpcode::kSubtract || + opcode == HloOpcode::kMultiply; + }; + + auto computation_is_addition = [](HloComputation* c) { + return c->instruction_count() == 3 && + Match(c->root_instruction(), m::Add(m::Parameter(), m::Parameter())); + }; + + if (!instruction->IsCrossModuleAllReduce() || + !computation_is_addition(instruction->called_computations()[0]) || + instruction->user_count() != 1) { + return false; } - return absl::optional(); + auto next = instruction->users()[0]; + while (!next->IsCrossReplicaAllReduce()) { + if (can_ar_move_past_instruction(next)) { + next = next->users()[0]; + } else { + return false; + } + } + return computation_is_addition(next->called_computations()[0]); } } // namespace absl::optional ArCrsCombiner::WhileFromBodyParameter( HloInstruction* instruction) { - CHECK(HloOpcode::kParameter == instruction->opcode()); + CHECK_EQ(HloOpcode::kParameter, instruction->opcode()); HloComputation* computation = instruction->parent(); 
auto caller_instructions = call_graph_->GetComputationCallers(computation); if (caller_instructions.size() == 1) { @@ -120,7 +136,7 @@ bool ArCrsCombiner::TupleElementsComputeSameValue( return false; } for (auto tuple : tuples) { - CHECK(tuple->opcode() == HloOpcode::kTuple); + CHECK_EQ(tuple->opcode(), HloOpcode::kTuple); if (!InstructionsComputeSameValue(tuple->mutable_operand(i1), tuple->mutable_operand(i2), visited_pairs)) { @@ -133,7 +149,7 @@ bool ArCrsCombiner::TupleElementsComputeSameValue( /* static */ bool ArCrsCombiner::TestInstructionsComputeSameValue(HloInstruction* i1, HloInstruction* i2) { - ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2); auto module = i1->parent()->parent(); CHECK_EQ(module, i2->parent()->parent()); combiner.call_graph_ = CallGraph::Build(module); @@ -160,13 +176,6 @@ bool ArCrsCombiner::InstructionsComputeSameValue( if (opcode1 != i2->opcode() || operands1.size() != i2->operands().size()) { return false; } - if (opcode1 == HloOpcode::kConstant || i1->IsCrossModuleAllReduce()) { - return i1->Identical( - *i2, - /*eq_operands=*/std::equal_to(), - /*eq_computations=*/std::equal_to(), - /*layout_sensitive=*/false); - } visited_pairs->emplace(min_uid, max_uid); for (int i = 0; i < operands1.size(); ++i) { auto operand1 = operands1[i]; @@ -175,22 +184,35 @@ bool ArCrsCombiner::InstructionsComputeSameValue( return false; } } + if (opcode1 == HloOpcode::kParameter) { + // In the general case, we don't try to prove equality of parameters. + // We only try in the context of get-tuple-element + // (see TupleElementsComputeSameValue). 
+ return false; + } if (opcode1 == HloOpcode::kGetTupleElement) { - if (i1->tuple_index() == i2->tuple_index()) { - return true; - } - return TupleElementsComputeSameValue(operands1[0], i1->tuple_index(), + return i1->tuple_index() == i2->tuple_index() || + TupleElementsComputeSameValue(operands1[0], i1->tuple_index(), i2->tuple_index(), visited_pairs); } - return true; + // Don't check that the operands are identical, because Identical can + // return false for instructions that compute the same value but are not + // identical, which we don't want. We have checked the arguments with + // InstructionsComputeSameValue earlier. + auto eq_instructions = [](const HloInstruction* i1, + const HloInstruction* i2) -> bool { return true; }; + auto eq_computations = [](const HloComputation* a, const HloComputation* b) { + return *a == *b; + }; + return i1->Identical(*i2, eq_instructions, eq_computations, + /*layout_sensitive=*/false); } void ArCrsCombiner::GroupAllReducesById(HloModule* module) { for (HloComputation* computation : module->MakeNonfusionComputations()) { for (HloInstruction* instruction : computation->instructions()) { - auto ar = MatchesArCrsPattern(instruction); - if (ar) { - all_reduce_map_[*((*ar)->all_reduce_id())].push_back(*ar); + if (MatchesArCrsPattern(instruction)) { + all_reduce_map_[*(instruction->all_reduce_id())].push_back(instruction); } } } @@ -198,21 +220,23 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) { void ArCrsCombiner::KeepProvablyEqualInstructionGroups() { for (auto it : all_reduce_map_) { + auto all_reduce_id = it.first; auto instruction_vec = it.second; CHECK_EQ(instruction_vec.size(), num_spatial_partitions_); - auto instr_0 = instruction_vec[0]; - auto add_0 = instr_0->users()[0]->users()[0]; - CHECK(HloOpcode::kAdd == add_0->opcode()); - for (int i = 1; i < instruction_vec.size(); ++i) { auto instr_i = instruction_vec[i]; - auto add_i = instr_i->users()[0]->users()[0]; - CHECK(HloOpcode::kAdd == add_i->opcode()); + 
auto next_0 = instr_0->users()[0]; + auto next_i = instr_i->users()[0]; absl::flat_hash_map visited_pairs; - if (!InstructionsComputeSameValue(add_0, add_i, &visited_pairs)) { - all_reduce_map_.erase(it.first); - } + do { + if (!InstructionsComputeSameValue(next_0, next_i, &visited_pairs)) { + all_reduce_map_.erase(all_reduce_id); + break; + } + next_0 = next_0->users()[0]; + next_i = next_i->users()[0]; + } while (!next_0->IsCrossReplicaAllReduce()); } } } @@ -221,51 +245,51 @@ StatusOr ArCrsCombiner::RewriteGraph() { if (all_reduce_map_.empty()) { return false; } - - auto computation_is_addition = [](HloComputation* c) { - return c->instruction_count() == 3 && - Match(c->root_instruction(), m::Add(m::Parameter(), m::Parameter())); - }; - for (auto it : all_reduce_map_) { auto instruction_vec = it.second; for (auto all_reduce : instruction_vec) { auto parent_computation = all_reduce->parent(); - auto convert = all_reduce->users()[0]; - auto add = convert->users()[0]; - auto crs = add->users()[0]; - - if (!computation_is_addition(all_reduce->called_computations()[0]) || - !computation_is_addition(crs->called_computations()[0])) { - continue; - } - HloInstruction* other_summand = (add->operands()[0] == convert) - ? 
add->operands()[1] - : add->operands()[0]; - // Remove the AllReduce and replace the CRS with an all-core AllReduce, - // then subtract: - // other_summand * num_replicas_ * (num_spatial_partitions_ - 1) - TF_CHECK_OK( - all_reduce->ReplaceAllUsesWith(all_reduce->mutable_operand(0))); - crs->set_all_reduce_id(all_reduce->all_reduce_id()); - auto new_shape = crs->shape(); - Literal lit(new_shape); - lit.PopulateWithValue(num_replicas_ * - (num_spatial_partitions_ - 1)); - auto partitions_minus_1_const = parent_computation->AddInstruction( - HloInstruction::CreateConstant(lit.Clone())); - auto to_subtract = - parent_computation->AddInstruction(HloInstruction::CreateBinary( - new_shape, HloOpcode::kMultiply, other_summand, - partitions_minus_1_const)); - auto sub = - parent_computation->AddInstruction(HloInstruction::CreateBinary( - new_shape, HloOpcode::kSubtract, crs, to_subtract)); - TF_CHECK_OK(crs->ReplaceAllUsesWith(sub)); + auto all_reduce_id = all_reduce->all_reduce_id(); + auto prev = all_reduce->mutable_operand(0); + auto next = all_reduce->users()[0]; + TF_CHECK_OK(all_reduce->ReplaceUseWith(next, prev)); TF_CHECK_OK(parent_computation->RemoveInstruction(all_reduce)); + while (!next->IsCrossReplicaAllReduce()) { + switch (next->opcode()) { + case HloOpcode::kBitcast: + case HloOpcode::kTranspose: + case HloOpcode::kReshape: + case HloOpcode::kConvert: + case HloOpcode::kMultiply: + break; + case HloOpcode::kAdd: + case HloOpcode::kSubtract: { + auto other_operand = (next->operands()[0] == prev) + ? next->operands()[1] + : next->operands()[0]; + // To move the AR past the addition/subtraction, we need to divide + // other_operand by the number of spatial partitions. 
+ auto shape = other_operand->shape(); + Literal lit(shape); + lit.PopulateWithValue(num_spatial_partitions_); + auto divisor = parent_computation->AddInstruction( + HloInstruction::CreateConstant(lit.Clone())); + auto division = + parent_computation->AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDivide, other_operand, divisor)); + TF_CHECK_OK(other_operand->ReplaceUseWith(next, division)); + break; + } + default: + LOG(FATAL) << "Unexpected instruction: " << next->ToShortString(); + } + prev = next; + next = next->users()[0]; + } + // The AllReduce and the CRS are combined to an all-core AllReduce. + next->set_all_reduce_id(all_reduce_id); } } - return true; } diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.h b/tensorflow/compiler/xla/service/ar_crs_combiner.h index 4abdb1f57d835ff0faa6f371df3a170a4a0b22f0..6f54b97615b270bc6b180dd47d9aff6473752b47 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.h +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.h @@ -25,14 +25,16 @@ limitations under the License. namespace xla { -// Combine an AllReduce and a CrossReplicaSum when they are close to each other -// in the graph, to use an efficient CrossReplicaSum implementation that -// fully utilizes the interconnect bandwidth. +// When the HLO graph contains a cross-module AllReduce, followed by some simple +// linear operations, followed by a cross-replica AllReduce, we can combine the +// CMAR and the CRAR, to use an efficient AllReduce implementation that fully +// utilizes the interconnect bandwidth. +// Such sequences appear in spatially partitioned models. +// This pass must run right after spatial partitioning. 
class ArCrsCombiner : public HloModulePass { public: - ArCrsCombiner(int num_spatial_partitions, int num_replicas) - : num_spatial_partitions_(num_spatial_partitions), - num_replicas_(num_replicas) {} + ArCrsCombiner(int num_spatial_partitions) + : num_spatial_partitions_(num_spatial_partitions) {} absl::string_view name() const override { return "ar-crs-combiner"; } StatusOr Run(HloModule* module) override; @@ -77,7 +79,6 @@ class ArCrsCombiner : public HloModulePass { StatusOr RewriteGraph(); int num_spatial_partitions_; - int num_replicas_; // Map from all-reduce ids to the all reduce instructions. absl::flat_hash_map> all_reduce_map_; diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc index 23d9aa9eb343054dbe3c6afba161161072195451..caa57296f465698eb70d7cb8327d4678f394b323 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc @@ -32,8 +32,8 @@ HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { %p = f32[2,2] parameter(0) - %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) - %constant.f32.2 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant({{1, 2}, {3, 4}}) ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2) } )"; @@ -48,13 +48,50 @@ ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); } +TEST_F(ArCrsCombinerTest, SameValueTestBasecase2) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (x: f32[]) -> (f32[], f32[]) { + %x = f32[] parameter(0) + ROOT %tuple = (f32[], f32[]) tuple(%x, %x) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = 
root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestBasecase3) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (x: f32[], y: f32[]) -> (f32[], f32[]) { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %tuple = (f32[], f32[]) tuple(%x, %y) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + TEST_F(ArCrsCombinerTest, SameValueTestNumOperands) { const char* module_str = R"( HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> ((f32[2,2]), (f32[2,2], f32[2,2])) { %p = f32[2,2] parameter(0) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}}) %tuple1 = (f32[2,2]) tuple(%constant.f32) %tuple2 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) ROOT %tuple = ((f32[2,2]), (f32[2,2], f32[2,2])) tuple(%tuple1, %tuple2) @@ -69,13 +106,53 @@ ENTRY %entrycomp (p: f32[2,2]) -> ((f32[2,2]), (f32[2,2], f32[2,2])) { EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); } +TEST_F(ArCrsCombinerTest, SameValueTestSliceIndicesMatch) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2]) -> (f32[1], f32[1]) { + %p = f32[2] parameter(0) + %slice.1 = f32[1] slice(f32[2] %p), slice={[0:1]} + %slice.2 = f32[1] slice(f32[2] %p), slice={[0:1]} + ROOT %tuple = (f32[1], f32[1]) tuple(%slice.1, %slice.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = 
root_tuple->operands()[1]; + EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestSliceIndicesDontMatch) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2]) -> (f32[1], f32[1]) { + %p = f32[2] parameter(0) + %slice.1 = f32[1] slice(f32[2] %p), slice={[0:1]} + %slice.2 = f32[1] slice(f32[2] %p), slice={[1:2]} + ROOT %tuple = (f32[1], f32[1]) tuple(%slice.1, %slice.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + TEST_F(ArCrsCombinerTest, SameValueTestTupleElementSameIndex) { const char* module_str = R"( HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { %p = f32[2,2] parameter(0) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}}) %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0 %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=0 @@ -97,7 +174,7 @@ HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { %p = f32[2,2] parameter(0) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}}) %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0 %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=1 @@ -119,8 +196,8 @@ HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { %p = f32[2,2] parameter(0) - %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) - %constant.f32.2 = f32[2,2] constant(f32[2,2] {{2, 3}, {4, 
5}}) + %constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant({{2, 3}, {4, 5}}) %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2) %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0 %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=1 @@ -149,7 +226,7 @@ HloModule foobar %body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { %x = (f32[2,2], f32[2,2]) parameter(0) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}}) %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0 %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1 %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32) @@ -158,7 +235,7 @@ HloModule foobar } ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { - %constant.f32 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) + %constant.f32 = f32[2,2] constant({{3, 4}, {5, 6}}) %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body } @@ -186,7 +263,7 @@ HloModule foobar %body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { %x = (f32[2,2], f32[2,2]) parameter(0) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}}) %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0 %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1 %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32) @@ -195,8 +272,8 @@ HloModule foobar } ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { - %constant.f32.1 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) - %constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {7, 8}}) + %constant.f32.1 = f32[2,2] constant({{3, 4}, {5, 6}}) + %constant.f32.2 = f32[2,2] constant({{3, 4}, {7, 8}}) %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2) ROOT %while = (f32[2,2], f32[2,2]) 
while(%init.tuple), condition=%condition, body=%body } @@ -224,8 +301,8 @@ HloModule foobar %body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { %x = (f32[2,2], f32[2,2]) parameter(0) - %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) - %constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {1, 2}}) + %constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant({{3, 4}, {1, 2}}) %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0 %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1 %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32.1) @@ -234,7 +311,7 @@ HloModule foobar } ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { - %constant.f32 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) + %constant.f32 = f32[2,2] constant({{3, 4}, {5, 6}}) %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body } @@ -249,11 +326,27 @@ ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); } -TEST_F(ArCrsCombinerTest, RewritePatternArConvertAddCrs) { +void CompareReplicaGroups(const std::vector& groups_before, + const std::vector& groups_after) { + ASSERT_EQ(groups_before.size(), groups_after.size()); + for (int i = 0; i < groups_before.size(); ++i) { + // Somewhat verbose way to compare the replica_ids, because EqualsProto + // is not available in the open-source build. 
+ auto group_before = groups_before[i]; + std::vector ids_before(group_before.replica_ids().begin(), + group_before.replica_ids().end()); + auto group_after = groups_after[i]; + std::vector ids_after(group_after.replica_ids().begin(), + group_after.replica_ids().end()); + EXPECT_EQ(ids_before, ids_after); + } +} + +TEST_F(ArCrsCombinerTest, RewriteArConvertCrs) { const char* module_str = R"( HloModule foobar -%binary_add (a: bf16[], b: bf16[]) -> bf16[] { +%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] { %a = bf16[] parameter(0) %b = bf16[] parameter(1) ROOT %add = bf16[] add(%a, %b) @@ -265,49 +358,257 @@ HloModule foobar ROOT %add = f32[] add(%x, %y) } -ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { - %p = f32[2,2] parameter(0) - %constant.bf16 = bf16[2,2] constant(bf16[2,2] {{1, 2}, {3, 4}}) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) +ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { + %p = bf16[] parameter(0) + + %all-reduce.ar.1 = bf16[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.bf16, + sharding={maximal device=0} + %convert.1 = f32[] + convert(%all-reduce.ar.1), + sharding={maximal device=0} + %all-reduce.1 = f32[] + all-reduce(%convert.1), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=0} + + %all-reduce.ar.2 = bf16[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.bf16, + sharding={maximal device=1} + %convert.2 = f32[] + convert(%all-reduce.ar.2), + sharding={maximal device=1} + %all-reduce.2 = f32[] + all-reduce(%convert.2), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto 
replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Convert(op::Parameter())), + op::AllReduce(op::Convert(op::Parameter())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, RewriteArBitcastCrs) { + const char* module_str = R"( +HloModule foobar + +%sum.1 (a: f32[2,1], b: f32[2,1]) -> f32[2,1] { + %a = f32[2,1] parameter(0) + %b = f32[2,1] parameter(1) + ROOT %add = f32[2,1] add(%a, %b) +} + +%sum.2 (x: f32[2], y: f32[2]) -> f32[2] { + %x = f32[2] parameter(0) + %y = f32[2] parameter(1) + ROOT %add = f32[2] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) { + %p = f32[2,1] parameter(0) + + %all-reduce.ar.1 = f32[2,1] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.1, + sharding={maximal device=0} + %bitcast.1 = f32[2]{0} bitcast(f32[2,1]{1,0} %all-reduce.ar.1) + %all-reduce.1 = f32[2] + all-reduce(%bitcast.1), + replica_groups={{0,1}}, + to_apply=%sum.2, + sharding={maximal device=0} + + %all-reduce.ar.2 = f32[2,1] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.1, + sharding={maximal device=1} + %bitcast.2 = f32[2]{0} bitcast(f32[2,1]{1,0} %all-reduce.ar.2) + %all-reduce.2 = f32[2] + all-reduce(%bitcast.2), + replica_groups={{0,1}}, + to_apply=%sum.2, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + 
module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Bitcast(op::Parameter())), + op::AllReduce(op::Bitcast(op::Parameter())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, RewriteArMultiplyCrs) { + const char* module_str = R"( +HloModule foobar + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %constant.f32 = f32[] constant(123) + + %all-reduce.ar.1 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.f32, + sharding={maximal device=0} + %multiply.1 = f32[] + multiply(%all-reduce.ar.1, %constant.f32), + sharding={maximal device=0} + %all-reduce.1 = f32[] + all-reduce(%multiply.1), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=0} - %cross-replica-sum.ar.1 = bf16[2,2] - cross-replica-sum(%constant.bf16), + %all-reduce.ar.2 = f32[] + all-reduce(%p), replica_groups={{0},{1}}, all_reduce_id=1, - to_apply=%binary_add, + to_apply=%sum.f32, + sharding={maximal device=1} + %multiply.2 = f32[] + multiply(%all-reduce.ar.2, %constant.f32), + sharding={maximal device=1} + %all-reduce.2 = f32[] + all-reduce(%multiply.2), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + 
ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Multiply(op::Parameter(), op::Constant())), + op::AllReduce(op::Multiply(op::Parameter(), op::Constant())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, RewriteArConvertAddCrs) { + const char* module_str = R"( +HloModule foobar + +%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] { + %a = bf16[] parameter(0) + %b = bf16[] parameter(1) + ROOT %add = bf16[] add(%a, %b) +} + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %constant.bf16 = bf16[] constant(1) + %constant.f32 = f32[] constant(2) + + %all-reduce.ar.1 = bf16[] + all-reduce(%constant.bf16), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.bf16, sharding={maximal device=0} - %convert.1 = f32[2,2] - convert(%cross-replica-sum.ar.1), + %convert.1 = f32[] + convert(%all-reduce.ar.1), sharding={maximal device=0} - %add.1 = f32[2,2] + %add.1 = f32[] add(%constant.f32, %convert.1), sharding={maximal device=0} - %cross-replica-sum.1 = f32[2,2] - cross-replica-sum(%add.1), + %all-reduce.1 = f32[] + all-reduce(%add.1), replica_groups={{0,1}}, to_apply=%sum.f32, sharding={maximal device=0} - %cross-replica-sum.ar.2 = bf16[2,2] - cross-replica-sum(%constant.bf16), + %all-reduce.ar.2 = bf16[] + all-reduce(%constant.bf16), replica_groups={{0},{1}}, all_reduce_id=1, - 
to_apply=%binary_add, + to_apply=%sum.bf16, sharding={maximal device=1} - %convert.2 = f32[2,2] - convert(%cross-replica-sum.ar.2), + %convert.2 = f32[] + convert(%all-reduce.ar.2), sharding={maximal device=1} - %add.2 = f32[2,2] + %add.2 = f32[] add(%constant.f32, %convert.2), sharding={maximal device=1} - %cross-replica-sum.2 = f32[2,2] - cross-replica-sum(%add.2), + %all-reduce.2 = f32[] + all-reduce(%add.2), replica_groups={{0,1}}, to_apply=%sum.f32, sharding={maximal device=1} - ROOT %tuple = (f32[2,2], f32[2,2]) - tuple(%cross-replica-sum.1, %cross-replica-sum.2), + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), sharding={{maximal device=0}, {maximal device=1}} } )"; @@ -317,37 +618,27 @@ ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); - ArCrsCombiner combiner(2, 1); + ArCrsCombiner combiner(2); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); EXPECT_THAT( module->entry_computation()->root_instruction(), - op::Tuple(op::Subtract(op::CrossReplicaSum(), - op::Multiply(op::Constant(), op::Constant())), - op::Subtract(op::CrossReplicaSum(), - op::Multiply(op::Constant(), op::Constant())))); - auto sub = module->entry_computation()->root_instruction()->operands()[0]; - auto crs_after = sub->operands()[0]; + op::Tuple( + op::AllReduce(op::Add(op::Divide(op::Constant(), op::Constant()), + op::Convert())), + op::AllReduce(op::Add(op::Divide(op::Constant(), op::Constant()), + op::Convert())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_after = crs_after->replica_groups(); - ASSERT_EQ(replica_groups_before.size(), replica_groups_after.size()); - for (int i = 0; i < replica_groups_before.size(); ++i) { - // Somewhat verbose way to compare the replica_ids, because EqualsProto - // is not available in the 
open-source build. - auto group_before = replica_groups_before[i]; - std::vector ids_before(group_before.replica_ids().begin(), - group_before.replica_ids().end()); - auto group_after = replica_groups_after[i]; - std::vector ids_after(group_after.replica_ids().begin(), - group_after.replica_ids().end()); - EXPECT_EQ(ids_before, ids_after); - } + CompareReplicaGroups(replica_groups_before, replica_groups_after); } TEST_F(ArCrsCombinerTest, OtherSummandNotTheSameDontRewrite) { const char* module_str = R"( HloModule foobar -%binary_add (a: bf16[], b: bf16[]) -> bf16[] { +%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] { %a = bf16[] parameter(0) %b = bf16[] parameter(1) ROOT %add = bf16[] add(%a, %b) @@ -359,57 +650,57 @@ HloModule foobar ROOT %add = f32[] add(%x, %y) } -ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { - %p = f32[2,2] parameter(0) - %constant.bf16 = bf16[2,2] constant(bf16[2,2] {{1, 2}, {3, 4}}) - %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) - %constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %constant.bf16 = bf16[] constant(1) + %constant.f32.1 = f32[] constant(2) + %constant.f32.2 = f32[] constant(3) - %cross-replica-sum.ar.1 = bf16[2,2] - cross-replica-sum(%constant.bf16), + %all-reduce.ar.1 = bf16[] + all-reduce(%constant.bf16), replica_groups={{0},{1}}, all_reduce_id=1, - to_apply=%binary_add, + to_apply=%sum.bf16, sharding={maximal device=0} - %convert.1 = f32[2,2] - convert(%cross-replica-sum.ar.1), + %convert.1 = f32[] + convert(%all-reduce.ar.1), sharding={maximal device=0} - %add.1 = f32[2,2] + %add.1 = f32[] add(%constant.f32.1, %convert.1), sharding={maximal device=0} - %cross-replica-sum.1 = f32[2,2] - cross-replica-sum(%add.1), + %all-reduce.1 = f32[] + all-reduce(%add.1), replica_groups={{0,1}}, to_apply=%sum.f32, sharding={maximal device=0} - %cross-replica-sum.ar.2 = bf16[2,2] - cross-replica-sum(%constant.bf16), + 
%all-reduce.ar.2 = bf16[] + all-reduce(%constant.bf16), replica_groups={{0},{1}}, all_reduce_id=1, - to_apply=%binary_add, + to_apply=%sum.bf16, sharding={maximal device=1} - %convert.2 = f32[2,2] - convert(%cross-replica-sum.ar.2), + %convert.2 = f32[] + convert(%all-reduce.ar.2), sharding={maximal device=1} - %add.2 = f32[2,2] + %add.2 = f32[] add(%constant.f32.2, %convert.2), sharding={maximal device=1} - %cross-replica-sum.2 = f32[2,2] - cross-replica-sum(%add.2), + %all-reduce.2 = f32[] + all-reduce(%add.2), replica_groups={{0,1}}, to_apply=%sum.f32, sharding={maximal device=1} - ROOT %tuple = (f32[2,2], f32[2,2]) - tuple(%cross-replica-sum.1, %cross-replica-sum.2), + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), sharding={{maximal device=0}, {maximal device=1}} } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(module_str)); - ArCrsCombiner combiner(2, 1); + ArCrsCombiner combiner(2); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_FALSE(changed); } diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index 5c180cbdd492031e133b81149f0f4698619b7788..2cf24a9dd5fa18abe9dde4eb49b03c6586bfef03 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -57,6 +57,16 @@ int BackendOptions::intra_op_parallelism_threads() const { return intra_op_parallelism_threads_; } +BackendOptions& BackendOptions::set_allowed_devices( + const absl::optional>& allowed_devices) { + allowed_devices_ = allowed_devices; + return *this; +} + +const absl::optional>& BackendOptions::allowed_devices() const { + return allowed_devices_; +} + // Define this in .cc file to avoid having to include eigen or forward declare // these types in the header. 
struct Backend::EigenThreadPoolWrapper { @@ -76,8 +86,9 @@ struct Backend::EigenThreadPoolWrapper { const BackendOptions& options) { se::Platform* platform = options.platform(); TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform)); - TF_ASSIGN_OR_RETURN(auto stream_executors, - PlatformUtil::GetStreamExecutors(platform)); + TF_ASSIGN_OR_RETURN( + auto stream_executors, + PlatformUtil::GetStreamExecutors(platform, options.allowed_devices())); TF_ASSIGN_OR_RETURN(auto transfer_manager, TransferManager::GetForPlatform(platform)); TF_ASSIGN_OR_RETURN(auto computation_placer, diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index a2dafbe803f8bd5f23e4e9f3f6d3e6f744c9fab9..7ca993fb2656037951d98d9c4459a3c3e4c64c61 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include @@ -53,9 +54,16 @@ class BackendOptions { BackendOptions& set_intra_op_parallelism_threads(int num_threads); int intra_op_parallelism_threads() const; + // Sets the allowed_devices for selectively constructing stream executors + // on the platform. + BackendOptions& set_allowed_devices( + const absl::optional>& allowed_devices); + const absl::optional>& allowed_devices() const; + private: se::Platform* platform_ = nullptr; int intra_op_parallelism_threads_ = -1; + absl::optional> allowed_devices_; }; // Class which encapsulates an XLA backend. 
It includes everything necessary diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc index e9d30fc03c1c3194de577e6683b36a95641694d9..6caef77ed00909040a54e65651cc6fb7ca74eb90 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc @@ -34,8 +34,8 @@ class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault { Status DefaultAction(HloInstruction* hlo) override; - // Special handling for cross-replica-sum which can have a tuple output. - Status HandleCrossReplicaSum(HloInstruction* crs) override; + // Special handling for all-reduce which can have a tuple output. + Status HandleAllReduce(HloInstruction* crs) override; static bool Run(HloComputation* computation, const BFloat16Support* bfloat16_support) { @@ -176,8 +176,7 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) { return TryFoldBF16Conversions(hlo); } -Status BFloat16ConversionFoldingVisitor::HandleCrossReplicaSum( - HloInstruction* crs) { +Status BFloat16ConversionFoldingVisitor::HandleAllReduce(HloInstruction* crs) { if (crs->IsCrossModuleAllReduce()) { // Cross-module all-reduce has side effect. 
return Status::OK(); diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index 4ce351acc2c359773e618da70360c96faf5ca379..2232a2cbdfe0cf64dc4fb10d4598c0ad8b51ee5e 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -38,7 +38,7 @@ class TestBFloat16Support : public BFloat16Support { hlo.opcode() == HloOpcode::kSubtract || hlo.opcode() == HloOpcode::kTuple || hlo.opcode() == HloOpcode::kGetTupleElement || - hlo.opcode() == HloOpcode::kCrossReplicaSum) { + hlo.opcode() == HloOpcode::kAllReduce) { return true; } return false; @@ -49,7 +49,7 @@ class TestBFloat16Support : public BFloat16Support { hlo.opcode() == HloOpcode::kSubtract || hlo.opcode() == HloOpcode::kTuple || hlo.opcode() == HloOpcode::kGetTupleElement || - hlo.opcode() == HloOpcode::kCrossReplicaSum) { + hlo.opcode() == HloOpcode::kAllReduce) { return true; } return false; @@ -58,7 +58,7 @@ class TestBFloat16Support : public BFloat16Support { bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple || hlo.opcode() == HloOpcode::kGetTupleElement || - hlo.opcode() == HloOpcode::kCrossReplicaSum) { + hlo.opcode() == HloOpcode::kAllReduce) { return true; } return false; @@ -213,7 +213,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { EXPECT_EQ(tuple->operand(1), convert0); } -TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { +TEST_F(BFloat16ConversionFoldingTest, FoldAllReduceTupleOutput) { auto builder = HloComputation::Builder(TestName()); auto module = CreateNewVerifiedModule(); @@ -236,11 +236,10 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateParameter(1, f32_shape, "b")); - HloInstruction* crs 
= - builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( - ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, - sum, /*replica_groups=*/{}, /*barrier=*/"", - /*all_reduce_id=*/absl::nullopt)); + HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateAllReduce( + ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, sum, + /*replica_groups=*/{}, /*barrier=*/"", + /*all_reduce_id=*/absl::nullopt)); HloInstruction* gte_a = builder.AddInstruction( HloInstruction::CreateGetTupleElement(f32_shape, crs, 0)); HloInstruction* gte_b = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc index b8a8f844eff17a95d4073f53495e0027c481f558..e3aefe906739b74e887f33d2ffc3ad7a60510b5b 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -362,7 +362,7 @@ Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) { } // TODO(b/112040122): Correctly normalize variadic reduce. 
if ((hlo->opcode() == HloOpcode::kSort || - hlo->opcode() == HloOpcode::kCrossReplicaSum) && + hlo->opcode() == HloOpcode::kAllReduce) && ShapeUtil::IsTuple(hlo->shape())) { return HandleMultipleOutputs(hlo); } diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 9f97d18c565c7915b9f9346f0c6330cdc3c707e9..551ac4be73a7630d213a53ca3606aa7f890cd794 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -232,7 +232,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { EXPECT_EQ(reduce->operand(1)->shape().element_type(), F32); } -TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { +TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllReduce) { auto module = CreateNewVerifiedModule(); HloComputation::Builder sum_builder("sum"); auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( @@ -253,11 +253,10 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateParameter(1, bf16_shape, "b")); - HloInstruction* crs = - builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( - ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction, - /*replica_groups=*/{}, /*barrier=*/"", - /*all_reduce_id=*/absl::nullopt)); + HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateAllReduce( + ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction, + /*replica_groups=*/{}, /*barrier=*/"", + /*all_reduce_id=*/absl::nullopt)); HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1)); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index 
63d4572f2028c462df1cac9d5e4ee616e407f37b..05dd4b3e914f5563a33d534829ffb01668279064 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -276,7 +276,7 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision( *use.instruction, use.operand_number)) { if (use.instruction->opcode() == HloOpcode::kTuple || - (use.instruction->opcode() == HloOpcode::kCrossReplicaSum && + (use.instruction->opcode() == HloOpcode::kAllReduce && ShapeUtil::IsTuple(use.instruction->shape()))) { ShapeIndex use_output_index{use.operand_number}; for (int64 i : use.operand_index) { diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 5be7141aae423adb4fe2f39262e463ff25ae8234..a9b5d9916e400b39039248098c22a715e44ccfd2 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -209,7 +209,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeAllReduce) { rb.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")))); auto reduction = module->AddEmbeddedComputation(rb.Build()); HloInstruction* all_reduce = - builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( + builder.AddInstruction(HloInstruction::CreateAllReduce( ShapeUtil::MakeTupleShape({shape, shape}), {a, b}, reduction, /*replica_groups=*/{}, /*barrier=*/"", /*all_reduce_id=*/1)); HloInstruction* gte0 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 8d7c62447852fd946440c41389300a92377c471f..202e45e181d13621f79e3bf95e33091b54e8b779 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -186,7 +186,7 @@ Status 
GatherComputationsByAllocationType( worklist.push_back(std::make_pair(subcomputation, false)); // Not thread local. break; - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index 7987343bfaf1069fd550909d127e4b11f2124701..173b3fc05f53d523fb07ef9b14be884fd5f8aeb1 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -58,7 +58,7 @@ CallContext GetInstructionCallContext(HloOpcode opcode) { case HloOpcode::kConditional: case HloOpcode::kWhile: return CallContext::kSequential; - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc index e6bf2143a21bd5001d3530fe8727c88504be1d43..d58f157242f5fb9690f7fda3e7d8f71ca6c8db84 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/convolution_feature_group_converter.h" +#include "tensorflow/compiler/xla/service/convolution_group_converter.h" #include #include @@ -30,10 +30,10 @@ limitations under the License. 
namespace xla { namespace { -using ConvolutionFeatureGroupConverterTest = HloTestBase; +using ConvolutionGroupConverterTest = HloTestBase; namespace op = testing::opcode_matchers; -TEST_F(ConvolutionFeatureGroupConverterTest, +TEST_F(ConvolutionGroupConverterTest, ConvertFeatureGroupCountEqualToInputFeatureDim) { string hlo_string = R"(HloModule Convolve1D1Window_0_module @@ -49,7 +49,8 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,2], filter: f32[1,1,2]) -> f32[1,2 auto computation = module->entry_computation(); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); - ConvolutionFeatureGroupConverter converter; + ConvolutionGroupConverter converter(nullptr, /*convert_batch_groups_only=*/ + false); ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); // Make sure the convolution is converted to one with feature_group_count = 1. @@ -63,7 +64,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,2], filter: f32[1,1,2]) -> f32[1,2 op::Broadcast(op::Constant()))); } -TEST_F(ConvolutionFeatureGroupConverterTest, +TEST_F(ConvolutionGroupConverterTest, ConvertFeatureGroupCountDivisorOfInputFeatureDim) { string hlo_string = R"(HloModule Convolve1D1Window_0_module @@ -79,7 +80,8 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,4], filter: f32[1,2,2]) -> f32[1,2 auto computation = module->entry_computation(); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); - ConvolutionFeatureGroupConverter converter; + ConvolutionGroupConverter converter(nullptr, /*convert_batch_groups_only=*/ + false); ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); // Make sure the convolution is replaced with a concatenate. 
diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_group_converter.cc similarity index 61% rename from tensorflow/compiler/xla/service/convolution_feature_group_converter.cc rename to tensorflow/compiler/xla/service/convolution_group_converter.cc index 95c7724c3c93507ae61a984301ecfc0111bef192..7a24faec17f0c4f0a57406328b1c21cd73506d82 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_group_converter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/convolution_feature_group_converter.h" +#include "tensorflow/compiler/xla/service/convolution_group_converter.h" #include #include @@ -50,8 +50,12 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault { Status HandleConvolution(HloInstruction* convolution) override; + Status HandleBatchGroupCount(HloInstruction* convolution); + // Runs the visitor on a computation. static bool Run(HloComputation* computation, + std::function is_cost_viable, + bool convert_batch_groups_only, bool canonicalize_depthwise_filter); // Returns whether any convolution ops were rewritten. 
@@ -60,10 +64,15 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault { ~ConvolutionVisitor() override = default; private: - explicit ConvolutionVisitor(HloComputation* computation, - bool canonicalize_depthwise_filter = false) + explicit ConvolutionVisitor( + HloComputation* computation, + std::function is_cost_viable, + bool convert_batch_groups_only, + bool canonicalize_depthwise_filter = false) : computation_(computation), - filter_expansion_(!canonicalize_depthwise_filter) {} + filter_expansion_(!canonicalize_depthwise_filter), + convert_batch_groups_only_(convert_batch_groups_only), + is_cost_viable_(is_cost_viable) {} // Current HloComputation instance the ConvolutionVisitor is traversing. HloComputation* computation_; @@ -73,11 +82,21 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault { // Whether filter expansion is required. bool filter_expansion_; + + // Decides whether to convert batch groups or feature groups. + bool convert_batch_groups_only_; + + // std::function(int64, int64)> chunk_fetcher + std::function is_cost_viable_; }; -bool ConvolutionVisitor::Run(HloComputation* computation, - bool canonicalize_depthwise_filter) { - ConvolutionVisitor visitor(computation, canonicalize_depthwise_filter); +bool ConvolutionVisitor::Run( + HloComputation* computation, + std::function is_cost_viable, + bool convert_batch_groups_only, bool canonicalize_depthwise_filter) { + ConvolutionVisitor visitor(computation, is_cost_viable, + convert_batch_groups_only, + canonicalize_depthwise_filter); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } @@ -176,18 +195,206 @@ HloInstruction* GetExpandedFilterMask( predicate_shape, HloOpcode::kEq, broadcasted_mask1, broadcasted_mask2)); } -Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { - int64 group_count = convolution->feature_group_count(); - if (group_count == 1) { +// This function handles batch_group_counts which are relevant only for +// depthwise 
backprop filter convolutions. +Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) { + auto dim_numbers = convolution->convolution_dimension_numbers(); + auto activation = convolution->mutable_operand(0); + auto filter = convolution->mutable_operand(1); + int64 batch_group_count = convolution->batch_group_count(); + + if (batch_group_count == 1) { return Status::OK(); } - auto filter = convolution->mutable_operand(1); - changed_ = true; + + VLOG(2) << "Dealing with batch_group_count " << batch_group_count << "\n"; + auto add = [&](std::unique_ptr inst) { return computation_->AddInstruction(std::move(inst)); }; + int64 input_batch_dimension = dim_numbers.input_batch_dimension(); + int64 input_feature_dimension = dim_numbers.input_feature_dimension(); + int64 output_batch_dimension = dim_numbers.output_batch_dimension(); + int64 output_feature_dimension = dim_numbers.output_feature_dimension(); + int64 kernel_input_feature_dimension = + dim_numbers.kernel_input_feature_dimension(); + + int64 input_batch = activation->shape().dimensions(input_batch_dimension); + + // We are not yet supporting batch_group of sizes greater than 1. + TF_RET_CHECK(input_batch == batch_group_count); + + if (is_cost_viable_(convolution)) { + // Add a dimension to the activation, and reshape. + Shape reshaped_activation_shape = activation->shape(); + ShapeUtil::AppendMajorDimension(1, &reshaped_activation_shape); + + activation = add( + HloInstruction::CreateReshape(reshaped_activation_shape, activation)); + + // Add a dimension to the filter, and reshape. 
+ Shape reshaped_filter_shape = filter->shape(); + ShapeUtil::AppendMajorDimension(1, &reshaped_filter_shape); + + filter = add(HloInstruction::CreateReshape(reshaped_filter_shape, filter)); + + int64 new_spatial_dim = reshaped_activation_shape.dimensions().size() - 1; + + Shape new_output_shape = convolution->shape(); + ShapeUtil::AppendMajorDimension(1, &new_output_shape); + + int64 input_feature = + activation->shape().dimensions(input_feature_dimension); + + // The code below edits convolution dimension numbers. Please refer to + // conv_op_helpers.cc to find how the dimensions were set up originally. + + // Effectively, the new input batch becomes 1, and so does the kernel + // input feature. The original input batch now becomes a spatial dimension. + // The output batch (remember that the output is the new kernel for in + // backprop) becomes a spatial dimension too. + + dim_numbers.set_input_batch_dimension(new_spatial_dim); + dim_numbers.set_input_feature_dimension(input_batch_dimension); + dim_numbers.set_kernel_input_feature_dimension(new_spatial_dim); + + dim_numbers.add_input_spatial_dimensions(input_feature_dimension); + dim_numbers.add_kernel_spatial_dimensions(kernel_input_feature_dimension); + + dim_numbers.add_output_spatial_dimensions(output_batch_dimension); + dim_numbers.set_output_batch_dimension(new_spatial_dim); + + // Add window for the new spatial dimension. + Window new_window = convolution->window(); + auto* dim = new_window.add_dimensions(); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + dim->set_stride(1); + dim->set_size(input_feature); + + auto new_convolution = add(HloInstruction::CreateConvolve( + new_output_shape, activation, filter, + /*feature_group_count=*/batch_group_count, /*batch_group_count=*/1, + new_window, dim_numbers, convolution->precision_config())); + + // Delete the extra spatial dimension, and reshape. 
+ Shape reshaped_convolution_shape = ShapeUtil::DeleteDimension( + new_spatial_dim - 1, new_convolution->shape()); + auto reshaped_convolution = HloInstruction::CreateReshape( + reshaped_convolution_shape, new_convolution); + + TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( + convolution, std::move(reshaped_convolution))); + + changed_ = true; + } else { + // We first obtain the expanded the filter (which is the convolution + // output). The batch dimension is the expanded one (which originally + // represents kernel input feature dimension). We mask the filter to zero + // out the expanded regions. Next we reduce the filter in the batch + // dimension to obtain the original filter size. + + HloInstruction* filter_mask = + GetExpandedFilterMask(convolution->shape(), output_batch_dimension, + output_feature_dimension, batch_group_count, add); + auto expanded_filter_shape = ExpandedFilterShape( + convolution->shape(), batch_group_count, output_batch_dimension); + + auto new_convolution = add(HloInstruction::CreateConvolve( + expanded_filter_shape, activation, filter, + /*feature_group_count=*/1, /*batch_group_count=*/1, + convolution->window(), dim_numbers, convolution->precision_config())); + + auto zero = add(HloInstruction::CreateConstant( + LiteralUtil::Zero(expanded_filter_shape.element_type()))); + auto zero_filter = + add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {})); + + auto new_filter = add(HloInstruction::CreateTernary( + expanded_filter_shape, HloOpcode::kSelect, filter_mask, new_convolution, + zero_filter)); + + auto zero_literal = LiteralUtil::CreateR0(0.0f); + TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(F32)); + auto zero_scalar = + add(HloInstruction::CreateConstant(std::move(zero_literal))); + + auto reduce_function = [&]() -> HloComputation* { + HloComputation::Builder b("add_computation"); + Shape shape = ShapeUtil::MakeShape(F32, {}); + auto lhs = + b.AddInstruction(HloInstruction::CreateParameter(0, 
shape, "lhs")); + auto rhs = + b.AddInstruction(HloInstruction::CreateParameter(1, shape, "rhs")); + auto scalar_op = b.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, lhs, rhs)); + return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); + }; + + // Ensure that data input to reduce window is of type F32. + if (primitive_util::BitWidth(new_filter->shape().element_type()) < + primitive_util::BitWidth(F32)) { + Shape convert_shape = new_filter->shape(); + convert_shape.set_element_type(F32); + new_filter = + add(HloInstruction::CreateBitcastConvert(convert_shape, new_filter)); + } + + auto reduce_window_shape = new_convolution->shape(); + reduce_window_shape.set_dimensions(output_batch_dimension, 1); + + // Create the reduce window. + Window window; + for (int64 i = 0; i < new_convolution->shape().dimensions_size(); ++i) { + auto* dim = window.add_dimensions(); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + if (i == output_batch_dimension) { + dim->set_stride(batch_group_count); + dim->set_size(batch_group_count); + } else { + dim->set_stride(1); + dim->set_size(1); + } + } + auto reduce_window = add(HloInstruction::CreateReduceWindow( + reduce_window_shape, new_filter, zero_scalar, window, + reduce_function())); + + Shape convert_back_shape = reduce_window->shape(); + convert_back_shape.set_element_type(activation->shape().element_type()); + + // Convert reduced data back to the original data type. 
+ auto reduce_window_converted = + HloInstruction::CreateBitcastConvert(convert_back_shape, reduce_window); + + TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( + convolution, std::move(reduce_window_converted))); + } + + return Status::OK(); +} + +Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { + if (convert_batch_groups_only_) { + return HandleBatchGroupCount(convolution); + } + + auto add = [&](std::unique_ptr inst) { + return computation_->AddInstruction(std::move(inst)); + }; + + int64 group_count = convolution->feature_group_count(); + if (group_count == 1) { + return Status::OK(); + } + + changed_ = true; auto dim_numbers = convolution->convolution_dimension_numbers(); + auto filter = convolution->mutable_operand(1); int64 kernel_input_feature_dim = dim_numbers.kernel_input_feature_dimension(); int64 group_size = filter->shape().dimensions(kernel_input_feature_dim); int64 kernel_output_feature_dim = @@ -205,6 +412,7 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { // If the code generator handles depthwise separable convolutions // inherently, then no filter expansion is needed. 
if (!filter_expansion_ && depthwise_separable) { + changed_ = false; return Status::OK(); } // We want to repeat 'filter' in the 'input_feature_dim' dimension @@ -233,8 +441,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { auto new_convolution = HloInstruction::CreateConvolve( convolution->shape(), convolution->mutable_operand(0), new_filter, - /*feature_group_count=*/1, convolution->window(), dim_numbers, - convolution->precision_config()); + /*feature_group_count=*/1, /*batch_group_count=*/1, + convolution->window(), dim_numbers, convolution->precision_config()); TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( convolution, std::move(new_convolution))); } else { @@ -294,8 +502,9 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { dim->set_size(group_size); auto new_convolution = add(HloInstruction::CreateConvolve( - new_output_shape, activation, filter, group_count, new_window, - dim_numbers, convolution->precision_config())); + new_output_shape, activation, filter, group_count, + /*batch_group_count=*/1, new_window, dim_numbers, + convolution->precision_config())); // Delete the extra spatial dimension, and reshape. 
Shape reshaped_convolution_shape = @@ -372,7 +581,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { auto new_convolution = add(HloInstruction::CreateConvolve( conv_slice_shape, activation_slice, filter_slice, - /*feature_group_count=*/1, convolution->window(), dim_numbers, + /*feature_group_count=*/1, /*batch_group_count=*/1, + convolution->window(), dim_numbers, convolution->precision_config())); sliced_convolutions.push_back(new_convolution); @@ -390,17 +600,19 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { } // namespace -StatusOr ConvolutionFeatureGroupConverter::Run(HloModule* module) { - XLA_VLOG_LINES(2, "ConvolutionFeatureGroupConverter::Run(), before:\n" + - module->ToString()); +StatusOr ConvolutionGroupConverter::Run(HloModule* module) { + XLA_VLOG_LINES( + 2, "ConvolutionGroupConverter::Run(), before:\n" + module->ToString()); bool changed = false; for (auto* comp : module->MakeNonfusionComputations()) { - if (ConvolutionVisitor::Run(comp, filter_expansion_)) { + if (ConvolutionVisitor::Run(comp, is_cost_viable_, + convert_batch_groups_only_, + filter_expansion_)) { changed = true; } } - XLA_VLOG_LINES(2, "ConvolutionFeatureGroupConverter::Run(), after:\n" + - module->ToString()); + XLA_VLOG_LINES( + 2, "ConvolutionGroupConverter::Run(), after:\n" + module->ToString()); return changed; } diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h b/tensorflow/compiler/xla/service/convolution_group_converter.h similarity index 58% rename from tensorflow/compiler/xla/service/convolution_feature_group_converter.h rename to tensorflow/compiler/xla/service/convolution_group_converter.h index cb6bc04c00a2ff10f970da2a07fb540a561dad5a..1caf1841119a965044502435fe0f5b38ca94f6a5 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h +++ b/tensorflow/compiler/xla/service/convolution_group_converter.h @@ -13,8 +13,8 @@ See the License for the specific 
language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_GROUP_CONVERTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_GROUP_CONVERTER_H_ #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -25,23 +25,34 @@ namespace xla { // A pass which rewrites convolutions with feature_group_count > 1 into // convolutions with feature_group_count = 1. -class ConvolutionFeatureGroupConverter : public HloModulePass { +class ConvolutionGroupConverter : public HloModulePass { public: - ConvolutionFeatureGroupConverter(bool canonicalize_depthwise_filter = false) - : filter_expansion_(canonicalize_depthwise_filter) {} + ConvolutionGroupConverter(std::function is_cost_viable, + bool convert_batch_groups_only, + bool canonicalize_depthwise_filter = false) + : is_cost_viable_(is_cost_viable), + convert_batch_groups_only_(convert_batch_groups_only), + filter_expansion_(canonicalize_depthwise_filter) {} absl::string_view name() const override { - return "convolution-feature-group-converter"; + return "convolution-group-converter"; } // Run convolution rewriting on the given computation. Returns whether the // computation was changed. StatusOr Run(HloModule* module) override; + // Lambda containing cost model that decides whether to expand + // batch_group_count. + std::function is_cost_viable_; + + // Decides whether to convert batch groups or feature groups. + bool convert_batch_groups_only_; + // Tells whether filter expansion is required. 
bool filter_expansion_; }; } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_GROUP_CONVERTER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index ce4c2a9cc69240b9565b35a3f2504d7fc9373917..f49b5110be5c4bab63b423e5ed2e67bc1828f6e3 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -112,7 +112,7 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", "//tensorflow/compiler/xla/service:conditional_simplifier", - "//tensorflow/compiler/xla/service:convolution_feature_group_converter", + "//tensorflow/compiler/xla/service:convolution_group_converter", "//tensorflow/compiler/xla/service:dot_decomposer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", @@ -572,6 +572,7 @@ cc_library( ":runtime_matvec", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/core:framework_lite", + "//tensorflow/core/kernels:eigen_contraction_kernel", "//third_party/eigen3", ], ) @@ -766,6 +767,8 @@ cc_library( ":target_machine_features", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:computation_layout", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:layout_assignment", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc index 2d9978404cc9ec1e40fc61aaf794a8f1f06050bb..8e55267a67d330e7e721f9b5fb25451357a49a9d 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc @@ -132,7 +132,8 @@ StatusOr 
ConvCanonicalization::Run(HloModule* module) { HloInstruction* new_conv = module->entry_computation()->AddInstruction( HloInstruction::CreateConvolve( new_conv_shape, new_input, new_kernel, hlo->feature_group_count(), - hlo->window(), new_dnums, hlo->precision_config())); + hlo->batch_group_count(), hlo->window(), new_dnums, + hlo->precision_config())); // Reshape the output back to the shape of the original convolution. TF_RETURN_IF_ERROR(module->entry_computation()->ReplaceWithNewInstruction( diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index c58175428fea6a2d38253c35de598b99a4281bf1..02085108a081358cd4f8aed6dc12557cbd8eea85 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -84,8 +84,8 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape( F32, {kOutputFeatureCount, kBatchSize, output_size, output_size}), - input, kernel, /*feature_group_count=*/1, conv_window_, dnums, - DefaultPrecisionConfig(2))); + input, kernel, /*feature_group_count=*/1, /*batch_group_count=*/1, + conv_window_, dnums, DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = @@ -147,8 +147,8 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape( F32, {kBatchSize, output_size, output_size, kOutputFeatureCount}), - input, kernel, /*feature_group_count=*/1, conv_window_, dnums, - DefaultPrecisionConfig(2))); + input, kernel, /*feature_group_count=*/1, /*batch_group_count=*/1, + conv_window_, dnums, DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc 
b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 6374822c81bf42fd12829f57cf93c19457128219..ba7dcde5c3d7e0406f46d642632f780d6d7db54f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -51,7 +51,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h" -#include "tensorflow/compiler/xla/service/convolution_feature_group_converter.h" +#include "tensorflow/compiler/xla/service/convolution_group_converter.h" #include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h" #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h" @@ -257,7 +257,16 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); - pipeline.AddPass(); + auto cost_model = [](HloInstruction* conv) { + // We need a cost model for CPUs. Currently, do nothing. + return false; + }; + pipeline.AddPass( + cost_model, + /*convert_batch_groups_only=*/true); + pipeline.AddPass( + cost_model, + /*convert_batch_groups_only=*/false); pipeline.AddPass(target_machine_features); { auto& pass = @@ -635,18 +644,17 @@ StatusOr> CpuCompiler::RunBackend( .EmitComputation( embedded_computation, embedded_computation->name(), /*is_top_level_computation=*/false, - &schedule.sequence(embedded_computation).instructions()) + schedule.sequence(embedded_computation).instructions()) .status()); } string function_name_prefix = entry_computation->name().empty() ? 
"__compute" : entry_computation->name(); - TF_ASSIGN_OR_RETURN( - llvm::Function * entry_function, - ir_emitter.EmitComputation( - entry_computation, function_name_prefix, - /*is_top_level_computation=*/true, - &schedule.sequence(entry_computation).instructions())); + TF_ASSIGN_OR_RETURN(llvm::Function * entry_function, + ir_emitter.EmitComputation( + entry_computation, function_name_prefix, + /*is_top_level_computation=*/true, + schedule.sequence(entry_computation).instructions())); string function_name = [&]() { llvm::SmallVector function_name_vector; @@ -835,7 +843,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, .EmitComputation( embedded_computation, embedded_computation->name(), /*is_top_level_computation=*/false, - &schedule.sequence(embedded_computation).instructions()) + schedule.sequence(embedded_computation).instructions()) .status()); } const string& entry_point_name = options.entry_point_name(); @@ -843,7 +851,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, ir_emitter.EmitComputation( computation, entry_point_name, /*is_top_level_computation=*/true, - &schedule.sequence(computation).instructions())); + schedule.sequence(computation).instructions())); CHECK(entry_function->getName() == llvm_ir::AsStringRef(entry_point_name)); diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 97f9b85a606e140fd7f3b1e3ecfb0dd5ba289f03..37cefcb2e827ffd15aa489b1b3199ba9f27d9dd6 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -323,11 +323,11 @@ void ColumnMajorMatrixVectorProductEmitter::Emit() { int64 column_remainder = k() % tile_cols(); int64 column_limit = k() - column_remainder; - ksl_.ForReturnVoid("dot.outer.tiled", - /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(), - [&](llvm::Value* column, bool is_first_column) { - EmitOuterLoopBody(column, tile_cols(), 
is_first_column); - }); + ksl_.For("dot.outer.tiled", + /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(), + [&](llvm::Value* column, bool is_first_column) { + EmitOuterLoopBody(column, tile_cols(), is_first_column); + }); if (column_remainder != 0) { EmitOuterLoopBody(b_->getInt64(column_limit), column_remainder, @@ -340,7 +340,7 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( int64 columns, bool is_first_column) { int64 row_limit = m() - (m() % tile_rows()); - ksl_.ForReturnVoid( + ksl_.For( "dot.inner.tiled", /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(), [&](llvm::Value* row) { std::vector lhs_tile = @@ -372,7 +372,7 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( // // initialized. // } - ksl_.ForReturnVoid( + ksl_.For( "dot.inner.epilg.outer", /*start=*/current_tile_col, /*end=*/b_->CreateAdd(columns_llvm, current_tile_col), /*step=*/1, /*peel_first_iteration=*/false, @@ -381,14 +381,14 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( llvm::Value* total_offset = b_->CreateMul(col, b_->getInt64(m())); llvm::Value* lhs_base_pointer = vsl_.ComputeOffsetPointer(lhs_, total_offset); - ksl_.ForReturnVoid( + ksl_.For( "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m(), /*step=*/1, [&](llvm::Value* scalar_row) { llvm::Value* product = vsl_.Mul( vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element); llvm::Value* setting_result_first_time = b_->CreateAnd( is_first_scalar_col, b_->getInt1(is_first_tiled_column)); - ksl_.IfReturnVoid( + ksl_.If( setting_result_first_time, /*true_block_generator=*/ [&]() { @@ -568,10 +568,9 @@ void RowMajorMatrixVectorProductEmitter::Emit() { int64 row_remainder = m() % tile_rows(); int64 row_limit = m() - row_remainder; - ksl_.ForReturnVoid( - "dot.outer.tiled", - /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(), - [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); }); + ksl_.For("dot.outer.tiled", + /*start=*/0, /*end=*/row_limit, 
/*step=*/tile_rows(), + [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); }); if (row_remainder != 0) { EmitOuterLoopBody(b_->getInt64(row_limit), row_remainder); @@ -583,17 +582,17 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( std::vector* vector_accumulators) { int64 column_limit = k() - (k() % tile_cols()); - ksl_.ForReturnVoid("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, - /*step=*/tile_cols(), [&](llvm::Value* col) { - std::vector lhs_tile = - lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col); - llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); - for (int i = 0; i < rows; i++) { - llvm::Value* old_sum = (*vector_accumulators)[i].Get(); - (*vector_accumulators)[i].Set(vsl_.Add( - old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); - } - }); + ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, + /*step=*/tile_cols(), [&](llvm::Value* col) { + std::vector lhs_tile = + lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col); + llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); + for (int i = 0; i < rows; i++) { + llvm::Value* old_sum = (*vector_accumulators)[i].Get(); + (*vector_accumulators)[i].Set( + vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); + } + }); } void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( @@ -609,7 +608,7 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( b_->CreateAdd(b_->getInt64(r), current_tile_row), b_->getInt64(k())); llvm::Value* lhs_base_pointer = vsl_.ComputeOffsetPointer(lhs_, total_offset); - ksl_.ForReturnVoid( + ksl_.For( "dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(), /*step=*/1, [&](llvm::Value* scalar_col) { llvm::Value* product = @@ -813,7 +812,7 @@ void TiledSmallGemmEmitter::HandleResiduesOnN() { if (n_start != dims().n()) { VectorSupportLibrary vsl(scalar_type(), 1, b_, "gemm"); - ksl_.ForReturnVoid("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { + ksl_.For("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { 
llvm::Value* n_i_next = b_->CreateAdd(n_i, b_->getInt64(1)); HandleResiduesOnK(&vsl, n_i, n_i_next); }); @@ -924,7 +923,7 @@ void TiledSmallGemmEmitter::EmitTiledGemm( VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end, int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) { - ksl_.ForReturnVoid( + ksl_.For( "dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) { MemoryTile result_memory_tile( vsl, b_, /*matrix=*/result_, @@ -935,11 +934,11 @@ void TiledSmallGemmEmitter::EmitTiledGemm( /*matrix_size_along_minor_dim=*/dims().k(), /*major_dim_offset=*/m_i, /*tile_size_along_major_dim=*/tile_size_m); - ksl_.ForReturnVoid( + ksl_.For( "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) { TileVariable result_tile_var(vsl, result_memory_tile.LoadTile(n_i)); - ksl_.ForReturnVoid( + ksl_.For( "dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) { MemoryTile rhs_memory_tile(vsl, b_, rhs_, dims().n(), k_i, tile_size_k); @@ -1406,16 +1405,20 @@ Status DotOpEmitter::EmitScalarDot() { llvm::Value* rhs_value = rhs_array_.EmitReadArrayElement(/*index=*/element_index, b_); if (ShapeUtil::ElementIsComplex(lhs_array_.GetShape())) { -#define REAL(x) b_->CreateExtractValue(x, {0}) -#define IMAG(x) b_->CreateExtractValue(x, {1}) - llvm::Value* real = - b_->CreateFSub(b_->CreateFMul(REAL(lhs_value), REAL(rhs_value)), - b_->CreateFMul(IMAG(lhs_value), IMAG(rhs_value))); - llvm::Value* imag = - b_->CreateFAdd(b_->CreateFMul(REAL(lhs_value), IMAG(rhs_value)), - b_->CreateFMul(IMAG(lhs_value), REAL(rhs_value))); -#undef IMAG -#undef REAL + auto get_real = [&](llvm::Value* x) { + return b_->CreateExtractValue(x, {0}); + }; + + auto get_imag = [&](llvm::Value* x) { + return b_->CreateExtractValue(x, {1}); + }; + + llvm::Value* real = b_->CreateFSub( + b_->CreateFMul(get_real(lhs_value), get_real(rhs_value)), + b_->CreateFMul(get_imag(lhs_value), get_imag(rhs_value))); + 
llvm::Value* imag = b_->CreateFAdd( + b_->CreateFMul(get_real(lhs_value), get_imag(rhs_value)), + b_->CreateFMul(get_imag(lhs_value), get_real(rhs_value))); result = llvm::ConstantAggregateZero::get(lhs_array_.GetElementLlvmType()); result = b_->CreateInsertValue(result, real, {0}); result = b_->CreateInsertValue(result, imag, {1}); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 4032c2da2f33ee61da8771ae6225a14172cbe6e8..ed7fe59c80ed68420cea8b51e1732489ac2a874e 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -111,10 +111,9 @@ IrEmitter::IrEmitter( StatusOr IrEmitter::EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - const std::vector* instruction_order) { + absl::Span instruction_order) { string function_name = name_uniquer_.GetUniqueName(function_name_prefix); - VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix - << "]; ordered? 
" << (instruction_order != nullptr); + VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]"; is_top_level_computation_ = is_top_level_computation; num_dynamic_loop_bounds_ = 0; if (!computation->root_instruction()->outer_dimension_partitions().empty()) { @@ -141,11 +140,7 @@ StatusOr IrEmitter::EmitComputation( bool use_rdtscp = arch_type_ == llvm::Triple::ArchType::x86 || arch_type_ == llvm::Triple::ArchType::x86_64; profiling_state_ = ProfilingState(use_rdtscp); - if (instruction_order == nullptr) { - TF_RETURN_IF_ERROR(computation->Accept(this)); - } else { - TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, *instruction_order)); - } + TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, instruction_order)); llvm::Function* ir_function = compute_function_->function(); InsertOrDie(&emitted_functions_, computation, ir_function); // Delete 'compute_function', finalizing 'ir_function' and restoring caller @@ -1338,11 +1333,11 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { return Status::OK(); } -Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { +Status IrEmitter::HandleAllReduce(HloInstruction* crs) { if (hlo_module_config_.replica_count() != 1) { // TODO(b/33011107): Support nontrivial cross replica sum on CPU. return Unimplemented( - "CrossReplicaSum with >1 replica is not implemented on CPU."); + "AllReduce with >1 replica is not implemented on CPU."); } // When there is a single replica, a cross replica sum is the identity @@ -1368,7 +1363,7 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { const Shape& operand_shape = crs->operand(i)->shape(); CHECK(ShapeUtil::IsArray(operand_shape)) - << "Operands to cross-replica-sum must be arrays: " << crs->ToString(); + << "Operands to all-reduce must be arrays: " << crs->ToString(); operand_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape)); // TODO(b/63762267): Be more aggressive about specifying alignment. 
@@ -2271,6 +2266,22 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { /*isVarArg=*/false))); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call)); + // Write the tuple table if the output is a tuple. + if (ShapeUtil::IsTuple(custom_call->shape())) { + std::vector base_ptrs; + for (int i = 0; i < ShapeUtil::TupleElementCount(custom_call->shape()); + ++i) { + const Shape& elem_shape = + ShapeUtil::GetTupleElementShape(custom_call->shape(), i); + TF_RET_CHECK(!ShapeUtil::IsTuple(elem_shape)) + << "Nested tuples not implemented"; + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, + assignment_.GetUniqueSlice(custom_call, {i})); + llvm::Value* addr = EmitBufferPointer(slice, elem_shape); + base_ptrs.push_back(addr); + } + llvm_ir::EmitTuple(GetIrArrayFor(custom_call), base_ptrs, &b_, module_); + } auto* output_address_arg = PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type); @@ -2851,7 +2862,9 @@ llvm::Value* IrEmitter::EmitBufferPointer(const BufferAllocation::Slice& slice, if (slice.allocation()->is_thread_local()) { return EmitThreadLocalBufferPointer(slice, target_shape); } else if (slice.allocation()->is_constant()) { - return FindOrDie(constant_buffer_to_global_, slice.allocation()->index()); + return BitCast( + FindOrDie(constant_buffer_to_global_, slice.allocation()->index()), + IrShapeType(target_shape)->getPointerTo()); } else { return EmitGlobalBufferPointer(slice, target_shape); } diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 559a8162a2d53f28ea6817653503c216af90a610..db76de4bb2b8ed568bf2557a30fa216d0cbe518d 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -101,7 +101,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, StatusOr EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - const std::vector* 
instruction_order); + absl::Span instruction_order); llvm::IRBuilder<>* b() { return &b_; } @@ -134,7 +134,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, Status HandleDot(HloInstruction* dot) override; Status HandleConvolution(HloInstruction* convolution) override; Status HandleFft(HloInstruction* fft) override; - Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleAllReduce(HloInstruction* crs) override; Status HandleInfeed(HloInstruction* infeed) override; Status HandleOutfeed(HloInstruction* outfeed) override; Status HandleSort(HloInstruction* sort) override; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index f0b65046c14ccec5336abf7c4d05d1d755f783bd..35ae62b42dfa768c6abd0508097d6b235b2ebf54 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -112,10 +112,10 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) { const string hlo_string = R"( HloModule TestTaskParallel_infeed_outfeed ENTRY InfeedOutfeed { - token = token[] after-all() - infeed0 = (u32[12345678,2]{1,0}, token[]) infeed(token) + token0 = token[] after-all() + infeed0 = (u32[12345678,2]{1,0}, token[]) infeed(token0) infeed0.data = u32[12345678,2]{1,0} get-tuple-element((u32[12345678,2]{1,0}, token[]) infeed0), index=0 - ROOT outfeed0 = token[] outfeed(infeed0.data, token) + ROOT outfeed0 = token[] outfeed(infeed0.data, token0) } )"; diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc index a71a85913cfef271bc2a226cb0cf2dd4204499a4..fe7e87a197b6cf571195537eaea2898659cd5e2e 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc @@ -23,12 +23,20 @@ limitations under the License. 
#include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/types.h" +#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) +#include "tensorflow/core/kernels/eigen_contraction_kernel.h" +#endif + using tensorflow::int32; using tensorflow::int64; namespace { -template +bool Is16BytesAligned(void* ptr) { + return reinterpret_cast(ptr) % 16 == 0; +} + +template void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { const xla::ExecutableRunOptions* run_options = @@ -46,11 +54,11 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, std::swap(rhs_rows, rhs_cols); } - const Eigen::TensorMap, Eigen::Aligned> A( - lhs, lhs_rows, lhs_cols); - const Eigen::TensorMap, Eigen::Aligned> B( - rhs, rhs_rows, rhs_cols); - Eigen::TensorMap, Eigen::Aligned> C(out, m, n); + const Eigen::TensorMap, Alignment> A(lhs, lhs_rows, + lhs_cols); + const Eigen::TensorMap, Alignment> B(rhs, rhs_rows, + rhs_cols); + Eigen::TensorMap, Alignment> C(out, m, n); typedef typename Eigen::Tensor::DimensionPair DimPair; int lhs_contract_dim = transpose_lhs ? 0 : 1; @@ -65,14 +73,24 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, } template -void MatMulImpl(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, - int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { +void MatMulDispatch(const void* run_options_ptr, T* out, T* lhs, T* rhs, + int64 m, int64 n, int64 k, int32 transpose_lhs, + int32 transpose_rhs) { + bool all_buffers_16b_aligned = + Is16BytesAligned(out) && Is16BytesAligned(lhs) && Is16BytesAligned(rhs); + + if (!all_buffers_16b_aligned) { + MatMul(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); + return; + } + if (m == 1 || n == 1) { // Despite being single threaded, this version of matrix * vector is faster. 
xla::EigenMatVec(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } else { - MatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, - transpose_rhs); + MatMul(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); } } @@ -82,20 +100,20 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF16( const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, Eigen::half* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); + MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF32( const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, - transpose_rhs); + MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, + transpose_rhs); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF64( const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, - transpose_rhs); + MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, + transpose_rhs); } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc index 16692e7f2e6145b2649b67987eef47916e958be2..1ed743afc30af7c7ff38c7d2a738f2e376270952 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc @@ -25,7 +25,11 @@ using tensorflow::int64; namespace { -template +bool Is16BytesAligned(void* ptr) { + return reinterpret_cast(ptr) % 16 == 
0; +} + +template void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { int64 lhs_rows = m; @@ -40,11 +44,11 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, std::swap(rhs_rows, rhs_cols); } - const Eigen::TensorMap, Eigen::Aligned> A( - lhs, lhs_rows, lhs_cols); - const Eigen::TensorMap, Eigen::Aligned> B( - rhs, rhs_rows, rhs_cols); - Eigen::TensorMap, Eigen::Aligned> C(out, m, n); + const Eigen::TensorMap, Alignment> A(lhs, lhs_rows, + lhs_cols); + const Eigen::TensorMap, Alignment> B(rhs, rhs_rows, + rhs_cols); + Eigen::TensorMap, Alignment> C(out, m, n); typedef typename Eigen::Tensor::DimensionPair DimPair; int lhs_contract_dim = transpose_lhs ? 0 : 1; @@ -59,14 +63,23 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, } template -void SingleThreadedMatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, - int64 m, int64 n, int64 k, int32 transpose_lhs, - int32 transpose_rhs) { +void SingleThreadedMatMulDispatch(const void* run_options_ptr, T* out, T* lhs, + T* rhs, int64 m, int64 n, int64 k, + int32 transpose_lhs, int32 transpose_rhs) { + bool all_buffers_16b_aligned = + Is16BytesAligned(out) && Is16BytesAligned(lhs) && Is16BytesAligned(rhs); + + if (!all_buffers_16b_aligned) { + MatMul(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); + return; + } + if (m == 1 || n == 1) { xla::EigenMatVec(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } else { - MatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, - transpose_rhs); + MatMul(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); } } @@ -77,8 +90,8 @@ __xla_cpu_runtime_EigenSingleThreadedMatMulF16( const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, Eigen::half* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - SingleThreadedMatMul(run_options_ptr, out, lhs, rhs, m, n, k, - 
transpose_lhs, transpose_rhs); + SingleThreadedMatMulDispatch(run_options_ptr, out, lhs, rhs, m, + n, k, transpose_lhs, transpose_rhs); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void @@ -87,8 +100,8 @@ __xla_cpu_runtime_EigenSingleThreadedMatMulF32(const void* run_options_ptr, float* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - SingleThreadedMatMul(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); + SingleThreadedMatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void @@ -97,6 +110,6 @@ __xla_cpu_runtime_EigenSingleThreadedMatMulF64(const void* run_options_ptr, double* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - SingleThreadedMatMul(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); + SingleThreadedMatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); } diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index efccadedf27181a4cddf4f1dc3610f7c6db1d821..296f39a4853f2d3f7030209a921001e92c39d609 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -139,7 +139,7 @@ llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) { } if (func_addr == nullptr) { - VLOG(2) << "Unable to resolve runtime symbol: " << name; + LOG(ERROR) << "Unable to resolve runtime symbol: " << name; return nullptr; } llvm::JITEvaluatedSymbol symbol_info(reinterpret_cast(func_addr), @@ -296,6 +296,9 @@ bool RegisterKnownJITSymbols() { REGISTER_LIBM_SYMBOL(sin, double (*)(double)); #ifdef __APPLE__ REGISTER_LIBM_SYMBOL(__sincos, void (*)(double, double*, double*)); + registry->Register("__sincosf_stret", + reinterpret_cast(__sincosf_stret)); + registry->Register("__sincos_stret", reinterpret_cast(__sincos_stret)); #else 
REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*)); #endif @@ -311,6 +314,13 @@ bool RegisterKnownJITSymbols() { registry->Register("memcpy", reinterpret_cast(memcpy)); registry->Register("memmove", reinterpret_cast(memmove)); registry->Register("memset", reinterpret_cast(memset)); + +#ifdef __APPLE__ + registry->Register("__bzero", reinterpret_cast(bzero)); + registry->Register("memset_pattern16", + reinterpret_cast(memset_pattern16)); +#endif + return true; } diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc index fa0e09ff6b5694c0e97963b83c6e541b858a1376..0584c0484f810a03ccccd522163f54535440ef8b 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -31,29 +31,27 @@ HloModule RepeatedConstants while_body { arg_body = f32[2,3,2] parameter(0) ROOT const = f32[2,3,2] constant( - f32[2,3,2] {{{1, 2}, {1001, 1002}, {2001, 2002}}, {{2, 1}, {2001, 3002}, {2001, 2002}}}) } while_cond { arg_cond = f32[2,3,2] parameter(0) - token = token[] after-all() - infeed = (pred[], token[]) infeed(token) + token0 = token[] after-all() + infeed = (pred[], token[]) infeed(token0) ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0 } ENTRY main { param = f32[2,3,2] parameter(0) const_a = f32[2,3,2] constant( - f32[2,3,2] {{{1, 2}, {1001, 1002}, {2001, 2002}}, {{2, 1}, {2001, 3002}, {2001, 2002}}}) const_b = f32[2,3,2] while(f32[2,3,2] const_a), condition=while_cond, body=while_body - token = token[] after-all() - out0 = token[] outfeed(f32[2,3,2] const_a, token[] token) - ROOT out1 = token[] outfeed(f32[2,3,2] const_b, token[] token) + token0 = token[] after-all() + out0 = token[] outfeed(f32[2,3,2] const_a, token[] token0) + ROOT out1 = token[] outfeed(f32[2,3,2] const_b, token[] token0) } )"; @@ -82,24 +80,24 @@ HloModule RepeatedConstants 
while_body { arg_body = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) - ROOT const = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} )) + ROOT const = (f32[2,1]{1,0}, f32[1]{0}) constant(({ { 1 }, { 2 } }, {2} )) } while_cond { arg_cond = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) - token = token[] after-all() - infeed = (pred[], token[]) infeed(token) + token0 = token[] after-all() + infeed = (pred[], token[]) infeed(token0) ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0 } ENTRY main { param = f32[2,3,2] parameter(0) - const_a = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} )) + const_a = (f32[2,1]{1,0}, f32[1]{0}) constant(( { { 1 }, { 2 } }, {2} )) const_b = (f32[2,1]{1,0}, f32[1]{0}) while((f32[2,1]{1,0}, f32[1]{0}) const_a), condition=while_cond, body=while_body - token = token[] after-all() - out0 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_a, token[] token) - ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_b, token[] token) + token0 = token[] after-all() + out0 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_a, token[] token0) + ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_b, token[] token0) } )"; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc index e2c7af541eede5265f274c72f55305549f059839..aab7f0b393881642437f1891256bd138823a3b87 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc @@ -28,12 +28,11 @@ HloModule Outfeed ENTRY main { const_a = f32[2,3,2] constant( - f32[2,3,2] {{{1, 2}, {1001, 1002}, {2001, 2002}}, {{2, 1}, {2001, 3002}, {2001, 2002}}}) - token = token[] after-all() - outfeed = token[] outfeed(f32[2,3,2] const_a, token) + token0 = token[] after-all() + outfeed = token[] outfeed(f32[2,3,2] const_a, token0) ROOT root = () tuple() } )"; diff --git 
a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index e84bf00153aa28df29d8df486b92654feab4afbf..2132468b9067ad4d5644d6cf3908a488a20ced05 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -105,7 +105,7 @@ class DfsHloVisitorBase { } virtual Status HandleConvolution(HloInstructionPtr hlo) = 0; virtual Status HandleFft(HloInstructionPtr fft) = 0; - virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0; + virtual Status HandleAllReduce(HloInstructionPtr hlo) = 0; virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0; virtual Status HandleGetDimensionSize(HloInstructionPtr hlo) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 80ea5be298aea44a0f424398da74c4e478f10346..680dd256bb15bd3a9eaff7241174c1d2833002c6 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -91,7 +91,7 @@ class DfsHloVisitorWithDefaultBase Status HandleFft(HloInstructionPtr fft) override { return DefaultAction(fft); } - Status HandleCrossReplicaSum(HloInstructionPtr crs) override { + Status HandleAllReduce(HloInstructionPtr crs) override { return DefaultAction(crs); } Status HandleAllToAll(HloInstructionPtr hlo) override { diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc index ea9ebed45d99797ce4f80376ec3d0b758da3ca17..1dd196821c05cc820e2a3bf53a04d96b15484cd4 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc @@ -292,7 +292,8 @@ TEST_F(DynamicDimensionInferenceTest, ConvolutionTest) { Window window; 
auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve( - zx_shape, a_param, b_param, /*feature_group_count=*/1, window, dnums, + zx_shape, a_param, b_param, /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums, HloTestBase::DefaultPrecisionConfig(2))); module_->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index bfd1b6cb1492f5cb709e2ecefe73782094e26f5e..6c23f921f40cac0dc5df08494dc1b63e6d1d5e93 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -694,6 +694,7 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", "//tensorflow/compiler/xla/service:conditional_simplifier", + "//tensorflow/compiler/xla/service:convolution_group_converter", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc index 443883a89f66a747def1049bc5afb53fec3c2409..dbcdc2b075bc72f3194af8e555faabb1511376e0 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc @@ -109,9 +109,11 @@ TEST_F(CudnnConvRewriterTest, BackwardFilterConvolve) { auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeInference::InferConvolveShape( activations->shape(), gradients->shape(), /*feature_group_count=*/1, - conv_window, tf_default_dnums_for_backward_filter_) + /*batch_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_) .ConsumeValueOrDie(), - activations, gradients, /*feature_group_count=*/1, conv_window, + activations, gradients, /*feature_group_count=*/1, + /*batch_group_count=*/1, conv_window, 
tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); OpMetadata metadata; @@ -147,9 +149,11 @@ TEST_F(CudnnConvRewriterTest, builder.AddInstruction(HloInstruction::CreateConvolve( ShapeInference::InferConvolveShape( activations->shape(), gradients->shape(), /*feature_group_count=*/1, - conv_window, tf_default_dnums_for_backward_filter_) + /*batch_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_) .ConsumeValueOrDie(), - activations, gradients, /*feature_group_count=*/1, conv_window, + activations, gradients, /*feature_group_count=*/1, + /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule(); @@ -179,7 +183,7 @@ TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithPaddedActivations) { } builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients, - /*feature_group_count=*/1, conv_window, + /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule(); @@ -209,7 +213,7 @@ TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithPaddedGradients) { } builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients, - /*feature_group_count=*/1, conv_window, + /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule(); @@ -238,7 +242,7 @@ TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithUnevenPadding) { } builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients, - /*feature_group_count=*/1, conv_window, + /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); 
auto module = CreateNewVerifiedModule(); @@ -283,13 +287,15 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveEvenPadding) { HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {4, 3, 16, 16}), /*lhs=*/output, - /*rhs=*/reverse_kernel, /*feature_group_count=*/1, conv_window, - conv_dnums, DefaultPrecisionConfig(2))); + /*rhs=*/reverse_kernel, /*feature_group_count=*/1, + /*batch_group_count=*/1, conv_window, conv_dnums, + DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. CHECK(ShapeUtil::Compatible( conv->shape(), ShapeInference::InferConvolveShape( output->shape(), reverse_kernel->shape(), - /*feature_group_count=*/1, conv_window, conv_dnums) + /*feature_group_count=*/1, /*batch_group_count=*/1, + conv_window, conv_dnums) .ValueOrDie())); auto module = CreateNewVerifiedModule(); @@ -332,10 +338,12 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolve1x1Filter) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeInference::InferConvolveShape(output->shape(), kernel->shape(), - /*feature_group_count=*/1, conv_window, + /*feature_group_count=*/1, + /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_input_) .ConsumeValueOrDie(), - /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, conv_window, + /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, + /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule(); @@ -365,11 +373,12 @@ TEST_F(CudnnConvRewriterTest, builder.AddInstruction(HloInstruction::CreateConvolve( ShapeInference::InferConvolveShape( output->shape(), kernel->shape(), /*feature_group_count=*/1, - default_conv_window_, tf_default_dnums_for_backward_input_) + /*batch_group_count=*/1, default_conv_window_, + tf_default_dnums_for_backward_input_) .ConsumeValueOrDie(), /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, - 
default_conv_window_, tf_default_dnums_for_backward_input_, - DefaultPrecisionConfig(2))); + /*batch_group_count=*/1, default_conv_window_, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = @@ -415,15 +424,15 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveUnevenPaddingOnGradients) { } HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel, - /*feature_group_count=*/1, conv_window, + /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. CHECK(ShapeUtil::Compatible( - conv->shape(), - ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, - conv_window, tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), + /*feature_group_count=*/1, /*batch_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = @@ -465,15 +474,15 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { } HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel, - /*feature_group_count=*/1, conv_window, + /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. 
CHECK(ShapeUtil::Compatible( - conv->shape(), - ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, - conv_window, tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), + /*feature_group_count=*/1, /*batch_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = @@ -519,15 +528,15 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) { forward_conv_col_dim->set_base_dilation(2); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {1, 1, 14, 1}), output, reverse_kernel, - /*feature_group_count=*/1, conv_window, + /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. 
CHECK(ShapeUtil::Compatible( - conv->shape(), - ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, - conv_window, tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), + /*feature_group_count=*/1, /*batch_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewVerifiedModule(); const HloComputation* entry_computation = @@ -574,15 +583,15 @@ TEST_F(CudnnConvRewriterTest, forward_conv_col_dim->set_padding_high(2); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {1, 1, 4, 1}), output, reverse_kernel, - /*feature_group_count=*/1, conv_window, + /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. 
CHECK(ShapeUtil::Compatible( - conv->shape(), - ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, - conv_window, tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), + /*feature_group_count=*/1, /*batch_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = @@ -599,7 +608,7 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveConstantFilter) { Array4D constant_arr(4, 4, 2, 2); constant_arr.FillIota(0); string constant_str = - LiteralUtil::CreateR4FromArray4D(constant_arr).ToString(); + LiteralUtil::CreateR4FromArray4D(constant_arr).ToStringWithoutShape(); const string module_str = absl::StrFormat(R"( HloModule test diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 2ffc8bfb49b205dced0d540ba72426e72d95e596..29756d27260b0f41b2dd4b649ea9b1610ff90268 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -369,7 +369,7 @@ TEST_F(LayoutAssignmentTest, SortLayout) { const char* hlo_text = R"( HloModule SortLayout ENTRY sort { - keys = f32[3,2]{0,1} constant(f32[3,2]{0,1}{{0,1},{0,1},{0,1}}) + keys = f32[3,2]{0,1} constant({{0,1},{0,1},{0,1}}) values = f32[2,3]{1,0} parameter(0) transpose = f32[3,2]{1,0} transpose(values), dimensions={1,0} ROOT sort = (f32[3,2]{1,0}, f32[3,2]{1,0}) sort(keys, transpose), diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index 8c3a026740851767855beae59d6a3c92f7a0d6bd..8a96b5fabc990ecd2e3d5a5cc5eb2f7b4b938c80 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ 
b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -36,6 +36,21 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, ShapeTree infeed_buffers = GetOrCreateInfeedManager()->BlockingGetNextDestination(); + // infeed_slices_'s shape should be a tuple of shape (buffers, token). + const auto& infeed_shape = infeed_slices_.shape(); + TF_RET_CHECK(ShapeUtil::IsTuple(infeed_shape)) + << ShapeUtil::HumanStringWithLayout(infeed_shape); + TF_RET_CHECK(infeed_shape.tuple_shapes().size() == 2) + << ShapeUtil::HumanStringWithLayout(infeed_shape); + TF_RET_CHECK(ShapeUtil::IsToken(infeed_shape.tuple_shapes(1))) + << ShapeUtil::HumanStringWithLayout(infeed_shape); + TF_RET_CHECK( + ShapeUtil::Equal(infeed_buffers.shape(), infeed_shape.tuple_shapes(0))) + << "Expected infeed of shape " + << ShapeUtil::HumanStringWithLayout(infeed_shape.tuple_shapes(0)) + << " but was " + << ShapeUtil::HumanStringWithLayout(infeed_buffers.shape()); + { // The infeed buffer has an extra outer tuple with a token. Adjust the index // accordingly. diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 6693f66d62d8b04d1b78e001fdb515b34539c67f..22db38ee03b9990cc2f21a01b6c0f2249d0991ea 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -637,9 +637,9 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { return Unimplemented("Hit a case for fft that is not implemented on GPU."); } -Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { +Status IrEmitter::HandleAllReduce(HloInstruction* crs) { // TODO(b/33011107): Support cross replica sum on GPU. 
- return Unimplemented("CrossReplicaSum is not implemented on GPU."); + return Unimplemented("AllReduce is not implemented on GPU."); } Status IrEmitter::HandleParameter(HloInstruction* parameter) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 2da46c016935d0e927879bbfb0d05cfc4899d818..f380aee9d3c06a29b503c81c7bd3846dbccf6ce5 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -81,7 +81,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, Status HandleDot(HloInstruction* dot) override; Status HandleConvolution(HloInstruction* convolution) override; Status HandleFft(HloInstruction* fft) override; - Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleAllReduce(HloInstruction* crs) override; Status HandleInfeed(HloInstruction* infeed) override; Status HandleOutfeed(HloInstruction* outfeed) override; Status HandleSend(HloInstruction* send) override; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index fb040aff30d48bf5817946ce53d37bc6685941e4..1472853dc443f0190c3bbed7f96c91ec65ae6dda 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include "absl/algorithm/container.h" -#include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/types/optional.h" @@ -548,91 +547,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { // TODO(b/112040122): Support variadic reduce. 
return Unimplemented("Variadic reduce is not supported on GPU"); } - VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString(); - std::vector> thunks; - absl::Span output_instructions = - root->opcode() == HloOpcode::kTuple - ? root->operands() - : absl::Span(&root, 1); - - // For multi-output fusion emit an initializer for each tuple element. - // Otherwise it's sufficient to just initialize the single output. - HloInstruction* first_reduce = nullptr; - for (int i = 0, e = output_instructions.size(); i != e; ++i) { - if (output_instructions[i]->opcode() == HloOpcode::kReduce) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr initializer_thunk, - BuildInitializerThunk(fusion, output_instructions[i] == root - ? ShapeIndex() - : ShapeIndex({i}))); - thunks.push_back(std::move(initializer_thunk)); - first_reduce = - first_reduce == nullptr ? output_instructions[i] : first_reduce; - } - } - CHECK(first_reduce != nullptr); - std::unique_ptr kernel_thunk = - BuildKernelThunk(fusion, /*implements_whole_instruction=*/false); - GpuElementalIrEmitter elemental_emitter( - hlo_module_config_, ir_emitter_context_->llvm_module(), &b_, - GetNestedComputer()); - FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion), - &elemental_emitter); - TF_RETURN_IF_ERROR(root->Accept(&fused_emitter)); - - // For multi-output fusion CHECK the constraints and feed all the - // reduces into a single loop code generator. Single-output reduce - // fusion is a special case of that. 
- InlinedVector input_gens; - InlinedVector init_value_gens; - std::vector> - extra_output_gens; - InlinedVector reducers; - InlinedVector reduce_output_shapes; - for (int i = 0, e = output_instructions.size(); i != e; ++i) { - const HloInstruction* inst = output_instructions[i]; - ShapeIndex output_shape_index; - if (root->opcode() == HloOpcode::kTuple) { - output_shape_index = {i}; - } - if (inst->opcode() == HloOpcode::kReduce) { - CHECK(IsReductionToVector(*inst)) - << "Only reductions to vector are supported"; - // Shapes, layouts and dimensions must be the same for all reduces - // inside of this fusion. - CHECK(ShapeUtil::Equal(first_reduce->shape(), inst->shape())); - CHECK(ShapeUtil::Equal(first_reduce->operand(0)->shape(), - inst->operand(0)->shape())); - CHECK(ShapeUtil::Equal(first_reduce->operand(1)->shape(), - inst->operand(1)->shape())); - CHECK(first_reduce->dimensions() == inst->dimensions()); - input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0))); - init_value_gens.push_back( - fused_emitter.GetGenerator(inst->operand(1))); - reducers.push_back(inst->to_apply()); - reduce_output_shapes.push_back(std::move(output_shape_index)); - } else { - // For extra outputs we can relax shape equality to allow different - // types (with the same number of elements). Layouts still have to - // match. 
- CHECK(ShapeUtil::CompatibleIgnoringElementType( - first_reduce->operand(0)->shape(), inst->shape())); - CHECK(LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(), - inst->shape().layout())); - extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst), - std::move(output_shape_index)); - } - } - const Shape& input_shape = first_reduce->operand(0)->shape(); - TF_CHECK_OK(EmitReductionToVector( - kernel_thunk.get(), first_reduce, input_shape, input_gens, - init_value_gens, first_reduce->dimensions(), reducers, - reduce_output_shapes, extra_output_gens)); - thunks.push_back(std::move(kernel_thunk)); - std::unique_ptr sequential_thunk = - absl::make_unique(std::move(thunks), fusion); - AddThunkToThunkSequence(std::move(sequential_thunk)); - return Status::OK(); + return EmitReductionToVector(fusion); } default: LOG(FATAL) << "Bad opcode for input fusion: " @@ -702,13 +617,12 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { } Status IrEmitterUnnested::EmitExtraOutputsForReduce( - const HloInstruction* reduce, const IrArray::Index& index, + const HloInstruction* unnested_hlo, const IrArray::Index& index, absl::Span> extra_output_gens) { for (int i = 0; i != extra_output_gens.size(); ++i) { - const HloInstruction* output = reduce->parent()->FusionInstruction(); llvm::Value* extra_output_address = - GetIrArray(*output, *output, extra_output_gens[i].second) + GetIrArray(*unnested_hlo, *unnested_hlo, extra_output_gens[i].second) .EmitArrayElementAddress(index, &b_, "extra_output_element_address"); TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value, @@ -718,984 +632,13 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce( return Status::OK(); } -Status IrEmitterUnnested::EmitReductionToScalar( - KernelThunk* kernel_thunk, HloInstruction* reduce, const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens) { - // Number of 
elements processed by a single thread. - constexpr int64 kTileSize = 16; - int64 num_elems = ShapeUtil::ElementsIn(input_shape); - - // Round up the number of tiles to a multiple of the warp size. This is - // necessary for correctness. We launch one thread per tile, and if the - // number of threads isn't a multiple of the number of the warp size, our - // shuffles will read from inactive threads, producing undefined values. - int64 num_tiles = - RoundUpToNearest(CeilOfRatio(num_elems, kTileSize), kWarpSize); - - Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( - reduce->shape().element_type(), {num_tiles}, {0}); - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - tiled_input_shape, ir_emitter_context_->device_description()); - - llvm::Type* index_ty = - GetIndexTypeForKernel(reduce, launch_dimensions.launch_bound(), &b_); - - auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - - // Check whether every thread will process a full tile's worth of elements - // without reading outside the bounds of the input. If this is true, we can - // skip some bounds checks in the final algorithm. 
- bool all_threads_in_bounds = num_tiles * kTileSize == num_elems; - - // __global__ void full_reduce_kernel() { - // x_in_tiles = threadIdx.x + blockIdx.x * blockDim.x; - // x = x_in_tiles * kTileSize; - // - // partial_result = init_value; - // if (all_threads_in_bounds || x + kTileSize <= num_elems) { - // for (i = 0; i < kTileSize; ++i) { - // partial_result = Reducer(partial_result, input[x + i]); - // } - // } else { - // for (i = 0; i < kTileSize; ++i) { - // if (x + i < num_elems) { - // partial_result = Reducer(partial_result, input[x + i]); - // } - // } - // } - // for (i = warpSize / 2; i > 0; i /= 2) { - // partial_result = Reducer(partial_result, - // __shfl_down(partial_result, i)); - // } - // if (lane_id == 0) { - // AtomicReducer(&output[y], partial_result); - // } - // } - // - // // Choose num_blocks and threads_per_block such that: - // // - // // num_blocks * threads_per_block = - // // RoundUpToNextMultipleOf(Ceil(num_elems / kTileSize), warpSize), - // // - // // and threads_per_block is a multiple of warpSize. - // reduce_kernel // - auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status { - const int num_reduces = reducers.size(); - llvm::Type* element_ir_type = - llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); - std::vector partial_reduction_result_addresses; - for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result_address = - Alloca(element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + llvm::Twine(i)); - TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, - init_value_gens[i](IrArray::Index(index_ty))); - Store(init_ir_value, partial_reduction_result_address); - partial_reduction_result_addresses.push_back( - partial_reduction_result_address); - } - - llvm::Value* x_in_tiles = tile_index[0]; - x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty); - - // Emit an inner for-loop that reduces the elements in the tile. 
- auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { - std::unique_ptr tile_element_loop = - llvm_ir::ForLoop::EmitForLoop( - "element_id_in_tile", index_typed_constant(0), - index_typed_constant(kTileSize), index_typed_constant(1), &b_); - - // Emit the body of the partial reduction loop. - llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), - &b_); - llvm::Value* x = - NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileSize)), - tile_element_loop->GetIndVarValue()); - // Unless we know the tile is entirely in bounds, we have to emit a - // x-in-bounds check before reading from the input. - if (!tile_in_bounds) { - llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - ICmpULT(x, index_typed_constant(num_elems)), "x_in_bounds", &b_); - - // Emit code that reads the input element and accumulates it to - // the partial reduction result. - llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); - } - - IrArray::Index input_index( - /*linear=*/x, input_shape, &b_); - llvm::Value* input_address = Alloca(element_ir_type); - for (int i = 0; i != num_reduces; ++i) { - TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, - input_gens[i](input_index)); - Store(input_ir_value, input_address); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducers[i], - {partial_reduction_result_addresses[i], input_address}, - partial_reduction_result_addresses[i])); - } - return EmitExtraOutputsForReduce(reduce, input_index, extra_output_gens); - }; - - // x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's - // immediately beyond the tile. - llvm::Value* x_end = - NSWAdd(index_typed_constant(kTileSize), - NSWMul(x_in_tiles, index_typed_constant(kTileSize))); - // The tile is entirely in bound if all_threads_in_bounds or - // x_end <= num_elems. 
- llvm::Value* tile_in_bounds = - Or(ICmpULE(x_end, index_typed_constant(num_elems)), - b_.getInt1(all_threads_in_bounds)); - llvm_ir::LlvmIfData if_tile_in_bounds_data = - llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &b_); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, &b_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true)); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block, &b_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false)); - - // After the if-then-else statement on tile_in_bounds, emit calls to - // shfl_down that accumulate the partial reduction results of all threads - // from the warp. - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, &b_); - int bit_width = llvm_ir::GetSizeInBits(element_ir_type); - // bitcast cannot be applied to aggregate types (even packed ones), so we - // instead bitcast addresses of load/store to intN* of the same bit-width. - llvm::Type* shuffle_ir_type = element_ir_type->isStructTy() - ? 
b_.getIntNTy(bit_width) - : element_ir_type; - for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1; - shuffle_distance /= 2) { - llvm::Value* result_from_other_lane = - Alloca(element_ir_type, nullptr, "result_from_other_lane"); - for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result = - Load(BitCast(partial_reduction_result_addresses[i], - shuffle_ir_type->getPointerTo()), - "partial_reduction_result"); - CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0) - << "Requires block size a multiple of the warp size, otherwise we " - "will read undefined elements."; - Store(EmitFullWarpShuffleDown(partial_reduction_result, - b_.getInt32(shuffle_distance), &b_), - BitCast(result_from_other_lane, shuffle_ir_type->getPointerTo())); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducers[i], - {partial_reduction_result_addresses[i], result_from_other_lane}, - partial_reduction_result_addresses[i])); - } - } - - const HloInstruction* output = - reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; - - // Emit an atomic operation that accumulates the partial reduction result of - // lane 0 (which holds the partially accumulated result for its warp) to the - // output element. 
- llvm::Value* lane_id = - URem(x_in_tiles, index_typed_constant(kWarpSize), "lane_id"); - llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_); - llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); - - for (int i = 0; i != num_reduces; ++i) { - llvm::Value* output_address = - GetIrArray(*output, *output, reduce_output_shapes[i]) - .EmitArrayElementAddress( - IrArray::Index( - /*linear=*/b_.getInt64(0), - ShapeUtil::GetSubshape(output->shape(), - reduce_output_shapes[i]), - &b_), - &b_, "output_element_address"); - TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( - *reducers[i], output_address, partial_reduction_result_addresses[i])); - } - return Status::OK(); - }; - - // Emit a parallel loop that iterates through all input tiles, one per thread. - UpdateLaunchDimensions(launch_dimensions, kernel_thunk, - ir_emitter_context_->llvm_module()); - return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, - launch_dimensions, &b_) - .EmitLoop(IrName(reduce), index_ty); -} - -Status IrEmitterUnnested::EmitColumnReduction( - KernelThunk* kernel_thunk, int64 height, int64 width, - HloInstruction* reduce, const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens) { - // Divide the input matrix into tiles of size KxL. For example, when the - // input matrix is 4x4, K=2, and L=1 the tiled matrix looks like - // - // 0123 - // 0123 - // 4567 - // 4567 // Numbers indicate tile IDs. - // - // Each tile is first partially reduced to a scalar by a thread, and then the - // scalar is accumulated to the output vector using atomic operations. - // - // We choose 128 as the tile size based on empirical evidence. It's big enough - // to reduce the amount of atomic adds in the end, maximizing the memory - // bandwidth. 
A tile width of 2 allows for high memory bandwidth utilization - // on 16b input data. - constexpr int64 kTileHeight = 128; - constexpr int64 kTileWidth = 2; - - // If the height is not a multiple of kTileHeight, we pad the bottom of the - // input matrix. - const int64 height_in_tiles = CeilOfRatio(height, kTileHeight); - // If width is not a multiple of kTileWidth the rightmost thread will process - // fewer input elements. - const int64 width_in_tiles = CeilOfRatio(width, kTileWidth); - Shape tiled_input_shape = - ShapeUtil::MakeShapeWithLayout(reduce->shape().element_type(), - {height_in_tiles, width_in_tiles}, {1, 0}); - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - tiled_input_shape, ir_emitter_context_->device_description()); - - // TODO(b/110211620): Convert to use i32 index_type when it is possible. - llvm::Type* index_ty = b_.getInt64Ty(); - - auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - - // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x; - // linear_index < height_in_tiles * width_in_tiles; - // linear_index += blockDim.x * gridDim.x) { - // y_in_tiles = linear_index / width_in_tiles; - // x_in_tiles = linear_index % width_in_tiles; - // - // partial_results[kTileWidth] = init_values; - // tile_in_y_bounds = height % kTileHeight == 0 || - // y_in_tiles * kTileHeight + kTileHeight <= height; - // tile_in_x_bounds = width % kTileWidth == 0 || - // x_in_tiles * kTileWidth + kTileWidth <= width; - // // The implementation handles y and x bound checks separately. 
- // if (tile_in_y_bounds && tile_in_x_bounds) { - // for (y_offset : range(kTileHeight)) { - // y = y_in_tiles * kTileHeight + y_offset; - // for (x_offset : range(kTileWidth)) { - // x = x_in_tiles * kTileWidth + x_offset; - // partial_result = Reducer(partial_result[x_offset], input[y][x]); - // } - // } - // } else { - // for (y_offset : range(kTileHeight)) { - // y = y_in_tiles * kTileHeight + y_offset; - // for (y_offset : range(kTileHeight)) { - // x = x_in_tiles * kTileWidth + x_offset; - // if (y < height && x < width) { - // partial_result = Reducer(partial_result, input[y][x]); - // } - // } - // } - // } - // for (x_offset : range(kTileWidth)) { - // AtomicReducer(&output[x + x_offset], partial_result[x_offset]); - // } - // } - auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status { - const int num_reduces = reducers.size(); - // Emit the loop body that reduces one tile. - llvm::Type* element_ir_type = - llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); - std::vector partial_reduction_result_addresses; - for (int i = 0; i != num_reduces; ++i) { - for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { - llvm::Value* partial_reduction_result_address = - Alloca(element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + - llvm::Twine(i * kTileWidth + x_offset)); - TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, - init_value_gens[i](IrArray::Index(index_ty))); - Store(init_ir_value, partial_reduction_result_address); - partial_reduction_result_addresses.push_back( - partial_reduction_result_address); - } - } - - // Emit an inner for-loop that partially reduces the elements in the given - // tile. 
- llvm::Value* y_in_tiles = tile_index[0]; - llvm::Value* x_in_tiles = tile_index[1]; - - y_in_tiles = ZExtOrTrunc(y_in_tiles, index_ty); - x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty); - - auto emit_tile_element_loop = [=](bool tile_in_y_bounds, - bool tile_in_x_bounds) -> Status { - std::unique_ptr tile_element_loop = - llvm_ir::ForLoop::EmitForLoop( - "element_id_in_tile", index_typed_constant(0), - index_typed_constant(kTileHeight), index_typed_constant(1), &b_); - - // Emit the body of the partial reduction loop. - llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), - &b_); - llvm::Value* y = - NSWAdd(NSWMul(y_in_tiles, index_typed_constant(kTileHeight)), - tile_element_loop->GetIndVarValue()); - - // Unless we know that y is in bounds, we have to emit a check before - // reading from the input. - if (!tile_in_y_bounds) { - llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - ICmpULT(y, index_typed_constant(height)), "y_in_bounds", &b_); - - // Emit code that reads the input element and accumulates it to - // the partial reduction result. - llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); - } - for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { - llvm::Value* x = - NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)), - index_typed_constant(x_offset)); - // Unless we know that x is in bounds, we have to emit a check before - // reading from the input. - if (!tile_in_x_bounds) { - llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - ICmpULT(x, index_typed_constant(width)), "x_in_bounds", &b_); - llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); - } - llvm::Value* input_address = Alloca(element_ir_type); - // {y,x} is an index to input_matrix_shape [height,width]. We need to - // convert that to an index to input_shape (the shape of the operand of - // "reduce"). 
This conversion is composed of a transposition from - // input_shape to normalized_input_shape and a reshape from - // normalized_input_shape to input_matrix_shape. - const Shape normalized_input_shape = - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( - input_shape); - auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape); - const std::vector transpose_dimension_mapping( - input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); - - const Shape input_matrix_shape = - ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(), - {height, width}); - const IrArray::Index input_matrix_index({y, x}, input_matrix_shape, - &b_); - const IrArray::Index input_index = - input_matrix_index - .SourceIndexOfReshape(input_matrix_shape, - normalized_input_shape, &b_) - .SourceIndexOfTranspose(normalized_input_shape, input_shape, - transpose_dimension_mapping, &b_); - for (int i = 0; i != num_reduces; ++i) { - TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, - input_gens[i](input_index)); - Store(input_ir_value, input_address); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducers[i], - {partial_reduction_result_addresses[i * kTileWidth + x_offset], - input_address}, - partial_reduction_result_addresses[i * kTileWidth + x_offset])); - TF_RETURN_IF_ERROR(EmitExtraOutputsForReduce(reduce, input_index, - extra_output_gens)); - } - } - return Status::OK(); - }; - - // y_end = kTileHeight + y_in_tiles * kTileHeight, i.e., the y location - // that's immediately beyond the tile. - llvm::Value* y_end = - NSWAdd(index_typed_constant(kTileHeight), - NSWMul(y_in_tiles, index_typed_constant(kTileHeight))); - // x_end = kTileWidth + x_in_tiles * kTileWidth, i.e., the x location - // that's immediately beyond the tile. 
- llvm::Value* x_end = - NSWAdd(index_typed_constant(kTileWidth), - NSWMul(x_in_tiles, index_typed_constant(kTileWidth))); - llvm::Value* tile_in_y_bounds = - Or(ICmpULE(y_end, index_typed_constant(height)), - b_.getInt1(height % kTileHeight == 0)); - llvm::Value* tile_in_x_bounds = - Or(ICmpULE(x_end, index_typed_constant(width)), - b_.getInt1(width % kTileWidth == 0)); - // The tile is in y bounds if "height" is a multiple of kTileHeight or - // y_end <= height. - llvm_ir::LlvmIfData if_tile_in_y_bounds_data = - llvm_ir::EmitIfThenElse(tile_in_y_bounds, "tile_in_y_bounds", &b_); - llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.true_block, &b_); - // The tile is in x bounds if "width" is a multiple of kTileWidth or - // x_end <= width. - llvm_ir::LlvmIfData if_tile_in_x_bounds_data = - llvm_ir::EmitIfThenElse(tile_in_x_bounds, "tile_in_x_bounds", &b_); - llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block, &b_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true, - /*tile_in_x_bounds=*/true)); - llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block, &b_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true, - /*tile_in_x_bounds=*/false)); - llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.false_block, &b_); - if_tile_in_x_bounds_data = - llvm_ir::EmitIfThenElse(tile_in_x_bounds, "tile_in_x_bounds", &b_); - llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block, &b_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false, - /*tile_in_x_bounds=*/true)); - llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block, &b_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false, - /*tile_in_x_bounds=*/false)); - - // After the nested if-then-else statement on tile_in_y_bounds and - // tile_in_x_bounds, emit atomic operations to accumulate the partial - // reduction result to the output element. 
- llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.after_block, &b_); - const HloInstruction* output = - reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; - for (int i = 0; i != num_reduces; ++i) { - for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { - llvm::Value* x = - NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)), - index_typed_constant(x_offset)); - llvm::Value* output_address = - GetIrArray(*output, *output, reduce_output_shapes[i]) - .EmitArrayElementAddress( - IrArray::Index( - x, - ShapeUtil::GetSubshape(output->shape(), - reduce_output_shapes[i]), - &b_), - &b_, "output_element_address"); - TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( - *reducers[i], output_address, - partial_reduction_result_addresses[i * kTileWidth + x_offset])); - } - } - return Status::OK(); - }; - - // Emit a parallel loop that iterate through all input tiles. - UpdateLaunchDimensions(launch_dimensions, kernel_thunk, - ir_emitter_context_->llvm_module()); - return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, - launch_dimensions, &b_) - .EmitLoop(IrName(reduce), index_ty); -} - -static std::pair ComputeKernelMappingSchemeForReduction( - int64 depth, int64 width, int64 kWarpSize) { - constexpr int64 kTargetNumElementsPerThread = 64; - int64 x_tile_size = kTargetNumElementsPerThread; - int64 z_tile_size = 1; - - // Only tile along the x dimension with tile size kTargetNumElementsPerThread - // if doing so doesn't require a slow version of loop with bound check on each - // dimension. A more sophisticated heuristics is to enable tile along the - // x dimension with tile size kTargetNumElementsPerThread when either width is - // a factor of (kWarpSize * kTargetNumElementsPerThread) or width is big - // enough so that only a small fraction of the threads execute the slow - // version of loop with bound check. 
- if (width % (kWarpSize * kTargetNumElementsPerThread) != 0) { - x_tile_size = 8; - z_tile_size = 8; - while (depth % z_tile_size != 0) { - z_tile_size -= 1; - } - } - - return std::pair(x_tile_size, z_tile_size); -} - -Status IrEmitterUnnested::EmitRowReduction( - KernelThunk* kernel_thunk, int64 depth, int64 height, int64 width, - HloInstruction* reduce, const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens) { - // A naive algorithm is: - // 1. Divide the x dimension of the input tensor into tiles of size 1x1xX. - // 2. Partially reduces each tile to a scalar using one thread. - // 3. Accumulates that scalar to the output vector using atomic operations. - // - // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x; - // linear_index < depth * height * width_in_tiles; - // linear_index += blockDim.x * gridDim.x) { - // int x_in_tiles = linear_index % width_in_tiles; - // int y = linear_index / width_in_tiles % height; - // int z = linear_index / (height * width_in_tiles); - // float partial_result = 0; - // for (element_id_in_tile : range(x_tile_size)) { - // int x = x_in_tiles * x_tile_size + element_id_in_tile; - // if (x < width) - // partial_result = reducer(partial_result, input[z][y][x]); - // } - // AtomicReducer(&output[y], partial_result); - // } - // - // Four optimizations are performed. - // - // 1. To coalesce global memory accesses, dilate the tile with a factor of 32 - // (i.e. the warp size). For example, suppose the width is 8x32=256. Instead - // of making each tile consecutive, we let make tile 0 column - // [0,32,64,...,224], tile 1 column [1,33,65,...,225], and so on. This ensures - // that threads in a warp access consecutive memory in one iteration (i.e. - // coalesced). 
In the above example, the warp that contains thread 0-31 - // accesses column 0-31 in the first iteration, and 32-63 in the second - // iteration, and so on. - // - // 2. Partially accumulate partial reduced results computed by threads in the - // same warp using shfl_down. Using shfl_down is faster than directly using - // atomic operations because shfl_down transfers the data between threads - // using shared memory and threads in the same warp run in lock step (thus no - // extra synchronization needed). See - // https://devblogs.nvidia.com/parallelforall/faster-parallel-reductions-kepler/ - // for details. The downside is, to produce correct results when using - // shfl_down, we need to guarantee threads in the same warp work on input - // elements with the same y, so the number of tiles in each row must be a - // multiple of 32. - // - // 3. Specialize the case that the entire tile is in bounds. When that is - // true, we don't need to emit "if(x 0; shuffle_distance /= 2) - // partial_result = Reducer( - // partial_result, - // __shfl_down_sync(CUDA_WARP_ALL, partial_result, shuffle_distance)); - // if (lane_id == 0) - // AtomicReducer(&output[y], partial_result); - // } - // - - int64 x_tile_size; - int64 z_tile_size; - std::tie(x_tile_size, z_tile_size) = - ComputeKernelMappingSchemeForReduction(depth, width, kWarpSize); - - // Round the width in tiles up to the nearest multiple of kWarpSize, so that - // the use of shfl_down is valid. 
- const int64 width_in_tiles = - RoundUpToNearest(CeilOfRatio(width, x_tile_size), kWarpSize); - Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( - reduce->shape().element_type(), - {depth / z_tile_size, height, width_in_tiles}, {2, 1, 0}); - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - tiled_input_shape, ir_emitter_context_->device_description()); - llvm::Type* index_ty = - GetIndexTypeForKernel(reduce, launch_dimensions.launch_bound(), &b_); - - auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - - auto loop_body_emitter = [=](const IrArray::Index& tile_index) { - const int num_reduces = reducers.size(); - llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType( - input_shape.element_type(), ir_emitter_context_->llvm_module()); - std::vector partial_reduction_result_addresses; - for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result_address = - Alloca(element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + llvm::Twine(i)); - TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, - init_value_gens[i](IrArray::Index(index_ty))); - Store(init_ir_value, partial_reduction_result_address); - partial_reduction_result_addresses.push_back( - partial_reduction_result_address); - } - - llvm::Value* z_tile = tile_index[0]; - llvm::Value* y = tile_index[1]; - llvm::Value* x_tile = tile_index[2]; - - x_tile = ZExtOrTrunc(x_tile, index_ty); - - llvm::Value* warp_id = - UDiv(x_tile, index_typed_constant(kWarpSize), "warp_id"); - llvm::Value* lane_id = - URem(x_tile, index_typed_constant(kWarpSize), "lane_id"); - - // The x-location of the last element in this z-x-tile. 
- // last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * x_tile_size); - llvm::Value* last_x = NSWAdd( - lane_id, - NSWMul(index_typed_constant(kWarpSize), - NSWAdd(index_typed_constant(x_tile_size - 1), - NSWMul(warp_id, index_typed_constant(x_tile_size))))); - - KernelSupportLibrary ksl( - &b_, - /*unroll_mode=*/xla::llvm_ir::UnrollMode::kFullyUnroll, - /*prevent_vectorization=*/false); - - // Emit a for-loop that partially reduces the elements in the given - // z-x-tile. - auto emit_z_x_tile_element_loop = [&](bool x_tile_in_bounds, - int64 x_tile_loop_bound) -> Status { - auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status { - llvm::Value* z = - NSWAdd(z_indvar, NSWMul(index_typed_constant(z_tile_size), z_tile)); - TF_RETURN_IF_ERROR(ksl.For( - "x_tile", - /*start=*/index_typed_constant(0), - /*end=*/index_typed_constant(x_tile_loop_bound), - /*step=*/1, [&](llvm::Value* x_indvar) -> Status { - // x = lane_id + - // warpSize * (element_id_in_x_tile + warp_id * x_tile_size); - llvm::Value* x = NSWAdd( - lane_id, - NSWMul(index_typed_constant(kWarpSize), - NSWAdd(x_indvar, - NSWMul(warp_id, llvm::ConstantInt::get( - index_ty, x_tile_size))))); - - // Unless we know the x-tile is entirely in bounds, we have to - // emit a x-in-bounds check before reading from the input. - if (!x_tile_in_bounds) { - llvm_ir::LlvmIfData if_x_in_bounds_data = - llvm_ir::EmitIfThenElse( - ICmpULT(x, index_typed_constant(width)), "x_in_bounds", - &b_); - // Points b_ to the then-block. - llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block, - &b_); - } - - // Emit code that reads the input element and accumulates it - // to the partial reduction result. - llvm::Value* input_address = Alloca(element_ir_type); - { - // {z,y,x} is an index to input_3d_tensor_shape - // [depth,height,width]. We need to convert that to an index - // to input_shape (the shape of the operand of "reduce"). 
- // This conversion is composed of a transposition from - // input_shape to normalized_input_shape and a reshape from - // normalized_input_shape to input_3d_tensor_shape. - const Shape normalized_input_shape = ShapeUtil:: - MakeShapeWithDescendingLayoutAndSamePhysicalLayout( - input_shape); - auto input_shape_min2maj = - LayoutUtil::MinorToMajor(input_shape); - const std::vector transpose_dimension_mapping( - input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); - const Shape input_3d_tensor_shape = - ShapeUtil::MakeShapeWithDescendingLayout( - input_shape.element_type(), {depth, height, width}); - const IrArray::Index input_3d_tensor_index( - {z, y, x}, input_3d_tensor_shape, &b_); - const IrArray::Index input_index = - input_3d_tensor_index - .SourceIndexOfReshape(input_3d_tensor_shape, - normalized_input_shape, &b_) - .SourceIndexOfTranspose( - normalized_input_shape, input_shape, - transpose_dimension_mapping, &b_); - - for (int i = 0; i != num_reduces; ++i) { - TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, - input_gens[i](input_index)); - Store(input_ir_value, input_address); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducers[i], - {partial_reduction_result_addresses[i], input_address}, - partial_reduction_result_addresses[i])); - } - return EmitExtraOutputsForReduce(reduce, input_index, - extra_output_gens); - } - })); - return Status::OK(); - }; - - return ksl.For("z_tile", - /*start=*/index_typed_constant(0), - /*end=*/index_typed_constant(z_tile_size), - /*step=*/1, emit_z_tile_element_loop); - }; - - llvm::Value* tile_in_bounds = - Or(b_.getInt1(width % (x_tile_size * kWarpSize) == 0), - ICmpULT(last_x, index_typed_constant(width))); - - TF_RETURN_IF_ERROR( - ksl.If(tile_in_bounds, - /*true_block_generator=*/ - [&]() -> Status { - return emit_z_x_tile_element_loop(/*x_tile_in_bounds=*/true, - x_tile_size); - }, - /*false_block_generator=*/ - [&]() -> Status { - return emit_z_x_tile_element_loop( - /*x_tile_in_bounds=*/false, - 
CeilOfRatio(width % (x_tile_size * kWarpSize), kWarpSize)); - })); - - // After accumulating the elements of the z_x_tile, emit calls to - // shfl_down that accumulate the partial reduction results of all - // threads in a warp. - int bit_width = llvm_ir::GetSizeInBits(element_ir_type); - // bitcast cannot be applied to aggregate types (even packed ones), so we - // instead bitcast addresses of load/store to intN* of the same bit-width. - llvm::Type* shuffle_ir_type = element_ir_type->isStructTy() - ? b_.getIntNTy(bit_width) - : element_ir_type; - for (int shuffle_distance = 16; shuffle_distance >= 1; - shuffle_distance /= 2) { - llvm::Value* result_from_other_lane = - Alloca(element_ir_type, nullptr, "result_from_other_lane"); - for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result = - Load(BitCast(partial_reduction_result_addresses[i], - shuffle_ir_type->getPointerTo()), - "partial_reduction_result"); - CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0) - << "Requires block size a multiple of the warp size, otherwise we " - "will read undefined elements."; - Store(EmitFullWarpShuffleDown(partial_reduction_result, - b_.getInt32(shuffle_distance), &b_), - BitCast(result_from_other_lane, shuffle_ir_type->getPointerTo())); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducers[i], - {partial_reduction_result_addresses[i], result_from_other_lane}, - partial_reduction_result_addresses[i])); - } - } - - const HloInstruction* output = - reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; - - // Emit an atomic operation that accumulates the partial reduction result of - // lane 0 (which holds the partially accumulated result for its warp) to the - // output element. 
- llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_); - llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); - for (int i = 0; i != num_reduces; ++i) { - llvm::Value* output_address = - GetIrArray(*output, *output, reduce_output_shapes[i]) - .EmitArrayElementAddress( - IrArray::Index(y, - ShapeUtil::GetSubshape( - output->shape(), reduce_output_shapes[i]), - &b_), - &b_, "output_element_address"); - // We don't need to emit atomic operations if there is only one tile of - // results. 'depth' is the z dimension, 'width' is the x dimension. - if (z_tile_size >= depth && x_tile_size >= width) { - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducers[i], - {output_address, partial_reduction_result_addresses[i]}, - output_address)); - } else { - TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( - *reducers[i], output_address, - partial_reduction_result_addresses[i])); - } - } - return Status::OK(); - }; - - // Emit a parallel loop that iterates through every input tiles. - UpdateLaunchDimensions(launch_dimensions, kernel_thunk, - ir_emitter_context_->llvm_module()); - return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, - launch_dimensions, &b_) - .EmitLoop(IrName(reduce), index_ty); -} - -// Figures out whether `reduce` is a row or column reduction, and which -// dimensions to reduce, and calls either `EmitRowReduction` or -// `EmitColumnReduction` as appropriate. -// Prerequisite: all the dimensions to keep are contiguous in the input layout -// and, if `reduce` is fused, the fused subgraph is pure -// elementwise. 
-Status IrEmitterUnnested::EmitReductionToVector( - KernelThunk* kernel_thunk, HloInstruction* reduce, const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span dimensions_to_reduce, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens) { - // This emission requires "reduce" to have an input layout. It is either set - // by LayoutAssignment (for a top-level kReduce) or by InstructionFusion (for - // a fused kReduce). - CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion " - "doesn't set the input layout of " - << reduce->ToString(); - - // Specialize multi-dimensional-array-to-vector reduction. - std::vector input_dims_to_keep; - for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape); - ++input_dim) { - if (std::find(dimensions_to_reduce.begin(), dimensions_to_reduce.end(), - input_dim) == dimensions_to_reduce.end()) { - input_dims_to_keep.push_back(input_dim); - } - } - - // Sort the dimensions to keep from minor to major, to facilitate checking - // whether another dimension is major or minor of them. - std::sort(input_dims_to_keep.begin(), input_dims_to_keep.end(), - [&input_shape](int64 dim_a, int64 dim_b) { - return PositionInContainer(LayoutUtil::MinorToMajor(input_shape), - dim_a) < - PositionInContainer(LayoutUtil::MinorToMajor(input_shape), - dim_b); - }); - // Now, if output rank is at least 1, `input_dims_to_keep.front()` is - // minormost and `input_dims_to_keep.back()` is majormost. - - // If the dimensions to keep are minormost, emit a column reduction. As all - // the dimensions to keep are contiguous, by prerequisite of - // `EmitReductionToVector`, we only need to check whether the minormost - // dimension of the input is to keep. 
- if (ShapeUtil::IsEffectiveScalar(reduce->shape())) { - return EmitReductionToScalar(kernel_thunk, reduce, input_shape, input_gens, - init_value_gens, reducers, - reduce_output_shapes, extra_output_gens); - } else if (input_dims_to_keep.front() == - LayoutUtil::Minor(input_shape.layout(), 0)) { - // Column reduction. Treat the result of "input" as a matrix whose width - // is the most minor dimension and height the product of other dimensions, - // and treat "reduce" as a column reduction of the input matrix. - const int64 width = ShapeUtil::ElementsIn(reduce->shape()); - // "width" can be zero, so don't do - // height = ShapeUtil::ElementsIn(input_shape) / width; - int64 height = 1; - for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape); - ++input_dim) { - if (!std::count(input_dims_to_keep.begin(), input_dims_to_keep.end(), - input_dim)) { - height *= input_shape.dimensions(input_dim); - } - } - return EmitColumnReduction(kernel_thunk, height, width, reduce, input_shape, - input_gens, init_value_gens, reducers, - reduce_output_shapes, extra_output_gens); - } else { - // Reduce the row dimension of a matrix or reduce dimension 0 and 2 in a - // 3D tensor. The size of dimension 1 (the height) is the size of the - // dimension to keep, the size of dimension 0 (the depth) is the product - // of dimensions that are more major than the dimension to keep, and the - // size of dimension 2 (the width) is the product of more minor - // dimensions. 
- int64 depth = 1; - int64 width = 1; - for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape); - ++input_dim) { - if (PositionInContainer(LayoutUtil::MinorToMajor(input_shape), - input_dim) > - PositionInContainer(LayoutUtil::MinorToMajor(input_shape), - input_dims_to_keep.back())) { - depth *= input_shape.dimensions(input_dim); - } else if (PositionInContainer(LayoutUtil::MinorToMajor(input_shape), - input_dim) < - PositionInContainer(LayoutUtil::MinorToMajor(input_shape), - input_dims_to_keep.front())) { - width *= input_shape.dimensions(input_dim); - } - } - const int64 height = ShapeUtil::ElementsIn(reduce->shape()); - return EmitRowReduction(kernel_thunk, depth, height, width, reduce, - input_shape, input_gens, init_value_gens, reducers, - reduce_output_shapes, extra_output_gens); - } -} - Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { // TODO(b/112040122): Support multi-output reduce. if (!ShapeUtil::IsArray(reduce->shape())) { return Unimplemented("Multi-output reduce is not supported on GPU"); } - auto input = reduce->operand(0); - auto init_value = reduce->operand(1); - absl::Span dimensions_to_reduce(reduce->dimensions()); - HloComputation* reducer = reduce->to_apply(); - // HandleReduce specializes reduction from a multi-dimensional array to a 1D - // array. The specialized version requires an initializer thunk that - // initializes the output array to the initial value of the reduce. 
if (IsReductionToVector(*reduce)) { - TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, - BuildInitializerThunk(reduce)); - std::vector> thunks; - thunks.push_back(std::move(initializer_thunk)); - std::unique_ptr kernel_thunk = - BuildKernelThunk(reduce, /*implements_whole_instruction=*/false); - - TF_CHECK_OK(EmitReductionToVector( - kernel_thunk.get(), reduce, input->shape(), - {[&](const IrArray::Index& index) { - return GetIrArray(*input, *reduce).EmitReadArrayElement(index, &b_); - }}, - {[&](const IrArray::Index& index) { - return GetIrArray(*init_value, *reduce) - .EmitReadArrayElement(index, &b_); - }}, - dimensions_to_reduce, {reducer}, {{}}, {})); - - thunks.push_back(std::move(kernel_thunk)); - - std::unique_ptr sequential_thunk = - absl::make_unique(std::move(thunks), reduce); - AddThunkToThunkSequence(std::move(sequential_thunk)); - return Status::OK(); + return EmitReductionToVector(reduce); } return IrEmitter::HandleReduce(reduce); @@ -1820,7 +763,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // Create the inner loop to iterate over the window. llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), &b_, index_type); - std::vector window_size; + DimensionVector window_size; for (const auto& dim : window.dimensions()) { window_size.push_back(dim.size()); CHECK_GT(dim.size(), 0); @@ -2352,11 +1295,11 @@ Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) { return IrEmitter::HandleTupleSelect(tuple_select); } -Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { +Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) { if (hlo_module_config_.replica_count() != 1) { // TODO(b/33011107): Support nontrivial cross replica sum on GPU. return Unimplemented( - "CrossReplicaSum with >1 replica is not implemented on GPU."); + "AllReduce with >1 replica is not implemented on GPU."); } // CRS with one operand and one replica is simply the identity function. 
@@ -2368,7 +1311,7 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { // and when it's run. if (crs->operand_count() == 1) { CHECK(ShapeUtil::IsArray(crs->operand(0)->shape())) - << "Operands to cross-replica-sum must be arrays: " << crs->ToString(); + << "Operands to all-reduce must be arrays: " << crs->ToString(); AddThunkToThunkSequence(absl::make_unique( /*source_address=*/GetAllocationSlice(*crs->operand(0)), /*destination_buffer=*/GetAllocationSlice(*crs), @@ -3121,11 +2064,9 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( // pressure, since we touch threadIdx.x and blockIdx.x at the beginning of the // kernel *anyway*. std::vector output_arrays = ConstructIrArrayForOutputs(hlo); - TF_RETURN_IF_ERROR( - KernelSupportLibrary(&b_).If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] { - llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_); - return Status::OK(); - })); + KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] { + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_); + }); // For multioutput fusion, we need to emit each operand and the root. 
TF_RETURN_IF_ERROR( @@ -3195,34 +2136,36 @@ int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape( namespace { -void EmitFullTile(const KernelMappingScheme* mapping_scheme, - const IrArray::Index& tile_origin_index, - llvm::IRBuilder<>* builder, llvm::Value* y, llvm::Value* x, - llvm::Type* index_ty, - const std::function& emit_elem_function) { +void EmitFullElementalTile( + const KernelMappingScheme* mapping_scheme, + const IrArray::Index& tile_origin_index, const string& loop_name, + KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, + llvm::Value* x, llvm::Type* index_ty, + const std::function& emit_elem_function) { int64 num_threads_x = mapping_scheme->GetNumberOfThreadsForDimensionX(); int64 num_threads_y = mapping_scheme->GetNumberOfThreadsForDimensionY(); int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); int64 tile_size_y = mapping_scheme->GetTileSizeForDimensionY(); - for (int64 i = 0; i < tile_size_y; i += num_threads_y) { - IrArray::Index source_idx_y = - tile_origin_index.AddOffsetToDim(llvm::ConstantInt::get(index_ty, i), - KernelMappingScheme::DimY, builder); - llvm::Value* y_loc = - builder->CreateAdd(llvm::ConstantInt::get(index_ty, i), y); - for (int64 j = 0; j < tile_size_x; j += num_threads_x) { - IrArray::Index source_idx = - source_idx_y.AddOffsetToDim(llvm::ConstantInt::get(index_ty, j), - KernelMappingScheme::DimX, builder); - llvm::Value* x_loc = - builder->CreateAdd(llvm::ConstantInt::get(index_ty, j), x); - emit_elem_function(source_idx, y_loc, x_loc); - } - } -} - -void EmitPartialTile( + ksl->For(loop_name + "_y", /*start=*/llvm::ConstantInt::get(index_ty, 0), + /*end=*/llvm::ConstantInt::get(index_ty, tile_size_y), + /*step=*/llvm::ConstantInt::get(index_ty, num_threads_y), + [&](llvm::Value* y_indvar) { + IrArray::Index source_idx_y = tile_origin_index.AddOffsetToDim( + y_indvar, KernelMappingScheme::DimY, builder); + llvm::Value* y_loc = builder->CreateAdd(y_indvar, y); + for 
(int64 j = 0; j < tile_size_x; j += num_threads_x) { + IrArray::Index source_idx = source_idx_y.AddOffsetToDim( + llvm::ConstantInt::get(index_ty, j), + KernelMappingScheme::DimX, builder); + llvm::Value* x_loc = + builder->CreateAdd(llvm::ConstantInt::get(index_ty, j), x); + emit_elem_function(source_idx, y_loc, x_loc); + } + }); +} + +void EmitPartialElementalTile( const KernelMappingScheme* mapping_scheme, const IrArray::Index& tile_origin_index, const string& loop_name, KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, @@ -3241,8 +2184,9 @@ void EmitPartialTile( llvm::Value* x_loc = builder->CreateAdd(llvm::ConstantInt::get(index_ty, j), x); - ksl->IfReturnVoid( - "x_in_tile", builder->CreateICmpULT(x_loc, tile_width), [&] { + ksl->If( + loop_name + "_x_in_tile", builder->CreateICmpULT(x_loc, tile_width), + [&] { // tile_height_bound = // ceil(tile_height / num_threads_y) * num_threads_y llvm::Value* ceiling_of_ratio = builder->CreateUDiv( @@ -3252,15 +2196,15 @@ void EmitPartialTile( llvm::Value* tile_height_bound = builder->CreateMul( ceiling_of_ratio, llvm::ConstantInt::get(index_ty, num_threads_y)); - ksl->ForReturnVoid( + ksl->For( loop_name, /*start=*/llvm::ConstantInt::get(index_ty, 0), /*end=*/tile_height_bound, /*step=*/llvm::ConstantInt::get(index_ty, num_threads_y), [&](llvm::Value* y_indvar) { llvm::Value* y_loc = builder->CreateAdd(y_indvar, y); - ksl->IfReturnVoid( - "y_in_tile", builder->CreateICmpULT(y_loc, tile_height), - [&] { + ksl->If( + loop_name + "_y_in_tile", + builder->CreateICmpULT(y_loc, tile_height), [&] { emit_elem_function( source_idx.AddOffsetToDim( y_indvar, KernelMappingScheme::DimY, builder), @@ -3290,21 +2234,21 @@ void EmitTiledElementalCodeWithBoundsCheck( int64 tile_size_y = mapping_scheme->GetTileSizeForDimensionY(); llvm::Type* index_ty = tile_width->getType(); - ksl->IfReturnVoid( - "full_tile", + ksl->If( + loop_name + "_full_tile", builder->CreateAnd( 
builder->CreateICmpEQ(llvm::ConstantInt::get(index_ty, tile_size_x), tile_width), builder->CreateICmpEQ(llvm::ConstantInt::get(index_ty, tile_size_y), tile_height)), [&] { - EmitFullTile(mapping_scheme, tile_origin_index, builder, y, x, index_ty, - emit_elem_function); + EmitFullElementalTile(mapping_scheme, tile_origin_index, loop_name, ksl, + builder, y, x, index_ty, emit_elem_function); }, [&] { - EmitPartialTile(mapping_scheme, tile_origin_index, loop_name, ksl, - builder, y, x, tile_height, tile_width, index_ty, - emit_elem_function); + EmitPartialElementalTile(mapping_scheme, tile_origin_index, loop_name, + ksl, builder, y, x, tile_height, tile_width, + index_ty, emit_elem_function); }); } } // namespace @@ -3382,7 +2326,395 @@ void IrEmitterUnnested::EmitTileElementForFusion( } } -// Emits a block of tiles, given a function object to emit one tile. +// Information to support the code generation for a tiled reduction kernel. +using AddressVector = InlinedVector; +class ReductionCodegenInfo : public IrEmitterUnnested::KernelCodegenInfo { + public: + explicit ReductionCodegenInfo(llvm_ir::KernelMappingScheme* mapping_scheme, + bool is_row_reduction) + : KernelCodegenInfo(mapping_scheme), + current_output_linear_index_address_(nullptr), + current_output_inbound_address_(nullptr), + is_row_reduction_(is_row_reduction) {} + + void SetCurrentOutputLinearIndexAddress(llvm::AllocaInst* a) { + current_output_linear_index_address_ = a; + } + // Returns the address of the memory that stores the linear index of the + // current output. Since we are processing reduction to contiguous physical + // dimensions, this linear index is the linear index of the 1D output array. 
+ llvm::AllocaInst* GetCurrentOutputLinearIndexAddress() const { + return current_output_linear_index_address_; + } + + void SetCurrentOutputInboundAddress(llvm::AllocaInst* a) { + current_output_inbound_address_ = a; + } + + llvm::AllocaInst* GetCurrentOutputInboundAddress() const { + return current_output_inbound_address_; + } + + AddressVector* GetMutablePartialResultAddresses() { + return &partial_result_addresses_; + } + absl::Span GetPartialResultAddresses() const { + return partial_result_addresses_; + } + + AddressVector* GetMutableReductionInputAddresses() { + return &reduction_input_addresses_; + } + absl::Span GetReductionInputAddresses() const { + return reduction_input_addresses_; + } + + InlinedVector* GetMutableReducers() { return &reducers_; } + const InlinedVector& GetReducers() const { + return reducers_; + } + int GetNumberOfReduces() const { return reducers_.size(); } + + InlinedVector* GetMutableReductionOutputShapeIndices() { + return &reduction_output_shape_indices_; + } + absl::Span GetReductionOutputShapeIndices() const { + return reduction_output_shape_indices_; + } + + bool IsRowReduction() const { return is_row_reduction_; } + + // Return the dimension that is being reduced between DimX and DimY. + int GetReducedDimensionEnum() const { + return IsRowReduction() ? llvm_ir::KernelMappingScheme::DimX + : llvm_ir::KernelMappingScheme::DimY; + } + + // Return the dimension that is being kept between DimX and DimY. + int GetKeptDimensionEnum() const { + return IsRowReduction() ?
llvm_ir::KernelMappingScheme::DimY + : llvm_ir::KernelMappingScheme::DimX; + } + + private: + AddressVector partial_result_addresses_; + AddressVector reduction_input_addresses_; + InlinedVector reducers_; + InlinedVector reduction_output_shape_indices_; + llvm::AllocaInst* current_output_linear_index_address_; + llvm::AllocaInst* current_output_inbound_address_; + bool is_row_reduction_; +}; + +namespace { +// Returns a group of instructions that generate the output for the kernel +// containing the given HLO instruction. The result may be an unnested kReduce +// HLO, a nested kReduce HLO of a kInput fusion, or the operands of the tuple +// for a multiple output fusion. +absl::Span GetOutputInstructions( + HloInstruction* const* reduce_or_tuple_pointer) { + HloOpcode opcode = (*reduce_or_tuple_pointer)->opcode(); + CHECK(opcode == HloOpcode::kReduce || opcode == HloOpcode::kTuple); + return opcode == HloOpcode::kTuple + ? (*reduce_or_tuple_pointer)->operands() + : absl::Span(reduce_or_tuple_pointer, 1); +} + +const HloInstruction* GetFirstReduceInstruction( + absl::Span instructions) { + auto first_reduce_iter = + absl::c_find_if(instructions, [](const HloInstruction* inst) { + return inst->opcode() == HloOpcode::kReduce; + }); + CHECK_NE(first_reduce_iter, instructions.end()); + return *first_reduce_iter; +} + +}; // namespace + +void IrEmitterUnnested::EmitPrologueForOneReduction( + HloInstruction* unnested_hlo, HloInstruction* reduce_inst, int reduce_idx, + KernelCodegenInfo* kernel_info, GpuElementalIrEmitter* elemental_emitter, + ShapeIndex output_shape_index) { + ReductionCodegenInfo* reduction_info = + static_cast(kernel_info); + + InlinedVector* reducers = + reduction_info->GetMutableReducers(); + CHECK(IsReductionToVector(*reduce_inst)); + reducers->push_back(reduce_inst->to_apply()); + + InlinedVector* reduction_output_shape_indices = + reduction_info->GetMutableReductionOutputShapeIndices(); + 
reduction_output_shape_indices->push_back(std::move(output_shape_index)); + + AddressVector* reduction_input_addresses = + reduction_info->GetMutableReductionInputAddresses(); + llvm::Type* element_type = llvm_ir::PrimitiveTypeToIrType( + reduce_inst->shape().element_type(), ir_emitter_context_->llvm_module()); + llvm::AllocaInst* reduction_input_address = Alloca(element_type); + reduction_input_addresses->push_back(reduction_input_address); + + AddressVector* partial_result_addresses = + reduction_info->GetMutablePartialResultAddresses(); + llvm::AllocaInst* partial_result_address = + Alloca(element_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(reduce_idx)); + partial_result_addresses->push_back(partial_result_address); + + // Initialize the partial result with the initial value of the reduction. + llvm::Value* init_ir_value; + if (unnested_hlo->opcode() == HloOpcode::kFusion) { + HloInstruction* init_value_operand = reduce_inst->mutable_operand(1); + FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo), + elemental_emitter); + + TF_CHECK_OK(init_value_operand->Accept(&fused_emitter)); + init_ir_value = + fused_emitter + .GetGenerator(init_value_operand)(IrArray::Index(b_.getInt32Ty())) + .ValueOrDie(); + } else { + const HloInstruction* init_value = unnested_hlo->operand(1); + init_ir_value = + GetIrArray(*init_value, *unnested_hlo) + .EmitReadArrayElement(IrArray::Index(b_.getInt32Ty()), &b_); + } + + Store(init_ir_value, partial_result_address); +} + +void IrEmitterUnnested::EmitPrologueForReduction( + HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info) { + VLOG(10) << "Emit prologue for reduction " << unnested_hlo->ToString(); + // Find the unnested kReduce or the tuple that contains a list of kReduce. + HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion + ? 
unnested_hlo->fused_expression_root() + : unnested_hlo; + absl::Span output_instructions = + GetOutputInstructions(&reduce_or_tuple); + ReductionCodegenInfo* reduction_info = + static_cast(kernel_info); + GpuElementalIrEmitter elemental_emitter(hlo_module_config_, + ir_emitter_context_->llvm_module(), + &b_, GetNestedComputer()); + const HloInstruction* first_reduce = nullptr; + for (int i = 0, e = output_instructions.size(); i != e; ++i) { + if (output_instructions[i]->opcode() != HloOpcode::kReduce) { + continue; + } + HloInstruction* reduce_inst = output_instructions[i]; + if (first_reduce == nullptr) { + first_reduce = reduce_inst; + } else { + CHECK(first_reduce->dimensions() == reduce_inst->dimensions()); + } + ShapeIndex output_shape_index; + if (reduce_or_tuple->opcode() == HloOpcode::kTuple) { + output_shape_index = {i}; + } + + EmitPrologueForOneReduction(unnested_hlo, reduce_inst, i, kernel_info, + &elemental_emitter, + std::move(output_shape_index)); + } + + // Allocate stack storage to store the current output linear index and record + // the address of the storage. 
+ reduction_info->SetCurrentOutputLinearIndexAddress( + Alloca(reduction_info->GetIndexType())); + + if (!reduction_info->IsRowReduction()) { + llvm::Type* bool_ty = b_.getInt1Ty(); + llvm::AllocaInst* output_inbound_addr = Alloca(bool_ty); + Store(llvm::ConstantInt::get(bool_ty, 0), output_inbound_addr); + reduction_info->SetCurrentOutputInboundAddress(output_inbound_addr); + } +} + +void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForAllReduces( + absl::Span reducers, + absl::Span partial_result_addresses) { + for (int distance = 16; distance >= 1; distance /= 2) { + for (int i = 0; i != reducers.size(); ++i) { + llvm::Type* element_type = + partial_result_addresses[i]->getType()->getElementType(); + int bit_width = llvm_ir::GetSizeInBits(element_type); + llvm::Value* result_from_other_lane = Alloca( + element_type, nullptr, "result_from_other_lane" + llvm::Twine(i)); + // Bitcast cannot be applied to aggregate types (even packed ones), so + // we bitcast addresses of load/store to intN* of the same bit-width. + llvm::Type* shuffled_value_type = + element_type->isStructTy() ? 
b_.getIntNTy(bit_width) : element_type; + auto convert_pointer_for_shuffle = [&](llvm::Value* ptr) { + return BitCast(ptr, shuffled_value_type->getPointerTo()); + }; + llvm::Value* partial_result = + Load(convert_pointer_for_shuffle(partial_result_addresses[i]), + "partial_reduction_result"); + Store(EmitFullWarpShuffleDown(partial_result, b_.getInt32(distance), &b_), + convert_pointer_for_shuffle(result_from_other_lane)); + TF_CHECK_OK(EmitCallToNestedComputation( + *reducers[i], {partial_result_addresses[i], result_from_other_lane}, + partial_result_addresses[i])); + } + } +} + +void IrEmitterUnnested::EmitEpilogueForReduction( + HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info) { + ReductionCodegenInfo* reduction_info = + static_cast(kernel_info); + int num_reduces = reduction_info->GetNumberOfReduces(); + absl::Span partial_result_addresses = + reduction_info->GetPartialResultAddresses(); + const InlinedVector& reducers = + reduction_info->GetReducers(); + absl::Span reduction_output_shape_indices = + reduction_info->GetReductionOutputShapeIndices(); + + if (reduction_info->IsRowReduction()) { + EmitFullWarpShuffleDownLoopForAllReduces(reducers, + partial_result_addresses); + llvm::Value* lane_id = reduction_info->GetLaneId(); + llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( + ICmpEQ(lane_id, llvm::ConstantInt::get(lane_id->getType(), 0)), + "lane_id_is_zero", &b_); + llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); + } else { + llvm::Value* output_inbound_addr = + reduction_info->GetCurrentOutputInboundAddress(); + llvm::Value* output_inbound = Load(output_inbound_addr); + llvm_ir::LlvmIfData if_output_inbound_data = llvm_ir::EmitIfThenElse( + ICmpEQ(output_inbound, + llvm::ConstantInt::get(output_inbound->getType(), 1)), + "output_inbound", &b_); + llvm_ir::SetToFirstInsertPoint(if_output_inbound_data.true_block, &b_); + } + + // Emit an atomic operation that accumulates the partial reduction to the 
+ // output element. For row reduction, this is only for lane 0 due to the + // if-statement emitted above. + for (int i = 0; i != num_reduces; ++i) { + IrArray::Index element_index( + /*linear=*/Load(reduction_info->GetCurrentOutputLinearIndexAddress(), + "output_linear_addr"), + ShapeUtil::GetSubshape(unnested_hlo->shape(), + reduction_output_shape_indices[i]), + &b_); + llvm::Value* output_address = + GetIrArray(*unnested_hlo, *unnested_hlo, + reduction_output_shape_indices[i]) + .EmitArrayElementAddress(element_index, &b_, + "output_element_address"); + // Do not emit atomic operations if each element in the reduction result is + // computed by one block, that is the dimension being reduced has only one + // block. + const llvm_ir::KernelMappingScheme* mapping_scheme = + reduction_info->GetKernelMappingScheme(); + if (mapping_scheme->GetTileBlockSizeForDimension( + llvm_ir::KernelMappingScheme::DimZ) == 1 && + mapping_scheme->GetTileBlockSizeForDimension( + reduction_info->GetReducedDimensionEnum()) == 1) { + TF_CHECK_OK(EmitCallToNestedComputation( + *reducers[i], {output_address, partial_result_addresses[i]}, + output_address)); + } else { + TF_CHECK_OK(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, partial_result_addresses[i])); + } + } +} + +void IrEmitterUnnested::EmitTileElementForReduction( + HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, + llvm::Value* x_loc) { + VLOG(10) << "Emit tile element for reduce " << unnested_hlo->ToString(); + HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion + ? unnested_hlo->fused_expression_root() + : unnested_hlo; + llvm_ir::TiledParameterInfo* tiled_param_info = + kernel_info->GetTiledParameterInfo(); + tiled_param_info->set_y(y_loc); + tiled_param_info->set_x(x_loc); + + // Record the linear address for the current reduction. 
+ const ReductionCodegenInfo* reduction_info = + dynamic_cast(kernel_info); + Store(index[reduction_info->GetKeptDimensionEnum()], + reduction_info->GetCurrentOutputLinearIndexAddress()); + if (!reduction_info->IsRowReduction()) { + llvm::Type* bool_ty = b_.getInt1Ty(); + llvm::AllocaInst* output_inbound_addr = + reduction_info->GetCurrentOutputInboundAddress(); + Store(llvm::ConstantInt::get(bool_ty, 1), output_inbound_addr); + } + + InlinedVector input_gens; + std::vector> + extra_output_gens; + GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_, + GetNestedComputer()); + FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo), + &elem_emitter); + absl::Span output_instructions = + GetOutputInstructions(&reduce_or_tuple); + // Construct the ElementGenerator for each reduction and extra output in the + // group of output instructions. + if (unnested_hlo->opcode() == HloOpcode::kFusion) { + fused_emitter.SetTiledParameterInfo(tiled_param_info); + TF_CHECK_OK(unnested_hlo->fused_expression_root()->Accept(&fused_emitter)); + + for (int i = 0, e = output_instructions.size(); i != e; ++i) { + const HloInstruction* inst = output_instructions[i]; + ShapeIndex output_shape_index; + if (reduce_or_tuple->opcode() == HloOpcode::kTuple) { + output_shape_index = {i}; + } + if (inst->opcode() == HloOpcode::kReduce) { + input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0))); + } else { + extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst), + std::move(output_shape_index)); + } + } + } else { + input_gens.push_back([&](const IrArray::Index& index) { + return GetIrArray(*unnested_hlo->operand(0), *unnested_hlo) + .EmitReadArrayElement(index, &b_); + }); + } + + IrArray::Index input_index = + reduction_info->GetKernelMappingScheme()->GetUnnormalizedIndex( + index, + GetFirstReduceInstruction(output_instructions)->operand(0)->shape()); + absl::Span partial_reduction_result_addresses = +
reduction_info->GetPartialResultAddresses(); + absl::Span reduction_input_addresses = + reduction_info->GetReductionInputAddresses(); + const InlinedVector& reducers = + reduction_info->GetReducers(); + + // Emit code to generate the input and perform the reduction computation for + // each reduction instruction. + for (int i = 0; i != reducers.size(); ++i) { + llvm::Value* const input_ir_value = input_gens[i](input_index).ValueOrDie(); + Store(input_ir_value, reduction_input_addresses[i]); + TF_CHECK_OK(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], reduction_input_addresses[i]}, + partial_reduction_result_addresses[i])); + } + + // Emit code to generate the output for the non-reduction instructions in the + // fusion, if any. + TF_CHECK_OK( + EmitExtraOutputsForReduce(unnested_hlo, input_index, extra_output_gens)); +} + +// Emits a kernel for the hlo instruction using the given tiling scheme. void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile, const KernelCodegenInfo* kernel_info, KernelSupportLibrary& ksl, @@ -3419,15 +2751,14 @@ void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile, Select(ICmpEQ(last_block_for_dim, block_id_for_dim), last_block_size_for_dim, block_size_for_dim); - ksl.ForReturnVoid( - loop_name, - /*start=*/index_typed_constant(0), - /*end=*/num_tiles_in_block, - /*step=*/1, [&](llvm::Value* block_dim_induction_var) { - IrArray::Index tile_index = starting_tile.AddOffsetToDim( - block_dim_induction_var, dim_id, &b_); - emit_next_block_dim(tile_index); - }); + ksl.For(loop_name, + /*start=*/index_typed_constant(0), + /*end=*/num_tiles_in_block, + /*step=*/1, [&](llvm::Value* block_dim_induction_var) { + IrArray::Index tile_index = starting_tile.AddOffsetToDim( + block_dim_induction_var, dim_id, &b_); + emit_next_block_dim(tile_index); + }); } }; @@ -3509,11 +2840,22 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( << llvm_ir::DumpToString(*param_shmem_buffers[id]); } - 
CHECK_EQ(mapping_scheme->GetThreadsPerTile() % kWarpSize, 0); - LaunchDimensions launch_dimensions = LaunchDimensions( - mapping_scheme->GetNumberOfBlocks(), mapping_scheme->GetThreadsPerTile()); - llvm::Type* index_ty = GetIndexTypeForKernel( - unnested_hlo, launch_dimensions.launch_bound(), &b_); + const ReductionCodegenInfo* reduction_info = + dynamic_cast(kernel_info); + bool is_column_reduction = + (reduction_info && !reduction_info->IsRowReduction()); + + LaunchDimensions launch_dimensions = + LaunchDimensions(mapping_scheme->GetNumberOfBlocks(), + mapping_scheme->GetThreadsPerBlock()); + + // TODO(b/110211620): Enable int32 index type for column reduction. + llvm::Type* index_ty = + is_column_reduction + ? b_.getInt64Ty() + : GetIndexTypeForKernel(unnested_hlo, + launch_dimensions.launch_bound(), &b_); + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { return llvm::ConstantInt::get(index_ty, c); }; @@ -3523,14 +2865,13 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( // but we do it at the beginning in the hopes of reducing register pressure, // since we touch threadIdx.x and blockIdx.x at the beginning of the kernel // *anyway*. - if (unnested_hlo->IsMultiOutputFusion()) { - TF_CHECK_OK(KernelSupportLibrary(&b_).If( + if (!reduction_info && unnested_hlo->IsMultiOutputFusion()) { + KernelSupportLibrary{&b_}.If( "emit_mof_tuple", IsBlock0Thread0(&b_), [&] { llvm_ir::EmitTuple(GetIrArray(*unnested_hlo, *unnested_hlo), ConstructIrArrayForOutputs(*unnested_hlo), &b_, module_); - return Status::OK(); - })); + }); } // For each tiled parameter, cast its input IrArray to the corresponding @@ -3553,6 +2894,7 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( kernel_info->SetLaneId( mapping_scheme->GetNumberOfThreadsForDimensionX() == kWarpSize ? x : nullptr); + kernel_info->SetIndexType(index_ty); KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); // Curry a few parameters to EmitTiledElementalCodeWithBoundsCheck. 
@@ -3577,29 +2919,31 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( input_tile_origin.AddOffsetToDim(x, KernelMappingScheme::DimX, &b_) .AddOffsetToDim(y, KernelMappingScheme::DimY, &b_); - // Copy input parameter values to shared memory buffers: - // tile[y, x] = input[index] - // Note that tile_width and tile_height are flipped here because we are - // reading a transposed tile. - emit_tiled_elemental_code_with_bounds_check( - input_index, "input", output_tile_bounds[2], output_tile_bounds[1], - [&](const IrArray::Index& index, llvm::Value* y_loc, - llvm::Value* x_loc) { - for (int64 id : tiled_param_ids) { - IrArray& input_in_logical_shape = param_in_reduced_shape_arrays[id]; - llvm::Value* shmem_buffer = param_shmem_buffers[id]; - // TODO(jlebar): Add AA metadata to this store. Tile buffers are - // global variables, so LLVM can't infer much about it. - Store(input_in_logical_shape.EmitReadArrayElement(index, &b_, - "input_element"), - GEP(shmem_buffer, {index_typed_constant(0), y_loc, x_loc})); - } - }); - // If shared memory transpose is needed, wait for all threads to reach this // point, lest we copy a value from tile to output before the other thread // copies it from input to tile. This is `__syncthreads` in CUDA. if (!tiled_param_ids.empty()) { + // Copy input parameter values to shared memory buffers: + // tile[y, x] = input[index] + // Note that tile_width and tile_height are flipped here because we are + // reading a transposed tile. + emit_tiled_elemental_code_with_bounds_check( + input_index, "input", output_tile_bounds[2], output_tile_bounds[1], + [&](const IrArray::Index& index, llvm::Value* y_loc, + llvm::Value* x_loc) { + for (int64 id : tiled_param_ids) { + IrArray& input_in_logical_shape = + param_in_reduced_shape_arrays[id]; + llvm::Value* shmem_buffer = param_shmem_buffers[id]; + // TODO(jlebar): Add AA metadata to this store. Tile buffers are + // global variables, so LLVM can't infer much about it. 
+ Store(input_in_logical_shape.EmitReadArrayElement( + index, &b_, "input_element"), + GEP(shmem_buffer, {index_typed_constant(0), y_loc, x_loc})); + } + }); + + // Wait for all threads to reach this point using `__syncthreads` in CUDA. llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, &b_); } @@ -3619,6 +2963,7 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( kernel_generator.GetTileElementGenerator()(unnested_hlo, index, kernel_info, y_loc, x_loc); }); + // If a tile block contains multiple tiles and shared memory buffers are // used, we need to wait for all threads to finish using the shared memory // buffer for the current tile before we move on to process the next tile @@ -3814,6 +3159,249 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { return true; } +namespace { +// Checks that the outputs of a fusion with reduction are consistent. +Status AreFusedReductionOutputsConsistent( + absl::Span output_instructions, + const HloInstruction* first_reduce) { + for (const HloInstruction* inst : output_instructions) { + if (inst->opcode() == HloOpcode::kReduce) { + // Shapes, layouts and dimensions must be the same for all reduces + // inside of this fusion. + TF_RET_CHECK(ShapeUtil::Equal(first_reduce->shape(), inst->shape())); + TF_RET_CHECK(ShapeUtil::Equal(first_reduce->operand(0)->shape(), + inst->operand(0)->shape())); + TF_RET_CHECK(ShapeUtil::Equal(first_reduce->operand(1)->shape(), + inst->operand(1)->shape())); + TF_RET_CHECK(first_reduce->dimensions() == inst->dimensions()); + } else { + // For extra outputs we can relax shape equality to allow different + // types (with the same number of elements). Layouts still have to + // match. 
+ TF_RET_CHECK(ShapeUtil::CompatibleIgnoringElementType( + first_reduce->operand(0)->shape(), inst->shape())); + TF_RET_CHECK(LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(), + inst->shape().layout())); + } + } + return Status::OK(); +} + +// Finds the dimensions to keep for the reduction, sorts and returns the +// dimensions from minor to major. +DimensionVector GetDimensionsToKeepMinorToMajor( + const Shape& input_shape, absl::Span dims_to_reduce) { + DimensionVector input_dims(ShapeUtil::Rank(input_shape), 0); + absl::c_iota(input_dims, 0); + DimensionVector input_dims_to_keep; + for (int input_dim : input_dims) { + auto it = absl::c_find_if(dims_to_reduce, [&](int64 dim_to_reduce) { + return dim_to_reduce == input_dim; + }); + if (it == dims_to_reduce.end()) { + input_dims_to_keep.push_back(input_dim); + } + } + + // Sort the dimensions to keep from minor to major. + absl::c_sort(input_dims_to_keep, [&input_shape](int64 dim_a, int64 dim_b) { + return PositionInContainer(LayoutUtil::MinorToMajor(input_shape), dim_a) < + PositionInContainer(LayoutUtil::MinorToMajor(input_shape), dim_b); + }); + + VLOG(10) << "dims to keep minor to major" + << absl::StrJoin(input_dims_to_keep, ","); + return input_dims_to_keep; +} + +// Given the input shape and dimensions to reduce for the reduction to vector, +// returns : +// num_kept: the number of elements in the contiguous dimensions to keep. +// num_reduced_major: the number of elements in the dimensions to reduce that +// are more major than the dimensions to keep. +// num_reduced_minor: the number of elements in the dimensions to reduce that +// are more minor than the dimensions to kept. 
+std::tuple GetReductionToVectorDimensions( + const Shape& input_shape, absl::Span dims_to_reduce) { + DimensionVector input_dims_to_keep_minor_to_major = + GetDimensionsToKeepMinorToMajor(input_shape, dims_to_reduce); + CHECK(LayoutUtil::AreDimensionsConsecutive( + input_shape.layout(), input_dims_to_keep_minor_to_major)); + int num_reduced_major = 1, num_kept = 1, num_reduced_minor = 1; + if (input_dims_to_keep_minor_to_major.empty()) { + return std::make_tuple(num_reduced_major, num_kept, num_reduced_minor); + } + DimensionVector input_dims(ShapeUtil::Rank(input_shape), 0); + absl::c_iota(input_dims, 0); + absl::Span minor_to_major = + LayoutUtil::MinorToMajor(input_shape); + for (int input_dim : input_dims) { + int64 curr_dim_size = input_shape.dimensions(input_dim); + if (PositionInContainer(minor_to_major, input_dim) > + PositionInContainer(minor_to_major, + input_dims_to_keep_minor_to_major.back())) { + num_reduced_major *= curr_dim_size; + } else if (PositionInContainer(minor_to_major, input_dim) < + PositionInContainer(minor_to_major, + input_dims_to_keep_minor_to_major.front())) { + num_reduced_minor *= curr_dim_size; + } else { + num_kept *= curr_dim_size; + } + } + + return std::make_tuple(num_reduced_major, num_kept, num_reduced_minor); +} + +} // namespace + +std::tuple +IrEmitterUnnested::ComputeMappingSchemeAndReductionKind( + const HloInstruction* first_reduce) { + int64 depth = 1; + int64 height = 1; + int64 width = 1; + bool is_row_reduction = true; + int64 tile_size_x = 1; + int64 tile_size_y = 1; + int64 block_size_z = 1; + int64 num_threads_x = 1; + int64 num_threads_y = 1; + const Shape& input_shape = first_reduce->operand(0)->shape(); + int64 num_input_elems = ShapeUtil::ElementsIn(input_shape); + int64 num_output_elems = ShapeUtil::ElementsIn(first_reduce->shape()); + int64 num_reduced_major, num_kept, num_reduced_minor; + std::tie(num_reduced_major, num_kept, num_reduced_minor) = + GetReductionToVectorDimensions(input_shape, 
first_reduce->dimensions()); + CHECK_EQ(num_output_elems, num_kept); + + if (num_kept == 1) { + // Scalar reduction is a special row reduction with depth = height = 1. + width = num_input_elems; + tile_size_x = kWarpSize * 16; + num_threads_x = kWarpSize; + } else if (num_reduced_minor == 1) { + // Column reduction reduces inputs with dimension [height, width], where + // width is the minor dimension, to dimension [width]. + height = num_reduced_major; + width = num_kept; + is_row_reduction = false; + // Column reduction without transpose doesn't require communication among + // threads processing elements in the same tile. The current implementation + // only support the use of on hardware thread block to process one block of + // tiles in the KernelMappingScheme. We try to maximize the values of + // num_threads_x and tile_size_x to allow a bigger hardware thread block. + int64 hw_threads_per_block_limit = + ThreadsPerBlockLimit(ir_emitter_context_->device_description()); + tile_size_x = std::min(hw_threads_per_block_limit, num_kept); + num_threads_x = tile_size_x; + int64 kNumElementsPerPartialSum = 128; + tile_size_y = kNumElementsPerPartialSum; + } else { + // Row reduction reduces inputs with dimension [depth, height, width], + // where width is the most minor dimension, to dimension [height] . 
+ depth = num_reduced_major; + height = num_kept; + width = num_reduced_minor; + num_threads_x = kWarpSize; + if (width % (kWarpSize * 64) == 0) { + tile_size_x = kWarpSize * 64; + } else { + tile_size_x = kWarpSize * 8; + block_size_z = 8; + while (depth % block_size_z != 0) { + block_size_z -= 1; + } + } + } + DCHECK_EQ(depth * height * width, num_input_elems); + VLOG(10) << "is_row_reduction " << is_row_reduction << depth << " " << height + << " " << width; + + DimensionVector dims_in_elem{depth, height, width}; + DimensionVector req_block_sizes{block_size_z, 1, 1}; + llvm_ir::KernelMappingScheme mapping_scheme( + dims_in_elem, tile_size_y, tile_size_x, req_block_sizes, num_threads_y, + num_threads_x, &b_); + return std::make_tuple(mapping_scheme, is_row_reduction); +} + +Status IrEmitterUnnested::EmitReductionToVector(HloInstruction* unnested_hlo) { + VLOG(10) << "Emitting reduction to vector " << unnested_hlo->ToString(); + + HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion + ? unnested_hlo->fused_expression_root() + : unnested_hlo; + absl::Span output_instructions = + GetOutputInstructions(&reduce_or_tuple); + const HloInstruction* first_reduce = + GetFirstReduceInstruction(output_instructions); + + if (output_instructions.size() > 1) { + TF_RETURN_IF_ERROR( + AreFusedReductionOutputsConsistent(output_instructions, first_reduce)); + } + + // Build an initializer thunk to initialize each reduction output. + std::vector> thunks; + for (int i = 0, e = output_instructions.size(); i != e; ++i) { + if (output_instructions[i]->opcode() != HloOpcode::kReduce) { + continue; + } + TF_ASSIGN_OR_RETURN( + std::unique_ptr initializer_thunk, + BuildInitializerThunk(unnested_hlo, + (output_instructions[i] == reduce_or_tuple) + ? ShapeIndex() + : ShapeIndex({i}))); + thunks.push_back(std::move(initializer_thunk)); + } + + // Build a kernel thunk to compute all the outputs. 
+ std::unique_ptr kernel_thunk = + BuildKernelThunk(unnested_hlo, /*implements_whole_instruction=*/false); + + const Shape& input_shape = first_reduce->operand(0)->shape(); + // The layout of a reduction input is either set by LayoutAssignment for + // unnested kReduce or by InstructionFusion for fused kReduce. + CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion " + "doesn't set the input layout of " + << first_reduce->ToString(); + + bool is_row_reduction; + llvm_ir::KernelMappingScheme mapping_scheme; + std::tie(mapping_scheme, is_row_reduction) = + ComputeMappingSchemeAndReductionKind(first_reduce); + ReductionCodegenInfo reduction_info(&mapping_scheme, is_row_reduction); + KernelCodeGenerator kernel_generator( + /*tile_element_generator=*/ + [&](HloInstruction* hlo, const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, + llvm::Value* x_loc) { + EmitTileElementForReduction(hlo, index, kernel_info, y_loc, x_loc); + }, + /*block_prologue_generator=*/ + [&](HloInstruction* hlo, KernelCodegenInfo* kernel_info) { + EmitPrologueForReduction(hlo, kernel_info); + }, + /*block_epilogue_generator*/ + [&](HloInstruction* hlo, KernelCodegenInfo* kernel_info) { + EmitEpilogueForReduction(hlo, kernel_info); + }); + + LaunchDimensions launch_dimensions = + EmitKernel(unnested_hlo, {}, kernel_generator, &reduction_info); + UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(), + ir_emitter_context_->llvm_module()); + + thunks.push_back(std::move(kernel_thunk)); + std::unique_ptr sequential_thunk = + absl::make_unique(std::move(thunks), unnested_hlo); + AddThunkToThunkSequence(std::move(sequential_thunk)); + + return Status::OK(); +} + Status IrEmitterUnnested::EmitConstantGlobals() { for (const BufferAllocation& allocation : ir_emitter_context_->buffer_assignment().Allocations()) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h 
b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index e09ed657a812be6ab4859a0e365a51c45a37bfed..d217ee36cf6e9b5278024a2f78513232328e7538 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_ +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" @@ -68,9 +69,12 @@ class IrEmitterUnnested : public IrEmitter { explicit KernelCodegenInfo(llvm_ir::KernelMappingScheme* mapping_scheme) : mapping_scheme_(mapping_scheme), tiled_param_info_(nullptr), - lane_id_(nullptr) {} + lane_id_(nullptr), + index_ty_(nullptr) {} + virtual ~KernelCodegenInfo() {} void SetLaneId(llvm::Value* v) { lane_id_ = v; } + void SetIndexType(llvm::Type* t) { index_ty_ = t; } void SetTiledParamInfo(llvm_ir::TiledParameterInfo* tiled_param_info) { CHECK_EQ(tiled_param_info_, nullptr); tiled_param_info_ = tiled_param_info; @@ -83,11 +87,13 @@ class IrEmitterUnnested : public IrEmitter { llvm_ir::TiledParameterInfo* GetTiledParameterInfo() const { return tiled_param_info_; } + llvm::Type* GetIndexType() const { return index_ty_; } private: llvm_ir::KernelMappingScheme* mapping_scheme_; llvm_ir::TiledParameterInfo* tiled_param_info_; llvm::Value* lane_id_; + llvm::Type* index_ty_; }; // A function object to prepare for the code generation for a tile block. 
@@ -170,7 +176,7 @@ class IrEmitterUnnested : public IrEmitter { Status HandleSelect(HloInstruction* select) override; Status HandleSort(HloInstruction* sort) override; Status HandleTupleSelect(HloInstruction* tuple_select) override; - Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleAllReduce(HloInstruction* crs) override; Status HandleAfterAll(HloInstruction* after_all) override; Status EmitTargetElementLoop( @@ -200,82 +206,19 @@ class IrEmitterUnnested : public IrEmitter { // Helper for writing extra outputs from inside a reduce kernel. Status EmitExtraOutputsForReduce( - const HloInstruction* reduce, const llvm_ir::IrArray::Index& index, + const HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index, absl::Span> extra_output_gens); - // EmitColumnReduction and EmitRowReduction emit code for column and row - // reduction of a matrix and/or 3D tensor. Row and column reduction have - // different memory access pattern, so for performance their implementations - // are significantly different. + // Generates code for reduction to contiguous dimensions. // - // Emits code that reduces a matrix of shape [height x width] to a vector of - // [width]. Other parameters have the same meaning as those of - // `EmitReductionToVector`. Note that input shape might not be - // [height x width], but can be bitcast to [height x width] with "height" - // being the major dimension. - Status EmitColumnReduction( - KernelThunk* kernel_thunk, int64 height, int64 width, - HloInstruction* reduce, const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens); - - // Emits code that reduces a 3D tensor of shape [depth x height x width] to a - // vector of shape [height]. Other parameters have the same meaning as those - // of `EmitReductionToVector`. 
Note that input shape might not be - // [depth x height x width], but can be bitcast to [depth x height x width] - // with "depth" being the most major dimension. - Status EmitRowReduction( - KernelThunk* kernel_thunk, int64 depth, int64 height, int64 width, - HloInstruction* reduce, const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens); - - // Emits code that reduces a tensor of arbitrary rank to a scalar. - Status EmitReductionToScalar( - KernelThunk* kernel_thunk, HloInstruction* reduce, - const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens); + // Prerequisite: `IsReductionToVector(*unnested_hlo)` + Status EmitReductionToVector(HloInstruction* unnested_hlo); - // Figures out whether `reduce` is a row or column reduction, and which - // dimensions to reduce, and calls either `EmitRowReduction` or - // `EmitColumnReduction` as appropriate. `input_shape` is the shape of the - // input array, which is the operand of the Reduce instruction if unfused or - // of the Fusion instruction if fused. `input_gen` and `init_value_gen` - // generate elements of the input and the initial value. Other parameters mean - // the same as for `HandleReduce`. - // - // Multiple reduces can be emitted in the same loop, assuming they have the - // same input and output shapes, and the same reduce dimensions. - // - // extra_output_gens can contain extra generators for intermediate outputs. - // These must have the same shape as the reduce input as they are computed - // when the reduce inputs are being read. 
- // - // Prerequisite: `IsReductionToVector(*reduce)` - Status EmitReductionToVector( - KernelThunk* kernel_thunk, HloInstruction* reduce, - const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span dimensions_to_reduce, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens); + // Computes the KernelMappingScheme for the reduce HLO and indicates whether + // the reduction is a row reduction. + std::tuple + ComputeMappingSchemeAndReductionKind(const HloInstruction* first_reduce); // Emits code for an in-place scatter, modifying `thunk`s launch dimensions in // the process. `scatter` may be fused, scatter indices are taken from @@ -314,6 +257,28 @@ class IrEmitterUnnested : public IrEmitter { const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, llvm::Value* x_loc); + // Emits code to process a tensor element in a tile for the given input hlo + // that is either a unnested kReduce or a kInput fusion. + void EmitTileElementForReduction(HloInstruction* unnested_hlo, + const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, + llvm::Value* y_loc, llvm::Value* x_loc); + // Prepares for the code generation for a tile block of a reduction kernel. + void EmitPrologueForReduction(HloInstruction* unnested_hlo, + KernelCodegenInfo* kernel_info); + void EmitPrologueForOneReduction(HloInstruction* unnested_hlo, + HloInstruction* reduce_inst, int reduce_idx, + KernelCodegenInfo* kernel_info, + GpuElementalIrEmitter* elemental_emitter, + ShapeIndex output_shape_index); + // Wraps up the code generation for a tile block of a reduction kernel. + void EmitEpilogueForReduction(HloInstruction* unnested_hlo, + KernelCodegenInfo* kernel_info); + // For each reducer, emits the shuffle-down loop to accumulate the partial + // result to the global result. 
+ void EmitFullWarpShuffleDownLoopForAllReduces( + absl::Span reducers, + absl::Span partial_result_addresses); // Generates the IrArray for each input of an hlo and returns a vector that // constains such IrArrays. diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc index abf9e7b6d62d931e8a937b243bda09f21f604467..bd53b90b42d8e657a3ee58e7ca03fb60522aae28 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc @@ -199,8 +199,7 @@ std::unique_ptr GetTargetMachine( } return absl::WrapUnique(target->createTargetMachine( triple.str(), llvm_ir::AsStringRef(cpu_name), "+ptx60", target_options, - Optional(RelocModel), Optional(CMModel), - codegen_opt_level)); + getRelocModel(), getCodeModel(), codegen_opt_level)); } // Adds the standard LLVM optimization passes, based on the speed optimization diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index e934cbda1765cb10b4ff2ac14c3ff2f7a5f5cc41..cd369d55987b96eed2efb64ae0df6b3a76acb672 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -36,6 +36,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h" +#include "tensorflow/compiler/xla/service/convolution_group_converter.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h" @@ -108,27 +109,33 @@ namespace { namespace tracing = tensorflow::tracing; -// Returns the directory containing nvvm libdevice files. config_cuda_data_dir -// should be equal to config().debug_options().xla_gpu_cuda_data_dir() of the -// HloModule being compiled. -string GetLibdeviceDir(const string& config_cuda_data_dir) { - std::vector potential_libdevice_dirs; - if (!config_cuda_data_dir.empty()) { - potential_libdevice_dirs.push_back(config_cuda_data_dir); - } - potential_libdevice_dirs.push_back(tensorflow::LibdeviceRoot()); - - // Tries all potential libdevice directories in the order they are inserted. - // Returns the first directory that exists in the file system. - for (const string& potential_libdevice_dir : potential_libdevice_dirs) { - if (tensorflow::Env::Default()->IsDirectory(potential_libdevice_dir).ok()) { - VLOG(2) << "Found libdevice dir " << potential_libdevice_dir; - return potential_libdevice_dir; - } - VLOG(2) << "Unable to find potential libdevice dir " - << potential_libdevice_dir; +// Returns a vector of potential locations of the CUDA root directory. +std::vector GetCudaRootCandidates( + const HloModuleConfig& hlo_module_config) { + std::vector potential_cuda_roots = tensorflow::CandidateCudaRoots(); + + // CUDA location explicitly specified by user via --xla_gpu_cuda_data_dir has + // highest priority. 
+ string xla_gpu_cuda_data_dir = + hlo_module_config.debug_options().xla_gpu_cuda_data_dir(); + if (!xla_gpu_cuda_data_dir.empty()) { + potential_cuda_roots.insert(potential_cuda_roots.begin(), + xla_gpu_cuda_data_dir); } + return potential_cuda_roots; +} +// Returns the directory containing nvvm libdevice files. +string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) { + for (const string& cuda_root : GetCudaRootCandidates(hlo_module_config)) { + string libdevice_dir = + tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice"); + VLOG(2) << "Looking for libdevice at " << libdevice_dir; + if (tensorflow::Env::Default()->IsDirectory(libdevice_dir).ok()) { + VLOG(2) << "Found libdevice dir " << libdevice_dir; + return libdevice_dir; + } + } LOG(WARNING) << "Unable to find libdevice dir. Using '.'"; // Last resort: maybe in the current folder. return "."; @@ -152,6 +159,13 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // TODO(b/64094172): make Call work on GPU instead of inlining. pipeline.AddPass(); + auto cost_model = [](HloInstruction* conv) { + // We need a cost model for GPUs. Currently, do nothing. + return false; + }; + pipeline.AddPass( + cost_model, + /*convert_batch_groups_only=*/true); // Convert BF16 operations to F32 operations so that the GPU backend can // support BF16 operations without directly implementing a BF16 lowering for // most ops. @@ -478,13 +492,19 @@ void WarnIfBadDriverJITVersion() { // Compiles the given PTX string using ptxas and returns the resulting machine // code (i.e. a cubin) as a byte array. 
-StatusOr> CompilePtx(const string& ptx, int cc_major, - int cc_minor) { +StatusOr> CompilePtx( + const string& ptx, int cc_major, int cc_minor, + const HloModuleConfig& hlo_module_config) { tracing::ScopedActivity activity("Compile PTX", /*is_expensive=*/true); - const string ptxas_path = - tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin", "ptxas"); - VLOG(2) << "Checking ptxas at " << ptxas_path; auto env = tensorflow::Env::Default(); + string ptxas_path; + for (const string& cuda_root : GetCudaRootCandidates(hlo_module_config)) { + ptxas_path = tensorflow::io::JoinPath(cuda_root, "bin", "ptxas"); + VLOG(2) << "Looking for ptxas at " << ptxas_path; + if (env->FileExists(ptxas_path).ok()) { + break; + } + } TF_RETURN_IF_ERROR(env->FileExists(ptxas_path)); VLOG(2) << "Using ptxas at " << ptxas_path; @@ -519,6 +539,9 @@ StatusOr> CompilePtx(const string& ptx, int cc_major, if (VLOG_IS_ON(2)) { ptxas_args.push_back("-v"); } + if (hlo_module_config.debug_options().xla_gpu_disable_ptxas_optimizations()) { + ptxas_args.push_back("-O0"); + } ptxas_info_dumper.SetProgram(ptxas_path, ptxas_args); ptxas_info_dumper.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE); @@ -681,12 +704,8 @@ StatusOr> NVPTXCompiler::RunBackend( // Find the directory containing libdevice. To avoid searching for it every // time, we have a one-element cache, keyed on the module's config's // cuda_data_dir. 
- const auto& config_cuda_data_dir = - module->config().debug_options().xla_gpu_cuda_data_dir(); - if (cached_libdevice_dir_.empty() || - cached_cuda_data_dir_ != config_cuda_data_dir) { - cached_cuda_data_dir_ = config_cuda_data_dir; - cached_libdevice_dir_ = GetLibdeviceDir(config_cuda_data_dir); + if (cached_libdevice_dir_.empty()) { + cached_libdevice_dir_ = GetLibdeviceDir(module->config()); } libdevice_dir = cached_libdevice_dir_; } @@ -740,7 +759,7 @@ StatusOr> NVPTXCompiler::RunBackend( } const std::vector cubin = - CompilePtxOrGetCachedResult(ptx, cc_major, cc_minor); + CompilePtxOrGetCachedResult(ptx, cc_major, cc_minor, module->config()); auto thunk_schedule = absl::make_unique( ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment), @@ -772,9 +791,9 @@ StatusOr> NVPTXCompiler::RunBackend( return std::unique_ptr(gpu_executable); } -std::vector NVPTXCompiler::CompilePtxOrGetCachedResult(const string& ptx, - int cc_major, - int cc_minor) { +std::vector NVPTXCompiler::CompilePtxOrGetCachedResult( + const string& ptx, int cc_major, int cc_minor, + const HloModuleConfig& hlo_module_config) { XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::CompilePtxOrGetCachedResult"); tracing::ScopedActivity activity("PTX->CUBIN", /*is_expensive=*/true); bool inserted; @@ -803,7 +822,7 @@ std::vector NVPTXCompiler::CompilePtxOrGetCachedResult(const string& ptx, CHECK(!cache_value->compilation_done); if (!ptx.empty()) { StatusOr> maybe_cubin = - CompilePtx(*cache_ptx, cc_major, cc_minor); + CompilePtx(*cache_ptx, cc_major, cc_minor, hlo_module_config); if (maybe_cubin.ok()) { cache_value->cubin_data = std::move(maybe_cubin).ValueOrDie(); VLOG(2) << "Compiled PTX size:" << ptx.size() diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h index f79ae2990ae7d6e6985b15727a72358289121aa9..b2077f42fd097330703fde063d80a20704fa48e2 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h +++ 
b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h @@ -97,8 +97,9 @@ class NVPTXCompiler : public LLVMCompiler { // Tries to compile the given ptx string to cubin. Returns a vector with the // compiled cubin. If compilation was unsuccessful, returns an empty vector. - std::vector CompilePtxOrGetCachedResult(const string& ptx, - int cc_major, int cc_minor); + std::vector CompilePtxOrGetCachedResult( + const string& ptx, int cc_major, int cc_minor, + const HloModuleConfig& hlo_module_config); // The compilation_cache_ map is a cache from {ptx string, cc_major, cc_minor} // -> cubin so we don't recompile the same ptx twice. This is important for diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc index 375f68a15957936151aee068582a714b62694af2..bfed4f5230dfe37bca48560ce83a2dd82c8950a4 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc @@ -39,6 +39,25 @@ std::ostream& operator<<(std::ostream& out, return out; } +int64 ThreadsPerBlockLimit(const se::DeviceDescription& device_desc) { + int64 threads_per_block = device_desc.threads_per_block_limit(); + if (threads_per_block == 0) { + static std::atomic log_count{0}; + if (log_count.fetch_add(1) < 8) { + LOG(WARNING) << "Attempting to calculate launch dimensions for GPU " + "without full information about its capabilities. " + "StreamExecutor's PopulateDeviceDescription should be " + "updated for this device."; + } + threads_per_block = device_desc.threads_per_warp(); + if (threads_per_block == 0) { + // Fall back to *something* if we can't even get num threads per warp. + threads_per_block = 32; + } + } + return threads_per_block; +} + // Calculates the launch dimensions used to invoke `hlo`. 
LaunchDimensions CalculateLaunchDimensions( const Shape& shape, const se::DeviceDescription& device_desc, @@ -62,21 +81,7 @@ LaunchDimensions CalculateLaunchDimensions( // // * = - int64 threads_per_block = device_desc.threads_per_block_limit(); - if (threads_per_block == 0) { - static std::atomic log_count{0}; - if (log_count.fetch_add(1) < 8) { - LOG(WARNING) << "Attempting to calculate launch dimensions for GPU " - "without full information about its capabilities. " - "StreamExecutor's PopulateDeviceDescription should be " - "updated for this device."; - } - threads_per_block = device_desc.threads_per_warp(); - if (threads_per_block == 0) { - // Fall back to *something* if we can't even get num threads per warp. - threads_per_block = 32; - } - } + int64 threads_per_block = ThreadsPerBlockLimit(device_desc); if (num_elements < threads_per_block) { threads_per_block = num_elements; diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.h b/tensorflow/compiler/xla/service/gpu/partition_assignment.h index 02471129e004b4876ce20a62cade34060c65b478..eb41dcccb938ccc088c2371def96ca73276771ab 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.h @@ -57,6 +57,9 @@ class LaunchDimensions { std::ostream& operator<<(std::ostream& out, const LaunchDimensions& launch_dims); +// Returns the maximum number of threads per block allowed by the device. 
+int64 ThreadsPerBlockLimit(const se::DeviceDescription& device_desc); + LaunchDimensions CalculateLaunchDimensions( const Shape& shape, const se::DeviceDescription& device_desc, int unroll_factor = 1); diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h index 1fc46bafa10e7ba6c896f081d5c836bd400886c9..92e4d6dbbc1bd564657f8a5de09d23d5ae81a93e 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ +#include "tensorflow/compiler/xla/layout.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 414c63271245315f037d04924c9291a9cd5b7a77..9b50f1ca5b5365463f32106fc005ef2c63f2e37a 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. -// Next ID: 58 +// Next ID: 59 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -82,6 +82,8 @@ message HloInstructionProto { // it will use a default value of 1. int64 feature_group_count = 50; + int64 batch_group_count = 58; + // Describes the [begin, end) index range and stride for slices. message SliceDimensions { int64 start = 1; @@ -166,7 +168,7 @@ message HloInstructionProto { // Cross replica op fields. 
repeated ReplicaGroup replica_groups = 49; int64 all_reduce_id = 45; - string cross_replica_sum_barrier = 46; + string all_reduce_barrier = 46; // Whether this Send/Recv instruction transfers data to/from the host. Only // present for Send and Recv instructions and their SendDone and RecvDone diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index ff122b529bdcdcc69d2245136e19101902dbf957..75630307186ba42f711a85454d73722533e59358 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include #include +#include #include #include #include @@ -332,7 +332,7 @@ void HloComputation::ComputeInstructionPostOrder( dfs_stack.emplace_back(op); } - // Add inputs for send->recv_done dependencies and cross-replica-sum + // Add inputs for send->recv_done dependencies and all-reduce // dependencies. 
switch (current->opcode()) { case HloOpcode::kRecvDone: { @@ -344,7 +344,7 @@ void HloComputation::ComputeInstructionPostOrder( } break; } - case HloOpcode::kCrossReplicaSum: { + case HloOpcode::kAllReduce: { auto all_reduce_id = current->all_reduce_id(); if (all_reduce_id) { auto it = channel_dependency_map.find(all_reduce_id.value()); @@ -372,7 +372,7 @@ HloComputation::ComputeChannelDependencies() const { instruction.get()); break; } - case HloOpcode::kCrossReplicaSum: { + case HloOpcode::kAllReduce: { auto all_reduce_id = instruction->all_reduce_id(); if (all_reduce_id) { auto& dependencies = channel_dependency_map[all_reduce_id.value()]; @@ -396,6 +396,7 @@ std::vector HloComputation::MakeInstructionPostOrder() const { post_order.reserve(instruction_count()); std::vector trace_instructions; absl::flat_hash_map visited; + visited.reserve(instruction_count()); for (auto& instruction : instructions_) { if (instruction->opcode() == HloOpcode::kTrace) { // Trace instructions aren't handled by the DFS visitor. Add trace @@ -711,8 +712,6 @@ bool HloComputation::operator==(const HloComputation& other) const { return eq(root_instruction(), other.root_instruction()); } -uint64 HloComputation::Hash() const { return root_instruction()->Hash(); } - Status HloComputation::ReplaceWithNewInstruction( HloInstruction* old_instruction, std::unique_ptr new_instruction) { @@ -797,7 +796,7 @@ Status HloComputation::AcceptWithOperandOrder( template Status HloComputation::AcceptOrdered( DfsHloVisitorBase* visitor, - const std::vector& order) const { + absl::Span order) const { VLOG(3) << "Accepting visitor with order."; for (HloInstruction* root : CollectUnreachableRoots()) { TF_RET_CHECK(std::find(order.begin(), order.end(), root) != order.end()) @@ -827,9 +826,9 @@ Status HloComputation::AcceptOrdered( // Explicit instantiations. 
template Status HloComputation::AcceptOrdered( - DfsHloVisitor*, const std::vector&) const; + DfsHloVisitor*, absl::Span) const; template Status HloComputation::AcceptOrdered( - ConstDfsHloVisitor*, const std::vector&) const; + ConstDfsHloVisitor*, absl::Span) const; Status HloComputation::Accept( const std::function& visitor_func) { diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index c584e4c7ca5770533f28352b0df9dadd9dbe1860..a0ccbc583f8c409f29d31756fcc1fa1b4af7dc35 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -264,12 +264,6 @@ class HloComputation { // Return whether `*this` and `other` are functionally equivalent. bool operator==(const HloComputation& other) const; - // Generates a hash value of an HLO computation. Hash considers - // information on opcode, shape, operands, and typically a root instruction. - // This function returns the same hash value for equivalent HLO computations, - // with respect to HloInstruction::Identical() method. - uint64 Hash() const; - // Replaces old instruction with newly created instruction. Removes old // instruction from computation. Updates uses and root instruction. Status ReplaceWithNewInstruction( @@ -307,7 +301,7 @@ class HloComputation { // be a topological sort of all instructions in the computation. template Status AcceptOrdered(DfsHloVisitorBase* visitor, - const std::vector& order) const; + absl::Span order) const; // Same as Accept() above, but the visitor is given as a function. Status Accept(const std::function& visitor_func); @@ -373,7 +367,7 @@ class HloComputation { // Returns a map from channel-id to directed dependencies of the channel // instructions. For send&recv pairs it means the send instruction and for - // cross-replica-sum the union of the dependencies for all participating + // all-reduce the union of the dependencies for all participating // instructions. 
using ChannelDependencyMap = absl::flat_hash_map>; diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 4f81dc94e577a63c09ae4019e5e8158252c712ce..92b748d813c3efef83ef0155f1d5d3c637ce2c57 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -252,7 +252,7 @@ const char* const kConstantFoldLargePad = R"( HloModule ConstantFoldLargePad ENTRY r { - a = f32[1,1,1] constant(f32[1,1,1]{{{7}}}) + a = f32[1,1,1] constant({{{7}}}) b = f32[] constant(42) ROOT pad = f32[2048,2048,128] pad(a, b), padding=1024_1023x1024_1023x64_63 })"; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index df7d3826dbad1f264a5dc53312c062900155b0f6..cb431aed47f0a751a697305986a8a0c194ac966c 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -552,7 +552,7 @@ Status HloCostAnalysis::HandleFft(const HloInstruction* fft) { return Status::OK(); } -Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) { +Status HloCostAnalysis::HandleAllReduce(const HloInstruction* crs) { // We assume 2 replicas, so that each output element is the sum of two input // elements. 
// diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 33983119c9b00a248c0e8dcc5815c6367192dca3..b52305626dd67336eb31098d086ad357f12d96c7 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -71,7 +71,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleDot(const HloInstruction* dot) override; Status HandleConvolution(const HloInstruction* convolution) override; Status HandleFft(const HloInstruction* fft) override; - Status HandleCrossReplicaSum(const HloInstruction* crs) override; + Status HandleAllReduce(const HloInstruction* crs) override; Status HandleAllToAll(const HloInstruction* hlo) override; Status HandleCollectivePermute(const HloInstruction* hlo) override; Status HandleInfeed(const HloInstruction* infeed) override; diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index b2005d3c210d4ae7e3702cb9624c3ad98056984c..e41aeab19e49ddd4f2363746f0ff8ba1740139b3 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -69,11 +69,11 @@ StatusOr MakeConvolveHlo( CHECK_EQ(computation, rhs->parent()); TF_ASSIGN_OR_RETURN(Shape convolve_shape, ShapeInference::InferConvolveShape( - lhs->shape(), rhs->shape(), feature_group_count, + lhs->shape(), rhs->shape(), feature_group_count, 1, window, dimension_numbers)); return computation->AddInstruction(HloInstruction::CreateConvolve( - convolve_shape, lhs, rhs, feature_group_count, window, dimension_numbers, - precision_config)); + convolve_shape, lhs, rhs, feature_group_count, 1, window, + dimension_numbers, precision_config)); } StatusOr MakeTransposeHlo(HloInstruction* operand, diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 
f7a1f19a6f52befd58a405d0e406d7d0d37a8e57..94de7c55dd2402e55ec344b79c24af2d8283fe73 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -1882,8 +1882,8 @@ TEST_P(HloDataflowAnalysisTest, AddDependency) { HloModule AddDependency ENTRY %AddDependency (p: f32[3]) -> f32[3] { %p = f32[3] parameter(0) - %token = token[] after-all() - ROOT %add_dep = f32[3] add-dependency(f32[3] %p, token[] %token) + %token0 = token[] after-all() + ROOT %add_dep = f32[3] add-dependency(f32[3] %p, token[] %token0) } )"; TF_ASSERT_OK_AND_ASSIGN( diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index acdb42128e3d9a1fb912a466c9c2c3cbbe3d3f83..fd4fb0246d8d42ab7329c05dc23e386303cdce3c 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -195,10 +195,10 @@ HloModule Module ENTRY entry { p0 = (f32[4]) parameter(0) a = f32[4] get-tuple-element(p0), index=0 - token = token[] after-all() - b = (f32[4], u32[], token[]) send(a, token), channel_id=1, sharding={maximal device=0} + token0 = token[] after-all() + b = (f32[4], u32[], token[]) send(a, token0), channel_id=1, sharding={maximal device=0} c = token[] send-done(b), channel_id=1, sharding={maximal device=0} - d = (f32[4], u32[], token[]) recv(token), channel_id=2, sharding={maximal device=0} + d = (f32[4], u32[], token[]) recv(token0), channel_id=2, sharding={maximal device=0} e = (f32[4], token[]) recv-done(d), channel_id=2, sharding={maximal device=0} e_element = f32[4] get-tuple-element(e), index=0, sharding={maximal device=0} f = f32[4] add(a, e_element) @@ -235,12 +235,12 @@ TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) { HloModule Module ENTRY entry { - token = token[] after-all(), sharding={maximal device=-1} - a = (f32[4], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=-1} + 
token0 = token[] after-all(), sharding={maximal device=-1} + a = (f32[4], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=-1} b = (f32[4], token[]) recv-done(a), channel_id=1, sharding={maximal device=-1} b_element = f32[4] get-tuple-element(b), index=0, sharding={maximal device=-1} c = f32[4] add(b_element, b_element), sharding={maximal device=-1} - d = (f32[4], u32[], token[]) send(c, token), channel_id=2, sharding={maximal device=-1} + d = (f32[4], u32[], token[]) send(c, token0), channel_id=2, sharding={maximal device=-1} ROOT e = token[] send-done(d), channel_id=2, sharding={maximal device=-1} } )"; @@ -259,12 +259,12 @@ TEST_F(HloDomainTest, CheckNormalizationOnPureIOComputation) { HloModule Module ENTRY entry { - token = token[] after-all(), sharding={maximal device=0} - a = (f32[4], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=0} + token0 = token[] after-all(), sharding={maximal device=0} + a = (f32[4], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=0} b = (f32[4], token[]) recv-done(a), channel_id=1, sharding={maximal device=0} b_element = f32[4] get-tuple-element(b), index=0, sharding={maximal device=0} c = f32[4] add(b_element, b_element) - d = (f32[4], u32[], token[]) send(c, token), channel_id=2, sharding={maximal device=0} + d = (f32[4], u32[], token[]) send(c, token0), channel_id=2, sharding={maximal device=0} ROOT e = token[] send-done(d), channel_id=2, sharding={maximal device=0} } )"; @@ -344,8 +344,8 @@ TEST_F(HloDomainTest, CheckNormalizationOnInfeedTuple) { HloModule Module ENTRY entry { - token = token[] after-all() - infeed = ((f32[4], f32[4]), token[]) infeed(token), + token0 = token[] after-all() + infeed = ((f32[4], f32[4]), token[]) infeed(token0), sharding={{maximal device=1}, {maximal device=0}, {maximal device=0}} infeed.data = (f32[4], f32[4]) get-tuple-element(infeed), index=0, sharding={{maximal device=1}, {maximal device=0}} diff --git 
a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc index 72006e17e7e7ec09b62e88d05b695ec9f4c49647..a40b6d888c548bf0909f413c092fc32cfc0a4892 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -141,10 +141,9 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { // These are ops with embedded computations where it suffices to convert // the embedded computations instead of converting the ops themselves. if (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall || - opcode == HloOpcode::kCrossReplicaSum || - opcode == HloOpcode::kFusion || opcode == HloOpcode::kMap || - opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow || - opcode == HloOpcode::kScatter || + opcode == HloOpcode::kAllReduce || opcode == HloOpcode::kFusion || + opcode == HloOpcode::kMap || opcode == HloOpcode::kReduce || + opcode == HloOpcode::kReduceWindow || opcode == HloOpcode::kScatter || opcode == HloOpcode::kSelectAndScatter || opcode == HloOpcode::kConditional) { continue; diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc index c170e36c73ad2bef830e528de3ec72d38683d888..a3b56a44a0b02923585c1dcb69571479236188a3 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc @@ -57,10 +57,10 @@ TEST_F(HloElementTypeConverterTest, InfeedsOutfeedsNotConverted) { const string& hlo_string = R"( HloModule InfeedOutfeed ENTRY RoundTrip16MiBR1.v2 { - token = token[] after-all() - infeed = (bf16[4]{0}, token[]) infeed(token) + token0 = token[] after-all() + infeed = (bf16[4]{0}, token[]) infeed(token0) ROOT infeed.data = bf16[4]{0} get-tuple-element(infeed), index=0 - outfeed = token[] outfeed(infeed.data, token) + outfeed = 
token[] outfeed(infeed.data, token0) } )"; auto module = CreateModuleFromHloString(hlo_string); @@ -96,13 +96,13 @@ TEST_F(HloElementTypeConverterTest, BatchNormGradBF16Converted) { const string& hlo_string = R"( HloModule BatchNormGrad ENTRY BatchNormGrad.v6 { - constant.4 = bf16[2,2,2,1]{3,2,1,0} constant(bf16[2,2,2,1] { { /*i0=0*/ + constant.4 = bf16[2,2,2,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {0}, {0} }, { /*i1=1*/ {0}, {0} } }, { /*i0=1*/ { /*i1=0*/ {0}, {0} }, { /*i1=1*/ {0}, {0} } } }) constant.5 = bf16[2]{0} constant({1, 1}) constant.6 = bf16[2]{0} constant({0, 0}) constant.7 = bf16[2]{0} constant({1, 1}) - constant.8 = bf16[2,2,2,1]{3,2,1,0} constant(bf16[2,2,2,1] { { /*i0=0*/ + constant.8 = bf16[2,2,2,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} } }, { /*i0=1*/ { /*i1=0*/ {5}, {6} }, { /*i1=1*/ {7}, {8} } } }) ROOT batch-norm-grad = (bf16[2,2,2,1]{3,2,1,0}, bf16[2]{0}, bf16[2]{0}) diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 3a7652a8dc856b23c8988c4676916c8199e78860..934c082bb9f003b1d2d80835f09a8f4109c7e7fd 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -33,6 +33,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -629,8 +630,11 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) { evaluated_[compare], Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); } break; - case F16: - return Unimplemented("unhandled primitive type: F16."); + case F16: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; case BF16: { TF_ASSIGN_OR_RETURN(evaluated_[compare], Compare(compare->shape(), opcode, @@ -1449,4 +1453,46 @@ template StatusOr HloEvaluator::Evaluate( template StatusOr HloEvaluator::Evaluate( HloInstruction* instruction, absl::Span arg_literals); +namespace { +template +std::unique_ptr> MatmulArray2DImpl( + const Array2D& lhs, const Array2D& rhs, + const std::function& impl_fn) { + CHECK_EQ(lhs.width(), rhs.height()); + int m = lhs.height(); + int n = rhs.width(); + int k = lhs.width(); + auto result = absl::make_unique>(m, n); + // Because Eigen is a header-oriented library, make sure that the Eigen code + // is the same as the code used by the CPU backend (otherwise the linker will + // randomly pick *some* definition). 
+ impl_fn( + /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m, + k, + /*transpose_lhs=*/0, + /*transpose_rhs=*/0); + return result; +} +} // namespace + +std::unique_ptr> HloEvaluator::MatmulArray2D( + const Array2D& lhs, const Array2D& rhs) { + return MatmulArray2DImpl( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF16); +} + +std::unique_ptr> HloEvaluator::MatmulArray2D( + const Array2D& lhs, const Array2D& rhs) { + return MatmulArray2DImpl( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF32); +} + +std::unique_ptr> HloEvaluator::MatmulArray2D( + const Array2D& lhs, const Array2D& rhs) { + return MatmulArray2DImpl( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF64); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 45ed8131dc6b71f706fce45d65b206363dd79ac3..d363a51c63de6fd4246c4970f580b68f4a627df8 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -21,6 +21,7 @@ limitations under the License. #include "absl/container/node_hash_map.h" #include "absl/memory/memory.h" #include "absl/types/span.h" +#include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -119,6 +120,17 @@ class HloEvaluator : public DfsHloVisitorWithDefault { const PrecisionConfig& precision_config, const Literal& lhs, const Literal& rhs); + // Enable the fast path for certain operations like dot or convolution. + void set_use_fast_path(bool value) { use_fast_path_ = value; } + + // Returns the result of a matrix multiply `lhs x rhs`. 
+ static std::unique_ptr> MatmulArray2D( + const Array2D& lhs, const Array2D& rhs); + static std::unique_ptr> MatmulArray2D( + const Array2D& lhs, const Array2D& rhs); + static std::unique_ptr> MatmulArray2D( + const Array2D& lhs, const Array2D& rhs); + protected: // Make HloEvaluatorTypedVisitor a friend because it is logically part of this // class. @@ -217,6 +229,9 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // we cannot use flat_hash_map any more. absl::node_hash_map evaluated_; + // Use fast path that uses eigen in the evaluator. + bool use_fast_path_ = false; + private: template static StatusOr ElementWiseUnaryOpImpl( @@ -250,6 +265,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator); }; +std::unique_ptr> MatmulArray2D(const Array2D& lhs, + const Array2D& rhs); } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 4eaaab20ea0add17d9b49b1b2b97991af0438dcc..8fa493a8732662d5357a68937bfad7ac2b3b8c5d 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -804,7 +804,7 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) { Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 3}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, - window, dnums, DefaultPrecisionConfig(2))); + /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -859,7 +859,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, - window, dnums, DefaultPrecisionConfig(2))); + /*batch_group_count=*/1, 
window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -943,7 +943,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, - window, dnums, DefaultPrecisionConfig(2))); + /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1021,7 +1021,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, - window, dnums, DefaultPrecisionConfig(2))); + /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1081,7 +1081,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, - window, dnums, DefaultPrecisionConfig(2))); + /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1145,7 +1145,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, - window, dnums, DefaultPrecisionConfig(2))); + /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1217,7 +1217,7 @@ TEST_P(HloEvaluatorTest, Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3}); 
b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, - window, dnums, DefaultPrecisionConfig(2))); + /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1288,7 +1288,8 @@ TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 8}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, - /*feature_group_count=*/2, window, dnums, DefaultPrecisionConfig(2))); + /*feature_group_count=*/2, /*batch_group_count=*/1, window, dnums, + DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index b87fc3e34012e75ee07bff6c1e113dce404f83cb..3ace2f544329253d217e1891ce387a8a55fe2339 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -23,6 +23,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/types/optional.h" +#include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" @@ -105,6 +106,12 @@ bool SafeLess(const NativeT& a, const NativeT& b) { template class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { private: + Status UnsupportedTypeError(HloInstruction* instruction) { + return InvalidArgument( + "Unsupported type for %s: %s", HloOpcodeString(instruction->opcode()), + PrimitiveType_Name(instruction->shape().element_type())); + } + // Get the value in the given literal static_cast as a double. 
template < typename NativeT, @@ -224,7 +231,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleRound(HloInstruction* round) { - return InvalidArgument("Unsupported type for Round"); + return UnsupportedTypeError(round); } Status HandleRound(HloInstruction* round) override { @@ -246,7 +253,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleCeil(HloInstruction* ceil) { - return InvalidArgument("Unsupported type for Ceil"); + return UnsupportedTypeError(ceil); } Status HandleCeil(HloInstruction* ceil) override { @@ -297,8 +304,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template < typename NativeT, typename std::enable_if::value>::type* = nullptr> - Status HandleExpm1(HloInstruction* floor) { - return InvalidArgument("Unsupported type for Expm1"); + Status HandleExpm1(HloInstruction* expm1) { + return UnsupportedTypeError(expm1); } Status HandleExpm1(HloInstruction* floor) override { @@ -321,7 +328,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleFloor(HloInstruction* floor) { - return InvalidArgument("Unsupported type for Floor"); + return UnsupportedTypeError(floor); } Status HandleFloor(HloInstruction* floor) override { @@ -351,12 +358,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template < typename NativeT, typename std::enable_if::value>::type* = nullptr> - Status HandleLog1p(HloInstruction* floor) { - return InvalidArgument("Unsupported type for Log1p"); + Status HandleLog1p(HloInstruction* log1p) { + return UnsupportedTypeError(log1p); } - Status HandleLog1p(HloInstruction* floor) override { - return HandleLog1p(floor); + Status HandleLog1p(HloInstruction* log1p) override { + return HandleLog1p(log1p); } 
template ::value>::type* = nullptr> Status HandleNot(HloInstruction* not_) { - return InvalidArgument("Unsupported type for Not"); + return UnsupportedTypeError(not_); } Status HandleNot(HloInstruction* not_) override { @@ -476,7 +483,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> Status HandleAtan2(HloInstruction* atan2) { - return InvalidArgument("Unsupported type for Atan2"); + return UnsupportedTypeError(atan2); } Status HandleAtan2(HloInstruction* atan2) override { @@ -624,7 +631,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleMaximum(HloInstruction* maximum) { - return InvalidArgument("Unsupported type for Maximum"); + return UnsupportedTypeError(maximum); } Status HandleMaximum(HloInstruction* maximum) override { @@ -659,7 +666,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleMinimum(HloInstruction* minimum) { - return InvalidArgument("Unsupported type for Minimum"); + return UnsupportedTypeError(minimum); } Status HandleMinimum(HloInstruction* minimum) override { @@ -724,7 +731,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleRemainder(HloInstruction* remainder) { - return InvalidArgument("Unsupported type for Remainder"); + return UnsupportedTypeError(remainder); } Status HandleRemainder(HloInstruction* remainder) override { @@ -746,14 +753,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> Status HandleAnd(HloInstruction* and_) { - return InvalidArgument("Unsupported type for And"); + return UnsupportedTypeError(and_); } template < typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleAnd(HloInstruction* 
and_) { - return InvalidArgument("Unsupported type for And"); + return UnsupportedTypeError(and_); } Status HandleAnd(HloInstruction* and_) override { @@ -775,7 +782,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> Status HandleOr(HloInstruction* or_) { - return InvalidArgument("Unsupported type for Or"); + return UnsupportedTypeError(or_); } template < @@ -804,14 +811,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> Status HandleXor(HloInstruction* xor_) { - return InvalidArgument("Unsupported type for Xor"); + return UnsupportedTypeError(xor_); } template < typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleXor(HloInstruction* xor_) { - return InvalidArgument("Unsupported type for Xor"); + return UnsupportedTypeError(xor_); } Status HandleXor(HloInstruction* xor_) override { @@ -836,8 +843,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value || std::is_same::value>::type* = nullptr> - Status HandleShiftLeft(HloInstruction*) { - return InvalidArgument("Unsupported type for ShiftLeft"); + Status HandleShiftLeft(HloInstruction* shift) { + return UnsupportedTypeError(shift); } Status HandleShiftLeft(HloInstruction* shl) override { @@ -866,8 +873,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value || std::is_same::value>::type* = nullptr> - Status HandleShiftRightArithmetic(HloInstruction*) { - return InvalidArgument("Unsupported type for ShiftRightArithmetic"); + Status HandleShiftRightArithmetic(HloInstruction* shift) { + return UnsupportedTypeError(shift); } Status HandleShiftRightArithmetic(HloInstruction* shra) override { @@ -897,8 +904,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value || std::is_same::value>::type* = nullptr> - Status 
HandleShiftRightLogical(HloInstruction*) { - return InvalidArgument("Unsupported type for ShiftRightLogical"); + Status HandleShiftRightLogical(HloInstruction* shift) { + return UnsupportedTypeError(shift); } Status HandleShiftRightLogical(HloInstruction* shrl) override { @@ -923,8 +930,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template < typename NativeT, typename std::enable_if::value>::type* = nullptr> - Status HandleClamp(HloInstruction*) { - return InvalidArgument("Unsupported type for Clamp"); + Status HandleClamp(HloInstruction* clamp) { + return UnsupportedTypeError(clamp); } Status HandleClamp(HloInstruction* clamp) override { @@ -1004,10 +1011,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { CHECK_EQ(num_spatial_dims + 2, lhs_rank); CHECK_EQ(num_spatial_dims + 2, rhs_rank); - TF_ASSIGN_OR_RETURN( - auto inferred_return_shape, - ShapeInference::InferConvolveShape( - lhs_shape, rhs_shape, conv->feature_group_count(), window, dnums)); + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, conv->feature_group_count(), + conv->batch_group_count(), window, dnums)); CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) << "return shape set to: " << ShapeUtil::HumanString(result_shape) << " but is inferred to be: " @@ -1031,7 +1038,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto rhs_literal_data = rhs_literal.data(); int64 feature_group_count = conv->feature_group_count(); + int64 batch_group_count = conv->batch_group_count(); + // The batch count > 1 case is unimplemented in the HLO evaluator so far. 
+ TF_RET_CHECK(batch_group_count == 1); auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window, &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data, rhs_literal_data, @@ -1148,6 +1158,78 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } Status HandleDot(HloInstruction* dot) override { + if (parent_->use_fast_path_) { + return HandleDot(dot); + } + return HandleDotSlowPath(dot); + } + + template ::value>::type* = nullptr> + Status HandleDot(HloInstruction* dot) { + const HloInstruction* lhs = dot->operand(0); + const HloInstruction* rhs = dot->operand(1); + CHECK(ShapeUtil::IsArray(dot->shape())); + CHECK(ShapeUtil::IsArray(lhs->shape())); + CHECK(ShapeUtil::IsArray(rhs->shape())); + + const auto& dnums = dot->dot_dimension_numbers(); + + const int64 lhs_rank = ShapeUtil::Rank(lhs->shape()); + const int64 rhs_rank = ShapeUtil::Rank(rhs->shape()); + + CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); + CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape())); + + // There must be 1 and only 1 Contracting dimension for lhs and rhs. + CHECK_EQ(dnums.lhs_contracting_dimensions_size(), 1); + CHECK_EQ(dnums.rhs_contracting_dimensions_size(), 1); + const int64 lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0); + const int64 rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0); + // Contracted dimension sizes must be the same. + CHECK_EQ(lhs->shape().dimensions(lhs_contracting_dimension), + rhs->shape().dimensions(rhs_contracting_dimension)) + << "lhs contracted dimension: " + << lhs->shape().dimensions(lhs_contracting_dimension) + << " rhs contracted dimension: " + << rhs->shape().dimensions(rhs_contracting_dimension); + + // The fast path is for a simple rank 2 dot with default layout operands. 
+ if (lhs_rank == 2 && rhs_rank == 2 && lhs_contracting_dimension == 1 && + rhs_contracting_dimension == 0 && + LayoutUtil::Equal(lhs->shape().layout(), + LayoutUtil::GetDefaultLayoutForR2()) && + LayoutUtil::Equal(rhs->shape().layout(), + LayoutUtil::GetDefaultLayoutForR2()) && + LayoutUtil::Equal(dot->shape().layout(), + LayoutUtil::GetDefaultLayoutForR2())) { + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + const int64 contracted_dimension_size = + lhs->shape().dimensions(lhs_contracting_dimension); + Array2D lhs_array(lhs->shape().dimensions(0), + contracted_dimension_size); + lhs_array.SetValues(lhs_literal.data()); + Array2D rhs_array(contracted_dimension_size, + rhs->shape().dimensions(1)); + rhs_array.SetValues(rhs_literal.data()); + std::unique_ptr> result_array = + HloEvaluator::MatmulArray2D(lhs_array, rhs_array); + Literal result(dot->shape()); + result.PopulateR2FromArray2D(*result_array); + parent_->evaluated_[dot] = std::move(result); + return Status::OK(); + } + return HandleDotSlowPath(dot); + } + + template ::value>::type* = nullptr> + Status HandleDot(HloInstruction* dot) { + return HandleDotSlowPath(dot); + } + + Status HandleDotSlowPath(HloInstruction* dot) { auto lhs = dot->operand(0); auto rhs = dot->operand(1); CHECK(ShapeUtil::IsArray(dot->shape())); @@ -1578,7 +1660,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::is_same::value>::type* = nullptr> Status HandleSort(HloInstruction* sort) { - return InvalidArgument("Unsupported type for Sort"); + return UnsupportedTypeError(sort); } Status HandleSort(HloInstruction* sort) override { @@ -2357,7 +2439,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::is_same::value || std::is_same::value)>::type* = nullptr> Status HandleClz(HloInstruction* clz) { - return InvalidArgument("Unsupported type for Clz"); + return UnsupportedTypeError(clz); } template 
::value || is_complex_t::value>::type* = nullptr> Status HandleSin(HloInstruction* sin) { - return InvalidArgument("Unsupported type for Sin"); + return UnsupportedTypeError(sin); } Status HandleSin(HloInstruction* sin) override { @@ -2425,7 +2507,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value || is_complex_t::value>::type* = nullptr> Status HandleCos(HloInstruction* cos) { - return InvalidArgument("Unsupported type for Cos"); + return UnsupportedTypeError(cos); } Status HandleCos(HloInstruction* cos) override { @@ -2534,7 +2616,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value || is_complex_t::value>::type* = nullptr> Status HandleReducePrecision(HloInstruction* reduce_precision) { - return InvalidArgument("Unsupported type for reduce precision"); + return UnsupportedTypeError(reduce_precision); } Status HandleReducePrecision(HloInstruction* reduce_precision) override { @@ -2543,15 +2625,27 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value || + std::is_same::value || std::is_integral::value || std::is_floating_point::value>::type* = nullptr> Status HandleIota(HloInstruction* instruction) { auto* iota = Cast(instruction); + const int64 iota_size = iota->shape().dimensions(iota->iota_dimension()); // Avoid using std::vector since std::vector does not convert to // absl::Span. - absl::InlinedVector data( - iota->shape().dimensions(iota->iota_dimension())); - std::iota(data.begin(), data.end(), 0); + absl::InlinedVector data(iota_size); + // We don't use std::iota for two reasons: + // + // (1) std:iota does not support bfloat16 and float16. + // + // (2) std::iota saturates for floating point types when the value is not + // representable, but the definition of HLO iota is the value as a + // 64-bit integer cast to the native type. + for (int64 i = 0; i < iota_size; ++i) { + // static_cast is required for Eigen::half (F16). 
+ data[i] = static_cast(i); + } auto result = LiteralUtil::CreateR1(data); if (ShapeUtil::Rank(iota->shape()) > 1) { @@ -2567,10 +2661,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template ::value || + !(std::is_same::value || + std::is_same::value || + std::is_integral::value || std::is_floating_point::value)>::type* = nullptr> Status HandleIota(HloInstruction* iota) { - return InvalidArgument("Unsupported type for iota"); + return UnsupportedTypeError(iota); } Status HandleIota(HloInstruction* iota) override { return HandleIota(iota); diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 302eca656be53a3cec86ddbf05a7fa3925c5185b..dbf0d2c113bf670da3617967d913da819ccf2663 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -1030,7 +1030,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kMap: case HloOpcode::kGetDimensionSize: return kGray; - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: case HloOpcode::kInfeed: @@ -1474,14 +1474,15 @@ string ExportGraph(const string& graph, GraphRendererInterface::GraphKind graph_kind, const DebugOptions& debug_options) { string path = debug_options.xla_hlo_graph_path(); - if (!path.empty()) { + if (!path.empty() && !debug_options.xla_hlo_dump_as_html()) { return SaveGraph(graph, graph_kind, path); } else { auto graph_renderer = GraphRendererRegistry::Default()->GetDefaultRenderer(); CHECK(graph_renderer != nullptr) << "No registered renderer for the HLO graph. 
" - "Use --xla_hlo_graph_path=PATH to export to local file system"; + "Use --xla_hlo_graph_path=PATH --xla_hlo_dump_as_html=false to " + "export to local file system"; return graph_renderer->RenderGraph(graph, graph_kind, debug_options); } } @@ -1589,5 +1590,145 @@ string MaybeDumpHloModule(const HloModule& module, const string& label, return graph_url; } +string WrapDotInHTML(const string& dot) { + static const char html_prefix[] = R"html( + + + + + + + + + + + +
+ + + +)html"; + + return html_prefix + dot + html_suffix; +} + +string RenderDotAsHTMLFile(const string& dot, + const DebugOptions& debug_options) { + string html = WrapDotInHTML(dot); + + auto env = tensorflow::Env::Default(); + std::vector dirs; + string output_dir = debug_options.xla_hlo_graph_path(); + if (output_dir.empty()) { + env->GetLocalTempDirectories(&dirs); + } else { + dirs.push_back(output_dir); + } + // Try each directory, as they might be full, have inappropriate + // permissions or have different problems at times. + string output; + for (const string& dir : dirs) { + string filename = tensorflow::io::JoinPath(dir, "graph-"); + if (env->CreateUniqueFileName(&filename, ".html")) { + output = filename; + break; + } + } + if (output.empty()) { + LOG(FATAL) << "Failed to create unique output file name."; + } + TF_CHECK_OK(tensorflow::WriteStringToFile(env, output, html)); + return "file://" + output; +} + } // namespace hlo_graph_dumper } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index de1eefab776f9c3d2c73959a5cd267e938a78a32..8e51454ef1cf992386cc7325e32705c08bf7712f 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -81,6 +81,12 @@ string DumpAllPathsFromTo(const HloInstruction& from, const HloInstruction& to, void DumpText(const HloModule& module, const string& label, const string& directory_path, bool do_prefix = true); +// Renders DOT graph as inline SVG and saves it in an HTML file in a temprary +// directory or directory specified via --xla_hlo_graph_path. Returns the file +// URI pointing to the file. +string RenderDotAsHTMLFile(const string& dot, + const DebugOptions& debug_options); + // Graph renderers may be added using a registration mechanism, e.g.: // XLA_REGISTER_GRAPH_RENDERER(AGraphRendererClass, 100) // The renderer with the highest numeric priority value is used. 
diff --git a/tensorflow/compiler/xla/service/hlo_graph_html_renderer.cc b/tensorflow/compiler/xla/service/hlo_graph_html_renderer.cc new file mode 100644 index 0000000000000000000000000000000000000000..84c4cf18df69816c611f4eb159ba247320ebc20e --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_graph_html_renderer.cc @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Implementation of an DOT graph renderer that uses Javascript to render DOT to +// SVG in a browser. 
+ +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace xla { +namespace hlo_graph_dumper { +namespace { + +class GraphHtmlRenderer : public GraphRendererInterface { + public: + string RenderGraph(const string& graph, GraphKind graph_kind, + const DebugOptions& debug_options) override { + switch (graph_kind) { + case DOT_GRAPH: + return RenderDotAsHTMLFile(graph, debug_options); + default: + LOG(FATAL) << "Only DOT graphs can be rendered"; + } + } +}; + +XLA_REGISTER_GRAPH_RENDERER(GraphHtmlRenderer); + +} // namespace +} // namespace hlo_graph_dumper +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 21b1dbc1676cccd2fe5b331a1f9d6ff5e3a73fcd..3e8903c95376ae1238b68280bbbb00b0db5a23a2 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -333,20 +333,20 @@ StatusOr> HloInstruction::CreateFromProto( proto.outfeed_config()); break; } - case HloOpcode::kCrossReplicaSum: { + case HloOpcode::kAllReduce: { TF_RET_CHECK(proto.called_computation_ids_size() == 1) - << "CrossReplicaSum should have 1 called computation but sees " + << "AllReduce should have 1 called computation but sees " << proto.called_computation_ids_size(); absl::optional all_reduce_id; if (proto.all_reduce_id() > 0) { all_reduce_id = proto.all_reduce_id(); } - instruction = CreateCrossReplicaSum( + instruction = CreateAllReduce( shape, all_operands(), computations(0), /*replica_groups=*/ std::vector(proto.replica_groups().begin(), proto.replica_groups().end()), - /*barrier=*/proto.cross_replica_sum_barrier(), + /*barrier=*/proto.all_reduce_barrier(), /*all_reduce_id=*/all_reduce_id); break; } @@ -383,7 +383,8 @@ StatusOr> HloInstruction::CreateFromProto( proto.operand_ids_size(), PrecisionConfig::DEFAULT); instruction = CreateConvolve( shape, operands(0), operands(1), - 
std::max(proto.feature_group_count(), 1), proto.window(), + std::max(proto.feature_group_count(), 1), + std::max(proto.batch_group_count(), 1), proto.window(), proto.convolution_dimension_numbers(), precision_config); break; } @@ -438,6 +439,9 @@ StatusOr> HloInstruction::CreateFromProto( static_cast(instruction.get()) ->set_feature_group_count( std::max(static_cast(proto.feature_group_count()), 1LL)); + static_cast(instruction.get()) + ->set_batch_group_count( + std::max(static_cast(proto.batch_group_count()), 1LL)); break; case HloOpcode::kPad: TF_RET_CHECK(proto.operand_ids_size() == 2) @@ -569,6 +573,11 @@ StatusOr> HloInstruction::CreateFromProto( instruction->SetAndSanitizeName(proto.name()); instruction->metadata_ = proto.metadata(); instruction->backend_config_ = proto.backend_config(); + + TF_RET_CHECK(proto.id() >= 0) + << "Instruction with negative id: " << proto.id(); + TF_RET_CHECK(proto.id() <= INT_MAX) + << "Instruction with id > INT_MAX: " << proto.id(); instruction->unique_id_ = proto.id(); if (proto.has_sharding()) { @@ -729,12 +738,12 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateConvolve( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - int64 feature_group_count, const Window& window, + int64 feature_group_count, int64 batch_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, const PrecisionConfig& precision_config) { return absl::make_unique( - shape, lhs, rhs, feature_group_count, window, dimension_numbers, - precision_config); + shape, lhs, rhs, feature_group_count, batch_group_count, window, + dimension_numbers, precision_config); } /* static */ std::unique_ptr HloInstruction::CreateFft( @@ -761,8 +770,7 @@ HloInstruction::CreateReducePrecision(const Shape& shape, shape, operand, exponent_bits, mantissa_bits); } -/* static */ std::unique_ptr -HloInstruction::CreateCrossReplicaSum( +/* static */ std::unique_ptr 
HloInstruction::CreateAllReduce( const Shape& shape, absl::Span operands, HloComputation* reduce_computation, const std::vector& replica_groups, absl::string_view barrier, @@ -914,12 +922,8 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, HloInstruction* operand, HloInstruction* update, HloInstruction* start_indices) { - auto instruction = absl::WrapUnique( - new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(update); - instruction->AppendOperand(start_indices); - return instruction; + return absl::make_unique( + shape, operand, update, start_indices); } /* static */ std::unique_ptr HloInstruction::CreateConcatenate( @@ -1160,7 +1164,7 @@ bool HloInstruction::HasSideEffectNoRecurse() const { case HloOpcode::kOutfeed: case HloOpcode::kTrace: return true; - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: return all_reduce_id().has_value(); default: return false; @@ -1283,7 +1287,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kParameter: case HloOpcode::kGetTupleElement: case HloOpcode::kReducePrecision: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: case HloOpcode::kInfeed: @@ -1740,7 +1744,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kReducePrecision: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: case HloOpcode::kConvolution: @@ -1760,7 +1764,12 @@ bool HloInstruction::IdenticalSlowPath( return false; } -uint64 HloInstruction::Hash() const { +static uint64 HashOperand(const HloInstruction* hlo) { + return ShapeUtil::Hash(hlo->shape()); +} + +uint64 HloInstruction::Hash( + const std::function& hash_operand) const { using tensorflow::Hash64Combine; uint64 hash_value = Hash64Combine(0, 
static_cast(opcode())); @@ -1769,7 +1778,7 @@ uint64 HloInstruction::Hash() const { if (!IsCrossModuleAllReduce()) { if (!operands().empty()) { for (size_t i = 0; i < operands().size(); ++i) { - hash_value = Hash64Combine(hash_value, operand(i)->Hash()); + hash_value = Hash64Combine(hash_value, hash_operand(operand(i))); } } } @@ -1778,6 +1787,11 @@ uint64 HloInstruction::Hash() const { return hash_value; } +uint64 HloInstruction::Hash() const { + // Use HashOperand as an argument to prevent non-termination. + return Hash(HashOperand); +} + uint64 HloInstruction::InnerHash() const { return 13; } void HloInstruction::RemoveUser(HloInstruction* user) { @@ -1879,7 +1893,7 @@ HloComputation* HloInstruction::to_apply() const { case HloOpcode::kMap: case HloOpcode::kReduceWindow: case HloOpcode::kReduce: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kScatter: CHECK_EQ(called_computations_.size(), 1); return called_computations_[0]; @@ -1898,7 +1912,7 @@ void HloInstruction::set_to_apply(HloComputation* computation) { case HloOpcode::kMap: case HloOpcode::kReduceWindow: case HloOpcode::kReduce: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kScatter: CHECK_EQ(called_computations_.size(), 1); called_computations_[0] = computation; @@ -2056,7 +2070,11 @@ bool HloInstruction::IsElementwiseImpl( } bool HloInstruction::IsCrossModuleAllReduce() const { - return opcode() == HloOpcode::kCrossReplicaSum && all_reduce_id(); + return opcode() == HloOpcode::kAllReduce && all_reduce_id(); +} + +bool HloInstruction::IsCrossReplicaAllReduce() const { + return opcode() == HloOpcode::kAllReduce && !all_reduce_id(); } string HloInstruction::ToStringWithCanonicalNameMap( @@ -2167,7 +2185,7 @@ std::vector HloInstruction::ExtraAttributesToString( } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap || opcode() == HloOpcode::kReduceWindow || opcode() == HloOpcode::kReduce || - opcode() == 
HloOpcode::kCrossReplicaSum || + opcode() == HloOpcode::kAllReduce || opcode() == HloOpcode::kScatter) { extra.push_back( StrCat("to_apply=", PrintName(to_apply()->name(), options))); @@ -2203,7 +2221,7 @@ std::vector HloInstruction::ExtraAttributesToString( case HloOpcode::kMap: case HloOpcode::kReduceWindow: case HloOpcode::kReduce: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kScatter: extra.push_back( StrCat("to_apply=\n", to_apply()->ToString(new_options))); @@ -2400,8 +2418,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleConvolution(this); case HloOpcode::kFft: return visitor->HandleFft(this); - case HloOpcode::kCrossReplicaSum: - return visitor->HandleCrossReplicaSum(this); + case HloOpcode::kAllReduce: + return visitor->HandleAllReduce(this); case HloOpcode::kAllToAll: return visitor->HandleAllToAll(this); case HloOpcode::kCollectivePermute: @@ -3256,13 +3274,12 @@ HloInstruction::source_target_pairs() const { return Cast(this)->source_target_pairs(); } -string HloInstruction::cross_replica_sum_barrier() const { - return Cast(this)->cross_replica_sum_barrier(); +string HloInstruction::all_reduce_barrier() const { + return Cast(this)->all_reduce_barrier(); } -void HloInstruction::set_cross_replica_sum_barrier(const string& barrier) { - return Cast(this)->set_cross_replica_sum_barrier( - barrier); +void HloInstruction::set_all_reduce_barrier(const string& barrier) { + return Cast(this)->set_all_reduce_barrier(barrier); } absl::optional HloInstruction::all_reduce_id() const { @@ -3308,6 +3325,18 @@ void HloInstruction::set_feature_group_count(int64 feature_group_count) { feature_group_count); } +int64 HloInstruction::batch_group_count() const { + if (auto convolution = DynCast(this)) { + return convolution->batch_group_count(); + } + return Cast(this)->batch_group_count(); +} + +void HloInstruction::set_batch_group_count(int64 batch_group_count) { + Cast(this)->set_batch_group_count( + 
batch_group_count); +} + HloComputation* HloInstruction::select() const { return Cast(this)->select(); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index a54716217d6bbc5c0601f5d9ff7bf4072a6b30f5..36e1ab49319a3e28143ef4d08888c68c86fbcf62 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -426,7 +426,7 @@ class HloInstruction { // and window describes how the filter is applied to lhs. static std::unique_ptr CreateConvolve( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - int64 feature_group_count, const Window& window, + int64 feature_group_count, int64 batch_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, const PrecisionConfig& precision_config); @@ -462,9 +462,7 @@ class HloInstruction { // `all_reduce_id`: for Allreduce nodes from different modules, if they have // the same all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will // not be applied cross modules. - // - // TODO(b/117564385): Rename this to AllReduce. - static std::unique_ptr CreateCrossReplicaSum( + static std::unique_ptr CreateAllReduce( const Shape& shape, absl::Span operands, HloComputation* reduce_computation, const std::vector& replica_groups, @@ -909,6 +907,14 @@ class HloInstruction { // information on opcode, shape, operands, and typically a root instruction. // This function returns the same hash value for equivalent HLO instructions, // with respect to HloInstruction::Identical() method. + // + // Uses hash_operand function to compute hash values of its operands. + // At the very top level, hash_operand should be non-recursive to prevent + // non-termination. + uint64 Hash( + const std::function& hash_operand) const; + + // Calls the above method with non-recursive hash_operand function. uint64 Hash() const; // Returns whether the instruction has a constant operand. 
@@ -1174,9 +1180,12 @@ class HloInstruction { // Returns true if this instruction is elementwise on all its operands. bool IsElementwise() const; - // Returns true if this is an cross module all-reduce instrucion. + // Returns true if this is a cross module all-reduce instruction. bool IsCrossModuleAllReduce() const; + // Returns true if this is a cross-replica all-reduce instruction. + bool IsCrossReplicaAllReduce() const; + // Returns true if this elementwise instruction implicitly broadcasts operand // `operand_idx`. // @@ -1448,9 +1457,9 @@ class HloInstruction { // Delegates to HloCollectivePermuteInstruction::source_target_pairs. const std::vector>& source_target_pairs() const; - // Delegates to HloAllReduceInstruction::cross_replica_sum_barrier. - string cross_replica_sum_barrier() const; - void set_cross_replica_sum_barrier(const string& barrier); + // Delegates to HloAllReduceInstruction::all_reduce_barrier. + string all_reduce_barrier() const; + void set_all_reduce_barrier(const string& barrier); // Delegates to HloAllReduceInstruction::all_reduce_id. absl::optional all_reduce_id() const; @@ -1484,6 +1493,11 @@ class HloInstruction { void set_feature_group_count(int64 feature_group_count); + // The number of batch groups. Must be a divisor of the input batch dimension + int64 batch_group_count() const; + + void set_batch_group_count(int64 batch_group_count); + // Delegates to HloSelectAndScatterInstruction::select. 
HloComputation* select() const; diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 1ea02cf9c03866a598bec0e5356f0eb31ad27755..756e260b60dcda660e89c211862c8c5800439f2c 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -363,9 +363,9 @@ HloAllReduceInstruction::HloAllReduceInstruction( HloComputation* reduce_computation, const std::vector& replica_groups, absl::string_view barrier, const absl::optional& all_reduce_id) - : HloCollectiveInstruction(HloOpcode::kCrossReplicaSum, shape, operands, + : HloCollectiveInstruction(HloOpcode::kAllReduce, shape, operands, replica_groups), - cross_replica_sum_barrier_(barrier), + all_reduce_barrier_(barrier), all_reduce_id_(all_reduce_id) { AppendComputation(reduce_computation); } @@ -381,7 +381,7 @@ HloInstructionProto HloAllReduceInstruction::ToProto() const { if (all_reduce_id_) { proto.set_all_reduce_id(*all_reduce_id_); } - proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_); + proto.set_all_reduce_barrier(all_reduce_barrier_); return proto; } @@ -389,8 +389,8 @@ std::vector HloAllReduceInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { std::vector result = HloCollectiveInstruction::ExtraAttributesToStringImpl(options); - if (!cross_replica_sum_barrier().empty()) { - result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\"")); + if (!all_reduce_barrier().empty()) { + result.push_back(StrCat("barrier=\"", all_reduce_barrier(), "\"")); } if (all_reduce_id_) { result.push_back(StrCat("all_reduce_id=", *all_reduce_id_)); @@ -405,8 +405,7 @@ bool HloAllReduceInstruction::IdenticalSlowPath( const auto& casted_other = static_cast(other); return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) && eq_computations(to_apply(), casted_other.to_apply()) && - cross_replica_sum_barrier() == - 
casted_other.cross_replica_sum_barrier() && + all_reduce_barrier() == casted_other.all_reduce_barrier() && all_reduce_id() == casted_other.all_reduce_id(); } @@ -415,8 +414,8 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* /*context*/) const { return absl::make_unique( - shape, new_operands, to_apply(), replica_groups(), - cross_replica_sum_barrier(), all_reduce_id()); + shape, new_operands, to_apply(), replica_groups(), all_reduce_barrier(), + all_reduce_id()); } HloAllToAllInstruction::HloAllToAllInstruction( @@ -905,7 +904,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( options.print_large_constants())) { // Literal::ToString emits multidimensional arrays over multiple // lines. Compact this into one line by stripping out white space. - string tmp = literal().ToString(); + string tmp = literal().ToStringWithoutShape(); std::replace(tmp.begin(), tmp.end(), '\n', ' '); std::vector v = absl::StrSplit(tmp, ' '); bool first = true; @@ -1372,8 +1371,14 @@ bool HloFusionInstruction::IdenticalSlowPath( other.fused_instructions_computation()); } +static uint64 HashOperandRecursive(const HloInstruction* hlo) { + return hlo->Hash(HashOperandRecursive); +} + uint64 HloFusionInstruction::InnerHash() const { - return fused_instructions_computation()->Hash(); + // Use HashOperandRecursive to recursively compute hash on inner operands. 
+ return fused_instructions_computation()->root_instruction()->Hash( + HashOperandRecursive); } std::unique_ptr HloFusionInstruction::CloneWithNewOperandsImpl( @@ -1649,11 +1654,12 @@ std::unique_ptr HloOutfeedInstruction::CloneWithNewOperandsImpl( HloConvolutionInstruction::HloConvolutionInstruction( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - int64 feature_group_count, const Window& window, + int64 feature_group_count, int64 batch_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, const PrecisionConfig& precision_config) : HloInstruction(HloOpcode::kConvolution, shape), feature_group_count_(feature_group_count), + batch_group_count_(batch_group_count), window_(window), convolution_dimension_numbers_(dimension_numbers), precision_config_(precision_config) { @@ -1731,8 +1737,9 @@ HloConvolutionInstruction::CloneWithNewOperandsImpl( HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); return absl::make_unique( - shape, new_operands[0], new_operands[1], feature_group_count_, window(), - convolution_dimension_numbers_, precision_config_); + shape, new_operands[0], new_operands[1], feature_group_count_, + batch_group_count_, window(), convolution_dimension_numbers_, + precision_config_); } HloReduceWindowInstruction::HloReduceWindowInstruction( @@ -1994,12 +2001,21 @@ std::unique_ptr HloPadInstruction::CloneWithNewOperandsImpl( HloDynamicSliceInstruction::HloDynamicSliceInstruction( const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, absl::Span slice_sizes) - : HloInstruction(HloOpcode::kDynamicSlice, shape), + : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape), dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) { AppendOperand(operand); AppendOperand(start_indices); } +HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* update, + HloInstruction* start_indices) + : 
HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) { + AppendOperand(operand); + AppendOperand(update); + AppendOperand(start_indices); +} + HloInstructionProto HloDynamicSliceInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); for (int64 slice_size : dynamic_slice_sizes_) { diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index b5c28137a145667a977d39c9d3c40c6d36a8436e..ca212c7f2c98f75ceefc14b7fbc2a1f530c06cf7 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -242,14 +242,10 @@ class HloAllReduceInstruction : public HloCollectiveInstruction { const std::vector& replica_groups, absl::string_view barrier, const absl::optional& all_reduce_id); - // Returns the barrier config used for the CrossReplicaSum implementation of + // Returns the barrier config used for the AllReduce implementation of // each backend. - string cross_replica_sum_barrier() const { - return cross_replica_sum_barrier_; - } - void set_cross_replica_sum_barrier(string barrier) { - cross_replica_sum_barrier_ = barrier; - } + string all_reduce_barrier() const { return all_reduce_barrier_; } + void set_all_reduce_barrier(string barrier) { all_reduce_barrier_ = barrier; } absl::optional all_reduce_id() const { return all_reduce_id_; } void set_all_reduce_id(const absl::optional& all_reduce_id); @@ -270,8 +266,8 @@ class HloAllReduceInstruction : public HloCollectiveInstruction { const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; - // The string representation of the barrier config used for CrossReplicaSum. - string cross_replica_sum_barrier_; + // The string representation of the barrier config used for AllReduce. + string all_reduce_barrier_; // For Allreduce nodes from different modules, if they have the same // all_reduce_id, they will be 'Allreduce'd. 
If empty, Allreduce will not be @@ -933,7 +929,7 @@ class HloConvolutionInstruction : public HloInstruction { public: explicit HloConvolutionInstruction( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - int64 feature_group_count, const Window& window, + int64 feature_group_count, int64 batch_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, const PrecisionConfig& precision_config); const Window& window() const override { return window_; } @@ -949,6 +945,10 @@ class HloConvolutionInstruction : public HloInstruction { // dimension and output feature dimension. int64 feature_group_count() const { return feature_group_count_; } + // The number of batch groups. Must be a divisor of the input batch + // dimension. + int64 batch_group_count() const { return batch_group_count_; } + // Returns the information used to tell the implementation information about // what sort of precision is requested. The meaning of the field is backend // specific. At the moment, it is only supported for kConvolution and kDot. @@ -977,6 +977,9 @@ class HloConvolutionInstruction : public HloInstruction { // The number of feature groups. Must be a divisor of the input feature // dimension and output feature dimension. int64 feature_group_count_; + // The number of batch groups. Must be a divisor of the input batch + // dimension. + int64 batch_group_count_; // Describes the window used for a convolution. Window window_; // Describes the dimension numbers used for a convolution. 
@@ -1099,7 +1102,11 @@ class HloCustomCallInstruction : public HloInstruction { void set_feature_group_count(int64 feature_group_count) { feature_group_count_ = feature_group_count; } + void set_batch_group_count(int64 batch_group_count) { + batch_group_count_ = batch_group_count; + } int64 feature_group_count() const { return feature_group_count_; } + int64 batch_group_count() const { return batch_group_count_; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -1134,6 +1141,7 @@ class HloCustomCallInstruction : public HloInstruction { std::unique_ptr convolution_dimension_numbers_; // The number of feature groups. This is used for grouped convolutions. int64 feature_group_count_; + int64 batch_group_count_; // Whether the result and operand layouts are constrained. bool layout_constrained_; // For layout-constrained custom calls, this vector holds the shape with @@ -1171,7 +1179,14 @@ class HloPadInstruction : public HloInstruction { PaddingConfig padding_config_; }; -class HloDynamicSliceInstruction : public HloInstruction { +class HloDynamicIndexInstruction : public HloInstruction { + public: + explicit HloDynamicIndexInstruction(HloOpcode opcode, const Shape& shape) + : HloInstruction(opcode, shape) {} + virtual int64 index_operand_number() const = 0; +}; + +class HloDynamicSliceInstruction : public HloDynamicIndexInstruction { public: explicit HloDynamicSliceInstruction(const Shape& shape, HloInstruction* operand, @@ -1189,6 +1204,8 @@ class HloDynamicSliceInstruction : public HloInstruction { // Returns a serialized representation of this instruction. 
HloInstructionProto ToProto() const override; + int64 index_operand_number() const override { return 1; } + private: std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; @@ -1206,6 +1223,16 @@ class HloDynamicSliceInstruction : public HloInstruction { std::vector dynamic_slice_sizes_; }; +class HloDynamicUpdateSliceInstruction : public HloDynamicIndexInstruction { + public: + explicit HloDynamicUpdateSliceInstruction(const Shape& shape, + HloInstruction* operand, + HloInstruction* update, + HloInstruction* start_indices); + + int64 index_operand_number() const override { return 2; } +}; + class HloGatherInstruction : public HloInstruction { public: explicit HloGatherInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index 1390537101e95a08e4ba4eef7ae8d6059a40e916..dc712e5e42c449737bf4415f3a5e3eb9d81d9be4 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "absl/strings/escaping.h" #include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -82,9 +83,23 @@ tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers( return tensorflow::RegexpStringPiece(begin, end - begin); } +TokKind HloLexer::LookAhead() { + if (GetKind() == TokKind::kEof || GetKind() == TokKind::kError) { + return GetKind(); + } + + const char* old_current_ptr = current_ptr_; + TokenState old_token_state = token_state_; + Lex(); + TokKind kind = GetKind(); + token_state_ = old_token_state; + current_ptr_ = old_current_ptr; + return kind; +} + TokKind HloLexer::LexToken() { while (true) { - token_start_ = current_ptr_; + token_state_.token_start = current_ptr_; int current_char = GetNextChar(); switch (current_char) { @@ -206,43 +221,37 @@ TokKind HloLexer::LexToken() { // dim_labels_pattern ::= [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} // identifiers ::= other cases that match [a-zA-Z_][a-zA-Z0-9_.-]* TokKind HloLexer::LexIdentifier() { - { - auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); - // 'consumable' will be advanced iff its prefix matches the pattern. - static LazyRE2 shape_pattern = { - R"(^(\w*\d*)\[([\d,\s]*)\](?:(dense|sparse)?{([\d,\s]+)})?)"}; - if (RE2::Consume(&consumable, *shape_pattern)) { - auto status_or_shape = ShapeUtil::ParseShapeString( - StringPieceFromPointers(token_start_, consumable.begin())); - if (status_or_shape.ok()) { - // This is a shape string. - shape_val_ = status_or_shape.ValueOrDie(); - current_ptr_ = consumable.begin(); - return TokKind::kShape; - } - } - } - while (IsIdentifierChar(PeekCurrentChar())) { current_ptr_++; } // If followed by ':', it's a name. 
if (PeekCurrentChar() == ':') { - str_val_.assign(token_start_, current_ptr_); + token_state_.str_val.assign(token_state_.token_start, current_ptr_); current_ptr_++; // skip ':' return TokKind::kName; } // If followed by '=', it's a attribute name. if (PeekCurrentChar() == '=') { - str_val_.assign(token_start_, current_ptr_); + token_state_.str_val.assign(token_state_.token_start, current_ptr_); current_ptr_++; // skip '=' return TokKind::kAttributeName; } absl::string_view identifier = - StringPieceFromPointers(token_start_, current_ptr_); + StringPieceFromPointers(token_state_.token_start, current_ptr_); + + // Primitive type strings are reserved words. The exception is 'tuple' whose + // type is represented using nested parentheses without the string 'tuple'. + if (primitive_util::IsPrimitiveTypeName(identifier)) { + PrimitiveType primitive_type = + primitive_util::StringToPrimitiveType(identifier).ValueOrDie(); + if (primitive_type != TUPLE) { + token_state_.primitive_type_val = primitive_type; + return TokKind::kPrimitiveType; + } + } // See if this is a keyword. 
#define KEYWORD(STR) \ @@ -261,21 +270,23 @@ TokKind HloLexer::LexIdentifier() { KEYWORD(ROOT); KEYWORD(maximal); KEYWORD(replicated); + KEYWORD(sparse); #undef KEYWORD { - auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + auto consumable = + RegexpStringPieceFromPointers(token_state_.token_start, buf_.end()); static LazyRE2 dim_labels_pattern = { R"([0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,})"}; if (RE2::Consume(&consumable, *dim_labels_pattern)) { current_ptr_ = consumable.begin(); - str_val_.assign(token_start_, current_ptr_); + token_state_.str_val.assign(token_state_.token_start, current_ptr_); return TokKind::kDimLabels; } } - str_val_ = string(identifier); + token_state_.str_val = string(identifier); return TokKind::kIdent; } @@ -289,7 +300,7 @@ TokKind HloLexer::LexPercent() { while (IsIdentifierChar(PeekCurrentChar())) { current_ptr_++; } - str_val_.assign(name_start, current_ptr_); + token_state_.str_val.assign(name_start, current_ptr_); return TokKind::kName; } return TokKind::kError; @@ -307,12 +318,14 @@ TokKind HloLexer::LexPercent() { // int ::= [-]?[0-9]+ // negative inf ::= '-inf' TokKind HloLexer::LexNumberOrPattern() { - auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + auto consumable = + RegexpStringPieceFromPointers(token_state_.token_start, buf_.end()); static LazyRE2 float_pattern = { R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"}; if (RE2::Consume(&consumable, *float_pattern)) { current_ptr_ = consumable.begin(); - CHECK(absl::SimpleAtod(string(token_start_, current_ptr_), &decimal_val_)); + CHECK(absl::SimpleAtod(string(token_state_.token_start, current_ptr_), + &token_state_.decimal_val)); return TokKind::kDecimal; } @@ -324,27 +337,28 @@ TokKind HloLexer::LexNumberOrPattern() { if (RE2::Consume(&consumable, *dim_labels_pattern)) { current_ptr_ = consumable.begin(); - str_val_.assign(token_start_, current_ptr_); + 
token_state_.str_val.assign(token_state_.token_start, current_ptr_); return TokKind::kDimLabels; } if (RE2::Consume(&consumable, *dxd_pattern)) { current_ptr_ = consumable.begin(); - str_val_.assign(token_start_, current_ptr_); + token_state_.str_val.assign(token_state_.token_start, current_ptr_); return TokKind::kDxD; } if (RE2::Consume(&consumable, *pad_pattern)) { current_ptr_ = consumable.begin(); - str_val_.assign(token_start_, current_ptr_); + token_state_.str_val.assign(token_state_.token_start, current_ptr_); return TokKind::kPad; } static LazyRE2 int_pattern = {R"([-]?\d+)"}; if (RE2::Consume(&consumable, *int_pattern)) { current_ptr_ = consumable.begin(); - auto slice = StringPieceFromPointers(token_start_, current_ptr_); - if (absl::SimpleAtoi(slice, &int64_val_)) { + auto slice = + StringPieceFromPointers(token_state_.token_start, current_ptr_); + if (absl::SimpleAtoi(slice, &token_state_.int64_val)) { return TokKind::kInt; } LOG(ERROR) << "Failed to parse int literal: " << slice; @@ -403,16 +417,17 @@ absl::string_view HloLexer::GetLine(LocTy loc) const { } // Lexes quoted string with escaping characters. If matched, the quoted string -// will be unescaped and stored to str_val_. +// will be unescaped and stored to token_state_.str_val. TokKind HloLexer::LexString() { - auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + auto consumable = + RegexpStringPieceFromPointers(token_state_.token_start, buf_.end()); static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"}; if (RE2::Consume(&consumable, *escaping_pattern)) { current_ptr_ = consumable.begin(); absl::string_view raw = - StringPieceFromPointers(token_start_ + 1, current_ptr_ - 1); + StringPieceFromPointers(token_state_.token_start + 1, current_ptr_ - 1); string error; - if (!absl::CUnescape(raw, &str_val_, &error)) { + if (!absl::CUnescape(raw, &token_state_.str_val, &error)) { LOG(ERROR) << "Failed unescaping string: " << raw << ". 
error: " << error; return TokKind::kError; } @@ -467,6 +482,10 @@ string TokKindToString(TokKind kind) { return "kw_inf"; case TokKind::kNegInf: return "kNegInf"; + case TokKind::kw_sparse: + return "kw_sparse"; + case TokKind::kPrimitiveType: + return "kPrimitiveType"; case TokKind::kName: return "kName"; case TokKind::kAttributeName: @@ -481,8 +500,6 @@ string TokKindToString(TokKind kind) { return "kIdent"; case TokKind::kString: return "kString"; - case TokKind::kShape: - return "kShape"; case TokKind::kInt: return "kInt"; case TokKind::kDecimal: diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h index d6a2b292a3916b2ff85f278cf5cb9f1567df88fa..41f5043904a2622814154693679a0e27cb92f642 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -19,7 +19,6 @@ limitations under the License. #include #include "absl/strings/string_view.h" -#include "tensorflow/compiler/xla/service/hlo_token.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -29,6 +28,57 @@ limitations under the License. namespace xla { +// Defines different kinds of tokens used by the HLO lexer. +// +// You shouldn't need to use this directly unless you're using HloLexer +// directly, and you probably don't need to do that. Use hlo_parser instead. +enum class TokKind { + // Markers + kEof, + kError, + + // Tokens with no info. + kEqual, // = + kComma, // , + kColon, // : + kLsquare, + kRsquare, // [ ] + kLbrace, + kRbrace, // { } + kLparen, + kRparen, // ( ) + + kArrow, // -> + + // Keywords + kw_HloModule, + kw_ENTRY, + kw_ROOT, + kw_true, + kw_false, + kw_maximal, + kw_replicated, + kw_nan, + kw_inf, + kw_sparse, + + kNegInf, // -inf + + // Typed tokens. + kPrimitiveType, // F32, PRED, etc. 
+ kName, // %foo + kAttributeName, // dimensions= + kDimLabels, // [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} + kDxD, // [0-9]+(x[0-9]+)+ + kPad, // [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* + kIdent, // other identifiers + kString, // "abcd\"\n" + kInt, // 42 + kDecimal, // 4.2 +}; + +string TokKindToString(TokKind kind); + // Lexer for the HloModule::ToString() format text. // // This class is meant to be used by hlo_parser.cc. You shouldn't need to use @@ -39,9 +89,9 @@ class HloLexer { current_ptr_ = buf_.begin(); } - TokKind Lex() { return current_kind_ = LexToken(); } + TokKind Lex() { return token_state_.current_kind = LexToken(); } - TokKind GetKind() const { return current_kind_; } + TokKind GetKind() const { return token_state_.current_kind; } string GetStrVal() const { switch (GetKind()) { case TokKind::kName: @@ -51,28 +101,28 @@ class HloLexer { case TokKind::kPad: case TokKind::kString: case TokKind::kIdent: - return str_val_; + return token_state_.str_val; default: LOG(FATAL) << "This token does not have string value"; } } - Shape GetShapeVal() const { - CHECK(GetKind() == TokKind::kShape); - return shape_val_; - } tensorflow::int64 GetInt64Val() const { CHECK(GetKind() == TokKind::kInt); - return int64_val_; + return token_state_.int64_val; } double GetDecimalVal() const { CHECK(GetKind() == TokKind::kDecimal); - return decimal_val_; + return token_state_.decimal_val; + } + PrimitiveType GetPrimitiveTypeVal() const { + CHECK(GetKind() == TokKind::kPrimitiveType); + return token_state_.primitive_type_val; } typedef const char* LocTy; // Returns the location of the current token. - LocTy GetLoc() const { return token_start_; } + LocTy GetLoc() const { return token_state_.token_start; } // Returns the line and column of a location in the buffer. std::pair GetLineAndColumn(LocTy location) const; @@ -80,6 +130,9 @@ class HloLexer { // Returns the whole line given the location. 
absl::string_view GetLine(LocTy loc) const; + // Looks ahead one token and returns it. Lexer state is unchanged. + TokKind LookAhead(); + private: // Returns the current character. If it's neither the end of input buffer nor // an invalid character, moves the pointer forward. @@ -112,12 +165,15 @@ class HloLexer { const char* current_ptr_; // Information about the current token. - const char* token_start_ = nullptr; - TokKind current_kind_; - string str_val_; - Shape shape_val_; - tensorflow::int64 int64_val_; - double decimal_val_; + struct TokenState { + const char* token_start = nullptr; + TokKind current_kind; + string str_val; + tensorflow::int64 int64_val; + double decimal_val; + PrimitiveType primitive_type_val; + }; + TokenState token_state_; struct LineNoCacheTy { const char* last_query; diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc index e0ae1173c6114f0bc6ef18b2cfff9d54ccfe2faf..436cccb1fb9ecf6f4efad772c700c611b28ce628 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc @@ -403,9 +403,9 @@ TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) { HloModule OutfeedLoop WhileBody { body_param = (s32[]) parameter(0) - token = token[] after-all() + token0 = token[] after-all() constant.2 = s32[] constant(2) - outfeed_tuple = (s32[]) outfeed(constant.2, token) + outfeed_tuple = (s32[]) outfeed(constant.2, token0) get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 constant.1 = s32[] constant(1) add = s32[] add(get-tuple-element.1, constant.1) @@ -436,9 +436,9 @@ TEST_F(HloLivenessAnalysisTest, NestedWhileWithOutfeed) { HloModule OutfeedLoop InnerWhileBody { body_param = (s32[]) parameter(0) - token = token[] after-all() + token0 = token[] after-all() constant.2 = s32[] constant(2) - outfeed_tuple = (s32[]) outfeed(constant.2, token) + outfeed_tuple = (s32[]) 
outfeed(constant.2, token0) get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 constant.1 = s32[] constant(1) add = s32[] add(get-tuple-element.1, constant.1) diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 235efb19ce4ed28a5cd9fe5ca52ae5d8e9e5ba3d..67488a6a9a0c9cba7f576f9036c3a0cbe1900fff 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -178,7 +178,7 @@ HLO_MATCHER(Constant); HLO_MATCHER(Convert); HLO_MATCHER(Convolution); HLO_MATCHER(Copy); -HLO_MATCHER(CrossReplicaSum); +HLO_MATCHER(AllReduce); HLO_MATCHER(CollectivePermute); HLO_MATCHER(Divide); HLO_MATCHER(Domain); @@ -312,8 +312,8 @@ inline ::testing::Matcher Shape( } inline ::testing::Matcher Shape( absl::string_view shape) { - return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher( - ShapeUtil::ParseShapeString(shape).ValueOrDie())); + return ::testing::MakeMatcher( + new ::xla::testing::HloShapeMatcher(ParseShape(shape).ValueOrDie())); } inline ::testing::Matcher ShapeWithLayout( const class Shape& shape) { @@ -323,7 +323,7 @@ inline ::testing::Matcher ShapeWithLayout( inline ::testing::Matcher ShapeWithLayout( absl::string_view shape) { return ::testing::MakeMatcher(new ::xla::testing::HloShapeAndLayoutMatcher( - ShapeUtil::ParseShapeString(shape).ValueOrDie())); + ParseShape(shape).ValueOrDie())); } // Verifies the value of the HloSharing against the provided sharding object. diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 7b9cbf9a53a2201b1312405bbd7ed2b88f65c9be..f1310e4b270898a21dbb4f86123edde4ba8993d0 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -136,7 +136,9 @@ class HloModule { // information on opcode, shape, operands, and typically a root instruction. 
// This function returns the same hash value for equivalent HLO modules, // with respect to HloInstruction::Identical() method. - uint64 Hash() const { return entry_computation()->Hash(); } + uint64 Hash() const { + return entry_computation()->root_instruction()->Hash(); + } // Gets the computations in this module. // diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc index bf66cc6bc37a5e11c9ecfc07a62ba0ea5ca11a03..e535b7d74943943069b4d795cf999a3b1e963360 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc @@ -373,9 +373,9 @@ TEST_F(HloModuleDceTest, WhileWithOutfeed) { HloModule OutfeedLoop WhileBody { body_param = (s32[]) parameter(0) - token = token[] after-all() + token0 = token[] after-all() constant.2 = s32[] constant(2) - outfeed_tuple = (s32[]) outfeed(constant.2, token) + outfeed_tuple = (s32[]) outfeed(constant.2, token0) get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 constant.1 = s32[] constant(1) add = s32[] add(get-tuple-element.1, constant.1) diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 127cfd165a5d8229cac3035f56a66f1bcfa734f3..94122ac38ff2a3f7053b19e55f9a400c80ae2134 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -49,6 +49,7 @@ namespace xla { V(kAdd, "add") \ V(kAddDependency, "add-dependency") \ V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \ + V(kAllReduce, "all-reduce") \ V(kAllToAll, "all-to-all") \ V(kAtan2, "atan2") \ V(kBatchNormGrad, "batch-norm-grad") \ @@ -70,7 +71,6 @@ namespace xla { V(kConvolution, "convolution") \ V(kCopy, "copy") \ V(kCos, "cosine") \ - V(kCrossReplicaSum, "cross-replica-sum") \ V(kCustomCall, "custom-call") \ V(kDivide, "divide") \ V(kDomain, "domain") \ diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc 
b/tensorflow/compiler/xla/service/hlo_parser.cc index 9b5bb5d0bd6af104ef62eaa5d3e53cedbe0213d3..44643951c14fb3a210b27064ffac4b99734bca0a 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -74,6 +75,7 @@ class HloParser { string GetError() const { return StrJoin(error_, "\n"); } // Stand alone parsing utils for various aggregate data types. + StatusOr ParseShapeOnly(); StatusOr ParseShardingOnly(); StatusOr ParseWindowOnly(); StatusOr ParseConvolutionDimensionNumbersOnly(); @@ -255,7 +257,9 @@ class HloParser { bool ParseName(string* result); bool ParseAttributeName(string* result); bool ParseString(string* result); + bool ParseDimensionSizes(std::vector* dimension_sizes); bool ParseShape(Shape* result); + bool ParseLayout(Layout* layout); bool ParseOpcode(HloOpcode* result); bool ParseFftType(FftType* result); bool ParseFusionKind(HloInstruction::FusionKind* result); @@ -279,9 +283,6 @@ class HloParser { // If the current token is 'kind', eats it (i.e. lexes the next token) and // returns true. bool EatIfPresent(TokKind kind); - // Parses a shape, and returns true if the result is compatible with the given - // shape. - bool EatShapeAndCheckCompatible(const Shape& shape); // Adds the instruction to the pool. Returns false and emits an error if the // instruction already exists. 
@@ -766,7 +767,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, HloInstruction::CreateBitcastConvert(shape, operands[0])); break; } - case HloOpcode::kCrossReplicaSum: { + case HloOpcode::kAllReduce: { optional>> tmp_groups; optional to_apply; optional> replica_group_ids; @@ -786,10 +787,9 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, if (tmp_groups) { replica_groups = CreateReplicaGroups(*tmp_groups); } - instruction = - builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( - shape, operands, *to_apply, replica_groups, - barrier ? *barrier : "", all_reduce_id)); + instruction = builder->AddInstruction(HloInstruction::CreateAllReduce( + shape, operands, *to_apply, replica_groups, barrier ? *barrier : "", + all_reduce_id)); break; } case HloOpcode::kAllToAll: { @@ -1006,11 +1006,14 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, optional window; optional dnums; optional feature_group_count; + optional batch_group_count; attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; attrs["dim_labels"] = {/*required=*/true, AttrTy::kConvolutionDimensionNumbers, &dnums}; attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, &feature_group_count}; + attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64, + &batch_group_count}; optional> operand_precision; attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList, &operand_precision}; @@ -1024,6 +1027,9 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, if (!feature_group_count) { feature_group_count = 1; } + if (!batch_group_count) { + batch_group_count = 1; + } PrecisionConfig precision_config; if (operand_precision) { *precision_config.mutable_operand_precision() = { @@ -1034,7 +1040,8 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, } instruction = builder->AddInstruction(HloInstruction::CreateConvolve( shape, /*lhs=*/operands[0], 
/*rhs=*/operands[1], - feature_group_count.value(), *window, *dnums, precision_config)); + feature_group_count.value(), batch_group_count.value(), *window, + *dnums, precision_config)); break; } case HloOpcode::kFft: { @@ -1697,11 +1704,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, } break; } - case TokKind::kShape: - // TODO(b/112302613): Left here for backward compatibility to ignore the - // removed tile shape data. - lexer_.Lex(); - break; case TokKind::kRbrace: break; default: @@ -1925,19 +1927,6 @@ bool HloParser::SetValueInLiteralHelper(ParsedElemT value, return true; } -bool HloParser::EatShapeAndCheckCompatible(const Shape& shape) { - Shape new_shape; - if (!ParseShape(&new_shape)) { - return TokenError(StrCat("expects shape ", ShapeUtil::HumanString(shape))); - } - if (!ShapeUtil::Compatible(shape, new_shape)) { - return TokenError(StrCat( - "expects shape ", ShapeUtil::HumanString(shape), - ", but sees a different shape: ", ShapeUtil::HumanString(new_shape))); - } - return true; -} - // literal // ::= tuple // ::= non_tuple @@ -1952,10 +1941,6 @@ bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) { // ::= /*empty*/ // ::= literal (',' literal)* bool HloParser::ParseTupleLiteral(Literal* literal, const Shape& shape) { - if (!EatShapeAndCheckCompatible(shape)) { - return TokenError(StrCat("expects tuple constant in shape ", - ShapeUtil::HumanString(shape))); - } if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) { return false; } @@ -1990,16 +1975,12 @@ bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) { return ParseSparseLiteral(literal, shape); } - CHECK(LayoutUtil::IsDenseArray(shape)); + CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ToString(true); return ParseDenseLiteral(literal, shape); } bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { const tensorflow::int64 rank = ShapeUtil::Rank(shape); - if (rank > 1 && 
!EatShapeAndCheckCompatible(shape)) { - return false; - } - // Create a literal with the given shape in default layout. *literal = LiteralUtil::CreateFromDimensions( shape.element_type(), AsInt64Slice(shape.dimensions())); @@ -2126,10 +2107,6 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { } bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) { - if (!EatShapeAndCheckCompatible(shape)) { - return false; - } - switch (shape.element_type()) { case PRED: return ParseSparseLiteralHelper(literal, shape); @@ -2994,6 +2971,39 @@ bool HloParser::ParseParamList() { return ParseToken(TokKind::kRparen, "expects ')' at the end of param list"); } +// dimension_sizes ::= '[' int64_list ']' +bool HloParser::ParseDimensionSizes(std::vector* dimension_sizes) { + auto parse_and_add_item = [&]() { + tensorflow::int64 i; + if (!ParseInt64(&i)) { + return false; + } + dimension_sizes->push_back(i); + return true; + }; + return ParseList(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma, + parse_and_add_item); +} + +// layout ::= '{' int64_list '}' +bool HloParser::ParseLayout(Layout* layout) { + std::vector minor_to_major; + auto parse_and_add_item = [&]() { + tensorflow::int64 i; + if (!ParseInt64(&i)) { + return false; + } + minor_to_major.push_back(i); + return true; + }; + if (!ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, + parse_and_add_item)) { + return false; + } + *layout = LayoutUtil::MakeLayout(minor_to_major); + return true; +} + // shape ::= shape_val_ // shape ::= '(' tuple_elements ')' // tuple_elements @@ -3017,19 +3027,61 @@ bool HloParser::ParseShape(Shape* result) { return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple."); } - if (lexer_.GetKind() != TokKind::kShape) { - return TokenError(absl::StrCat("expected shape, saw ", + if (lexer_.GetKind() != TokKind::kPrimitiveType) { + return TokenError(absl::StrCat("expected primitive type, saw ", TokKindToString(lexer_.GetKind()))); } - 
*result = lexer_.GetShapeVal(); + PrimitiveType primitive_type = lexer_.GetPrimitiveTypeVal(); lexer_.Lex(); + + std::vector dimension_sizes; + if (!ParseDimensionSizes(&dimension_sizes)) { + return false; + } + result->set_element_type(primitive_type); + *result->mutable_dimensions() = dimension_sizes; + LayoutUtil::SetToDefaultLayout(result); + + if (lexer_.GetKind() == TokKind::kw_sparse) { + lexer_.Lex(); + const string message = + "expects a brace-bracketed integer for sparse layout"; + tensorflow::int64 max_sparse_elements; + if (!ParseToken(TokKind::kLbrace, message) || + !ParseInt64(&max_sparse_elements) || + !ParseToken(TokKind::kRbrace, message)) { + return false; + } + *result->mutable_layout() = + LayoutUtil::MakeSparseLayout(max_sparse_elements); + return true; + } + + // We need to lookahead to see if a following open brace is the start of a + // layout. The specific problematic case is: + // + // ENTRY %foo (x: f32[42]) -> f32[123] { + // ... + // } + // + // The open brace could either be the start of a computation or the start of a + // layout for the f32[123] shape. We consider it the start of a layout if the + // next token after the open brace is a integer + if (lexer_.GetKind() == TokKind::kLbrace && + lexer_.LookAhead() == TokKind::kInt) { + Layout layout; + if (!ParseLayout(&layout)) { + return false; + } + *result->mutable_layout() = layout; + } return true; } bool HloParser::CanBeShape() { - // A non-tuple shape starts with a kShape token; a tuple shape starts with - // '('. - return lexer_.GetKind() == TokKind::kShape || + // A non-tuple shape starts with a kPrimitiveType token; a tuple shape starts + // with '('. 
+ return lexer_.GetKind() == TokKind::kPrimitiveType || lexer_.GetKind() == TokKind::kLparen; } @@ -3332,6 +3384,18 @@ bool HloParser::AddComputation(const string& name, HloComputation* computation, return true; } +StatusOr HloParser::ParseShapeOnly() { + lexer_.Lex(); + Shape shape; + if (!ParseShape(&shape)) { + return InvalidArgument("Syntax error:\n%s", GetError()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument("Syntax error:\nExtra content after shape"); + } + return shape; +} + StatusOr HloParser::ParseShardingOnly() { lexer_.Lex(); OpSharding op_sharding; @@ -3475,4 +3539,9 @@ StatusOr ParsePaddingConfig(absl::string_view str) { return parser.ParsePaddingConfigOnly(); } +StatusOr ParseShape(absl::string_view str) { + HloParser parser(str); + return parser.ParseShapeOnly(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index d830fa61438239005875f785f85cf2486123ebc9..450a54c54c156c2ae27475d145a8e83dc841b431 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -60,6 +60,9 @@ StatusOr ParseConvolutionDimensionNumbers( // Parses the result of PaddingConfigToString(), e.g. "0_0x1_1". StatusOr ParsePaddingConfig(absl::string_view str); +// Parses and returns a Shape::ToString-format string. 
+StatusOr ParseShape(absl::string_view str); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index ab71f011ac9d77d00ddfb41aca7a224d26d416b7..ef31cec32770690505b437d8678c45150766e559 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -82,7 +82,7 @@ ENTRY %constant_pred () -> pred[] { R"(HloModule module ENTRY %constant_pred_array () -> pred[2,3] { - ROOT %constant = pred[2,3]{1,0} constant(pred[2,3] { { 0, 1, 0 }, { 1, 0, 1 } }) + ROOT %constant = pred[2,3]{1,0} constant({ { 0, 1, 0 }, { 1, 0, 1 } }) } )" @@ -128,7 +128,7 @@ ENTRY %ConstantF32Empty.v4 () -> f32[0] { R"(HloModule ConstantF32R4Empty_module ENTRY %ConstantF32R4Empty.v4 () -> f32[2,0,4,3] { - ROOT %constant = f32[2,0,4,3]{3,2,1,0} constant(f32[2,0,4,3] { { /*i0=0*/ }, { /*i0=1*/ } }) + ROOT %constant = f32[2,0,4,3]{3,2,1,0} constant({ { /*i0=0*/ }, { /*i0=1*/ } }) } )" @@ -139,7 +139,7 @@ ENTRY %ConstantF32R4Empty.v4 () -> f32[2,0,4,3] { R"(HloModule Small_3x2x1x1_module ENTRY %Small_3x2x1x1.v1 () -> f32[3,2,1,1] { - ROOT %constant = f32[3,2,1,1]{3,2,1,0} constant(f32[3,2,1,1] { { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) + ROOT %constant = f32[3,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) } )" @@ -196,7 +196,7 @@ ENTRY %add_constants () -> f32[] { R"(HloModule TupleConstant_module ENTRY %TupleConstant.v1 () -> (f32[2,1], f32[2]) { - ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { {1}, {2} }, {2, 42} )) + ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant(( { {1}, {2} }, {2, 42} )) } )" @@ -295,11 
+295,11 @@ ENTRY %WhileWithScalarS32Result.v2 () -> s32[] { R"(HloModule TwoSendRecvBothWayRecvFist_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) { - %token = token[] after-all() - %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15, sharding={maximal device=1} + %token0 = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, sharding={maximal device=1} ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, sharding={maximal device=1} %constant = f32[] constant(2.1), sharding={maximal device=0} - %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv} + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv} %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, sharding={maximal device=0} } @@ -310,11 +310,11 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) { R"(HloModule HostTransferSendRecv_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) { - %token = token[] after-all() - %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15, is_host_transfer=true + %token0 = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, is_host_transfer=true ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, is_host_transfer=true %constant = f32[] constant(2.1), sharding={maximal device=0} - %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, is_host_transfer=true + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, is_host_transfer=true %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, is_host_transfer=true } @@ -327,7 +327,7 @@ R"(HloModule 
GetTupleElement_module ENTRY %GetTupleElement.v4 () -> s32[2,3] { %constant = f32[3]{0} constant({1, 2, 3}) - %constant.1 = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 4, 5, 6 } }) + %constant.1 = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 } }) %tuple = (f32[3]{0}, s32[2,3]{1,0}) tuple(f32[3]{0} %constant, s32[2,3]{1,0} %constant.1) ROOT %get-tuple-element = s32[2,3]{1,0} get-tuple-element((f32[3]{0}, s32[2,3]{1,0}) %tuple), index=1, sharding={maximal device=0} } @@ -434,7 +434,7 @@ ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f R"(HloModule Reverse4DFloatArrayOnDim01_module ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] { - %constant = f32[4,3,2,1]{0,1,2,3} constant(f32[4,3,2,1] { { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} }, { /*i1=2*/ {5}, {6} } }, { /*i0=1*/ { /*i1=0*/ {7}, {8} }, { /*i1=1*/ {9}, {10} }, { /*i1=2*/ {11}, {12} } }, { /*i0=2*/ { /*i1=0*/ {13}, {14} }, { /*i1=1*/ {15}, {16} }, { /*i1=2*/ {17}, {18} } }, { /*i0=3*/ { /*i1=0*/ {19}, {20} }, { /*i1=1*/ {21}, {22} }, { /*i1=2*/ {23}, {24} } } }) + %constant = f32[4,3,2,1]{0,1,2,3} constant({ { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} }, { /*i1=2*/ {5}, {6} } }, { /*i0=1*/ { /*i1=0*/ {7}, {8} }, { /*i1=1*/ {9}, {10} }, { /*i1=2*/ {11}, {12} } }, { /*i0=2*/ { /*i1=0*/ {13}, {14} }, { /*i1=1*/ {15}, {16} }, { /*i1=2*/ {17}, {18} } }, { /*i0=3*/ { /*i1=0*/ {19}, {20} }, { /*i1=1*/ {21}, {22} }, { /*i1=2*/ {23}, {24} } } }) ROOT %reverse = f32[4,3,2,1]{0,1,2,3} reverse(f32[4,3,2,1]{0,1,2,3} %constant), dimensions={0,1} } @@ -446,8 +446,8 @@ ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] { R"(HloModule Concat2x3With2x5_module ENTRY %Concat2x3With2x5.v3 () -> f32[2,8] { - %constant = f32[2,3]{1,0} constant(f32[2,3] { { 0, 1, 2 }, { 1000, 1001, 1002 } }) - %constant.1 = f32[2,5]{1,0} constant(f32[2,5] { { 64, 65, 66, 67, 68 }, { 1064, 1065, 1066, 1067, 1068 } }) + %constant = f32[2,3]{1,0} constant({ { 0, 1, 2 }, { 1000, 1001, 
1002 } }) + %constant.1 = f32[2,5]{1,0} constant({ { 64, 65, 66, 67, 68 }, { 1064, 1065, 1066, 1067, 1068 } }) ROOT %concatenate = f32[2,8]{1,0} concatenate(f32[2,3]{1,0} %constant, f32[2,5]{1,0} %constant.1), dimensions={1} } @@ -471,8 +471,8 @@ R"(HloModule R4F32OverlapSmall_module } ENTRY %R4F32OverlapSmall.v4 () -> f32[4,5,1,1] { - %constant = f32[4,5,1,1]{3,2,1,0} constant(f32[4,5,1,1] { { /*i0=0*/ { /*i1=0*/ {7} }, { /*i1=1*/ {2} }, { /*i1=2*/ {5} }, { /*i1=3*/ {3} }, { /*i1=4*/ {8} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {8} }, { /*i1=2*/ {9} }, { /*i1=3*/ {3} }, { /*i1=4*/ {4} } }, { /*i0=2*/ { /*i1=0*/ {1} }, { /*i1=1*/ {5} }, { /*i1=2*/ {7} }, { /*i1=3*/ {5} }, { /*i1=4*/ {6} } }, { /*i0=3*/ { /*i1=0*/ {0} }, { /*i1=1*/ {6} }, { /*i1=2*/ {2} }, { /*i1=3*/ {10} }, { /*i1=4*/ {2} } } }) - %constant.1 = f32[2,2,1,1]{3,2,1,0} constant(f32[2,2,1,1] { { /*i0=0*/ { /*i1=0*/ {2} }, { /*i1=1*/ {6} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {1} } } }) + %constant = f32[4,5,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {7} }, { /*i1=1*/ {2} }, { /*i1=2*/ {5} }, { /*i1=3*/ {3} }, { /*i1=4*/ {8} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {8} }, { /*i1=2*/ {9} }, { /*i1=3*/ {3} }, { /*i1=4*/ {4} } }, { /*i0=2*/ { /*i1=0*/ {1} }, { /*i1=1*/ {5} }, { /*i1=2*/ {7} }, { /*i1=3*/ {5} }, { /*i1=4*/ {6} } }, { /*i0=3*/ { /*i1=0*/ {0} }, { /*i1=1*/ {6} }, { /*i1=2*/ {2} }, { /*i1=3*/ {10} }, { /*i1=4*/ {2} } } }) + %constant.1 = f32[2,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {2} }, { /*i1=1*/ {6} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {1} } } }) %constant.2 = f32[] constant(0) ROOT %select-and-scatter = f32[4,5,1,1]{3,2,1,0} select-and-scatter(f32[4,5,1,1]{3,2,1,0} %constant, f32[2,2,1,1]{3,2,1,0} %constant.1, f32[] %constant.2), window={size=2x3x1x1 stride=2x2x1x1}, select=%ge_F32.v3, scatter=%add_F32.v3 } @@ -523,7 +523,7 @@ ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { R"(HloModule Slice3x3x3_To_1x3x3_F32_module ENTRY 
%Slice3x3x3_To_1x3x3_F32.v2 () -> f32[1,3,3] { - %constant = f32[3,3,3]{2,1,0} constant(f32[3,3,3] { { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } }, { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } }, { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } } }) + %constant = f32[3,3,3]{2,1,0} constant({ { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } }, { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } }, { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } } }) ROOT %slice = f32[1,3,3]{2,1,0} slice(f32[3,3,3]{2,1,0} %constant), slice={[0:1], [0:3], [0:3]} } @@ -547,7 +547,7 @@ ENTRY %SliceR0.v2 () -> s32[] { R"(HloModule Transpose_module ENTRY %Transpose.v2 () -> s32[1,2,3] { - %constant = s32[1,2,3]{2,1,0} constant(s32[1,2,3] { { { 1, 2, 3 }, { 4, 5, 6 } } }) + %constant = s32[1,2,3]{2,1,0} constant({ { { 1, 2, 3 }, { 4, 5, 6 } } }) ROOT %transpose = s32[1,2,3]{2,1,0} transpose(s32[1,2,3]{2,1,0} %constant), dimensions={0,1,2} } @@ -588,7 +588,7 @@ ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_ R"(HloModule BasicTraining_module ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) { - %constant = f32[2,2,1,2]{3,2,1,0} constant(f32[2,2,1,2] { { /*i0=0*/ { /*i1=0*/ { 1, 2 } }, { /*i1=1*/ { 3, 4 } } }, { /*i0=1*/ { /*i1=0*/ { 5, 6 } }, { /*i1=1*/ { 7, 8 } } } }) + %constant = f32[2,2,1,2]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ { 1, 2 } }, { /*i1=1*/ { 3, 4 } } }, { /*i0=1*/ { /*i1=0*/ { 5, 6 } }, { /*i1=1*/ { 7, 8 } } } }) %constant.1 = f32[2]{0} constant({2, 3}) %constant.2 = f32[2]{0} constant({1, 2}) ROOT %batch-norm-training = (f32[2,2,1,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-training(f32[2,2,1,2]{3,2,1,0} %constant, f32[2]{0} %constant.1, f32[2]{0} %constant.2), epsilon=0.001, feature_index=3 @@ -728,7 +728,7 @@ R"(HloModule fusion_module } ENTRY %fusion.v3 () -> f32[3,2,1,1] { - %constant = f32[3,2,1,1]{3,2,1,0} constant(f32[3,2,1,1] { { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, 
{ /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) + %constant = f32[3,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) %constant.1 = f32[2]{0} constant({3.14, 4.25}) ROOT %fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %constant, f32[2]{0} %constant.1), kind=kLoop, calls=%fused_computation } @@ -740,7 +740,7 @@ ENTRY %fusion.v3 () -> f32[3,2,1,1] { R"(HloModule sparse_f32 ENTRY %sparse () -> f32[2,3,4] { - ROOT %foo = f32[2,3,4]sparse{10} constant(f32[2,3,4]{[0, 1, 2]: 1, [1, 2, 3]: 2, [2, 3, 4]: 3}) + ROOT %foo = f32[2,3,4]sparse{10} constant({[0, 1, 2]: 1, [1, 2, 3]: 2, [2, 3, 4]: 3}) } )" @@ -750,7 +750,7 @@ ENTRY %sparse () -> f32[2,3,4] { R"(HloModule sparse_f32_empty ENTRY %sparse_f32_empty () -> f32[2,3,4] { - ROOT %foo = f32[2,3,4]sparse{10} constant(f32[2,3,4]{}) + ROOT %foo = f32[2,3,4]sparse{10} constant({}) } )" @@ -760,7 +760,7 @@ ENTRY %sparse_f32_empty () -> f32[2,3,4] { R"(HloModule sparse_f32_r1 ENTRY %sparse_f32_r1 () -> f32[9] { - ROOT %foo = f32[9]sparse{10} constant(f32[9]{1: 2, 3: 4, 5: 6}) + ROOT %foo = f32[9]sparse{10} constant({1: 2, 3: 4, 5: 6}) } )" @@ -931,11 +931,11 @@ ENTRY reduce_entry { R"(HloModule outfeed_module ENTRY InfeedToOutfeed { - token = token[] after-all() - infeed = ((u32[3]{0}, pred[]), token[]) infeed(token) + token0 = token[] after-all() + infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0) infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0 - outfeed = token[] outfeed(infeed.data, token) - ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token) + outfeed = token[] outfeed(infeed.data, token0) + ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token0) infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0 infeed.1.token = token[] get-tuple-element(infeed.1), index=1 outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token) 
@@ -1117,9 +1117,9 @@ ENTRY Gather { )" }, -// cross-replica-sum +// all-reduce { -"CrossReplicaSum", +"AllReduce", R"(HloModule CRS add { @@ -1130,14 +1130,14 @@ add { ENTRY CRS { input = f32[8]{0} parameter(0) - ROOT crs = f32[8]{0} cross-replica-sum(input), replica_groups={}, to_apply=add + ROOT crs = f32[8]{0} all-reduce(input), replica_groups={}, to_apply=add } )" }, -// cross-replica-sum with subgroups +// all-reduce with subgroups { -"CrossReplicaSumWithSubgroups", +"AllReduceWithSubgroups", R"(HloModule CRS_Subgroups add { @@ -1146,16 +1146,16 @@ add { ROOT add = f32[] add(lhs, rhs) } -ENTRY CrossReplicaSumWithSubgroups { +ENTRY AllReduceWithSubgroups { input = f32[128,32]{0,1} parameter(0) - ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), replica_groups={{0,1},{2,3}}, barrier="abc", to_apply=add + ROOT all-reduce = f32[128,32]{0,1} all-reduce(input), replica_groups={{0,1},{2,3}}, barrier="abc", to_apply=add } )" }, -// cross-replica-sum with all-reduce-id +// all-reduce with all-reduce-id { -"CrossReplicaSumAllReduce", +"AllReduceAllReduce", R"(HloModule CRS add { @@ -1166,8 +1166,8 @@ add { ENTRY CRS { input = f32[8]{0} parameter(0) - crs.1 = f32[8]{0} cross-replica-sum(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add - ROOT crs.0 = f32[8]{0} cross-replica-sum(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add + crs.1 = f32[8]{0} all-reduce(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add + ROOT crs.0 = f32[8]{0} all-reduce(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add } )" @@ -1266,8 +1266,8 @@ R"(HloModule AddDependency ENTRY AddDependency { p = f32[] parameter(0) neg = f32[] negate(p) - token = token[] after-all(neg) - p_after_token = f32[] add-dependency(p, token) + token0 = token[] after-all(neg) + p_after_token = f32[] add-dependency(p, token0) exp = f32[] exponential(p_after_token) ROOT sum = f32[] add(neg, exp) } @@ -1419,7 +1419,7 @@ TEST_F(HloParserTest, MoreConstants) { ENTRY 
%SelectScalarS32True.v4 () -> s32[] { %constant.2 = pred[] constant(true) - %constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,2]1,2,3,4} + %constant.1 = s32[] constant(-42), sharding={devices=[2,2]1,2,3,4} %constant = s32[] constant(42) %select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant) } @@ -1462,7 +1462,7 @@ TEST_F(HloParserTest, LiteralDimensionsMismatch_2) { const string original = R"(HloModule some_2x3_module ENTRY %some_2x3 () -> f32[2,3] { - ROOT %constant = f32[2,3]{1,0} constant(f32[2,3] {1, 2, 3, 4, 5, 6}) + ROOT %constant = f32[2,3]{1,0} constant({1, 2, 3, 4, 5, 6}) } )"; @@ -1476,7 +1476,7 @@ TEST_F(HloParserTest, LiteralDimensionsMismatch_3) { const string original = R"(HloModule some_2x3x2_module ENTRY %some_2x3x2 () -> f32[2,3,2] { - ROOT %constant = f32[2,3,2]{2,1,0} constant(f32[2,3,2] {{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}, {11, 12}}}) + ROOT %constant = f32[2,3,2]{2,1,0} constant({{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}, {11, 12}}}) } )"; @@ -1594,11 +1594,11 @@ TEST_F(HloParserTest, UnexpectedAttribute) { const string original = R"(HloModule unexpected_attr_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %token = token[] after-all() - %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15 + %token0 = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15 %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15 ROOT %constant = f32[] constant(2.1) - %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, calls=%recv + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, calls=%recv %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16 } @@ -1611,11 +1611,11 @@ TEST_F(HloParserTest, MissingAttribute) { const string original = R"(HloModule missing_attr_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %token = token[] 
after-all() - %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15 + %token0 = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15 %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15 ROOT %constant = f32[] constant(-2.1) - %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token) + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0) %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16 } @@ -1628,11 +1628,11 @@ TEST_F(HloParserTest, PredecessorUndefined) { const string original = R"(HloModule pre_not_found_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %token = token[] after-all() - %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15 + %token0 = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15 %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15 ROOT %constant = f32[] constant(2.1) - %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, control-predecessors={%done} + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, control-predecessors={%done} %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16 } @@ -1940,8 +1940,8 @@ TEST_F(HloParserTest, ParsePaddingConfigInteriorPaddingImplicitZeroDim) { TEST_F(HloParserTest, NontupleInfeed) { const string original = R"(HloModule nontuple_infeed: ENTRY nontuple_infeed { - token = token[] after-all() - ROOT infeed = pred[] infeed(token) + token0 = token[] after-all() + ROOT infeed = pred[] infeed(token0) })"; ExpectHasSubstr(ParseHloString(original).status().error_message(), "infeed must have a non-empty tuple shape"); @@ -2239,7 +2239,7 @@ HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> f32[2,2] { %p = f32[2,2] parameter(0) - %constant.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + 
%constant.1 = f32[2,2] constant({{1, 2}, {3, 4}}) ROOT %add.1 = f32[2,2] add(f32[2,2] %p, f32[2,5] %constant.1) } )"; @@ -2249,7 +2249,85 @@ ENTRY %entrycomp (p: f32[2,2]) -> f32[2,2] { " with the shape of the operand instruction f32[2,2]{1,0}."); } -// custom call incompatible shape. +TEST_F(HloParserTest, ParseShapeStringR2F32) { + string shape_string = "f32[123,456]"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = ShapeUtil::MakeShape(F32, {123, 456}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseShapeStringTupleOfArrays) { + string shape_string = "(f32[1572864],s8[5120,1024])"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}), + ShapeUtil::MakeShape(S8, {5120, 1024})}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseShapeStringNestedTuple) { + string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {1}), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}), + ShapeUtil::MakeOpaqueShape(), + ShapeUtil::MakeShape(F32, {3}), + }); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseShapeStringWithLayout) { + string shape_string = "f32[123,456]{0,1}"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}); + 
ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseShapeStringWithSparseLayout) { + string shape_string = "f32[123,456]sparse{10}"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = ShapeUtil::MakeShapeWithSparseLayout(F32, {123, 456}, 10); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseOpaqueType) { + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape("opaque[]")); + Shape expected = ShapeUtil::MakeOpaqueShape(); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseTokenType) { + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape("token[]")); + Shape expected = ShapeUtil::MakeTokenShape(); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseInvalidShapeString) { + string shape_strings[] = { + "f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}", + "f32[123,456]dense{foo}", "f32[123,456]sparse{foo}", + }; + for (const string& shape_string : shape_strings) { + StatusOr result = ParseShape(shape_string); + ASSERT_FALSE(result.ok()) << "shape: " << shape_string; + } +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 312b5d020c398feb7738d14a9cfa0928d5178948..33ce7e23a82d840676bba5f1ca9c0ffc4433465d 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -77,6 +77,11 @@ std::vector 
HloPassPipeline::GetEnabledPasses( auto repeated_field = debug_options.xla_disable_hlo_passes(); absl::flat_hash_set disabled_pass_names(repeated_field.begin(), repeated_field.end()); + if (debug_options.xla_disable_all_hlo_passes()) { + VLOG(1) << "*All* passes disabled by --xla_disable_all_hlo_passes."; + return {}; + } + if (!disabled_pass_names.empty()) { VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: " << absl::StrJoin(disabled_pass_names, ", "); @@ -113,7 +118,7 @@ void HloPassPipeline::MaybeDumpHlo(const HloModule& module, } const string message = - StrCat("after ", after_pass_name, ", before ", before_pass_name); + absl::StrCat("after ", after_pass_name, ", before ", before_pass_name); hlo_graph_dumper::MaybeDumpHloModule(module, message); VLOG(3) << "HLO " << message << ":"; VLOG(3) << module.entry_computation_layout().ToString(); diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc index 981d06ce101644ecce587c4bd2f7a12c8edf6548..3a9ee57e5551ae5b608f02d9f8bd0428ff16db13 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc @@ -39,6 +39,7 @@ HloProto MakeHloProto(const HloModule& module) { StatusOr> CreateModuleFromProto( const HloModuleProto& proto, const HloModuleConfig& module_config) { + VLOG(4) << proto.ShortDebugString(); TF_ASSIGN_OR_RETURN(std::unique_ptr module, HloModule::CreateFromProto(proto, module_config)); TF_RETURN_IF_ERROR( diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index 4aa8067752481ffab29e1a573ffa49d4aa046f1f..edaa4c59e2674e5f165c468059747d3dd2d54218 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -93,7 +93,7 @@ std::unique_ptr HloReachabilityMap::Build( } break; } - case HloOpcode::kCrossReplicaSum: { + case HloOpcode::kAllReduce: { auto all_reduce_id 
= hlo->all_reduce_id(); if (all_reduce_id) { auto it = channel_dependency_map.find(all_reduce_id.value()); diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 48add75523f02005c70bc6baf69a6b7d5aa4f7ef..ac74e2432f2176e13eaf7d4a1934a50ee89d1042 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -63,7 +63,7 @@ bool IsRematerializable(const HloInstruction* instruction) { case HloOpcode::kCall: case HloOpcode::kConstant: case HloOpcode::kConditional: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kCustomCall: case HloOpcode::kParameter: case HloOpcode::kWhile: diff --git a/tensorflow/compiler/xla/service/hlo_token.h b/tensorflow/compiler/xla/service/hlo_token.h deleted file mode 100644 index 4458c251dee4af365e39027dd4289925c8890efd..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/hlo_token.h +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { - -// Defines different kinds of tokens in a hlo module string. -// -// You shouldn't need to use this directly unless you're using HloLexer -// directly, and you probably don't need to do that. Use hlo_parser instead. -enum class TokKind { - // Markers - kEof, - kError, - - // Tokens with no info. - kEqual, // = - kComma, // , - kColon, // : - kLsquare, - kRsquare, // [ ] - kLbrace, - kRbrace, // { } - kLparen, - kRparen, // ( ) - - kArrow, // -> - - // Keywords - kw_HloModule, - kw_ENTRY, - kw_ROOT, - kw_true, - kw_false, - kw_maximal, - kw_replicated, - kw_nan, - kw_inf, - - kNegInf, // -inf - - // Typed tokens. - kName, // %foo - kAttributeName, // dimensions= - kDimLabels, // [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} - kDxD, // [0-9]+(x[0-9]+)+ - kPad, // [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* - kIdent, // other identifiers - kString, // "abcd\"\n" - kShape, // f32[2,3]{1,0} - kInt, // 42 - kDecimal, // 4.2 -}; - -string TokKindToString(TokKind kind); - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 77db7b098a38ff4efdcc7447935fae61561c9ff4..e1c737132f72948e0e46d37dd08ddf8e7b29bfca 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -44,7 +44,7 @@ bool IsCallerInstruction(HloInstruction* hlo) { case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kWhile: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: @@ -153,8 +153,8 @@ 
Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) { const Shape expected, ShapeInference::InferConvolveShape( convolution->operand(0)->shape(), convolution->operand(1)->shape(), - convolution->feature_group_count(), convolution->window(), - convolution->convolution_dimension_numbers())); + convolution->feature_group_count(), convolution->batch_group_count(), + convolution->window(), convolution->convolution_dimension_numbers())); return CheckShape(convolution, expected); } @@ -167,13 +167,12 @@ Status ShapeVerifier::HandleFft(HloInstruction* fft) { return CheckShape(fft, expected); } -Status ShapeVerifier::HandleCrossReplicaSum(HloInstruction* crs) { +Status ShapeVerifier::HandleAllReduce(HloInstruction* crs) { std::vector operand_shapes; for (const HloInstruction* operand : crs->operands()) { operand_shapes.push_back(&operand->shape()); } - return CheckShape(crs, - ShapeInference::InferCrossReplicaSumShape(operand_shapes)); + return CheckShape(crs, ShapeInference::InferAllReduceShape(operand_shapes)); } Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) { @@ -481,7 +480,9 @@ Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) { const Shape& operand_shape_with_layout = custom_call->operand_shapes_with_layout()[i]; TF_RET_CHECK(ShapeUtil::Compatible(custom_call->operand(i)->shape(), - operand_shape_with_layout)); + operand_shape_with_layout)) + << custom_call->operand(i)->shape().ToString() << " operand " + << operand_shape_with_layout.ToString(); TF_RET_CHECK(LayoutUtil::HasLayout(operand_shape_with_layout)); } } @@ -683,7 +684,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kConstant: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kCustomCall: case HloOpcode::kDomain: case HloOpcode::kFusion: @@ -1344,7 +1345,7 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { return 
Status::OK(); } - Status HandleCrossReplicaSum(HloInstruction* crs) override { + Status HandleAllReduce(HloInstruction* crs) override { if (crs->all_reduce_id().has_value()) { TF_RET_CHECK(crs->all_reduce_id().value() > 0) << "All reduce id must be greater than 0 for " diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index e4d0c3d6957885f1d719fedb5a900de601e397f8..a1a6aba9728c137d17487b5914f67cb3966fc12b 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -52,7 +52,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleDot(HloInstruction* dot) override; Status HandleConvolution(HloInstruction* convolution) override; Status HandleFft(HloInstruction* fft) override; - Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleAllReduce(HloInstruction* crs) override; Status HandleAllToAll(HloInstruction* hlo) override; Status HandleCollectivePermute(HloInstruction* hlo) override; Status HandleReducePrecision(HloInstruction* reduce_precision) override; diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index 98246d5403e4aebc2f4d81e52145706355ddd9a9..295465c8481bcb7d1385192febe0d09614e393b3 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -99,7 +99,7 @@ TEST_F(IndexedArrayAnalysisTest, SimpleOneToOneConstantGather) { HloModule SimpleGather ENTRY main { - operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + operand = s32[3,3] constant({{1,2,3},{1,2,3},{1,2,3}}) indices = s32[5] parameter(0) ROOT gather = s32[5,3] gather(operand, indices), offset_dims={1}, @@ -119,7 +119,7 @@ TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed0) { HloModule SimpleGather ENTRY main { - operand = s32[3,3] 
constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + operand = s32[3,3] constant({{1,2,3},{1,2,3},{1,2,3}}) indices = s32[5,2] parameter(0) ROOT gather = s32[5] gather(operand, indices), offset_dims={}, @@ -195,7 +195,7 @@ TEST_F(IndexedArrayAnalysisTest, GatherOfGather_OneToOne) { HloModule SimpleGather ENTRY main { - operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + operand = s32[3,3] constant({{1,2,3},{1,2,3},{1,2,3}}) indices_a = s32[5] parameter(0) indices_b = s32[2] parameter(1) gather_a = s32[5,3] gather(operand, indices_a), @@ -309,7 +309,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather0) { HloModule ReshapeOfGather ENTRY main { - operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + operand = s32[3,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4}}) indices = s32[5] parameter(0) gather = s32[5,4] gather(operand, indices), offset_dims={1}, @@ -330,7 +330,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather1) { HloModule ReshapeOfGather ENTRY main { - operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + operand = s32[3,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4}}) indices = s32[5,7] parameter(0) gather = s32[5,4,7] gather(operand, indices), offset_dims={1}, @@ -352,7 +352,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather2) { HloModule ReshapeOfGather ENTRY main { - operand = s32[3,2,6] constant(s32[3,2,6]{ + operand = s32[3,2,6] constant({ {{1,2,3,4,5,6},{1,2,3,4,5,6}}, {{1,2,3,4,5,6},{1,2,3,4,5,6}}, {{1,2,3,4,5,6},{1,2,3,4,5,6}}}) @@ -377,7 +377,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather3) { HloModule ReshapeOfGather ENTRY main { - operand = s32[2,6] constant(s32[2,6]{ + operand = s32[2,6] constant({ {1,2,3,4,5,6},{1,2,3,4,5,6}}) indices = s32[1] parameter(0) gather = s32[1,6] gather(operand, indices), @@ -405,7 +405,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather4) { HloModule ReshapeOfGather ENTRY main { - operand = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 1, 2, 3 } }) + operand = 
s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 1, 2, 3 } }) i.0 = s64[1,3]{1,0} parameter(0) g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), offset_dims={2}, @@ -438,7 +438,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather5) { HloModule ReshapeOfGather ENTRY main { - operand = s32[1,6] constant(s32[1,6]{{1,2,3,4,5,6}}) + operand = s32[1,6] constant({{1,2,3,4,5,6}}) indices = s32[1] parameter(0) gather = s32[1,6] gather(operand, indices), offset_dims={1}, @@ -465,7 +465,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather6) { HloModule ReshapeOfGather ENTRY main { - operand = s32[1,2,6] constant(s32[1,2,6]{{ + operand = s32[1,2,6] constant({{ {1,2,3,4,5,6},{1,2,3,4,5,6}}}) indices = s32[1] parameter(0) gather = s32[1,1,6] gather(operand, indices), @@ -496,7 +496,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather7) { HloModule ReshapeOfGather ENTRY main { - operand = s32[2,6] constant(s32[2,6]{ + operand = s32[2,6] constant({ {1,2,3,4,5,6},{1,2,3,4,5,6}}) indices = s32[1,5] parameter(0) gather = s32[1,5,6] gather(operand, indices), @@ -527,7 +527,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold0) { HloModule ReshapeOfGather ENTRY main { - operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + operand = s32[3,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4}}) indices = s32[5,6] parameter(0) gather = s32[5,4,6] gather(operand, indices), offset_dims={1}, @@ -556,7 +556,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold1) { HloModule ReshapeOfGather ENTRY main { - operand = s32[3,5,2] constant(s32[3,5,2]{ + operand = s32[3,5,2] constant({ {{1,2},{3,4},{5,6},{7,8},{9,10}}, {{1,2},{3,4},{5,6},{7,8},{9,10}}, {{1,2},{3,4},{5,6},{7,8},{9,10}}}) @@ -588,7 +588,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold2) { HloModule ReshapeOfGather ENTRY main { - operand = s32[3,4,1] constant(s32[3,4,1]{ + operand = s32[3,4,1] constant({ {{1},{2},{3},{4}}, {{1},{2},{3},{4}}, {{1},{2},{3},{4}}}) @@ -620,7 +620,7 @@ TEST_F(IndexedArrayAnalysisTest, 
UnaryOpOfGather) { HloModule UnaryOpOfGather ENTRY main { - operand = f32[3,4] constant(f32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + operand = f32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) indices = s32[5] parameter(0) gather = f32[5,4] gather(operand, indices), offset_dims={1}, @@ -645,7 +645,7 @@ TEST_F(IndexedArrayAnalysisTest, AddBroadcastedScalarWithGather) { HloModule AddBroadcastedScalarWithGather ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) constant = s32[] constant(5) constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} indices = s32[5] parameter(0) @@ -673,7 +673,7 @@ TEST_F(IndexedArrayAnalysisTest, HloModule SubtractBroadcastedScalarWithGather ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) constant = s32[] constant(5) constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} indices = s32[5] parameter(0) @@ -701,7 +701,7 @@ TEST_F(IndexedArrayAnalysisTest, HloModule SubtractBroadcastedScalarWithGather ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) constant = s32[] constant(5) constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} indices = s32[5] parameter(0) @@ -728,7 +728,7 @@ TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather) { HloModule AddBroadcastedVectorWithGather ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) constant_vect = s32[4] constant({10,11,12,13}) constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={1} indices = s32[5] parameter(0) @@ -755,7 +755,7 @@ TEST_F(IndexedArrayAnalysisTest, 
AddBroadcastedVectorWithGather_Negative) { HloModule AddBroadcastedVectorWithGather ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) constant_vect = s32[5] constant({10,11,12,13,14}) constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={0} indices = s32[5] parameter(0) @@ -804,8 +804,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpBasic_0) { HloModule DotOp ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) - dot_rhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_rhs_constant = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) indices = s32[5] parameter(0) dot_lhs = s32[5,4] gather(gather_operand, indices), offset_dims={1}, @@ -831,8 +831,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpBasic_1) { HloModule DotOp ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) - dot_rhs_constant = s32[3,3] constant(s32[3,3]{{1,2,3},{4,5,6},{7,8,9}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_rhs_constant = s32[3,3] constant({{1,2,3},{4,5,6},{7,8,9}}) indices = s32[5] parameter(0) dot_lhs = s32[3,5] gather(gather_operand, indices), offset_dims={0}, @@ -859,8 +859,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpBasic_2) { HloModule DotOp ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) - dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_lhs_constant = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) indices = s32[5] parameter(0) dot_rhs = s32[3,5] gather(gather_operand, indices), offset_dims={0}, @@ -888,8 +888,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpBasic_3) { HloModule DotOp 
ENTRY main { - gather_operand = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) - dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + gather_operand = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + dot_lhs_constant = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) indices = s32[5] parameter(0) dot_rhs = s32[5,3] gather(gather_operand, indices), offset_dims={1}, @@ -917,8 +917,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpWithBatch) { HloModule DotOp ENTRY main { - gather_operand = s32[2,3,2] constant(s32[2,3,2]{{{1,2},{3,4},{5,6}},{{7,8},{9,10},{11,12}}}) - dot_lhs_constant = s32[2,2,3] constant(s32[2,2,3]{{{1,2,3},{4,5,6}},{{7,8,9},{10,11,12}}}) + gather_operand = s32[2,3,2] constant({{{1,2},{3,4},{5,6}},{{7,8},{9,10},{11,12}}}) + dot_lhs_constant = s32[2,2,3] constant({{{1,2,3},{4,5,6}},{{7,8,9},{10,11,12}}}) indices = s32[4] parameter(0) dot_rhs = s32[2,3,4] gather(gather_operand, indices), offset_dims={0,1}, @@ -948,8 +948,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpNegative) { HloModule DotOp ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) - dot_rhs_constant = s32[2,3] constant(s32[2,3]{{1,2,3},{4,5,6}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_rhs_constant = s32[2,3] constant({{1,2,3},{4,5,6}}) indices = s32[2] parameter(0) dot_lhs = s32[3,2] gather(gather_operand, indices), offset_dims={0}, diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 7559ed1bab84b21a4d51bc38db999900befcfad7..07448715293ca8dde5492a054b84c3408004bdaf 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/fusion_queue.h" @@ -126,7 +127,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kConvolution: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: case HloOpcode::kCustomCall: @@ -570,19 +571,42 @@ HloInstruction* InstructionFusion::FuseIntoMultiOutput( bool InstructionFusion::MultiOutputFusionCreatesCycle( HloInstruction* producer, HloInstruction* consumer) { - auto is_reachable = [&](const HloInstruction* a, const HloInstruction* b) { - // A consumer operand may have been multi-output fused into a parallel - // consumer and thus be missing from the original reachability map. - if (!reachability_->IsPresent(a) || !reachability_->IsPresent(b)) { - reachability_ = HloReachabilityMap::Build(consumer->parent()); + absl::flat_hash_set operands; + for (const HloInstruction* operand : consumer->operands()) { + if (operand == producer) { + continue; + } + + // If the reachability map already contains the producer and the operand of + // the consumer, and the producer can reach the operand, then we know for + // sure MultiOutputFusion would create a cycle. If not, we need to do a DFS + // traversal of the computation to verify that this multioutput fusion would + // not create a cycle. 
+ if (reachability_->IsPresent(producer) && + reachability_->IsPresent(operand) && + reachability_->IsReachable(producer, operand)) { + return true; } - return reachability_->IsReachable(a, b); - }; - return absl::c_any_of(consumer->operands(), - [&](const HloInstruction* consumer_operand) { - return consumer_operand != producer && - is_reachable(producer, consumer_operand); - }); + operands.insert(operand->unique_id()); + } + + // Do a DFS on the producer to see if any of the other consumer operands are + // reachable in the current state of the graph. + std::vector worklist = producer->users(); + absl::flat_hash_set visits; + while (!worklist.empty()) { + const HloInstruction* user = worklist.back(); + worklist.pop_back(); + if (operands.count(user->unique_id()) != 0) { + return true; + } + if (visits.count(user->unique_id()) == 0) { + visits.insert(user->unique_id()); + worklist.insert(worklist.end(), user->users().begin(), + user->users().end()); + } + } + return false; } bool InstructionFusion::ShouldFuse(HloInstruction* consumer, diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 58b7135cea7419f13d60ed510ecf7a88126aee48..611cfd404d7622f561f0acc86fc9b05e16eea22e 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -259,8 +259,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) { add = f32[4,3]{1,0} add(p0, p0) abs1 = f32[4,3]{1,0} abs(add) log = f32[4,3]{1,0} log(abs1) - token = token[] after-all() - send = f32[4,3]{1,0} send(log, token), channel_id=0 + token0 = token[] after-all() + send = f32[4,3]{1,0} send(log, token0), channel_id=0 abs2 = f32[4,3]{1,0} abs(log) ROOT root = f32[4,3]{1,0} subtract(abs2, add) })") @@ -290,8 +290,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) { p0 = f32[4,3]{1,0} parameter(0) add1 = f32[4,3]{1,0} 
add(p0, p0) log = f32[4,3]{1,0} log(p0) - token = token[] after-all() - send = f32[4,3]{1,0} send(log, token), channel_id=0 + token0 = token[] after-all() + send = f32[4,3]{1,0} send(log, token0), channel_id=0 add2 = f32[4,3]{1,0} add(log, add1) ROOT root = f32[4,3]{1,0} subtract(add1, add2) })") @@ -324,8 +324,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) { add1 = f32[4,3]{1,0} add(p0, p0) add2 = f32[4,3]{1,0} add(add1, add1) log = f32[4,3]{1,0} log(add2) - token = token[] after-all() - send = f32[4,3]{1,0} send(log, token), channel_id=0 + token0 = token[] after-all() + send = f32[4,3]{1,0} send(log, token0), channel_id=0 sub1 = f32[4,3]{1,0} subtract(log, add2) sub2 = f32[4,3]{1,0} subtract(add2, add1) ROOT root = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub1, sub2) diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index 3a5177c418e3af8253df228a51f2fc0901d10041..d37ae94bf6c4c697bbf30390c02a5099271e00a4 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -76,9 +76,12 @@ StatusOr> InterpreterCompiler::RunBackend( // need to compile anything // Create executable from only the Hlo module. 
+ auto evaluator = absl::make_unique(); + evaluator->set_use_fast_path( + hlo_module->config().debug_options().xla_hlo_evaluator_use_fast_path()); std::unique_ptr executable = - absl::make_unique( - std::move(hlo_module), absl::make_unique()); + absl::make_unique(std::move(hlo_module), + std::move(evaluator)); return std::move(executable); } diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index eddef850cf5250b85b564c1e6c92d1cc8ecd1a43..b9ddd9636fe29e85092ed67fc644a54332b218d3 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -2012,7 +2012,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kConditional: case HloOpcode::kConvert: case HloOpcode::kCos: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: case HloOpcode::kDivide: diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 5c661bfacb08fe27f3cbdc1fb9db083315166008..31d78752f07c57aef6023fabb8e3a7de20c4278c 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -847,12 +847,12 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { ENTRY entry_computation { param = (f32[2,2]) parameter(0) gte = f32[2,2] get-tuple-element(param), index=0 - token = token[] after-all() - recv = (f32[2,2], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=1} + token0 = token[] after-all() + recv = (f32[2,2], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=1} recv-done = (f32[2,2], token[]) recv-done(recv), channel_id=1, sharding={maximal device=1} ROOT root = f32[2,2] get-tuple-element(recv-done), index=0 - send = (f32[2,2], u32[], token[]) send(gte, token), channel_id=1, + send = 
(f32[2,2], u32[], token[]) send(gte, token0), channel_id=1, sharding={maximal device=0} send-done = token[] send-done(send), channel_id=1, sharding={maximal device=0} } @@ -894,11 +894,11 @@ TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) { ENTRY entry_computation { param = (f32[2,2]) parameter(0) gte = f32[2,2] get-tuple-element(param), index=0 - ar.0 = f32[2,2] cross-replica-sum(gte), + ar.0 = f32[2,2] all-reduce(gte), all_reduce_id=1, replica_groups={{0}}, to_apply=add, sharding={maximal device=0} - const = f32[2,2] constant(f32[2,2]{{0,1},{2,3}}) - ROOT ar.1 = f32[2,2] cross-replica-sum(const), + const = f32[2,2] constant({{0,1},{2,3}}) + ROOT ar.1 = f32[2,2] all-reduce(const), all_reduce_id=1, replica_groups={{0}}, to_apply=add, sharding={maximal device=1} })"; diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc index bd0139f85b6a5c5dc23dad962263038451921e65..5eeb29c478a371dae83251771f2dc4844672d3e9 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc @@ -18,28 +18,29 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" namespace xla { -Status KernelSupportLibrary::For( +Status KernelSupportLibrary::ForWithStatus( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { - return If(b_->CreateICmpSLT(start, end), [&]() -> Status { + return IfWithStatus(b_->CreateICmpSLT(start, end), [&]() -> Status { TF_RETURN_IF_ERROR(for_body_generator(start, /*is_first_iteration=*/true)); - return For(name, b_->CreateAdd(start, step), end, step, - [&](llvm::Value* iv) { return for_body_generator(iv, false); }); + return ForWithStatus( + name, b_->CreateAdd(start, step), end, step, + [&](llvm::Value* iv) { return for_body_generator(iv, false); }); }); } -Status KernelSupportLibrary::For( +Status KernelSupportLibrary::ForWithStatus( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, bool peel_first_iteration, const std::function& for_body_generator) { if (peel_first_iteration) { - return For(name, start, end, step, true, - [&](llvm::Value* indvar, bool is_first_iteration) -> Status { - return for_body_generator(indvar, - b_->getInt1(is_first_iteration)); - }); + return ForWithStatus( + name, start, end, step, true, + [&](llvm::Value* indvar, bool is_first_iteration) -> Status { + return for_body_generator(indvar, b_->getInt1(is_first_iteration)); + }); } else { std::unique_ptr loop = llvm_ir::ForLoop::EmitForLoop( name, start, end, step, b_, @@ -55,7 +56,7 @@ Status KernelSupportLibrary::For( } } -Status KernelSupportLibrary::If( +Status KernelSupportLibrary::IfWithStatus( absl::string_view name, llvm::Value* condition, const std::function& true_block_generator, const std::function& false_block_generator) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index 43fec311f150d6054f6ad24f99db332f90ff94a3..612b839cfa15711061e1ae53358a72d5220e1801 
100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -48,41 +48,42 @@ class KernelSupportLibrary { // for (i64 i = `start` + `step`; i s< `end`; i += `step`) // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/false)`; // } - Status For( + Status ForWithStatus( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator); - void ForReturnVoid( + void For( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { CHECK_EQ(Status::OK(), - For(name, start, end, step, + ForWithStatus( + name, start, end, step, [&](llvm::Value* ind_var, bool is_first_iteration) -> Status { for_body_generator(ind_var, is_first_iteration); return Status::OK(); })); } - Status For(absl::string_view name, int64 start, int64 end, int64 step, - const std::function& - for_body_generator) { - return For(name, /*start=*/b_->getInt64(start), - /*end=*/b_->getInt64(end), - /*step=*/b_->getInt64(step), for_body_generator); + Status ForWithStatus( + absl::string_view name, int64 start, int64 end, int64 step, + const std::function& for_body_generator) { + return ForWithStatus(name, /*start=*/b_->getInt64(start), + /*end=*/b_->getInt64(end), + /*step=*/b_->getInt64(step), for_body_generator); } - void ForReturnVoid( + void For( absl::string_view name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - ForReturnVoid(name, /*start=*/b_->getInt64(start), - /*end=*/b_->getInt64(end), - /*step=*/b_->getInt64(step), for_body_generator); + For(name, /*start=*/b_->getInt64(start), + /*end=*/b_->getInt64(end), + /*step=*/b_->getInt64(step), for_body_generator); } // Generates the following control flow structure if `peel_first_iteration` is @@ -99,19 +100,19 @@ class KernelSupportLibrary { // for (i64 i = `start`; i s< `end`; i += `step`) // 
`for_body_generator(/*ind_var=*/,i, // /*is_first_iteration=*/,(i != `start`))`; - Status For(absl::string_view name, llvm::Value* start, llvm::Value* end, - llvm::Value* step, bool peel_first_iteration, - const std::function& - for_body_generator); + Status ForWithStatus( + absl::string_view name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, bool peel_first_iteration, + const std::function& + for_body_generator); - void ForReturnVoid(absl::string_view name, llvm::Value* start, - llvm::Value* end, llvm::Value* step, - bool peel_first_iteration, - const std::function& - for_body_generator) { - TF_CHECK_OK(For( + void For(absl::string_view name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, bool peel_first_iteration, + const std::function& + for_body_generator) { + TF_CHECK_OK(ForWithStatus( name, start, end, step, peel_first_iteration, [&](llvm::Value* ind_var, llvm::Value* is_first_iteration) -> Status { for_body_generator(ind_var, is_first_iteration); @@ -119,80 +120,81 @@ class KernelSupportLibrary { })); } - Status For(absl::string_view name, llvm::Value* start, llvm::Value* end, - int64 step, bool peel_first_iteration, - const std::function& - for_body_generator) { - return For(name, /*start=*/start, /*end=*/end, - /*step=*/llvm::ConstantInt::get(start->getType(), step), - peel_first_iteration, for_body_generator); + Status ForWithStatus( + absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, + bool peel_first_iteration, + const std::function& + for_body_generator) { + return ForWithStatus( + name, /*start=*/start, /*end=*/end, + /*step=*/llvm::ConstantInt::get(start->getType(), step), + peel_first_iteration, for_body_generator); } - void ForReturnVoid(absl::string_view name, llvm::Value* start, - llvm::Value* end, int64 step, bool peel_first_iteration, - const std::function& - for_body_generator) { - ForReturnVoid(name, /*start=*/start, /*end=*/end, - /*step=*/llvm::ConstantInt::get(start->getType(), step), - 
peel_first_iteration, for_body_generator); + void For(absl::string_view name, llvm::Value* start, llvm::Value* end, + int64 step, bool peel_first_iteration, + const std::function& + for_body_generator) { + For(name, /*start=*/start, /*end=*/end, + /*step=*/llvm::ConstantInt::get(start->getType(), step), + peel_first_iteration, for_body_generator); } - Status For( + Status ForWithStatus( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { - return For(name, start, end, step, - /*peel_first_iteration=*/false, - [&](llvm::Value* indvar, llvm::Value*) -> Status { - return for_body_generator(indvar); - }); + return ForWithStatus(name, start, end, step, + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) -> Status { + return for_body_generator(indvar); + }); } - void ForReturnVoid( + void For( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { - ForReturnVoid(name, start, end, step, - /*peel_first_iteration=*/false, - [&](llvm::Value* indvar, llvm::Value*) { - return for_body_generator(indvar); - }); + For(name, start, end, step, + /*peel_first_iteration=*/false, [&](llvm::Value* indvar, llvm::Value*) { + return for_body_generator(indvar); + }); } - Status For( + Status ForWithStatus( absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, const std::function& for_body_generator) { - return For(name, start, end, llvm::ConstantInt::get(start->getType(), step), - /*peel_first_iteration=*/false, - [&](llvm::Value* indvar, llvm::Value*) -> Status { - return for_body_generator(indvar); - }); + return ForWithStatus(name, start, end, + llvm::ConstantInt::get(start->getType(), step), + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) -> Status { + return for_body_generator(indvar); + }); } - void ForReturnVoid( + void For( absl::string_view name, llvm::Value* start, 
llvm::Value* end, int64 step, const std::function& for_body_generator) { - ForReturnVoid(name, start, end, - llvm::ConstantInt::get(start->getType(), step), - for_body_generator); + For(name, start, end, llvm::ConstantInt::get(start->getType(), step), + for_body_generator); } - Status For( + Status ForWithStatus( absl::string_view name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - return For(name, /*start=*/b_->getInt64(start), - /*end=*/b_->getInt64(end), - /*step=*/b_->getInt64(step), for_body_generator); + return ForWithStatus(name, /*start=*/b_->getInt64(start), + /*end=*/b_->getInt64(end), + /*step=*/b_->getInt64(step), for_body_generator); } - void ForReturnVoid( + void For( absl::string_view name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - ForReturnVoid(name, /*start=*/b_->getInt64(start), - /*end=*/b_->getInt64(end), - /*step=*/b_->getInt64(step), for_body_generator); + For(name, /*start=*/b_->getInt64(start), + /*end=*/b_->getInt64(end), + /*step=*/b_->getInt64(step), for_body_generator); } // Generates the following control flow structure: @@ -201,38 +203,43 @@ class KernelSupportLibrary { // `true_block_generator()`; // else // `false_block_generator()`; - Status If(absl::string_view name, llvm::Value* condition, - const std::function& true_block_generator, - const std::function& false_block_generator = - []() -> Status { return Status::OK(); }); + Status IfWithStatus( + absl::string_view name, llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = []() -> Status { + return Status::OK(); + }); - Status If(llvm::Value* condition, - const std::function& true_block_generator, - const std::function& false_block_generator = - []() -> Status { return Status::OK(); }) { - return If("", condition, true_block_generator, false_block_generator); + Status IfWithStatus( + llvm::Value* condition, + const std::function& 
true_block_generator, + const std::function& false_block_generator = []() -> Status { + return Status::OK(); + }) { + return IfWithStatus("", condition, true_block_generator, + false_block_generator); } - void IfReturnVoid(llvm::Value* condition, - const std::function& true_block_generator, - const std::function& false_block_generator = []() { - }) { - IfReturnVoid("", condition, true_block_generator, false_block_generator); + void If( + llvm::Value* condition, const std::function& true_block_generator, + const std::function& false_block_generator = []() {}) { + If("", condition, true_block_generator, false_block_generator); } - void IfReturnVoid(absl::string_view name, llvm::Value* condition, - const std::function& true_block_generator, - const std::function& false_block_generator = []() { - }) { - TF_CHECK_OK(If(name, condition, - [&]() { - true_block_generator(); - return Status::OK(); - }, - [&]() { - false_block_generator(); - return Status::OK(); - })); + void If( + absl::string_view name, llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = []() {}) { + TF_CHECK_OK(IfWithStatus( + name, condition, + [&]() { + true_block_generator(); + return Status::OK(); + }, + [&]() { + false_block_generator(); + return Status::OK(); + })); } using ArgumentVector = absl::Span; diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc index c26711e526c9b89cdedcb6aed9f93d41dd25dc83..cebbc4290163d4e98003cd7b5df6ec906509a446 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc @@ -120,7 +120,7 @@ KernelMappingScheme::KernelMappingScheme( absl::Span req_block_sizes, int64 num_threads_y, int64 num_threads_x, llvm::IRBuilder<>* b) : b_(b), - dims_in_elems_(dims_in_elems), + dims_in_elems_(dims_in_elems.begin(), dims_in_elems.end()), tile_sizes_{1, tile_size_y, 
tile_size_x}, num_threads_x_(num_threads_x), num_threads_y_(num_threads_y) { @@ -170,14 +170,16 @@ IrArray::Index KernelMappingScheme::EmitBlockIndex(llvm::Type* index_ty) { IrArray::Index KernelMappingScheme::GetTileIndexForBlockOrigin( const IrArray::Index& block_index) { - IrArray::Index tile_index = block_index; + DCHECK_EQ(block_index.size(), block_sizes_.size()); + std::vector multidim; + multidim.reserve(block_sizes_.size()); for (int i = 0; i < block_sizes_.size(); ++i) { - tile_index[i] = b_->CreateMul( + multidim.push_back(b_->CreateMul( block_index[i], llvm::ConstantInt::get(block_index[i]->getType(), block_sizes_[i]), - "block_origin." + std::to_string(i)); + "block_origin." + std::to_string(i))); } - return tile_index; + return IrArray::Index(multidim, block_index[0]->getType()); } IrArray::Index KernelMappingScheme::GetElementIndexForTileOrigin( @@ -217,14 +219,14 @@ KernelMappingScheme::EmitThreadYXCoordinate(llvm::Type* index_ty) { // defined by (num_thread_y, num_thread_x) from thread_id. 
llvm::CallInst* thread_id_raw = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b_); - llvm_ir::AddRangeMetadata(0, GetThreadsPerTile(), thread_id_raw); + llvm_ir::AddRangeMetadata(0, GetThreadsPerBlock(), thread_id_raw); llvm::Value* thread_id_int = b_->CreateIntCast(thread_id_raw, index_ty, /*isSigned=*/true, "thread.id.x"); llvm::Value* num_thread_x = llvm::ConstantInt::get(index_ty, GetNumberOfThreadsForDimensionX()); - llvm::Value* x = b_->CreateURem(thread_id_int, num_thread_x); - llvm::Value* y = b_->CreateUDiv(thread_id_int, num_thread_x); + llvm::Value* x = b_->CreateURem(thread_id_int, num_thread_x, "thread.x"); + llvm::Value* y = b_->CreateUDiv(thread_id_int, num_thread_x, "thread.y"); return std::make_tuple(y, x); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h index 06002d57b0d7daa07f903feebe67a60a083c0e7c..fb633b12e60d1a9f3103fb2919ad2c3f3f14de20 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h @@ -90,15 +90,16 @@ class KernelMappingScheme { enum { DimZ = 0, DimY, DimX, DimTot }; public: + KernelMappingScheme() {} // dims_in_elems: the normalized tensor dimensions. // req_block_sizes: the requested block size in number of tiles for each // dimension. The actual block size is set to min(req_block_size, // dims_in_number_of_blocks). 
- explicit KernelMappingScheme(absl::Span dims_in_elems, - int64 tile_size_y, int64 tile_size_x, - absl::Span req_block_sizes, - int64 num_threads_y, int64 num_threads_x, - llvm::IRBuilder<>* b); + KernelMappingScheme(absl::Span dims_in_elems, int64 tile_size_y, + int64 tile_size_x, + absl::Span req_block_sizes, + int64 num_threads_y, int64 num_threads_x, + llvm::IRBuilder<>* b); absl::Span GetDimensionsInElements() const { return dims_in_elems_; @@ -133,11 +134,15 @@ class KernelMappingScheme { } absl::Span GetBlockSizes() const { return block_sizes_; } + int64 GetTileBlockSizeForDimension(int d) const { + DCHECK(d >= DimZ && d <= DimX); + return dims_in_blocks_[d]; + } int64 GetNumberOfThreadsForDimensionX() const { return num_threads_x_; } int64 GetNumberOfThreadsForDimensionY() const { return num_threads_y_; } - int64 GetThreadsPerTile() const { + int64 GetThreadsPerBlock() const { return GetNumberOfThreadsForDimensionX() * GetNumberOfThreadsForDimensionY(); } @@ -163,7 +168,7 @@ class KernelMappingScheme { private: llvm::IRBuilder<>* b_; // The number of elements in each dimension. - absl::Span dims_in_elems_; + std::vector dims_in_elems_; // The number of elements for each dimension of a tile. 
std::vector tile_sizes_; diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc index e22c2173c271fc9571be1ddb0759d2b31562dc98..6a9406bfebafcc02dc2e144b62284a9e83c3edeb 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc @@ -108,7 +108,7 @@ void EmitCompareLoopBody( // if (is_smaller_index && index_is_inbounds) KernelSupportLibrary ksl(b); - ksl.IfReturnVoid("smaller_comparison_index", do_comparison, [&]() { + ksl.If("smaller_comparison_index", do_comparison, [&]() { auto key1 = read_element(0, current_keys_index); auto key2 = read_element(0, compare_keys_index); auto compare_key1 = key1; @@ -155,7 +155,7 @@ void EmitCompareLoopBody( is_smaller_than = b->CreateOr( is_smaller_than, b->CreateAnd(keys_equal, index_is_smaller_than)); } - ksl.IfReturnVoid("is_smaller_than", is_smaller_than, [&]() { + ksl.If("is_smaller_than", is_smaller_than, [&]() { // Swap key1 with key2. write_element(0, current_keys_index, key2); write_element(0, compare_keys_index, key1); @@ -192,7 +192,7 @@ void EmitTiledCompareLoop( b->CreateShl(tiled_keys_index[dimension_to_sort], value_one); // We want to copy two adjacent elements. We first check whether the // first index position is within bounds. - ksl.IfReturnVoid( + ksl.If( "smaller_keys_index", b->CreateICmpSLT(current_keys_index, tiled_keys_index.GetConstantWithIndexType( @@ -203,15 +203,14 @@ void EmitTiledCompareLoop( // Increment to go the next index position. current_keys_index = b->CreateAdd(current_keys_index, value_one); // Here we check whether the next index position is within bounds. 
- ksl.IfReturnVoid( - "inner_smaller_keys_index", - b->CreateICmpSLT(current_keys_index, - tiled_keys_index.GetConstantWithIndexType( - dimension_to_sort_bound)), - [&]() { - cache_index = b->CreateAdd(cache_index, value_one); - read_or_write(cache_index, current_keys_index); - }); + ksl.If("inner_smaller_keys_index", + b->CreateICmpSLT(current_keys_index, + tiled_keys_index.GetConstantWithIndexType( + dimension_to_sort_bound)), + [&]() { + cache_index = b->CreateAdd(cache_index, value_one); + read_or_write(cache_index, current_keys_index); + }); }); }; @@ -253,7 +252,7 @@ void EmitTiledCompareLoop( if (dimension_to_sort_bound % tile_size) { // Otherwise we need a bounds check for the last tile. The last tile has // size 'dimension_to_sort_bound' % 'tile_size'. - ksl.IfReturnVoid( + ksl.If( "is_last_tile", b->CreateICmpUGE( b->CreateMul(tiled_keys_index[dimension_to_sort], diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 6c89700983363fec46c41b5430c6eab6b366a1b6..600b069ecdbabf6b05e6abb3a6b8d9b1a4b0ecf4 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -52,8 +52,10 @@ namespace xla { } BackendOptions backend_options; - backend_options.set_platform(platform).set_intra_op_parallelism_threads( - options.intra_op_parallelism_threads()); + backend_options.set_platform(platform) + .set_intra_op_parallelism_threads(options.intra_op_parallelism_threads()) + .set_allowed_devices(options.allowed_devices()); + TF_ASSIGN_OR_RETURN(std::unique_ptr backend, Backend::CreateBackend(backend_options)); diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index ac2f79674feceff436c0e9c65338967f498e4473..daa718879ddd45afb02725b557380b2f49fe833e 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -17,6 +17,7 @@ limitations under 
the License. #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -42,6 +43,7 @@ NameUniquer::NameUniquer(const string& separator) { if (name.empty()) { return ""; } + string result = name; char c = static_cast(result[0]); if (!isalpha(c) && c != '_') { @@ -52,6 +54,13 @@ NameUniquer::NameUniquer(const string& separator) { result[i] = '_'; } } + + // HLO primitive type names (with the exception of 'tuple') are keywords in + // the HLO text representation and cannot be names, so append an underscore if + // the name is a primitive type. + if (primitive_util::IsPrimitiveTypeName(result) && result != "tuple") { + result += "_"; + } return result; } diff --git a/tensorflow/compiler/xla/service/name_uniquer_test.cc b/tensorflow/compiler/xla/service/name_uniquer_test.cc index 3e2592c6ac626143f1421e545a31d9be91e376bc..d0d04147e0c29c66cba447550c0a9c703f35573a 100644 --- a/tensorflow/compiler/xla/service/name_uniquer_test.cc +++ b/tensorflow/compiler/xla/service/name_uniquer_test.cc @@ -104,5 +104,21 @@ TEST_F(NameUniquerTest, KeepNamesInRandomOrder) { EXPECT_EQ("foo.3", uniquer.GetUniqueName("foo.3")); } +TEST_F(NameUniquerTest, AvoidKeywords) { + NameUniquer uniquer("."); + + EXPECT_EQ("f32_", uniquer.GetUniqueName("f32")); + EXPECT_EQ("s64_", uniquer.GetUniqueName("s64")); + EXPECT_EQ("pred_", uniquer.GetUniqueName("pred")); + + // Though a primitive type, "tuple" is not a keyword. + EXPECT_EQ("tuple", uniquer.GetUniqueName("tuple")); + + // Keywords are not capitalized. 
+ EXPECT_EQ("F32", uniquer.GetUniqueName("F32")); + EXPECT_EQ("S32", uniquer.GetUniqueName("S32")); + EXPECT_EQ("Pred", uniquer.GetUniqueName("Pred")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index c35f72699bfe90f7b8021916c0f81d5e1926ff4c..fdb6a9b01be4b9198e40aa9bf7cdc07ff068a619 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -1737,7 +1737,8 @@ class HloConstantScalarImpl { literal_r0_as_val_ty_or.ValueOrDie() == val_literal && literal_r0 == val_as_literal_ty; if (!rv) { - EXPLAIN << "HloInstruction's constant value " << literal_r0.ToString() + EXPLAIN << "HloInstruction's constant value " + << literal_r0.ToStringWithoutShape() << " did not match expected value " << *val_; } return rv; @@ -2035,7 +2036,7 @@ XLA_UNOP_PATTERN(Ceil) XLA_UNOP_PATTERN(Convert) XLA_UNOP_PATTERN(Copy) XLA_UNOP_PATTERN(Cos) -XLA_UNOP_PATTERN(CrossReplicaSum) +XLA_UNOP_PATTERN(AllReduce) XLA_UNOP_PATTERN(Exp) XLA_UNOP_PATTERN(Fft) XLA_UNOP_PATTERN(Floor) diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index 186ef0c7911a2724df810780e018f52586e3e6a8..5c3c009a68bffbda8642fceedfb724879fbf1530 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -242,8 +242,8 @@ TEST(PatternMatcherTest, ConstantScalar) { HloModule test_module ENTRY test { a = s32[] constant(1) - b = s32[1,1] constant(s32[1,1]{{2}}) - c = s32[1,2] constant(s32[1,2]{{2,2}}) + b = s32[1,1] constant({{2}}) + c = s32[1,2] constant({{2,2}}) d = f32[] constant(1) e = f32[] constant(1.25) ROOT tuple = (s32[], s32[1,1], s32[1,2], f32[], f32[]) tuple(a,b,c,d,e) diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index 
c227106511c2c17b44569d3b696cd7d764226e81..896b73cda41cb21b539b586aa4701c5bad43f8b9 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -205,7 +205,9 @@ static bool IsDeviceSupported(se::StreamExecutor* executor) { } /* static */ StatusOr> -PlatformUtil::GetStreamExecutors(se::Platform* platform) { +PlatformUtil::GetStreamExecutors( + se::Platform* platform, + const absl::optional>& allowed_devices) { int device_count = platform->VisibleDeviceCount(); if (device_count <= 0) { return NotFound("no %s devices found", platform->Name()); @@ -226,6 +228,17 @@ PlatformUtil::GetStreamExecutors(se::Platform* platform) { tensorflow::thread::ThreadPool thread_pool( tensorflow::Env::Default(), "device_initialization", device_count); for (int i = 0; i < device_count; ++i) { + // Once a stream executor is instantiated it will cause allocations on + // the device, for example for GPUs cuda context, cudnn handles etc. will + // be constructed. By constructing stream executors only on the + // allowed_devices, we don't make any allocations on other devices. + // This helps in multi-process executions on the same host like horovod or + // shared hosts. + if (allowed_devices && allowed_devices->count(i) == 0) { + VLOG(1) << "Not initializing StreamExecutor for device " << i + << " since it is not in the visible device list"; + continue; + } thread_pool.Schedule([platform, i, &stream_executors]() { VLOG(1) << "Started device init " << i; se::StreamExecutorConfig config; diff --git a/tensorflow/compiler/xla/service/platform_util.h b/tensorflow/compiler/xla/service/platform_util.h index 571451ba43a81d19b70e4954e45d3447f15dcedc..592b20282f334e12e0d7a7f683c9a6ab59d21fea 100644 --- a/tensorflow/compiler/xla/service/platform_util.h +++ b/tensorflow/compiler/xla/service/platform_util.h @@ -16,6 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PLATFORM_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_PLATFORM_UTIL_H_ +#include #include #include @@ -60,10 +61,14 @@ class PlatformUtil { // Returns a vector of StreamExecutors for the given platform. The vector is // indexed by device ordinal (device numbering used by StreamExecutor). If an // element is nullptr, then the device is present by not supported by XLA. + // If populated, only the devices in allowed_devices will have + // their StreamExecutors initialized, otherwise all StreamExecutors will be + // initialized and returned. // // If the platform has no visible devices, a not-found error is returned. static StatusOr> GetStreamExecutors( - se::Platform* platform); + se::Platform* platform, + const absl::optional>& allowed_devices = absl::nullopt); private: TF_DISALLOW_COPY_AND_ASSIGN(PlatformUtil); diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 5ec7fe2adedac2fc3d8a7588e853dba90e99006f..a0126f39b3dc4281abedc36a19dd20c3b128e249 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -113,6 +113,16 @@ int ServiceOptions::intra_op_parallelism_threads() const { return intra_op_parallelism_threads_; } +ServiceOptions& ServiceOptions::set_allowed_devices( + const absl::optional>& allowed_devices) { + allowed_devices_ = allowed_devices; + return *this; +} + +const absl::optional>& ServiceOptions::allowed_devices() const { + return allowed_devices_; +} + /* static */ StatusOr> Service::NewService( se::Platform* platform) { ServiceOptions default_options; @@ -129,6 +139,7 @@ int ServiceOptions::intra_op_parallelism_threads() const { } BackendOptions backend_options; backend_options.set_platform(platform); + backend_options.set_allowed_devices(options.allowed_devices()); TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options)); std::unique_ptr service( @@ -150,17 +161,13 @@ 
Service::Service(const ServiceOptions& options, LOG(INFO) << StrFormat( "XLA service %p executing computations on platform %s. Devices:", this, execute_backend_->platform()->Name()); + auto stream_executors = execute_backend_->stream_executors(); for (int i = 0; i < execute_backend_->device_count(); ++i) { - if (execute_backend_->device_ordinal_supported(i)) { - se::StreamExecutor* executor = - execute_backend_->stream_executor(i).ValueOrDie(); - const auto& description = executor->GetDeviceDescription(); - LOG(INFO) << StrFormat(" StreamExecutor device (%d): %s, %s", i, - description.name(), - description.platform_version()); - } else { - LOG(INFO) << StrFormat(" StreamExecutor device (%d) not supported", i); - } + se::StreamExecutor* executor = stream_executors.at(i); + const auto& description = executor->GetDeviceDescription(); + LOG(INFO) << StrFormat(" StreamExecutor device (%d): %s, %s", i, + description.name(), + description.platform_version()); } } else { VLOG(1) << "XLA compile-only service constructed"; @@ -1078,9 +1085,11 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, ProgramShape program_shape(arg->computation().host_program_shape()); TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result())); + absl::optional output_layout; if (arg->has_output_layout()) { + output_layout = Layout::CreateFromProto(arg->output_layout()); TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape( - arg->output_layout(), program_shape.result())); + *output_layout, program_shape.result())); } HloModuleConfig config(program_shape); @@ -1096,8 +1105,8 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, // relayout here. // // TODO(b/77824332): Make HloEvaluator take care of the re-layout. 
- if (arg->has_output_layout()) { - result_literal = result_literal.Relayout(arg->output_layout()); + if (output_layout.has_value()) { + result_literal = result_literal.Relayout(*output_layout); } *result->mutable_literal() = result_literal.ToProto(); diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 11e1a79552fbd944ab28da129b08cfe676fb08e9..abd3ee5a059ac0910d6acc8076899950498b4c43 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include @@ -61,10 +62,17 @@ class ServiceOptions { ServiceOptions& set_intra_op_parallelism_threads(int num_threads); int intra_op_parallelism_threads() const; + // Sets the allowed_devices set for selectively constructing stream executors + // on the platform. + ServiceOptions& set_allowed_devices( + const absl::optional>& allowed_devices); + const absl::optional>& allowed_devices() const; + private: se::Platform* platform_ = nullptr; int number_of_replicas_ = 1; int intra_op_parallelism_threads_ = -1; + absl::optional> allowed_devices_; }; // The XLA service object, which is the same across all platforms. 
It maintains diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 7e7282a737041458aed39b0054f901c23aa87d7a..8e571675c79b08efd454ee5e0fe47bacdcf3dbb7 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -1556,7 +1556,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferConvolveShape( const Shape& lhs, const Shape& rhs, int64 feature_group_count, - const Window& window, const ConvolutionDimensionNumbers& dnums) { + int64 batch_group_count, const Window& window, + const ConvolutionDimensionNumbers& dnums) { TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution")); TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution")); @@ -1565,6 +1566,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "feature_group_count must be a positive number, got %d", feature_group_count); } + + if (batch_group_count <= 0) { + return InvalidArgument( + "batch_group_count must be a positive number, got %d", + batch_group_count); + } + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( "Convolution with different element types: %s and %s.", @@ -1700,6 +1708,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs), dnums.DebugString()); } + + if (input_batch % batch_group_count > 0) { + return InvalidArgument( + "Expected input batch dimension (value %d) to be divisible by " + "batch_group_count (value %d); " + "got (%s, %s)\n" + "Dimension numbers: {%s}.", + input_batch, batch_group_count, ShapeUtil::HumanString(lhs), + ShapeUtil::HumanString(rhs), dnums.DebugString()); + } + std::vector window_dims(num_spatial_dims); for (int i = 0; i < num_spatial_dims; ++i) { window_dims[i] = window.dimensions(i).size(); @@ -1722,7 +1741,7 @@ 
ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /*allow_negative_padding=*/true)); std::vector dimensions(num_dims); - dimensions[dnums.output_batch_dimension()] = input_batch; + dimensions[dnums.output_batch_dimension()] = input_batch / batch_group_count; dimensions[dnums.output_feature_dimension()] = kernel_output_features; for (int i = 0; i < num_spatial_dims; ++i) { dimensions[dnums.output_spatial_dimensions(i)] = @@ -1814,7 +1833,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, #undef RET_CHECK_RANK } -/* static */ StatusOr ShapeInference::InferCrossReplicaSumShape( +/* static */ StatusOr ShapeInference::InferAllReduceShape( absl::Span operand_shapes) { for (const Shape* operand_shape : operand_shapes) { TF_RETURN_IF_ERROR( diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index d94385a04d50baff8156570a09620fd458547936..1b8fd10d691498087b28ef68517868c5def1da5a 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -109,7 +109,7 @@ class ShapeInference { // filter (rhs) to lhs in the way specified by the fields on window. static StatusOr InferConvolveShape( const Shape& lhs, const Shape& rhs, int64 feature_group_count, - const Window& window, + int64 batch_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers); // Infers the shape produced by the given FFT type on the given operand. @@ -118,7 +118,7 @@ class ShapeInference { // Infers the shape produced by a cross replica sum with the given operand // shapes. 
- static StatusOr InferCrossReplicaSumShape( + static StatusOr InferAllReduceShape( absl::Span operand_shapes); // Infers final shape of an Alltoall operation that is created by the xla diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 4639e32db4d59080a9e85e46983fac61d9e76be9..0a870808d4cd89fa18382522ea5a4bf2355e5ce7 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -420,7 +420,8 @@ TEST_F(ShapeInferenceTest, Convolve) { dim1->set_window_dilation(1); dim1->set_base_dilation(1); auto inferred_status = ShapeInference::InferConvolveShape( - lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); + lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, + window, dnums); ASSERT_IS_OK(inferred_status.status()); Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}), @@ -465,7 +466,8 @@ TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) { dim1->set_window_dilation(2); dim1->set_base_dilation(1); auto inferred_status = ShapeInference::InferConvolveShape( - lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); + lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, + window, dnums); ASSERT_IS_OK(inferred_status.status()); Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}), @@ -510,7 +512,8 @@ TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) { dim1->set_window_dilation(1); dim1->set_base_dilation(2); auto inferred_status = ShapeInference::InferConvolveShape( - lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); + lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, + window, dnums); ASSERT_IS_OK(inferred_status.status()); Shape inferred_shape = 
inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}), @@ -548,7 +551,8 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { dim1->set_padding_low(1); dim1->set_padding_high(1); auto inferred_status = ShapeInference::InferConvolveShape( - lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); + lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, + window, dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), HasSubstr("each dimension exactly once")); diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index 7c1f4b5cc67dd2a84271b4f2b8015fdb2ff6e846..eaf4f28b87ce7706832eebb0bc02d015e64ee89a 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -178,7 +178,8 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { auto new_conv = HloInstruction::CreateConvolve( convolution.shape(), new_lhs, new_rhs, convolution.feature_group_count(), - convolution.window(), new_dnums, convolution.precision_config()); + convolution.batch_group_count(), convolution.window(), new_dnums, + convolution.precision_config()); TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction( &convolution, std::move(new_conv))); diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 17cdaa74fc328d156292f5af828d4222a9a01f1f..f8a5fa0215007310d6bec35d20fc643afc824dda 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -139,9 +139,9 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) { HloModule FoldDotTransposeConstant ENTRY entry_computation { - constant = f32[2,1]{1,0} constant(f32[2,1] { { 1 }, { 2 } }) + constant = f32[2,1]{1,0} 
constant({ { 1 }, { 2 } }) transpose = f32[1,2]{1,0} transpose(constant), dimensions={1,0} - constant.1 = f32[3,2]{1,0} constant(f32[3,2] { { 1, 2 }, { 3, 4 }, { 5, 6 } }) + constant.1 = f32[3,2]{1,0} constant({ { 1, 2 }, { 3, 4 }, { 5, 6 } }) transpose.1 = f32[2,3]{1,0} transpose(constant.1), dimensions={1,0} ROOT dot = f32[1,3]{1,0} dot(transpose, transpose.1), lhs_contracting_dims={1}, rhs_contracting_dims={0} } @@ -240,12 +240,13 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr conv_shape = ShapeInference::InferConvolveShape( - x->shape(), transpose_y->shape(), /*feature_group_count=*/1, window, - dnums); + x->shape(), transpose_y->shape(), /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), x, transpose_y, - /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); + /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums, + DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule("test_module"); HloComputation* entry_computation = @@ -295,12 +296,13 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr conv_shape = ShapeInference::InferConvolveShape( - x->shape(), transpose_y->shape(), /*feature_group_count=*/1, window, - dnums); + x->shape(), transpose_y->shape(), /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), x, transpose_y, - /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); + /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums, + DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule("test_module"); 
HloComputation* entry_computation = @@ -355,12 +357,13 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr conv_shape = ShapeInference::InferConvolveShape( - transpose_x->shape(), y->shape(), /*feature_group_count=*/1, window, - dnums); + transpose_x->shape(), y->shape(), /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), transpose_x, y, - /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); + /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums, + DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule("test_module"); HloComputation* entry_computation = @@ -421,12 +424,13 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr conv_shape = ShapeInference::InferConvolveShape( - transpose_x->shape(), y->shape(), /*feature_group_count=*/1, window, - dnums); + transpose_x->shape(), y->shape(), /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), transpose_x, y, - /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); + /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums, + DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule("test_module"); HloComputation* entry_computation = diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc index 75d406435b6f58faecc86b82c33e9e2dd6bccbea..3bcf5c38309a86e9e3cab3268f3f065005f7a923 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc +++ 
b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc @@ -129,7 +129,7 @@ condition { ENTRY entry { const_0 = f32[2] constant({1, 2}) - const_1 = (f32[2], f32[2]) constant((f32[2], f32[2]) ({2, 1},{3,1})) + const_1 = (f32[2], f32[2]) constant(({2, 1},{3,1})) while_init = (f32[2],(f32[2],f32[2])) tuple(const_0, const_1) ROOT while = (f32[2],(f32[2],f32[2])) while(while_init), condition=condition, body=body } @@ -206,8 +206,8 @@ body { p_body.0 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=0 p_body.1 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=1 - token = token[] after-all() - outfeed = token[] outfeed(p_body.0, token) + token0 = token[] after-all() + outfeed = token[] outfeed(p_body.0, token0) ROOT root = (f32[2],f32[2],f32[2]) tuple(p_body.0, p_body.1, p_body.1) } @@ -305,7 +305,7 @@ condition { ENTRY entry { const_0 = f32[] constant(0) - const_1 = (f32[], f32[]) constant((f32[], f32[]) (1, 10)) + const_1 = (f32[], f32[]) constant((1, 10)) while_init = (f32[],(f32[],f32[])) tuple(const_0, const_1) ROOT while = (f32[],(f32[],f32[])) while(while_init), condition=condition, body=body } diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 4950e8269e9cf0723d717bd1734518d104c0c9f2..3713989ca2f64ee1d94c9f77255017909d957da2 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -554,8 +555,7 @@ TEST_F(WhileLoopSimplifierTest, FlattenNestedTuple) { HloInstruction* new_while = FindFirstWhile(m.get()); Shape flat_tuple = - ShapeUtil::ParseShapeString("(s32[1], s32[2], s32[3], s32[4])") - .ValueOrDie(); + ParseShape("(s32[1], s32[2], s32[3], s32[4])").ValueOrDie(); SCOPED_TRACE(m->ToString()); EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), flat_tuple)); EXPECT_TRUE(ShapeUtil::Equal( @@ -567,8 +567,7 @@ TEST_F(WhileLoopSimplifierTest, FlattenNestedTuple) { flat_tuple)); EXPECT_TRUE(ShapeUtil::Equal( m->entry_computation()->root_instruction()->shape(), - ShapeUtil::ParseShapeString("((s32[1]), (s32[2], s32[3], (s32[4])))") - .ValueOrDie())); + ParseShape("((s32[1]), (s32[2], s32[3], (s32[4])))").ValueOrDie())); } // Edge-case: All elements of the loop carry are constants which can be removed, @@ -641,8 +640,7 @@ TEST_F(WhileLoopSimplifierTest, RemoveConstantFromLoopCarry) { EXPECT_TRUE(TupleSimplifier().Run(m.get()).ok()); HloInstruction* new_while = FindFirstWhile(m.get()); - Shape new_while_shape = - ShapeUtil::ParseShapeString("(s32[1], s32[3])").ValueOrDie(); + Shape new_while_shape = ParseShape("(s32[1], s32[3])").ValueOrDie(); EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), new_while_shape)); EXPECT_TRUE(ShapeUtil::Equal( new_while->while_body()->root_instruction()->shape(), new_while_shape)); @@ -652,9 +650,9 @@ TEST_F(WhileLoopSimplifierTest, RemoveConstantFromLoopCarry) { EXPECT_TRUE(ShapeUtil::Equal( new_while->while_condition()->parameter_instruction(0)->shape(), new_while_shape)); - EXPECT_TRUE(ShapeUtil::Equal( - 
m->entry_computation()->root_instruction()->shape(), - ShapeUtil::ParseShapeString("(s32[1], s32[2], s32[3])").ValueOrDie())); + EXPECT_TRUE( + ShapeUtil::Equal(m->entry_computation()->root_instruction()->shape(), + ParseShape("(s32[1], s32[2], s32[3])").ValueOrDie())); EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple(_, op::Constant(), _)); } @@ -712,7 +710,7 @@ TEST_F(WhileLoopSimplifierTest, MergeInductionVariables_Simple) { // We should have added a new loop counter for s32[] to the end of the tuple. SCOPED_TRACE(m->ToString()); Shape new_while_shape = - ShapeUtil::ParseShapeString("(s32[], s32[], s32[], s32[])").ValueOrDie(); + ParseShape("(s32[], s32[], s32[], s32[])").ValueOrDie(); EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), new_while_shape)); EXPECT_TRUE(ShapeUtil::Equal( new_while->while_body()->root_instruction()->shape(), new_while_shape)); diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc index 5e6941933330fde29bc9c779aae4bb3c36914660..d92b9870f373564ae8fd904c8bf9f0d1afbff9c4 100644 --- a/tensorflow/compiler/xla/service/while_util_test.cc +++ b/tensorflow/compiler/xla/service/while_util_test.cc @@ -180,8 +180,8 @@ body { cond { param.c = (s32[], s32[]) parameter(0) - token = token[] after-all() - infeed = (pred[], token[]) infeed(token) + token0 = token[] after-all() + infeed = (pred[], token[]) infeed(token0) ROOT condition = pred[] get-tuple-element(infeed), index=0 } diff --git a/tensorflow/compiler/xla/shape.cc b/tensorflow/compiler/xla/shape.cc index 746ab9e9977b1b10cdb0cb57197027d65bd50f55..b206345db2ac2940b1f139c82fa03a93538b5ccd 100644 --- a/tensorflow/compiler/xla/shape.cc +++ b/tensorflow/compiler/xla/shape.cc @@ -32,7 +32,7 @@ Shape::Shape(const ShapeProto& shape_proto) { *add_tuple_shapes() = Shape(element_shape); } if (shape_proto.has_layout()) { - *mutable_layout() = shape_proto.layout(); + *mutable_layout() = 
Layout::CreateFromProto(shape_proto.layout()); } } @@ -48,7 +48,7 @@ ShapeProto Shape::ToProto() const { *proto.add_tuple_shapes() = shape.ToProto(); } if (has_layout()) { - *proto.mutable_layout() = layout(); + *proto.mutable_layout() = layout().ToProto(); } return proto; } diff --git a/tensorflow/compiler/xla/shape.h b/tensorflow/compiler/xla/shape.h index 7f6b14ab4286c696dce64d2250a3fe8a57e4865b..7643f64d8a5f0450be1cddad35cf7422afb89048 100644 --- a/tensorflow/compiler/xla/shape.h +++ b/tensorflow/compiler/xla/shape.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/types/optional.h" +#include "tensorflow/compiler/xla/layout.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/types.h" @@ -76,21 +77,10 @@ class Shape { std::vector* mutable_tuple_shapes() { return &tuple_shapes_; } // Methods for accessing the layout field. - bool has_layout() const { return layout_.has_value(); } - const Layout& layout() const { - if (layout_.has_value()) { - return *layout_; - } else { - return Layout::default_instance(); - } - } - Layout* mutable_layout() { - if (!layout_.has_value()) { - layout_ = Layout(); - } - return &layout_.value(); - } - void clear_layout() { layout_.reset(); } + bool has_layout() const { return layout_.format() != INVALID_FORMAT; } + const Layout& layout() const { return layout_; } + Layout* mutable_layout() { return &layout_; } + void clear_layout() { layout_.Clear(); } void Swap(Shape* other) { using std::swap; @@ -101,7 +91,7 @@ class Shape { element_type_ = PRIMITIVE_TYPE_INVALID; dimensions_.clear(); tuple_shapes_.clear(); - layout_.reset(); + clear_layout(); } string SerializeAsString() const { return ToProto().SerializeAsString(); } @@ -118,8 +108,8 @@ class Shape { // The tuple element subshapes. This is nonempty only for tuple shapes. std::vector tuple_shapes_; - // The array layout of the shape. This is present only for array shapes. 
- absl::optional layout_; + // The layout of the shape. Only relevant for arrays. + Layout layout_; }; // Shape of the parameters and output of an XLA computation. This is analogous diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index f3cc51ca9158d5c355c656b5450da1a66d96a379..be7d71ada009535a5c08aa3d3d062fa451cfeef3 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -164,9 +164,9 @@ StatusOr MakeShapeWithLayoutInternal( TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::MakeValidatedShape(element_type, dimensions)); auto min2maj = shape.mutable_layout()->mutable_minor_to_major(); - min2maj->Clear(); + min2maj->clear(); for (int64 value : minor_to_major) { - min2maj->Add(value); + min2maj->push_back(value); } if (!shape.has_layout()) { return InvalidArgument("Shape has no layout."); @@ -234,7 +234,7 @@ StatusOr MakeShapeWithLayoutInternal( /* static */ StatusOr ShapeUtil::MakeValidatedShape( PrimitiveType element_type, absl::Span dimensions) { - CHECK(IsArrayPrimitiveType(element_type)); + CHECK(IsArrayPrimitiveType(element_type)) << element_type; Shape result; TF_RETURN_IF_ERROR(PopulateShape(element_type, dimensions, &result)); return result; @@ -480,54 +480,6 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return IsScalar(shape) && shape.element_type() == element_type; } -namespace { - -// Class to memoize the computation of -// absl::AsciiStrToLower(PrimitiveType_Name(p)) -// for all PrimitiveType values "p" -class PrimitiveTypeNameGenerator { - public: - PrimitiveTypeNameGenerator() { - for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { - if (PrimitiveType_IsValid(i)) { - lowercase_name_[i] = absl::AsciiStrToLower( - PrimitiveType_Name(static_cast(i))); - } - } - } - const string& LowercaseName(PrimitiveType t) { - return lowercase_name_[static_cast(t)]; - } - - private: - string lowercase_name_[PrimitiveType_ARRAYSIZE]; -}; - -const string& 
LowercasePrimitiveTypeName(PrimitiveType s) { - static PrimitiveTypeNameGenerator* gen = new PrimitiveTypeNameGenerator(); - return gen->LowercaseName(s); -} - -StatusOr StringToPrimitiveType(const string& name) { - static std::unordered_map* name_to_type = [] { - static auto* map = new std::unordered_map; - for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { - if (PrimitiveType_IsValid(i)) { - auto value = static_cast(i); - (*map)[LowercasePrimitiveTypeName(value)] = value; - } - } - return map; - }(); - auto found = name_to_type->find(name); - if (found == name_to_type->end()) { - return InvalidArgument("Invalid element type string: \"%s\".", name); - } - return found->second; -} - -} // namespace - /* static */ string ShapeUtil::HumanString(const Shape& shape) { if (IsTuple(shape)) { string text = "("; @@ -539,8 +491,9 @@ StatusOr StringToPrimitiveType(const string& name) { text += ")"; return text; } - return StrCat(LowercasePrimitiveTypeName(shape.element_type()), "[", - absl::StrJoin(shape.dimensions(), ","), "]"); + return StrCat( + primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "[", + absl::StrJoin(shape.dimensions(), ","), "]"); } /* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { @@ -554,7 +507,8 @@ StatusOr StringToPrimitiveType(const string& name) { text += ")"; return text; } - string result = StrCat(LowercasePrimitiveTypeName(shape.element_type()), "["); + string result = StrCat( + primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "["); for (int i = 0; i < shape.dimensions().size(); i++) { StrAppend(&result, (i > 0) ? "," : "", shape.dimensions(i)); } @@ -580,116 +534,6 @@ StatusOr StringToPrimitiveType(const string& name) { HumanString(program_shape.result())); } -namespace { -// Parses shapes with simple recursive descent structure -- consumes from the -// front of s and passes that view recursively as required. 
-StatusOr ParseShapeStringInternal(absl::string_view* s) { - *s = StripLeadingAsciiWhitespace(*s); - - if (absl::ConsumePrefix(s, "(")) { // Tuple. - std::vector shapes; - bool must_end = false; - while (true) { - if (absl::ConsumePrefix(s, ")")) { - break; - } else if (must_end) { - return InvalidArgument("Expected end of tuple; got: \"%s\"", *s); - } - shapes.emplace_back(); - TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s)); - *s = StripLeadingAsciiWhitespace(*s); - must_end = !absl::ConsumePrefix(s, ","); - } - return ShapeUtil::MakeTupleShape(shapes); - } - - string element_type_string; - string dimensions_string; - string format_string; - string layout_string; - // absl::string_view is not compatible with internal RE2 StringPiece, so - // we convert in to the RE2-consumable type and then consume the corresponding - // amount from our string_view type. - static LazyRE2 shape_pattern = { - "^(\\w*\\d*)\\[([\\d,\\s]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,\\s]+)})" - "?"}; - tensorflow::RegexpStringPiece s_consumable(s->data(), s->size()); - if (RE2::Consume(&s_consumable, *shape_pattern, &element_type_string, - &dimensions_string, &format_string, &layout_string)) { - size_t consumed = s->size() - s_consumable.size(); - s->remove_prefix(consumed); - auto string_to_int64 = [&s](absl::string_view input) -> StatusOr { - int64 element; - if (!absl::SimpleAtoi(input, &element)) { - return InvalidArgument( - "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"", input, - *s); - } - return element; - }; - - auto comma_list_to_int64s = - [string_to_int64](const string& input) -> StatusOr> { - std::vector results; - for (const auto& piece : absl::StrSplit(input, ',', absl::SkipEmpty())) { - TF_ASSIGN_OR_RETURN(int64 element, string_to_int64(piece)); - results.push_back(element); - } - return results; - }; - - // Extract the dimensions. 
- TF_ASSIGN_OR_RETURN(std::vector dimensions, - comma_list_to_int64s(dimensions_string)); - - // Extract the primitive element type. - TF_ASSIGN_OR_RETURN(const PrimitiveType primitive_type, - StringToPrimitiveType(element_type_string)); - if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE) { - return InvalidArgument("Invalid element type string: \"%s\".", - element_type_string); - } - - Shape result; - if (primitive_type == OPAQUE) { - result = ShapeUtil::MakeOpaqueShape(); - } else if (primitive_type == TOKEN) { - result = ShapeUtil::MakeTokenShape(); - } else if (format_string.empty() && layout_string.empty()) { - // Create a shape without a layout set. - TF_ASSIGN_OR_RETURN( - result, ShapeUtil::MakeValidatedShape(primitive_type, dimensions)); - } else if (format_string == "sparse") { - TF_ASSIGN_OR_RETURN(int64 max_elements, string_to_int64(layout_string)); - result = ShapeUtil::MakeShapeWithSparseLayout(primitive_type, dimensions, - max_elements); - } else if (format_string.empty() || format_string == "dense") { - // Extract the layout minor-to-major and set it. - TF_ASSIGN_OR_RETURN(std::vector min2maj, - comma_list_to_int64s(layout_string)); - TF_ASSIGN_OR_RETURN(result, MakeShapeWithLayoutInternal( - primitive_type, dimensions, min2maj)); - } else { - // This should not be reached. 
- LOG(FATAL) << "Unhandled condition when parsing shape; format: \"" - << format_string << "\", layout: \"" << layout_string << "\""; - } - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(result)); - return std::move(result); - } - - return InvalidArgument("Invalid shape string to parse: \"%s\"", *s); -} -} // namespace - -/* static */ StatusOr ShapeUtil::ParseShapeString(absl::string_view s) { - TF_ASSIGN_OR_RETURN(Shape shape, ParseShapeStringInternal(&s)); - if (!s.empty()) { - return InvalidArgument("Invalid shape string to parse: \"%s\"", s); - } - return shape; -} - /* static */ bool ShapeUtil::SameDimensions(const Shape& lhs, const Shape& rhs) { CHECK(ShapeUtil::IsArray(lhs)); @@ -867,13 +711,13 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { if (shape.dimensions_size() != 0) { return InvalidArgument( "shape has %s element type, but has dimensions field: %s", - LowercasePrimitiveTypeName(shape.element_type()), + primitive_util::LowercasePrimitiveTypeName(shape.element_type()), shape.ShortDebugString()); } if (shape.has_layout()) { return InvalidArgument( "shape has %s element type, but has layout field: %s", - LowercasePrimitiveTypeName(shape.element_type()), + primitive_util::LowercasePrimitiveTypeName(shape.element_type()), shape.ShortDebugString()); } return Status::OK(); @@ -1067,6 +911,11 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { return absl::c_linear_search(shape.dimensions(), 1); } +/* static */ Shape ShapeUtil::DropDegenerateDimensions(const Shape& shape) { + return FilterDimensions( + [&](int64 dim) -> bool { return shape.dimensions()[dim] != 1; }, shape); +} + namespace { // Helper for ForEachSubshape which visits the subshapes of the given shape in @@ -1618,10 +1467,10 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, if (LayoutUtil::HasLayout(shape)) { Layout* layout = shape.mutable_layout(); layout->set_format(DENSE); - for (size_t i = 0; i < layout->minor_to_major().size();) { + 
for (int64 i = 0; i < layout->minor_to_major().size();) { if (layout->minor_to_major(i) == dim_to_delete) { layout->mutable_minor_to_major()->erase( - layout->minor_to_major().begin() + i); + layout->mutable_minor_to_major()->begin() + i); continue; } if (layout->minor_to_major(i) > dim_to_delete) { diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 84a27f662a57ba274562e2e9be57b7e971c9b477..8a7d755951e6ec1d0a5416e844e55b6d7e7beb7b 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -241,10 +241,6 @@ class ShapeUtil { // (param_name: f32[42x12], ...) -> f32[24x42] static string HumanString(const ProgramShape& program_shape); - // Parses a ShapeUtil::HumanString-format shape string back into a shape - // object. - static StatusOr ParseShapeString(absl::string_view s); - // Returns whether the LHS and RHS shapes have the same dimensions; note: does // not check element type. // Precondition: IsArray(lhs) && IsArray(rhs) @@ -266,7 +262,7 @@ class ShapeUtil { } // Returns the higher-precision element type if a and b are both floating - // point types; otherwise, checks that that they have the same element type + // point types; otherwise, checks that they have the same element type // and returns it. static PrimitiveType HigherPrecisionElementType(const Shape& a, const Shape& b) { @@ -551,6 +547,9 @@ class ShapeUtil { // (dimensions with bound 1). static bool HasDegenerateDimensions(const Shape& shape); + // Drops any degenerate dimensions (i.e. dimensions of size 1) + static Shape DropDegenerateDimensions(const Shape& shape); + // Permutes the dimensions by the given permutation, so // return_value.dimensions[permutation[i]] = argument.dimensions[i]. 
// diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 60bdbe302045e6f3b4bae500c50bc68fb217525d..0a3081f5161f80ac97e864ba08d186df4fbdb55d 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -82,102 +82,6 @@ TEST(ShapeUtilTest, Rank4DimensionIndexing) { ASSERT_EQ(3, shape.dimensions(0)); } -TEST(ShapeUtilTest, ParseShapeStringR2F32) { - string shape_string = "f32[123,456]"; - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString(shape_string)); - Shape expected = ShapeUtil::MakeShape(F32, {123, 456}); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) { - string shape_string = "(f32[1572864],s8[5120,1024])"; - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString(shape_string)); - Shape expected = - ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}), - ShapeUtil::MakeShape(S8, {5120, 1024})}); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseShapeStringNestedTuple) { - string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])"; - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString(shape_string)); - Shape expected = ShapeUtil::MakeTupleShape({ - ShapeUtil::MakeShape(F32, {1}), - ShapeUtil::MakeTupleShape( - {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}), - ShapeUtil::MakeOpaqueShape(), - ShapeUtil::MakeShape(F32, {3}), - }); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseShapeStringWithLayout) { - string shape_string = "f32[123,456]{0,1}"; - 
TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString(shape_string)); - Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseShapeStringWithExplicitDenseLayout) { - string shape_string = "f32[123,456]dense{0,1}"; - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString(shape_string)); - Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseShapeStringWithSparseLayout) { - string shape_string = "f32[123,456]sparse{10}"; - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString(shape_string)); - Shape expected = ShapeUtil::MakeShapeWithSparseLayout(F32, {123, 456}, 10); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseOpaqueType) { - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString("opaque[]")); - Shape expected = ShapeUtil::MakeOpaqueShape(); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseTokenType) { - TF_ASSERT_OK_AND_ASSIGN(Shape actual, ShapeUtil::ParseShapeString("token[]")); - Shape expected = ShapeUtil::MakeTokenShape(); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseInvalidShapeString) { - string shape_strings[] = { - "f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", 
"f32[123,456]{foo}", - "f32[123,456]dense{foo}", "f32[123,456]sparse{foo}", - }; - for (const string& shape_string : shape_strings) { - StatusOr result = ShapeUtil::ParseShapeString(shape_string); - ASSERT_FALSE(result.ok()) << "shape: " << shape_string; - } -} - TEST(ShapeUtilTest, CompatibleIdenticalShapes) { Shape shape1 = ShapeUtil::MakeShape(F32, {3, 2}); Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2}); diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 5a7a4faa7e89b27fb537f20d94c21cb4a76e000d..ee24d4d99cb1f7ce51a72c6258cbadd6adf12f81 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -1,6 +1,13 @@ # Description: # Base testing infrastructure for XLA. +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites", "generate_backend_test_macros", "xla_test", "xla_test_library") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") + licenses(["notice"]) # Apache 2.0 package( @@ -23,17 +30,6 @@ filegroup( ]), ) -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test_library") -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites") -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_test_macros") -load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load( - "//tensorflow/core:platform/default/build_config_root.bzl", - "tf_cuda_tests_tags", -) - # Generate test_suites for all backends, named "${backend}_tests". 
generate_backend_suites() @@ -846,6 +842,7 @@ xla_test( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -1348,6 +1345,7 @@ xla_test( xla_test( name = "custom_call_test", srcs = ["custom_call_test.cc"], + backends = ["cpu"], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", @@ -1637,8 +1635,8 @@ xla_test( ) xla_test( - name = "cross_replica_sum_test", - srcs = ["cross_replica_sum_test.cc"], + name = "all_reduce_test", + srcs = ["all_reduce_test.cc"], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/all_reduce_test.cc similarity index 94% rename from tensorflow/compiler/xla/tests/cross_replica_sum_test.cc rename to tensorflow/compiler/xla/tests/all_reduce_test.cc index 410732c07b7b6d3ece33ab11f4778241dc53ca50..7e695f829e39831e2c8558cb07d0689e560bbafa 100644 --- a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc +++ b/tensorflow/compiler/xla/tests/all_reduce_test.cc @@ -41,7 +41,7 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) { ENTRY test_computation { p = f32[3] parameter(0) - ROOT crs = f32[3] cross-replica-sum(p), to_apply=add + ROOT crs = f32[3] all-reduce(p), to_apply=add })"; auto module = ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); @@ -62,7 +62,7 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) { ENTRY test_computation { p0 = f32[3] parameter(0) p1 = f32[2] parameter(1) - ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add + ROOT crs = (f32[3], f32[2]) all-reduce(p0, p1), to_apply=add })"; auto module = 
ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); @@ -88,7 +88,7 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) { ENTRY test_computation { p0 = f32[3] parameter(0) p1 = f32[2] constant({10, 20}) - ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add + ROOT crs = (f32[3], f32[2]) all-reduce(p0, p1), to_apply=add })"; auto module = ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index f6be27bee27f5f28b1474b78ef78a0d2fd99894c..915b456b52215f8d6a9eb6c5b933f3502f1d3d2c 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -329,13 +329,13 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { Literal b_literal = LiteralUtil::CreateR1({b_values}); std::unique_ptr b_data = client_->TransferToServer(b_literal).ConsumeValueOrDie(); - auto b_constant = Parameter(&builder, 1, a_literal.shape(), "b_param"); - auto b_param = ConstantR1(&builder, b_values); + auto b_param = Parameter(&builder, 1, a_literal.shape(), "b_param"); + auto b_constant = ConstantR1(&builder, b_values); - auto sum1 = Add(a_constant, b_constant); - auto sum2 = Add(a_constant, b_param); - auto sum3 = Add(a_param, b_constant); - auto sum4 = Add(a_param, b_param); + auto sum1 = Add(a_constant, b_param); + auto sum2 = Add(a_constant, b_constant); + auto sum3 = Add(a_param, b_param); + auto sum4 = Add(a_param, b_constant); auto sum = Add(sum1, sum2); sum = Add(sum, sum3); diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 12c029983336cc9aed0fde4ce6881c9a00a9869e..a350715597044730429ee9fa268ecd6f2bf26b66 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ 
-15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include #include #include "absl/memory/memory.h" @@ -74,6 +75,9 @@ ClientLibraryTestBase::ClientLibraryTestBase( // default. execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( "constant_folding"); + + execution_options_.mutable_debug_options() + ->set_xla_hlo_evaluator_use_fast_path(true); } ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform) @@ -88,6 +92,9 @@ ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform) execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( "constant_folding"); + + execution_options_.mutable_debug_options() + ->set_xla_hlo_evaluator_use_fast_path(true); } string ClientLibraryTestBase::TestName() const { @@ -273,9 +280,10 @@ StatusOr ClientLibraryTestBase::ComputeAndTransfer( if (!arguments_.empty()) { CHECK(arguments.empty()); for (const auto& argument : arguments_) { - owning_arguments.push_back( - client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)) - .ValueOrDie()); + TF_ASSIGN_OR_RETURN( + std::unique_ptr owned_argument, + client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument))); + owning_arguments.push_back(std::move(owned_argument)); arguments.push_back(owning_arguments.back().get()); } } @@ -296,9 +304,10 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( if (!arguments_.empty()) { CHECK(arguments.empty()); for (const auto& argument : arguments_) { - owning_arguments.push_back( - client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)) - .ValueOrDie()); + TF_ASSIGN_OR_RETURN( + std::unique_ptr owned_argument, + client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument))); + owning_arguments.push_back(std::move(owned_argument)); arguments.push_back(owning_arguments.back().get()); } } @@ -356,9 +365,10 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( if 
(!arguments_.empty()) { CHECK(arguments.empty()); for (const auto& argument : arguments_) { - owning_arguments.push_back( - client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)) - .ValueOrDie()); + TF_ASSIGN_OR_RETURN( + std::unique_ptr owned_argument, + client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument))); + owning_arguments.push_back(std::move(owned_argument)); arguments.push_back(owning_arguments.back().get()); } } diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index 72ff1e74a47c8584cb5336c86a1c978c4637a902..9174f2651cb90b364f869364fe108cf208c11a84 100644 --- a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -25,7 +25,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -178,5 +180,33 @@ TEST_F(ConstantsTest, Token) { TF_ASSERT_OK(Execute(&builder, {}).status()); } +class ConstantsHloTest : public HloTestBase {}; + +// TODO(b/121147351): Fails on GPU. Not clear if this is expected behavior. 
+XLA_TEST_F(ConstantsHloTest, DISABLED_ON_GPU(BitcastOfConstant)) { + const char* testcase = R"( + HloModule module, is_scheduled=true + + func { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT mul = s32[] add(lhs, rhs) + } + + ENTRY test { + constant.0 = s32[1]{0} constant({0}) + parameter.0 = s32[] parameter(0) + constant-as-scalar = s32[] bitcast(constant.0) + ROOT result = s32[] call(parameter.0, constant-as-scalar), to_apply=func + } + )"; + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = LiteralUtil::CreateR0(1); + auto result = ExecuteNoHloPasses(std::move(module), {&param}); + EXPECT_TRUE(LiteralTestUtil::Equal(param, result)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 4a58a1ed66c438d1dd9561f4eb029b38d8c6cbdd..249693891290e14645ee5b4b4d97b2d506a01302 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -98,7 +98,7 @@ class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest { precision.add_operand_precision(PrecisionConfig::HIGHEST); precision.add_operand_precision(PrecisionConfig::DEFAULT); Conv(lhs, rhs, {1, 1}, Padding::kValid, /*feature_group_count=*/1, - &precision); + /*batch_group_count=*/1, &precision); ComputeAndCompare(&builder, {}, error_spec_); } diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 3622f2c1e84639baed13059b21b20609d1347da6..df005a67097bb8aaf070c57d1c51acd1909fee12 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -133,7 +133,9 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { // Reverse the minor-to-major order of the literal. 
Layout* literal_layout = literal.mutable_shape_do_not_use()->mutable_layout(); ASSERT_EQ(2, literal_layout->minor_to_major_size()); - literal_layout->mutable_minor_to_major()->SwapElements(0, 1); + // Swap the first and second elements. + *literal_layout->mutable_minor_to_major() = { + literal_layout->minor_to_major(1), literal_layout->minor_to_major(0)}; HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 738b6442354b01364278e3e3c713aa2cdb5cf47d..cad43d1b5547d74701760fa623e50466fc15c263 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -54,11 +54,20 @@ void Add1ToValues(float* out, float** in) { out[2] = array[2] + 1; out[3] = array[3] + 1; } + +void F32TupleSwap(float** out, float** in) { + TF_ANNOTATE_MEMORY_IS_INITIALIZED(in[0], sizeof(float)); + TF_ANNOTATE_MEMORY_IS_INITIALIZED(in[1], sizeof(float)); + *out[0] = *in[1]; + *out[1] = *in[0]; +} + } // namespace REGISTER_CUSTOM_CALL_TARGET(R0F32Add2); REGISTER_CUSTOM_CALL_TARGET(R2F32ReduceSum); REGISTER_CUSTOM_CALL_TARGET(Add1ToValues); +REGISTER_CUSTOM_CALL_TARGET(F32TupleSwap); namespace xla { namespace { @@ -69,7 +78,7 @@ class CustomCallTest : public HloTestBase { Shape r2f32_ = ShapeUtil::MakeShape(F32, {2, 2}); }; -XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { +XLA_TEST_F(CustomCallTest, CustomCallR0F32Add2) { auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); @@ -84,7 +93,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { LiteralTestUtil::ExpectR0Near(44.0f, result, error_spec_); } -XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { +XLA_TEST_F(CustomCallTest, CustomCallR2F32Reduce) { auto module = CreateNewUnverifiedModule(); auto builder = 
HloComputation::Builder(TestName()); @@ -105,7 +114,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { LiteralTestUtil::ExpectR0Near(10.0f, result, error_spec_); } -XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(UsedInOtherComputations)) { +XLA_TEST_F(CustomCallTest, UsedInOtherComputations) { auto module = CreateNewUnverifiedModule(); auto b = HloComputation::Builder(TestName()); @@ -129,7 +138,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(UsedInOtherComputations)) { Array3D{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result); } -XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(InputAndOutputLayoutDiffer)) { +XLA_TEST_F(CustomCallTest, InputAndOutputLayoutDiffer) { auto module = CreateNewUnverifiedModule(); auto b = HloComputation::Builder(TestName()); @@ -151,7 +160,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(InputAndOutputLayoutDiffer)) { LiteralTestUtil::ExpectR2Equal({{2.f, 4.f}, {3.f, 5.f}}, result); } -XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(LayoutConstrained)) { +XLA_TEST_F(CustomCallTest, LayoutConstrained) { // The argument and result of the computation are set to different layouts, // but the custom call is layout constrained to a fixed operand and result // layout, so the correct result should be produced. 
@@ -176,6 +185,26 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(LayoutConstrained)) { LiteralTestUtil::ExpectR2Equal({{2.f, 3.f}, {4.f, 5.f}}, result); } +XLA_TEST_F(CustomCallTest, TupleOutput) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT %custom-call = (f32[], f32[]) custom-call(f32[] %p0, f32[] %p1), custom_call_target="F32TupleSwap", operand_layout_constraints={f32[], f32[]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + Literal arg0 = LiteralUtil::CreateR0(7.f); + Literal arg1 = LiteralUtil::CreateR0(42.f); + + Literal expected = LiteralUtil::MakeTuple({&arg1, &arg0}); + Literal result = ExecuteAndTransfer(std::move(module), {&arg0, &arg1}); + EXPECT_EQ(result, expected); +} + class CustomCallClientAPITest : public ClientLibraryTestBase {}; // When using the client API, CustomCall targets can't begin with '$' -- these diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 989a7c705a8254f99e5cc0e97dfde5942f146964..d57846e19bb80c5b9c87d50596da2915f9aef317 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -181,6 +181,7 @@ DebugOptions HloTestBase::GetDebugOptionsForTest() { // TODO(b/38354253): Change tests to use Parameters instead of Constants. 
debug_options.add_xla_disable_hlo_passes("constant_folding"); debug_options.set_xla_gpu_max_kernel_unroll_factor(1); + debug_options.set_xla_hlo_evaluator_use_fast_path(true); return debug_options; } diff --git a/tensorflow/compiler/xla/tests/iota_test.cc b/tensorflow/compiler/xla/tests/iota_test.cc index 65205f53ddc582ae477d67705f161fef1e31b857..37b2c635eebe57590e1ba73c62f015ccf399b548 100644 --- a/tensorflow/compiler/xla/tests/iota_test.cc +++ b/tensorflow/compiler/xla/tests/iota_test.cc @@ -80,7 +80,7 @@ TEST_P(IotaR2Test, DoIt) { } INSTANTIATE_TEST_CASE_P(IotaR2TestInstantiation, IotaR2Test, - ::testing::Combine(::testing::Values(F32, S32), + ::testing::Combine(::testing::Values(F32, S32, BF16), ::testing::Range(/*start=*/10, /*end=*/1001, /*step=*/10), diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index b6f9b8156b51144e4f74d285b1e4111d098f13c2..ea9b3037cf482e41238413179888f125822d161c 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -89,11 +89,11 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { Literal literal = Literal::CreateFromProto(literal_proto).ConsumeValueOrDie(); if (result.find("expected") != string::npos) { - EXPECT_EQ("2", literal.ToString()); + EXPECT_EQ("f32[] 2", literal.ToString()); } else if (result.find("actual") != string::npos) { - EXPECT_EQ("4", literal.ToString()); + EXPECT_EQ("f32[] 4", literal.ToString()); } else if (result.find("mismatches") != string::npos) { - EXPECT_EQ("true", literal.ToString()); + EXPECT_EQ("pred[] true", literal.ToString()); } else { FAIL() << "unknown file in temporary directory: " << result; } @@ -105,9 +105,9 @@ TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) { auto actual = LiteralUtil::CreateR1({4, 5, 6}); ::testing::AssertionResult result = LiteralTestUtil::Equal(expected, actual); 
EXPECT_THAT(result.message(), - ::testing::HasSubstr("Expected literal:\n{1, 2, 3}")); + ::testing::HasSubstr("Expected literal:\ns32[3] {1, 2, 3}")); EXPECT_THAT(result.message(), - ::testing::HasSubstr("Actual literal:\n{4, 5, 6}")); + ::testing::HasSubstr("Actual literal:\ns32[3] {4, 5, 6}")); } TEST(LiteralTestUtilTest, NearComparatorR1) { diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index e8f5d7a9a79ebddea3cb989dbe8eab90b630d5e7..448a66cfdd897b17cce1c87c050520a2f2eb0ea2 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -61,11 +61,11 @@ XLA_TEST_F(TestUtilsTest, Token) { R"(HloModule outfeed_module ENTRY InfeedToOutfeed { - token = token[] parameter(0) - infeed = ((u32[3]{0}, pred[]), token[]) infeed(token) + token0 = token[] parameter(0) + infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0) infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0 - outfeed = token[] outfeed(infeed.data, token) - ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token) + outfeed = token[] outfeed(infeed.data, token0) + ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token0) infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0 infeed.1.token = token[] get-tuple-element(infeed.1), index=1 outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token) diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc index 601c6b06938fef1f1ae809b33209ae59b24c70a2..b77cf38ed8e29973985406015c0a3936916ad5e6 100644 --- a/tensorflow/compiler/xla/tests/token_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -214,8 +214,8 @@ ENTRY %AddDependency (p0: f32[], p1: f32[]) -> f32[] { %forty_two = f32[] constant(42.0) %add = f32[] add(f32[] %p0, f32[] %forty_two) - %token = token[] after-all(f32[] %add) - %p1_after_token = f32[] add-dependency(f32[] %p1, 
token[] %token) + %token0 = token[] after-all(f32[] %add) + %p1_after_token = f32[] add-dependency(f32[] %p1, token[] %token0) %neg = f32[] negate(f32[] %p1_after_token) ROOT %product = f32[] multiply(f32[] %add, f32[] %neg) } @@ -236,8 +236,8 @@ HloModule AddDependencyOfConstant, is_scheduled=true ENTRY %AddDependency (p0: f32[]) -> f32[] { %p0 = f32[] parameter(0) %forty_two = f32[] constant(42.0) - %token = token[] after-all(f32[] %p0) - %forty_two_after_token = f32[] add-dependency(f32[] %forty_two, token[] %token) + %token0 = token[] after-all(f32[] %p0) + %forty_two_after_token = f32[] add-dependency(f32[] %forty_two, token[] %token0) ROOT %product = f32[] multiply(f32[] %p0, f32[] %forty_two_after_token) } )"; @@ -255,8 +255,8 @@ HloModule AddDependencyAsRoot, is_scheduled=true ENTRY %AddDependency (p: f32[3]) -> f32[3] { %p = f32[3] parameter(0) %neg = f32[3] negate(f32[3] %p) - %token = token[] after-all() - ROOT %add_dep = f32[3] add-dependency(f32[3] %neg, token[] %token) + %token0 = token[] after-all() + ROOT %add_dep = f32[3] add-dependency(f32[3] %neg, token[] %token0) } )"; TF_ASSERT_OK_AND_ASSIGN( @@ -274,9 +274,9 @@ ENTRY %TupleShapedAddDependency (p0: f32[3], p1: f32[3]) -> f32[3] { %p0 = f32[3] parameter(0) %p1 = f32[3] parameter(1) %forty_two = f32[] constant(42.0) - %token = token[] after-all() - %tuple = (f32[3], token[], f32[3], f32[]) tuple(f32[3] %p0, token[] %token, f32[3] %p1, f32[] %forty_two) - %add_dep = (f32[3], token[], f32[3], f32[]) add-dependency((f32[3], token[], f32[3], f32[]) %tuple, token[] %token) + %token0 = token[] after-all() + %tuple = (f32[3], token[], f32[3], f32[]) tuple(f32[3] %p0, token[] %token0, f32[3] %p1, f32[] %forty_two) + %add_dep = (f32[3], token[], f32[3], f32[]) add-dependency((f32[3], token[], f32[3], f32[]) %tuple, token[] %token0) %elem0 = f32[3] get-tuple-element((f32[3], token[], f32[3], f32[]) %add_dep), index=0 %elem2 = f32[3] get-tuple-element((f32[3], token[], f32[3], f32[]) %add_dep), index=2 ROOT 
%diff = f32[3] subtract(f32[3] %elem0, f32[3] %elem2) diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 27ce243e9bd4afbdcc1fdc5b6873d4968086e459..9c586bdeb05afb7378e92caed1f3edc408e051bf 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -555,8 +555,8 @@ XLA_TEST_F(TupleHloTest, s = (f32[2],f32[2]) tuple-select(cond, tup0, tup1) gte = f32[2] get-tuple-element(s), index=0 tuple = (f32[2]) tuple(gte) - token = token[] after-all() - ROOT outfeed = token[] outfeed(tuple, token) + token0 = token[] after-all() + ROOT outfeed = token[] outfeed(tuple, token0) } )"; auto module = diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc index cdde88c1359416d423685f330e9cbdf77948034f..c78ec522aa5f13556c6d4602267544694887f250 100644 --- a/tensorflow/compiler/xla/text_literal_reader.cc +++ b/tensorflow/compiler/xla/text_literal_reader.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "absl/strings/string_view.h" #include "absl/strings/strip.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -66,7 +67,7 @@ StatusOr TextLiteralReader::ReadAllLines() { } absl::StripAsciiWhitespace(&shape_string); - TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::ParseShapeString(shape_string)); + TF_ASSIGN_OR_RETURN(Shape shape, ParseShape(shape_string)); if (shape.element_type() != F32) { return Unimplemented( "unsupported element type for text literal reading: %s", diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 8926bbed2b54fceaaf0e6e991f0e881d35731ef4..99b32c19a52bf2a1f02047a1ceea626947d994fc 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -14,7 +14,7 @@ filegroup( visibility = ["//tensorflow/compiler/xla:internal"], ) -load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") tf_cc_binary( name = "hex_floats_to_packed_literal", @@ -234,3 +234,50 @@ tf_cc_binary( "//tensorflow/core:lib", ], ) + +tf_cc_test( + name = "hlo_extractor_test", + srcs = ["hlo_extractor_test.cc"], + deps = [ + ":hlo_extractor", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "hlo_extractor", + srcs = ["hlo_extractor.cc"], + hdrs = ["hlo_extractor.h"], + deps = [ + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_verifier", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/memory", + ], +) + +tf_cc_binary( + name = "interactive_graphviz", + srcs = ["interactive_graphviz.cc"], 
+ deps = [ + ":hlo_extractor", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/service:hlo_graph_dumper", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/compiler/xla/service:hlo_runner", + "//tensorflow/compiler/xla/service:local_service", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + ], +) diff --git a/tensorflow/compiler/xla/tools/hlo_extractor.cc b/tensorflow/compiler/xla/tools/hlo_extractor.cc new file mode 100644 index 0000000000000000000000000000000000000000..f3ce5f99b0c2a8e9ae5446f4bedc34b678c95b96 --- /dev/null +++ b/tensorflow/compiler/xla/tools/hlo_extractor.cc @@ -0,0 +1,159 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/tools/hlo_extractor.h" + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "tensorflow/compiler/xla/service/hlo_clone_context.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/status.h" + +namespace xla { +namespace { + +// Visitor that build a new HLO module with an entry computation and a root that +// is provided to the visit function. Only HLOs that are reachable from the new +// root instruction are included in the new module. +// +// The constructor allows specifying a set of boundary HLOs to prune the HLO +// graph. HLOs at the boundary are replaced with parameters. Can be nullptr +// which means no boundary, i.e. no HLOs are replaced with parameters. +class ExtractionVisitor : public ConstDfsHloVisitorWithDefault { + public: + explicit ExtractionVisitor( + const HloModule& old_module, + absl::flat_hash_set* boundary) + : old_module_(old_module), + module_(absl::make_unique("extracted", config_)), + clone_context_(module_.get()), + builder_("entry_computation"), + boundary_(boundary) {} + + Status HandleParameter(const HloInstruction* parameter) override { + // Entry parameters need renumbering. + auto new_parameter = HloInstruction::CreateParameter( + parameter_number_++, parameter->shape(), parameter->name()); + clone_context_.MapInstruction(parameter, new_parameter.get()); + builder_.AddInstruction(std::move(new_parameter)); + return Status::OK(); + } + + Status DefaultAction(const HloInstruction* hlo) override { + // Replace instructions at the boundary with parameters, but leave constants + // untouched. 
+ if (boundary_ != nullptr && boundary_->count(hlo) > 0) { + auto new_parameter = HloInstruction::CreateParameter( + parameter_number_, hlo->shape(), hlo->name()); + parameter_number_++; + clone_context_.MapInstruction(hlo, new_parameter.get()); + builder_.AddInstruction(std::move(new_parameter)); + return Status::OK(); + } + std::vector new_operands; + for (auto operand : hlo->operands()) { + new_operands.push_back(clone_context_.GetInstruction(operand)); + } + auto instruction = + hlo->CloneWithNewOperands(hlo->shape(), new_operands, &clone_context_); + builder_.AddInstruction(std::move(instruction)); + return Status::OK(); + } + + Status FinishVisit(const HloInstruction* /*root*/) override { + module_->AddEntryComputation(builder_.Build()); + // Rename HLOs so that their name matches the original. By default, + // HLOs get new unique names when adding a new entry computation to + // a module. + for (auto computation : old_module_.MakeComputationPostOrder()) { + for (auto old_instruction : computation->MakeInstructionPostOrder()) { + if (auto new_instruction = + clone_context_.FindInstruction(old_instruction)) { + new_instruction->SetAndSanitizeName(old_instruction->name()); + } + } + } + return Status::OK(); + } + + HloModule* module() { return module_.get(); } + + std::unique_ptr ConsumeModule() { return std::move(module_); } + + private: + const HloModule& old_module_; + HloModuleConfig config_; + std::unique_ptr module_; + HloCloneContext clone_context_; + HloComputation::Builder builder_; + absl::flat_hash_set* boundary_; + int64 parameter_number_ = 0; +}; + +void ComputeBoundary(const HloInstruction* root, int64 limit, + absl::flat_hash_set* boundary) { + std::deque worklist; + absl::flat_hash_map visited; + worklist.push_back(root); + visited.emplace(root, 0); + while (!worklist.empty()) { + auto hlo = worklist.front(); + worklist.pop_front(); + int64 hops = visited[hlo]; + if (hops > limit) { + boundary->insert(hlo); + continue; + } + for (const 
HloInstruction* operand : hlo->operands()) { + if (visited.count(operand)) { + continue; + } + worklist.push_back(operand); + visited.emplace(operand, hops + 1); + } + } +} + +} // namespace + +std::unique_ptr ExtractModule(HloInstruction* instruction, + int64 height) { + absl::flat_hash_set boundary; + if (height != -1) { + ComputeBoundary(instruction, height, &boundary); + } + ExtractionVisitor visitor(*instruction->GetModule(), &boundary); + CHECK(instruction->Accept(&visitor).ok()); + + // The first pass may leave unused parameter instructions. Do another + // extraction pass to remove unused parameters. This is done because + // HloComputation does not allow removing parameters after the computation has + // been built. + ExtractionVisitor cleanup_visitor(*visitor.module(), /*boundary=*/nullptr); + TF_CHECK_OK(visitor.module()->entry_computation()->root_instruction()->Accept( + &cleanup_visitor)); + + HloVerifier verifier(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true); + TF_CHECK_OK(verifier.Run(cleanup_visitor.module()).status()); + return cleanup_visitor.ConsumeModule(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/hlo_extractor.h b/tensorflow/compiler/xla/tools/hlo_extractor.h new file mode 100644 index 0000000000000000000000000000000000000000..bc13dc7e438fe0e64312746150af02df805e746a --- /dev/null +++ b/tensorflow/compiler/xla/tools/hlo_extractor.h @@ -0,0 +1,36 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_HLO_EXTRACTOR_H_ +#define TENSORFLOW_COMPILER_XLA_TOOLS_HLO_EXTRACTOR_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace xla { + +// Creates a new HLO module rooted with an entry computation rooted at the given +// instruction. +// +// By default (height == -1), the new computation includes all transitive +// operands of `root`. If you specify a different height, the new computation +// will include all instructions <= `height` hops away from `root`. +// Instructions at the boundary are replaced by parameters. +std::unique_ptr ExtractModule(HloInstruction* instruction, + int64 height = -1); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TOOLS_HLO_EXTRACTOR_H_ diff --git a/tensorflow/compiler/xla/tools/hlo_extractor_test.cc b/tensorflow/compiler/xla/tools/hlo_extractor_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c187222a11ee721b006194a68620c58749707193 --- /dev/null +++ b/tensorflow/compiler/xla/tools/hlo_extractor_test.cc @@ -0,0 +1,142 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/tools/hlo_extractor.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { +namespace { + +namespace op = testing::opcode_matchers; + +using HloExtractorTest = HloTestBase; + +TEST_F(HloExtractorTest, ExtractTopLevel) { + const string& hlo_string = R"( +HloModule test + +ENTRY %entry { + param.0 = f32[4]{0} parameter(0) + negate = f32[4]{0} negate(f32[4]{0} param.0) + ROOT exp = f32[4]{0} exponential(f32[4]{0} negate) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr hlo_module, + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + + { + auto extracted_module = + ExtractModule(FindInstruction(hlo_module.get(), "exp")); + EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), + op::Exp(op::Negate(op::Parameter(0)))); + } + + { + auto extracted_module = + ExtractModule(FindInstruction(hlo_module.get(), "exp"), /*height=*/0); + EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), + op::Exp(op::Parameter(0))); + } + + { + auto extracted_module = ExtractModule( + FindInstruction(hlo_module.get(), "negate"), /*height=*/0); + EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), + op::Negate(op::Parameter(0))); + } +} + +TEST_F(HloExtractorTest, ExtractDag) { + const string& hlo_string = R"( +HloModule test + +ENTRY %entry { + param.0 = f32[4]{0} parameter(0) + tanh = f32[4]{0} tanh(f32[4]{0} param.0) + negate = f32[4]{0} negate(f32[4]{0} tanh) + exp = f32[4]{0} exponential(f32[4]{0} negate) + ROOT add = f32[4]{0} add(f32[4]{0} negate, f32[4]{0} exp) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr hlo_module, + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + + { + auto extracted_module = + ExtractModule(FindInstruction(hlo_module.get(), "exp")); + 
EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), + op::Exp(op::Negate(op::Tanh(op::Parameter(0))))); + } + + { + auto extracted_module = + ExtractModule(FindInstruction(hlo_module.get(), "add"), /*height=*/0); + EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), + op::Add(op::Parameter(0), op::Parameter(1))); + } + { + auto extracted_module = + ExtractModule(FindInstruction(hlo_module.get(), "add"), /*height=*/1); + EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), + op::Add(op::Negate(op::Parameter(0)), + op::Exp(op::Negate(op::Parameter(0))))); + } + { + auto extracted_module = + ExtractModule(FindInstruction(hlo_module.get(), "add"), /*height=*/2); + EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), + op::Add(op::Negate(op::Tanh(op::Parameter(0))), + op::Exp(op::Negate(op::Tanh(op::Parameter(0)))))); + } +} + +TEST_F(HloExtractorTest, ExtractWithConstant) { + const string& hlo_string = R"( +HloModule test + +ENTRY %entry { + p = f32[4]{0} parameter(0) + tanh = f32[4]{0} tanh(p) + c = f32[4]{0} constant({1, 2, 3, 4}) + ROOT add = f32[4]{0} add(tanh, c) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr hlo_module, + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + + { + auto extracted_module = + ExtractModule(FindInstruction(hlo_module.get(), "add"), /*height=*/0); + EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), + op::Add(op::Parameter(0), op::Parameter(1))); + } + { + auto extracted_module = + ExtractModule(FindInstruction(hlo_module.get(), "add"), /*height=*/1); + EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), + op::Add(op::Tanh(op::Parameter(0)), op::Constant())); + } +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/interactive_graphviz.cc b/tensorflow/compiler/xla/tools/interactive_graphviz.cc new file mode 100644 index 
0000000000000000000000000000000000000000..6c90cde5a75a93837ee149fd9b5a60e6413c2ac4 --- /dev/null +++ b/tensorflow/compiler/xla/tools/interactive_graphviz.cc @@ -0,0 +1,652 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// A tool for interactively exploring graphviz dumps of HLO graphs. +// +// Input can be a binary HloSnapshot proto, a binary HloProto proto, or a +// textual HLO string. +// +// Generated visualization is opened in a new default browser window using +// /usr/bin/sensible-browser. 
+ +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/match.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view_utils.h" +#include "absl/strings/util.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_runner.h" +#include "tensorflow/compiler/xla/service/local_service.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/tools/hlo_extractor.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/subprocess.h" +#include "tensorflow/core/util/command_line_flags.h" +#if defined(PLATFORM_GOOGLE) +#include "util/readline/readline.h" +#endif + +namespace xla { +namespace tools { +namespace { + +bool ReadLine(const char *prompt, string *line) { +#if defined(PLATFORM_GOOGLE) + return util::ReadLine(prompt, line); +#else + std::cout << prompt; + return std::getline(std::cin, *line); +#endif +} + +// Command-line opts to this tool. See main() for descriptions of these +// fields. +struct Options { + string hlo_snapshot; + string hlo_proto; + string hlo_text; + string platform; + string browser; +}; + +const char* const kUsage = R"( +This tool lets you load an XLA dump and then interactively explore its graphical +representation. + +Most models are too large to visualize in their entirety using graphviz, but +it's still useful to be able to look at the nodes "near" a particular node of +interest. + +If you pass --platform, this tool will compile the HloModule for the given +platform. 
This means that if you acquired your proto from a binary running at a +particular CL, the HLO graph it ran isn't necessarily the same as the one shown +here, unless this program was built at the same CL (and our compiler is +deterministic :). + +Be patient when starting this program if you give it a large input; it has to +compile the whole thing. + +Usage: + + interactive_graphviz -- \ + --{hlo_snapshot,hlo_proto,hlo_text}=path/to/binary_proto + --platform={CUDA,CPU,...} +)"; + +// Unless an explicit width is specified, we will render a neighborhood of +// kDefaultWidth nodes around the requested instruction. +constexpr int64 kDefaultWidth = 2; + +// When printing all paths between two nodes, we print out only this many nodes +// by default, truncating the graph if there are more nodes than this in the +// all-paths set. +constexpr int64 kDefaultMaxNumNodesInAllPaths = 100; + +using absl::EqualsIgnoreCase; + +// A global control for whether backend configuration display is enabled. +bool show_backend_config = true; + +HloInstruction* FindInstruction(const HloModule& module, string node_name) { + if (absl::StartsWith(node_name, "%")) { + node_name.erase(node_name.begin()); + } + for (const auto& computation : module.computations()) { + auto instrs = computation->instructions(); + auto it = absl::c_find_if(instrs, [&](const HloInstruction* instr) { + // Try with and without "%" at the beginning of the node name. + return EqualsIgnoreCase(instr->name(), node_name) || + EqualsIgnoreCase(instr->name(), absl::StrCat("%", node_name)); + }); + if (it != instrs.end()) { + return *it; + } + } + return nullptr; +} + +HloComputation* FindComputation(const HloModule& module, + const string& comp_name) { + for (auto* computation : module.computations()) { + if (EqualsIgnoreCase(computation->name(), comp_name)) { + return computation; + } + } + return nullptr; +} + +// Print a help message describing the various available commands. 
+void DoHelpCommand() { + std::cout << R"(Commands: + [] + Renders a neighborhood of nodes around . If + is not provided, the default value is )" + << kDefaultWidth << R"(. + allpaths [] + Renders a subset of all paths from one instruction to the other. Either + order of nodes is accepted. Shows the nodes in the all-paths set on the + shortest paths; default is )" + << kDefaultMaxNumNodesInAllPaths << R"(. + + Renders all nodes in . + backend_config [on|off] + Controls whether backend operation configuration information is printed. + list [name|op_name|op_type] + Lists all instructions whose name, metadata op_name, or metadata op_type + contains as a substring. + list computations + Lists all computations in the module. + info + info + Prints information about or . + extract + Creates a new HLO module with as entry computation root. If + is specified, the new computation contains nodes up to + nodes above the root. + help + Prints this usage information.)" + << std::endl; +} + +// Turn metadata-printing on or off. +void DoBackendConfigCommand(const std::vector& tokens) { + if (tokens.size() == 2 && tokens[1] == "on") { + show_backend_config = true; + } else if (tokens.size() == 2 && tokens[1] == "off") { + show_backend_config = false; + } else if (tokens.size() != 1) { + std::cerr << "(Illegal backend_config value. Use either 'on' or 'off'.)" + << std::endl; + } + std::cout << "Backend configuration display " + << (show_backend_config ? "ON" : "OFF") << std::endl; +} + +// List all computations in the module. 
+void DoListComputationsCommand(const HloModule& module, + const std::vector& tokens) { + if (tokens.size() > 2) { + std::cout << R"(Illegal syntax; "list computations" takes no arguments.)"; + return; + } + if (module.entry_computation() != nullptr) { + std::cout << "Entry computation:" << std::endl; + std::cout << " " << module.entry_computation()->name() << std::endl + << std::endl; + } + std::cout << "Subcomputations:" << std::endl; + std::vector names; + for (const auto& computation : module.computations()) { + if (computation == module.entry_computation()) { + continue; + } + std::cout << " " << computation->name() << std::endl; + } +} + +// List all instructions matching a pattern. +void DoListCommand(const HloModule& module, const std::vector& tokens) { + string pattern = ""; + string type = "name"; + if (tokens.size() == 2) { + pattern = tokens[1]; + } else if (tokens.size() == 3) { + type = tokens[1]; + pattern = tokens[2]; + } else { + std::cout << "Illegal list query syntax. Use " + << R"("list [name|op_name|op_type] pattern".)" << std::endl; + return; + } + + std::cout << "Query results:" << std::endl; + for (const auto& computation : module.computations()) { + for (const auto& instr : computation->instructions()) { + if ((type == "name" && instr->name().find(pattern) != string::npos) || + (type == "op_name" && + instr->metadata().op_name().find(pattern) != string::npos) || + (type == "op_type" && + instr->metadata().op_type().find(pattern) != string::npos)) { + std::cout << " " << instr->name(); + std::cout << ", op_name '" << instr->metadata().op_name() << "'"; + std::cout << ", op_type '" << instr->metadata().op_type() << "'"; + std::cout << std::endl; + } + } + } +} + +// Print info about an instruction or computation. +void DoInfoCommand(const HloModule& module, const std::vector& tokens) { + if (tokens.size() != 2) { + std::cerr << "Illegal info query syntax. 
Use " + << R"("info name".)"; + return; + } + string node_name = tokens[1]; + + const HloInstruction* instr = FindInstruction(module, node_name); + const HloComputation* comp = FindComputation(module, node_name); + if (!instr && !comp) { + std::cerr << "Couldn't find HloInstruction or HloComputation named " + << node_name << std::endl; + return; + } + + if (comp != nullptr) { + std::cout << "HloComputation " << comp->name() << std::endl; + if (comp->IsFusionComputation()) { + std::cout << " Fusion instruction: " << comp->FusionInstruction()->name() + << std::endl; + } + std::cout << " Parameters:" << std::endl; + for (const auto& param : comp->parameter_instructions()) { + std::cout << " " << param->name() << " (" + << ShapeUtil::HumanStringWithLayout(param->shape()) << ")" + << std::endl; + } + HloInstruction* root = comp->root_instruction(); + std::cout << " Root instruction: " << root->name() << " (" + << ShapeUtil::HumanStringWithLayout(root->shape()) << ")" + << std::endl; + + auto embedded_computations = comp->MakeEmbeddedComputationsList(); + std::cout << " " << embedded_computations.size() << " embedded computation" + << (embedded_computations.size() != 1 ? "s" : "") + << (!embedded_computations.empty() ? ":" : ".") << std::endl; + for (const HloComputation* c : embedded_computations) { + std::cout << " " << c->name() << std::endl; + } + + // Find which computations reference comp as an embedded computation. + std::vector users; + for (const HloComputation* c : module.computations()) { + if (absl::c_linear_search(c->MakeEmbeddedComputationsList(), comp)) { + users.push_back(c); + } + } + std::cout << " Used by " << users.size() << " computation" + << (users.size() != 1 ? "s" : "") << (!users.empty() ? 
":" : "."); + for (const HloComputation* c : users) { + std::cout << " " << c->name() << std::endl; + } + } else { + std::cout << "HloInstruction " << instr->name() << std::endl; + std::cout << " Parent computation: " << instr->parent()->name() + << std::endl; + std::cout << " Opcode: " << HloOpcodeString(instr->opcode()) << std::endl; + std::cout << " Shape: " << ShapeUtil::HumanStringWithLayout(instr->shape()) + << std::endl; + std::cout << " Metadata:" << std::endl; + if (!instr->metadata().op_name().empty()) { + std::cout << " Name: " << instr->metadata().op_name() << std::endl; + } + if (!instr->metadata().op_type().empty()) { + std::cout << " Type: " << instr->metadata().op_type() << std::endl; + } + if (!instr->raw_backend_config_string().empty()) { + std::cout << " Backend configuration: " + << instr->raw_backend_config_string() << std::endl; + } + if (instr->opcode() == HloOpcode::kFusion) { + std::cout << " Fusion kind: " << xla::ToString(instr->fusion_kind()) + << std::endl; + std::cout << " Fusion computation: " + << instr->fused_instructions_computation()->name() << std::endl; + std::cout << " Fused computation root: " + << instr->fused_expression_root()->name() << std::endl; + } + std::cout << " Operands:" << std::endl; + for (HloInstruction* operand : instr->operands()) { + std::cout << " " << operand->name() << " (" + << ShapeUtil::HumanStringWithLayout(operand->shape()) << ")" + << std::endl; + } + std::cout << " Users:" << std::endl; + for (HloInstruction* user : instr->users()) { + std::cout << " " << user->name() << std::endl; + } + if (instr->parent()->root_instruction() == instr) { + std::cout << " Root instruction of " << instr->parent()->name() + << std::endl; + } + } +} + +void DoExtractCommand(const HloModule& module, + absl::Span tokens) { + if (tokens.size() > 3) { + std::cerr << R"(Illegal input. Enter e.g. "extract %fusion.1 2")" + << std::endl; + return; + } + + // Find the node with the given name. 
+ string node_name = tokens[1]; + HloInstruction* instr = FindInstruction(module, node_name); + if (!instr) { + std::cerr << "Couldn't find HloInstruction named " << node_name << "." + << std::endl; + return; + } + + int64 height = -1; + if (tokens.size() == 3) { + if (!absl::SimpleAtoi(tokens[2], &height)) { + std::cerr << "Can't parse '" << tokens[2] << "' as an integer." + << std::endl; + return; + } + } + + auto extracted_module = ExtractModule(instr, height); + std::cout << extracted_module->ToString( + HloPrintOptions::ShortParsable().set_print_backend_config( + show_backend_config)) + << std::endl; +} + +// Checks if there is a use-def path from `from` to `to`. +bool ExistsPathFromTo(const HloInstruction* from, const HloInstruction* to) { + std::unordered_set visited; + std::vector to_visit = {from}; + while (!to_visit.empty()) { + auto* n = to_visit.back(); + if (n == to) { + return true; + } + to_visit.pop_back(); + visited.insert(n); + for (auto* user : n->users()) { + if (!visited.count(user)) { + to_visit.push_back(user); + } + } + } + return false; +} + +void DisplayGraphHandle(const Options &opts, const string& handle) { + std::cout << handle << std::endl; + + // If it is a url, try to open it up in the user's browser too. + if (strings::StartsWithIgnoreCase(handle, "http://") || + strings::StartsWithIgnoreCase(handle, "https://") || + strings::StartsWithIgnoreCase(handle, "file://")) { + const char* browser_bin = opts.browser.empty() ? "/usr/bin/sensible-browser" + : opts.browser.c_str(); + tensorflow::SubProcess p; + p.SetProgram(browser_bin, {browser_bin, handle}); + p.Start(); + } else if (handle.empty()) { + std::cerr << "Unable to render graph, perhaps due to graphviz server " + "timeout. Run with --logtostderr to see." + << std::endl; + } else { + std::cerr << "\nExpected a URL, but got strange graph result (dumped " + "above). If this isn't what you expected, maybe file a bug?" 
+ << std::endl; + } +} + +void DoAllPathsCommand(const Options& opts, const HloModule& module, + const std::vector& tokens) { + if (tokens.size() > 4) { + std::cerr << R"(Illegal input. Enter e.g. "allpaths %add.4 %subtract.2" or +"allpaths add.4 subtract.2 42.)" + << std::endl; + return; + } + + int64 max_nodes = kDefaultMaxNumNodesInAllPaths; + if (tokens.size() == 4 && !absl::SimpleAtoi(tokens[3], &max_nodes)) { + std::cerr << "Can't parse '" << tokens[3] << "' as an integer." + << std::endl; + return; + } + + const HloInstruction* n1 = FindInstruction(module, tokens[1]); + if (!n1) { + std::cerr << "Couldn't find HloInstruction named " << tokens[1]; + return; + } + const HloInstruction* n2 = FindInstruction(module, tokens[2]); + if (!n2) { + std::cerr << "Couldn't find HloInstruction named " << tokens[2]; + return; + } + + // Is there a path from n1 to n2, or vice versa? + const HloInstruction* from; + const HloInstruction* to; + if (ExistsPathFromTo(n1, n2)) { + from = n1; + to = n2; + } else if (ExistsPathFromTo(n2, n1)) { + from = n2; + to = n1; + } else { + std::cerr << "No path from/to " << tokens[1] << " to/from " << tokens[2]; + return; + } + DisplayGraphHandle(opts, hlo_graph_dumper::DumpAllPathsFromTo( + *from, *to, max_nodes, /*show_backend_config=*/show_backend_config)); +} + +// Plot a given instruction neighborhood or computation with graphviz. +void DoPlotCommand(const Options& opts, const HloModule& module, + const std::vector& tokens) { + if (tokens.size() > 2) { + std::cerr << R"(Illegal input. Enter e.g. "%fusion.1 42" or "%fusion.1".)" + << std::endl; + return; + } + + string node_name = tokens[0]; + + // Find the node with the given name. + const HloInstruction* instr = FindInstruction(module, node_name); + const HloComputation* comp = FindComputation(module, node_name); + if (!instr && !comp) { + std::cerr << "Couldn't find HloInstruction or HloComputation named " + << node_name << "." 
<< std::endl; + return; + } + + uint64 graph_width = kDefaultWidth; + if (tokens.size() == 2) { + if (comp) { + std::cerr << "Can only use graph-size parameter with instructions, but " + << node_name << " is a computation." << std::endl; + return; + } + if (!absl::SimpleAtoi(tokens[1], &graph_width)) { + std::cerr << "Can't parse '" << tokens[1] << "' as an integer." + << std::endl; + return; + } + } + + // Generate the graph and print the resulting string, which should be a + // graphviz url. + if (comp) { + DisplayGraphHandle(opts, hlo_graph_dumper::DumpGraph( + *comp, "", comp->parent()->config().debug_options(), nullptr, + /*show_backend_config=*/show_backend_config)); + } else { + DisplayGraphHandle(opts, hlo_graph_dumper::DumpNeighborhoodAround( + *instr, graph_width, /*show_backend_config=*/show_backend_config)); + } +} + +// Run the main event loop, reading user commands and processing them. +void InteractiveDumpGraphs(const Options& opts, const HloModule& module) { + // This is an interactive tool, but some may use `extract` in non-tty + // environment anyway. Give them a clean hlo dump. + if (isatty(fileno(stdin))) { + std::cout << "\n\nLoaded module " << module.name() << "." << std::endl; + DoHelpCommand(); + } + for (string line; ReadLine("\ncommand: ", &line);) { + if (line.empty()) { + std::cout << R"(Enter e.g. 
"fusion.1 3" or "add.8".)" << std::endl + << R"(Enter "help" for help; ^D, "quit", or "exit" to exit.)" + << std::endl; + continue; + } + std::vector tokens = strings::Split(line, ' '); + if (tokens[0] == "quit" || tokens[0] == "exit") { + break; + } else if (tokens[0] == "help") { + DoHelpCommand(); + } else if (tokens[0] == "backend_config") { + DoBackendConfigCommand(tokens); + } else if (tokens[0] == "list") { + if (tokens.size() > 1 && tokens[1] == "computations") { + DoListComputationsCommand(module, tokens); + } else { + DoListCommand(module, tokens); + } + } else if (tokens[0] == "info") { + DoInfoCommand(module, tokens); + } else if (tokens[0] == "extract") { + DoExtractCommand(module, tokens); + } else if (tokens[0] == "allpaths") { + DoAllPathsCommand(opts, module, tokens); + } else { + DoPlotCommand(opts, module, tokens); + } + } +} + +void CheckFlags(const Options &opts) { + std::vector nonempty_proto_flags; + if (!opts.hlo_proto.empty()) { + nonempty_proto_flags.push_back("--hlo_proto"); + } + if (!opts.hlo_snapshot.empty()) { + nonempty_proto_flags.push_back("--hlo_snapshot"); + } + if (!opts.hlo_text.empty()) { + nonempty_proto_flags.push_back("--hlo_text"); + } + switch (nonempty_proto_flags.size()) { + case 1: + // We're good to go. 
+ break; + case 0: + LOG(FATAL) << "Need one of the following options: " + << absl::StrJoin(nonempty_proto_flags, ", "); + default: + LOG(FATAL) << "Can only specify one of " + << absl::StrJoin(nonempty_proto_flags, ", "); + } +} + +void RealMain(const Options& opts) { + if (!isatty(fileno(stdin))) { + LOG(ERROR) << "\n\n*****************************************\n" + << "This is an interactive tool, but stdin is not a tty.\n" + << "*****************************************\n\n"; + } + + CheckFlags(opts); + + std::unique_ptr module; + if (!opts.hlo_snapshot.empty()) { + HloSnapshot snapshot; + TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), + opts.hlo_snapshot, &snapshot)) + << "Can't open, read, or parse HloSnapshot proto at " + << opts.hlo_snapshot; + auto config = + HloModule::CreateModuleConfigFromProto(snapshot.hlo().hlo_module(), + xla::GetDebugOptionsFromFlags()) + .ValueOrDie(); + module = HloModule::CreateFromProto(snapshot.hlo().hlo_module(), config) + .ValueOrDie(); + } else if (!opts.hlo_proto.empty()) { + module = HloRunner::ReadModuleFromBinaryProtoFile( + opts.hlo_proto, xla::GetDebugOptionsFromFlags()) + .ValueOrDie(); + } else if (!opts.hlo_text.empty()) { + module = HloRunner::ReadModuleFromHloTextFile( + opts.hlo_text, xla::GetDebugOptionsFromFlags()) + .ValueOrDie(); + } + + // If a platform was specified, compile the module for that platform. 
+ if (!opts.platform.empty()) { + se::Platform* platform = + PlatformUtil::GetPlatform(opts.platform).ValueOrDie(); + LOG(INFO) << "Compiling module for " << platform->Name(); + + se::StreamExecutor* executor = + platform->ExecutorForDevice(/*ordinal=*/0).ValueOrDie(); + auto compiler = Compiler::GetForPlatform(platform).ValueOrDie(); + module = compiler + ->RunHloPasses(std::move(module), executor, + /*device_allocator=*/nullptr) + .ValueOrDie(); + auto executable = compiler + ->RunBackend(std::move(module), executor, + /*device_allocator=*/nullptr) + .ValueOrDie(); + InteractiveDumpGraphs(opts, executable->module()); + } else { + InteractiveDumpGraphs(opts, *module); + } +} + +} // namespace +} // namespace tools +} // namespace xla + +int main(int argc, char** argv) { + xla::tools::Options opts; + opts.browser = "/usr/bin/sensible-browser"; + bool need_help = false; + const std::vector flag_list = { + tensorflow::Flag("hlo_snapshot", &opts.hlo_snapshot, + "HloSnapshot proto to interactively dump to graphviz"), + tensorflow::Flag("hlo_proto", &opts.hlo_proto, + "XLA hlo proto to interactively dump to graphviz"), + tensorflow::Flag("hlo_text", &opts.hlo_text, + "XLA hlo proto to interactively dump to graphviz"), + tensorflow::Flag("platform", &opts.platform, + "Platform to compile for: CPU, CUDA, etc"), + tensorflow::Flag("browser", &opts.browser, + "Path to web browser used to display produced graphs."), + tensorflow::Flag("help", &need_help, + "Prints this help message"), + }; + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + tensorflow::port::InitMain(argv[0], &argc, &argv); + if (argc != 1 || !parse_ok || need_help) { + LOG(QFATAL) << usage; + } + xla::tools::RealMain(opts); + return 0; +} diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 
ff2c3399928c0e6339304323c4f93e212933a340..27a8dd13308b29da9a5013ac9f696613981d68bb 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -118,7 +118,12 @@ StatusOr ReplayComputation(const HloSnapshot& module, std::vector> global_data_arguments; std::vector argument_ptrs; if (opts.use_fake_data) { - global_data_arguments = MakeFakeArgumentsOrDie(computation, client); + // Run fake computations with debug options ignoring XLA_FLAGS. Users very + // likely want XLA_FLAGS only to apply to the "real" computation being run, + // not to the fake computations we use for generating arguments. + auto debug_opts = DefaultDebugOptionsIgnoringFlags(); + global_data_arguments = + MakeFakeArgumentsOrDie(computation, client, &debug_opts); for (const auto& data : global_data_arguments) { argument_ptrs.push_back( client->GlobalDataToShapedBuffer(data->handle(), /*device_ordinal=*/0) @@ -140,8 +145,7 @@ StatusOr ReplayComputation(const HloSnapshot& module, bool provide_infeed = false; Shape infeed_shape; if (!opts.fake_infeed_shape.empty()) { - StatusOr shape_status = - ShapeUtil::ParseShapeString(opts.fake_infeed_shape); + StatusOr shape_status = ParseShape(opts.fake_infeed_shape); TF_CHECK_OK(shape_status.status()); infeed_shape = std::move(shape_status).ValueOrDie(); provide_infeed = true; diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index bdeb1728fa2321f25d9db230f2d449a7b4b348ee..0e8fa73f8170addfa5061b33f3d6882a13890bce 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -100,6 +100,14 @@ message DebugOptions { // names as specified by the HloPassInterface::name() method. repeated string xla_disable_hlo_passes = 30; + // Disables all HLO passes. Notes that some passes are necessary for + // correctness and the invariants that must be satisfied by "fully optimized" + // HLO are different for different devices and may change over time. 
The only + // "guarantee", such as it is, is that if you compile XLA and dump the + // optimized HLO for some graph, you should be able to run it again on the + // same device with the same build of XLA. + bool xla_disable_all_hlo_passes = 104; + // Numerical optimization level for the XLA compiler backend; the specific // interpretation of this value is left to the backends. int32 xla_backend_optimization_level = 31; @@ -213,6 +221,17 @@ message DebugOptions { // the host that run models in parallel across multiple devices. int32 xla_force_host_platform_device_count = 102; + // If set to true XLA:GPU invokes `ptxas` with -O0 (default is -O3). + bool xla_gpu_disable_ptxas_optimizations = 103; + + // Dump HLO graphs as an HTML (DOT -> SVG inlined in HTML) + bool xla_hlo_dump_as_html = 105; + + // Enable fast math with eigen in the HLO evaluator. + bool xla_hlo_evaluator_use_fast_path = 106; + + // Next id: 107 + // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. map xla_backend_extra_options = 500; @@ -386,7 +405,7 @@ message WaitForExecutionResponse { message ComputeConstantGraphRequest { HloModuleProto computation = 1; - Layout output_layout = 2; + LayoutProto output_layout = 2; } message ComputeConstantResponse { diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 85ec83437a10d973687a7fb84285c2e2541a53c7..e9c86abe5094244988d3465ef7c949509deaec37 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -100,6 +100,8 @@ message PaddingConfig { // A format specifies the method used by a layout to store an array in memory. enum Format { + // TODO(b/120869032): Rename this to FORMAT_NONE or something else which + // better corresponds to its meaning. INVALID_FORMAT = 0; // The default layout, with exactly one storage location per element. 
DENSE = 1; @@ -109,8 +111,9 @@ enum Format { } // Describes a tile used in tiling-based layout. Refer to -// g3doc/layout_with_tiling.md for details about tiling-based layout. -message Tile { +// g3doc/third_party/tensorflow/compiler/xla/g3doc/layout_with_tiling.md for +// details about tiling-based layout. +message TileProto { // Number of elements in each dimension of the tile. It's ordered from the // most major dimension of the tile to the most minor dimension of the tile. // The dimensions correspond to a suffix of the dimensions of the shape being @@ -128,7 +131,7 @@ message Tile { // See the XLA documentation for more information on shapes and layouts. // // LINT.IfChange -message Layout { +message LayoutProto { // The method used to store the data in memory. The format determines which of // the other fields are used by the layout. Format format = 4; @@ -153,7 +156,7 @@ message Layout { // // TODO(b/119839262): implement tiling in each backend or add Unimplemented // error. - repeated Tile tiles = 6; + repeated TileProto tiles = 6; // Bit size of each element. If the size is bigger than what the element // type requires, the value is stored in the least significant @@ -196,7 +199,7 @@ message ShapeProto { repeated ShapeProto tuple_shapes = 4; // The layout used to back this shape. 
- Layout layout = 5; + LayoutProto layout = 5; // Important: if any field is added, be sure to modify ShapeUtil::Equal(), // ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account for diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc index 3258286c10665225aab917107ffa614459c53f3d..1a5bfac337baf773b84b92af5f88ef7a4c8ba81f 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc @@ -120,4 +120,9 @@ REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle") .HostMemory("handle"), XRTReleaseAllocationOp); +REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllAllocations").Device(DEVICE_XLA_GPU), + XRTReleaseAllAllocationsOp); +REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllAllocations").Device(DEVICE_XLA_CPU), + XRTReleaseAllAllocationsOp); + } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h index 26a58fa42d8b730b365b11d2e5608e9945497763..2e2f3ff116a7b331df8dbd58a9fe40096f524140 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h @@ -183,9 +183,7 @@ class XRTAllocateOp : public OpKernel { // We are guaranteed that the underlying device object won't be deleted out // from under us, while the ScopedRef is live. class DeviceAccessor::ScopedRef device_ref; - OP_REQUIRES_OK(ctx, - DeviceAccessor::InitScopedRef( - ctx, allocation_proto.device_ordinal(), &device_ref)); + OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref)); XRTTupleAllocation* allocation; OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer( @@ -469,6 +467,26 @@ class XRTReleaseAllocationOp : public OpKernel { } }; +// Op that discards a handle to device memory. 
+template +class XRTReleaseAllAllocationsOp : public OpKernel { + public: + explicit XRTReleaseAllAllocationsOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + ~XRTReleaseAllAllocationsOp() override = default; + XRTReleaseAllAllocationsOp(const XRTReleaseAllAllocationsOp&) = delete; + XRTReleaseAllAllocationsOp& operator=(const XRTReleaseAllAllocationsOp&) = + delete; + + void Compute(OpKernelContext* ctx) override { + VLOG(1) << "XRTReleaseAllAllocationsOp::Compute"; + + ResourceMgr* rm; + OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); + OP_REQUIRES_OK(ctx, XRTTupleAllocation::ReleaseAllAllocations(rm)); + } +}; + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_ diff --git a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc index a3d63106fa14674a9f5887ccfd908ce17dbc6384..fe6bee0dacf5dc2050613fc9ad34d3235b5a7b63 100644 --- a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc @@ -133,4 +133,11 @@ used. 'handle' is the id returned from the Op that produced the on-device allocation. )"); +REGISTER_OP("XRTReleaseAllAllocations") + .SetShapeFn(tensorflow::shape_inference::NoOutputs) + .Doc( + R"( +Discards all the XRT allocations. All the client held handles will be invalid. 
+)"); + } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc index abaa17e50e3f5e47a45f5a8a45fa2090d3efee39..5f8121703e108f26b048feb7a0412a282f52892c 100644 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc @@ -217,7 +217,6 @@ xla::ProgramShape XlaCompiledProgramShape( TEST(RawApiTest, AllocAndRewrite) { xrt::XLAAllocation alloc; - alloc.set_device_ordinal(0); *alloc.mutable_value() = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto(); @@ -265,9 +264,38 @@ TEST(RawApiTest, AllocAndRewrite) { &outputs)); } +TEST(RawApiTest, AllocAndClearAll) { + xrt::XLAAllocation alloc; + *alloc.mutable_value() = + xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto(); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto value = + ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString()); + auto handle = ops::XRTAllocate(root, value); + TF_ASSERT_OK(root.status()); + + tensorflow::ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({handle}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + int64 allocation_handle = outputs[0].scalar()(); + + auto clear_all = ops::XRTReleaseAllAllocations(root); + + outputs.clear(); + TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, + {clear_all}, &outputs)); + EXPECT_EQ(outputs.size(), 0); + + auto read_after_clear = ops::XRTReadLiteral(root, Input(allocation_handle)); + EXPECT_EQ(session.Run({read_after_clear}, &outputs).code(), + tensorflow::error::Code::NOT_FOUND); +} + TEST(RawApiTest, ReadAndWriteState) { xrt::XLAAllocation alloc; - alloc.set_device_ordinal(0); *alloc.mutable_value() = TwoElementTuple(); Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); @@ -292,7 +320,6 @@ TEST(RawApiTest, ReadAndWriteState) { TEST(RawApiTest, ReadAndWriteStateAutoFree) { xrt::XLAAllocation alloc; - alloc.set_device_ordinal(0); 
*alloc.mutable_value() = TwoElementTuple(); Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); @@ -313,7 +340,6 @@ TEST(RawApiTest, ReadAndWriteStateAutoFree) { TEST(RawApiTest, SubBuffer) { xrt::XLAAllocation alloc; - alloc.set_device_ordinal(0); *alloc.mutable_value() = NestedTuple(); Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); @@ -354,10 +380,8 @@ TEST(RawApiTest, SubBuffer) { TEST(RawApiTest, MakeTuple) { xrt::XLAAllocation alloc_0; - alloc_0.set_device_ordinal(0); *alloc_0.mutable_value() = TwoElementTuple(); xrt::XLAAllocation alloc_1; - alloc_1.set_device_ordinal(0); *alloc_1.mutable_value() = ScalarLiteral(); // The trivial tuple that just forwards its input and releases it. @@ -428,10 +452,8 @@ TEST(RawApiTest, MakeTuple) { TEST(RawApiTest, CompileAndExecute) { xrt::XLAAllocation p0; - p0.set_device_ordinal(0); *p0.mutable_value() = FloatVector({1.0f, 2.0f}); xrt::XLAAllocation p1; - p1.set_device_ordinal(0); *p1.mutable_value() = FloatVector({8.0f, 5.0f}); xrt::XLAComputation c; @@ -483,10 +505,8 @@ TEST(RawApiTest, CompileAndExecute) { TEST(RawApiTest, CompileAndExecuteWithArgumentVector) { xrt::XLAAllocation p0; - p0.set_device_ordinal(0); *p0.mutable_value() = FloatVector({1.0f, 2.0f}); xrt::XLAAllocation p1; - p1.set_device_ordinal(0); *p1.mutable_value() = FloatVector({8.0f, 5.0f}); xrt::XLAComputation c; @@ -606,10 +626,8 @@ TEST(RawApiTest, DotGeneralWithLayoutTest) { auto layout = xla::LayoutUtil::MakeLayout({0, 1}); xrt::XLAAllocation p0; - p0.set_device_ordinal(0); *p0.mutable_value() = FloatMatrix({{1.0f, 2.0f}, {3.0f, 4.0f}}, layout); xrt::XLAAllocation p1; - p1.set_device_ordinal(0); *p1.mutable_value() = FloatMatrix({{8.0f}, {5.0f}}, layout); xrt::XLAComputation c; @@ -692,10 +710,8 @@ TEST(RawApiTest, CompileAndExecuteZeroArg) { TEST(RawApiTest, CompileAndExecuteReturnTuple) { xrt::XLAAllocation p0; - p0.set_device_ordinal(0); *p0.mutable_value() = FloatVector({1.0f, 2.0f}); xrt::XLAAllocation p1; - 
p1.set_device_ordinal(0); *p1.mutable_value() = FloatVector({8.0f, 5.0f}); xrt::XLAComputation c; @@ -745,11 +761,9 @@ TEST(RawApiTest, CompileAndExecuteReturnTuple) { TEST(RawApiTest, CompileAndExecuteReturnExplodedTuple) { xrt::XLAAllocation p0; - p0.set_device_ordinal(0); *p0.mutable_value() = xla::LiteralUtil::CreateR0(12.0f).ToProto(); xrt::XLAAllocation p1; - p1.set_device_ordinal(0); *p1.mutable_value() = xla::LiteralUtil::CreateR0(3.0f).ToProto(); xrt::XLAComputation c; @@ -833,10 +847,8 @@ TEST(RawApiTest, LeakCompilationReference) { TEST(RawApiTest, CompileAndExecuteWithS64Argument) { xrt::XLAAllocation p0; - p0.set_device_ordinal(0); *p0.mutable_value() = xla::LiteralUtil::CreateR0(11031965).ToProto(); xrt::XLAAllocation p1; - p1.set_device_ordinal(0); *p1.mutable_value() = xla::LiteralUtil::CreateR0(4091934).ToProto(); xrt::XLAComputation c; diff --git a/tensorflow/compiler/xrt/xrt.proto b/tensorflow/compiler/xrt/xrt.proto index 378bb9246f27b8106310d565435404d7ac260a87..84adee7392825d408dd88dd74dc0c1bc7b06d7c4 100644 --- a/tensorflow/compiler/xrt/xrt.proto +++ b/tensorflow/compiler/xrt/xrt.proto @@ -59,7 +59,7 @@ message XLAComputation { // Literal to allocate space for, and transfer to, device memory. 
message XLAAllocation { - int32 device_ordinal = 1; + reserved 1; xla.LiteralProto value = 2; } diff --git a/tensorflow/compiler/xrt/xrt_device.cc b/tensorflow/compiler/xrt/xrt_device.cc index ea40e6c895c4f6af13b74735685f2c342181ada9..34cb64742a20985b29d8e153bbaf5ee184fd385d 100644 --- a/tensorflow/compiler/xrt/xrt_device.cc +++ b/tensorflow/compiler/xrt/xrt_device.cc @@ -43,4 +43,12 @@ namespace tensorflow { return Status::OK(); } +/*static*/ Status XRTGenericDeviceAccessor::InitScopedRef( + OpKernelContext* ctx, ScopedRef* scoped_ref) { + const XlaDevice::Metadata* metadata; + TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(ctx, &metadata)); + scoped_ref->Acquire(metadata->client()); + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/xrt_device.h b/tensorflow/compiler/xrt/xrt_device.h index 1e3fddd2a72a3657d1e115375133c244772ea9f3..fb010651d9bf76c540517b9596e472c881241d8a 100644 --- a/tensorflow/compiler/xrt/xrt_device.h +++ b/tensorflow/compiler/xrt/xrt_device.h @@ -59,6 +59,8 @@ class XRTGenericDeviceAccessor { static Status InitScopedRef(OpKernelContext* ctx, int device_ordinal, ScopedRef* scoped_ref); + + static Status InitScopedRef(OpKernelContext* ctx, ScopedRef* scoped_ref); }; } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc index 5c7c537c340e45648e3a95ed49d69474154694af..343460ff107fa81be127950837f786fe4eeadf26 100644 --- a/tensorflow/compiler/xrt/xrt_state.cc +++ b/tensorflow/compiler/xrt/xrt_state.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xrt/xrt_state.h" #include +#include #include #include #include @@ -34,6 +35,7 @@ limitations under the License. 
#include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/stream_executor/stream_executor.h" @@ -41,6 +43,34 @@ namespace tensorflow { namespace { +class BufferAllocStats { + public: + struct Stats { + int64 count = 0; + int64 size = 0; + }; + + Stats ReportAlloc(int64 device, int64 msize) { + mutex_lock lock(lock_); + Stats* device_stats = &stats_[device]; + device_stats->count += 1; + device_stats->size += msize; + return *device_stats; + } + + Stats ReportFree(int64 device, int64 msize) { + mutex_lock lock(lock_); + Stats* device_stats = &stats_[device]; + device_stats->count -= 1; + device_stats->size -= msize; + return *device_stats; + } + + private: + mutable mutex lock_; + std::map stats_; +}; + const char* kTupleContainer = "tuples"; int64 get_uid() { @@ -48,6 +78,11 @@ int64 get_uid() { return static_cast(unsigned_rand); } +BufferAllocStats* GetAllocStats() { + static BufferAllocStats* stats = new BufferAllocStats(); + return stats; +} + Status AllocateScopedShapedBuffer( xla::Backend* backend, int device_ordinal, const xla::Shape& shape, std::unique_ptr* buffer) { @@ -100,9 +135,19 @@ XRTBufferAllocation::XRTBufferAllocation(const se::DeviceMemoryBase& allocation, xla::DeviceMemoryAllocator* allocator) : allocation_(allocation), device_ordinal_(device_ordinal), - allocator_(allocator) {} + allocator_(allocator) { + if (VLOG_IS_ON(2)) { + auto stats = + GetAllocStats()->ReportAlloc(device_ordinal_, allocation_.size()); + LOG(INFO) << "XRT Allocation Stats: device=" << device_ordinal_ + << " count=" << stats.count << " size=" << stats.size; + } +} XRTBufferAllocation::~XRTBufferAllocation() { + if (VLOG_IS_ON(2)) { + GetAllocStats()->ReportFree(device_ordinal_, allocation_.size()); + } // Deallocate explicitly allows allocation_ to be null. 
Status s = allocator_->Deallocate(device_ordinal_, allocation_); // Nothing to do but check fail here if memory datastructures are corrupted. @@ -227,6 +272,11 @@ const se::DeviceMemoryBase& XRTTupleAllocation::root_allocation() { return rm->Delete(kTupleContainer, key_string); } +/* static */ Status XRTTupleAllocation::ReleaseAllAllocations(ResourceMgr* rm) { + VLOG(1) << "Releasing all XRT held device memory"; + return rm->Cleanup(kTupleContainer); +} + // Helper typedef to make ShapeTree ForEach helper lambda signatures more // readable. They need a type of const T& where in this case T is the // following pointer. diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h index 3664c0cd4e6ad26945ae1012208fdb006164a066..3e3d5024124e13b87eed6f79596d50cd64325914 100644 --- a/tensorflow/compiler/xrt/xrt_state.h +++ b/tensorflow/compiler/xrt/xrt_state.h @@ -129,6 +129,10 @@ class XRTTupleAllocation : public ResourceBase { // Deletes the reference in the rm to an allocation interned under key. static Status DeleteFromResourceManager(ResourceMgr* rm, int64 key); + // Releases all the device memory allocated by XRT within the resource + // manager. + static Status ReleaseAllAllocations(ResourceMgr* rm); + // Adds the allocation to a ResourceMgr and returns the key that will be used // to retrieve it. Transfers a reference on *this to rm. 
Status Intern(ResourceMgr* rm, int64* key); diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc index e95dc577184f7e81d942755b41065f52131ce9f6..3fe71a2ea730cc9b60b2e2088a0d80a08b38d1a9 100644 --- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc @@ -399,6 +399,17 @@ BigtableTestClient::AsyncMutateRows( return nullptr; } +std::unique_ptr> +BigtableTestClient::AsyncCheckAndMutateRow( + grpc::ClientContext* context, + const google::bigtable::v2::CheckAndMutateRowRequest& request, + grpc::CompletionQueue* cq) { + LOG(WARNING) << "Call to InMemoryDataClient::" << __func__ + << "(); this will likely cause a crash!"; + return nullptr; +} + std::shared_ptr BigtableTestClient::Channel() { LOG(WARNING) << "Call to InMemoryDataClient::Channel(); this will likely " "cause a crash!"; diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h index c4a1f06bc504c3565c7bb09b42e48e7fbddb9cc6..85705904573e9e7710912e3f4ff30dd8fed5bf85 100644 --- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h @@ -80,6 +80,13 @@ class BigtableTestClient : public ::google::cloud::bigtable::DataClient { const ::google::bigtable::v2::MutateRowsRequest& request, ::grpc::CompletionQueue* cq, void* tag) override; + std::unique_ptr> + AsyncCheckAndMutateRow( + grpc::ClientContext* context, + const google::bigtable::v2::CheckAndMutateRowRequest& request, + grpc::CompletionQueue* cq) override; + std::shared_ptr Channel() override; private: diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py index 
b6cdc7aab0320fe5f457288ada03a46e18a694cc..fa64055dfd65a134afdf46cebccb7f7d96106502 100644 --- a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py +++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py @@ -489,7 +489,7 @@ class BigtableTable(object): "len(dataset.output_types))") return gen_bigtable_ops.dataset_to_bigtable( self._resource, - dataset._as_variant_tensor(), # pylint: disable=protected-access + dataset._variant_tensor, # pylint: disable=protected-access column_families, columns, timestamp) @@ -582,13 +582,14 @@ class _BigtableKeyDataset(dataset_ops.DatasetSource): """_BigtableKeyDataset is an abstract class representing the keys of a table. """ - def __init__(self, table): + def __init__(self, table, variant_tensor): """Constructs a _BigtableKeyDataset. Args: table: a Bigtable class. + variant_tensor: DT_VARIANT representation of the dataset. """ - super(_BigtableKeyDataset, self).__init__() + super(_BigtableKeyDataset, self).__init__(variant_tensor) self._table = table @property @@ -601,13 +602,11 @@ class _BigtablePrefixKeyDataset(_BigtableKeyDataset): """ def __init__(self, table, prefix): - super(_BigtablePrefixKeyDataset, self).__init__(table) self._prefix = prefix - - def _as_variant_tensor(self): - return gen_bigtable_ops.bigtable_prefix_key_dataset( - table=self._table._resource, # pylint: disable=protected-access + variant_tensor = gen_bigtable_ops.bigtable_prefix_key_dataset( + table=table._resource, # pylint: disable=protected-access prefix=self._prefix) + super(_BigtablePrefixKeyDataset, self).__init__(table, variant_tensor) class _BigtableRangeKeyDataset(_BigtableKeyDataset): @@ -615,15 +614,13 @@ class _BigtableRangeKeyDataset(_BigtableKeyDataset): """ def __init__(self, table, start, end): - super(_BigtableRangeKeyDataset, self).__init__(table) self._start = start self._end = end - - def _as_variant_tensor(self): - return gen_bigtable_ops.bigtable_range_key_dataset( - table=self._table._resource, # pylint: 
disable=protected-access + variant_tensor = gen_bigtable_ops.bigtable_range_key_dataset( + table=table._resource, # pylint: disable=protected-access start_key=self._start, end_key=self._end) + super(_BigtableRangeKeyDataset, self).__init__(table, variant_tensor) class _BigtableSampleKeysDataset(_BigtableKeyDataset): @@ -633,11 +630,9 @@ class _BigtableSampleKeysDataset(_BigtableKeyDataset): # TODO(saeta): Expose the data size offsets into the keys. def __init__(self, table): - super(_BigtableSampleKeysDataset, self).__init__(table) - - def _as_variant_tensor(self): - return gen_bigtable_ops.bigtable_sample_keys_dataset( - table=self._table._resource) # pylint: disable=protected-access + variant_tensor = gen_bigtable_ops.bigtable_sample_keys_dataset( + table=table._resource) # pylint: disable=protected-access + super(_BigtableSampleKeysDataset, self).__init__(table, variant_tensor) class _BigtableLookupDataset(dataset_ops.DatasetSource): @@ -651,20 +646,18 @@ class _BigtableLookupDataset(dataset_ops.DatasetSource): self._normalized = normalized self._column_families = [i[0] for i in normalized] self._columns = [i[1] for i in normalized] + variant_tensor = gen_bigtable_ops.bigtable_lookup_dataset( + keys_dataset=self._dataset._variant_tensor, # pylint: disable=protected-access + table=self._table._resource, # pylint: disable=protected-access + column_families=self._column_families, + columns=self._columns) + super(_BigtableLookupDataset, self).__init__(variant_tensor) @property def _element_structure(self): return structure.NestedStructure(tuple( [structure.TensorStructure(dtypes.string, [])] * self._num_outputs)) - def _as_variant_tensor(self): - # pylint: disable=protected-access - return gen_bigtable_ops.bigtable_lookup_dataset( - keys_dataset=self._dataset._as_variant_tensor(), - table=self._table._resource, - column_families=self._column_families, - columns=self._columns) - class _BigtableScanDataset(dataset_ops.DatasetSource): """_BigtableScanDataset represents 
a dataset that retrieves keys and values. @@ -679,14 +672,7 @@ class _BigtableScanDataset(dataset_ops.DatasetSource): self._columns = [i[1] for i in normalized] self._probability = probability self._num_outputs = len(normalized) + 1 # 1 for row key - - @property - def _element_structure(self): - return structure.NestedStructure(tuple( - [structure.TensorStructure(dtypes.string, [])] * self._num_outputs)) - - def _as_variant_tensor(self): - return gen_bigtable_ops.bigtable_scan_dataset( + variant_tensor = gen_bigtable_ops.bigtable_scan_dataset( table=self._table._resource, # pylint: disable=protected-access prefix=self._prefix, start_key=self._start, @@ -694,6 +680,13 @@ class _BigtableScanDataset(dataset_ops.DatasetSource): column_families=self._column_families, columns=self._columns, probability=self._probability) + super(_BigtableScanDataset, self).__init__(variant_tensor) + + @property + def _element_structure(self): + return structure.NestedStructure( + tuple( + [structure.TensorStructure(dtypes.string, [])] * self._num_outputs)) class _BigtableSampleKeyPairsDataset(dataset_ops.DatasetSource): @@ -705,17 +698,15 @@ class _BigtableSampleKeyPairsDataset(dataset_ops.DatasetSource): self._prefix = prefix self._start = start self._end = end + variant_tensor = gen_bigtable_ops.bigtable_sample_key_pairs_dataset( + table=self._table._resource, # pylint: disable=protected-access + prefix=self._prefix, + start_key=self._start, + end_key=self._end) + super(_BigtableSampleKeyPairsDataset, self).__init__(variant_tensor) @property def _element_structure(self): return structure.NestedStructure( (structure.TensorStructure(dtypes.string, []), structure.TensorStructure(dtypes.string, []))) - - def _as_variant_tensor(self): - # pylint: disable=protected-access - return gen_bigtable_ops.bigtable_sample_key_pairs_dataset( - table=self._table._resource, - prefix=self._prefix, - start_key=self._start, - end_key=self._end) diff --git 
a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 9fdc2fc0c2c7b85502f7a3f9ae7c85cf05d5916c..a5951fb7377d48748f5eb578c034176517df7749 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -614,13 +614,19 @@ class GradientBoostedDecisionTreeModel(object): predictions_dict[NUM_TREES_ATTEMPTED] % self._logits_dimension) return constant_op.constant(-1, dtype=dtypes.int32) - def update_stats(self, loss, predictions_dict): + def update_stats(self, loss, predictions_dict, gradients=None, hessians=None): """Update the accumulators with stats from this batch. Args: loss: A scalar tensor representing average loss of examples. predictions_dict: Dictionary of Rank 2 `Tensor` representing information about predictions per example. + gradients: A tensor with the gradients with the respect to logits from + predictions_dict. If not provided, tensorflow will do + autodifferentiation. + hessians: A tensor with the hessians with the respect to logits from + predictions_dict. If not provided, tensorflow will do + autodifferentiation. 
Returns: Three values: @@ -642,13 +648,14 @@ class GradientBoostedDecisionTreeModel(object): predictions = predictions_dict[PREDICTIONS] partition_ids = predictions_dict[PARTITION_IDS] ensemble_stamp = predictions_dict[ENSEMBLE_STAMP] - gradients = gradients_impl.gradients( - loss, - predictions, - name="Gradients", - colocate_gradients_with_ops=False, - gate_gradients=0, - aggregation_method=None)[0] + if gradients is None: + gradients = gradients_impl.gradients( + loss, + predictions, + name="Gradients", + colocate_gradients_with_ops=False, + gate_gradients=0, + aggregation_method=None)[0] strategy = self._learner_config.multi_class_strategy class_id = self._get_class_id(predictions_dict) @@ -657,17 +664,20 @@ class GradientBoostedDecisionTreeModel(object): # We build one vs rest trees. if self._logits_dimension == 1: # We have only 1 score, gradients is of shape [batch, 1]. - hessians = gradients_impl.gradients( - gradients, - predictions, - name="Hessian", - colocate_gradients_with_ops=False, - gate_gradients=0, - aggregation_method=None)[0] + if hessians is None: + hessians = gradients_impl.gradients( + gradients, + predictions, + name="Hessian", + colocate_gradients_with_ops=False, + gate_gradients=0, + aggregation_method=None)[0] squeezed_gradients = array_ops.squeeze(gradients, axis=[1]) squeezed_hessians = array_ops.squeeze(hessians, axis=[1]) else: + if hessians is not None: + raise ValueError("Providing hessians is not yet supported here.") hessian_list = self._diagonal_hessian(gradients, predictions) # Assemble hessian list into a tensor. hessians = array_ops.stack(hessian_list, axis=1) @@ -678,6 +688,8 @@ class GradientBoostedDecisionTreeModel(object): squeezed_hessians = array_ops.squeeze( _get_column_by_index(hessians, class_id)) else: + if hessians is not None: + raise ValueError("Providing hessians is not yet supported here.") # Other multiclass strategies. 
if strategy == learner_pb2.LearnerConfig.FULL_HESSIAN: hessian_list = self._full_hessian(gradients, predictions) @@ -835,9 +847,9 @@ class GradientBoostedDecisionTreeModel(object): stats_update_ops.append( control_flow_ops.cond( continue_centering, - self._make_update_bias_stats_fn( - ensemble_stamp, predictions, gradients, - bias_stats_accumulator), control_flow_ops.no_op)) + self._make_update_bias_stats_fn(ensemble_stamp, predictions, + gradients, bias_stats_accumulator, + hessians), control_flow_ops.no_op)) # Update handler stats. handler_reads = collections.OrderedDict() @@ -1162,7 +1174,8 @@ class GradientBoostedDecisionTreeModel(object): def get_max_tree_depth(self): return self._max_tree_depth - def train(self, loss, predictions_dict, labels): + def train(self, loss, predictions_dict, labels, gradients=None, + hessians=None): """Updates the accumalator stats and grows the ensemble. Args: @@ -1171,6 +1184,12 @@ class GradientBoostedDecisionTreeModel(object): about predictions per example. labels: Rank 2 `Tensor` representing labels per example. Has no effect on the training and is only kept for backward compatibility. + gradients: A tensor with the gradients with the respect to logits from + predictions_dict. If not provided, tensorflow will do + autodifferentiation. + hessians: A tensor with the hessians with the respect to logits from + predictions_dict. If not provided, tensorflow will do + autodifferentiation. Returns: An op that adds a new tree to the ensemble. @@ -1179,7 +1198,8 @@ class GradientBoostedDecisionTreeModel(object): ValueError: if inputs are not valid. """ del labels # unused; kept for backward compatibility. 
- update_op, _, training_state = self.update_stats(loss, predictions_dict) + update_op, _, training_state = self.update_stats(loss, predictions_dict, + gradients, hessians) with ops.control_dependencies(update_op): return self.increment_step_counter_and_maybe_update_ensemble( predictions_dict, training_state) @@ -1271,21 +1291,28 @@ class GradientBoostedDecisionTreeModel(object): ps_ops=ps_ops, ps_strategy=ps_strategy) - def _make_update_bias_stats_fn(self, ensemble_stamp, predictions, gradients, - bias_stats_accumulator): + def _make_update_bias_stats_fn(self, + ensemble_stamp, + predictions, + gradients, + bias_stats_accumulator, + hessians=None): """A method to create the function which updates the bias stats.""" def _update_bias_stats(): """A method to update the bias stats.""" # Get reduced gradients and hessians. grads_sum = math_ops.reduce_sum(gradients, 0) - hess = gradients_impl.gradients( - grads_sum, - predictions, - name="Hessians", - colocate_gradients_with_ops=False, - gate_gradients=0, - aggregation_method=None)[0] + if hessians is not None: + hess = hessians + else: + hess = gradients_impl.gradients( + grads_sum, + predictions, + name="Hessians", + colocate_gradients_with_ops=False, + gate_gradients=0, + aggregation_method=None)[0] hess_sum = math_ops.reduce_sum(hess, 0) # Accumulate gradients and hessians. diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md index df8b48dfc46124d3b9454d92ffb70dbcf1bc4217..b2badc5785bdb1ea90c7f07e544ea9047146eebd 100644 --- a/tensorflow/contrib/cmake/README.md +++ b/tensorflow/contrib/cmake/README.md @@ -157,7 +157,7 @@ suitable interface for project configuration and dependency setting. press `Finish`. Wait for a moment, the default project dependecy would automatically generate. 6. There are a few options that you can customize your own build. 
**The setting - here is crucial for a sucessful build, please check all items carefully.** + here is crucial for a successful build, please check all items carefully.** * `tensorflow_BUILD_ALL_KERNELS` should alway be `on` * `tensorflow_BUILD_CC_EXAMPLE` is default to be `on`. This can help you diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD index e4566437c60ebb2da039e61c171fbe954a7355c9..e32097ceddfec95b8677fc762d641d09078e5343 100644 --- a/tensorflow/contrib/compiler/BUILD +++ b/tensorflow/contrib/compiler/BUILD @@ -70,22 +70,30 @@ py_library( ], ) -tf_py_test( +cuda_py_test( name = "xla_test", srcs = ["xla_test.py"], additional_deps = [ ":xla", - "@six_archive//:six", + "@absl_py//absl/testing:parameterized", + "//tensorflow/compiler/tests:xla_test", + "//tensorflow/contrib/tpu:tpu_estimator", + "//tensorflow/contrib/tpu:tpu_lib", + "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", "//tensorflow/python:control_flow_util", "//tensorflow/python:math_ops", "//tensorflow/python:platform", - "//tensorflow/contrib/tpu:tpu_lib", "//tensorflow/python:state_ops", "//tensorflow/python:summary", "//tensorflow/python:training", "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", + ], + tags = [ + "no_mac", + "no_windows", ], - tags = ["no_pip"], + xla_enabled = True, ) diff --git a/tensorflow/contrib/compiler/xla_test.py b/tensorflow/contrib/compiler/xla_test.py index 3b49755afcf0753d31c0ce506dce42709b1ee8bc..c4384dcde75035dc55e67bd503e348fe19b97025 100644 --- a/tensorflow/contrib/compiler/xla_test.py +++ b/tensorflow/contrib/compiler/xla_test.py @@ -18,11 +18,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import re +from absl.testing import parameterized + from tensorflow.contrib.compiler import xla +from tensorflow.contrib.tpu.python.tpu import tpu_estimator from 
tensorflow.contrib.tpu.python.tpu import tpu_feed +from tensorflow.contrib.training.python.training import hparam from tensorflow.python import summary +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import logging_ops @@ -30,6 +38,14 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test +from tensorflow.python.training import training + + +_TRAIN = model_fn_lib.ModeKeys.TRAIN +_EVAL = model_fn_lib.ModeKeys.EVAL +_EXPECTED_LOSS = 1 +_EXPECTED_FEATURE = 2 +_EXPECTED_LABEL = 3 class XLACompileContextTest(test.TestCase): @@ -252,5 +268,329 @@ class CheckFunctionArgumentCountTest(test.TestCase): xla.check_function_argument_count(func, 0, queue)) +def _test_train_model_fn(features, labels, mode, params): + """A dummy model_fn for testing purpose.""" + del features, labels, params + loss = constant_op.constant(_EXPECTED_LOSS) + return model_fn_lib.EstimatorSpec( + mode=mode, loss=loss, train_op=array_ops.identity(loss)) + + +@xla.estimator_model_fn +def decorated_model_fn(features, labels, mode, params): + return _test_train_model_fn(features, labels, mode, params) + + +def make_dummy_features_labels(): + # XLA CPU/GPU backend doesn't support guaranteed constant, thus use dataset + # container to work around. 
+ features_dataset = dataset_ops.Dataset.from_tensors( + constant_op.constant(_EXPECTED_FEATURE)).repeat(10) + features_op = features_dataset.make_one_shot_iterator().get_next() + labels_dataset = dataset_ops.Dataset.from_tensors( + constant_op.constant(_EXPECTED_LABEL)).repeat(10) + labels_op = labels_dataset.make_one_shot_iterator().get_next() + return features_op, labels_op + + +class XlaDecoratorTest(test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + ('test_use_as_decorator', decorated_model_fn, None), + ('test_use_as_function', xla.estimator_model_fn(_test_train_model_fn), + None), + ('test_use_tpu_false_hparams', decorated_model_fn, + hparam.HParams(use_tpu=False)), + ('test_use_tpu_false_dict_params', decorated_model_fn, { + 'use_tpu': False + }), + ) + def test_compile(self, model_fn, params): + """Calls model_fn and verifies it is compiled.""" + with test.mock.patch.object(xla, 'compile') as mock_xla_compile: + loss = constant_op.constant(_EXPECTED_LOSS) + mock_xla_compile.return_value = [loss] + + features, labels = make_dummy_features_labels() + estimator_spec = model_fn( + features=features, labels=labels, mode=_TRAIN, params=params or {}) + + self.assertEqual(mock_xla_compile.call_count, 1) + self.assertEqual(estimator_spec.mode, _TRAIN) + + with self.test_session() as sess: + self.assertEqual(sess.run(estimator_spec.loss), sess.run(loss)) + self.assertEqual(sess.run(estimator_spec.train_op), sess.run(loss)) + + @parameterized.named_parameters( + ('test_use_tpu_true_hparams', decorated_model_fn, + hparam.HParams(use_tpu=True)), + ('test_use_tpu_true_dict_params', decorated_model_fn, { + 'use_tpu': True + }), + ) + def test_not_compile(self, model_fn, params): + """Calls model_fn and verifies it is NOT compiled.""" + with test.mock.patch.object(xla, 'compile') as mock_xla_compile: + loss = constant_op.constant(_EXPECTED_LOSS) + mock_xla_compile.return_value = [loss] + + features, labels = make_dummy_features_labels() + 
estimator_spec = model_fn( + features=features, labels=labels, mode=_TRAIN, params=params or {}) + + mock_xla_compile.assert_not_called() + self.assertEqual(estimator_spec.mode, _TRAIN) + + with self.test_session() as sess: + self.assertEqual(sess.run(estimator_spec.loss), sess.run(loss)) + self.assertEqual(sess.run(estimator_spec.train_op), sess.run(loss)) + + def test_model_with_summary(self): + """Tests that summary ops are disabled.""" + + @xla.estimator_model_fn + def model_fn_with_summary(features, labels, mode, params): + del features, labels, params + loss = constant_op.constant(_EXPECTED_LOSS) + summary.scalar('loss_scalar_summary', loss) + summary.histogram('loss_histogram_summary', loss) + summary.image('loss_image_summary', loss) + return model_fn_lib.EstimatorSpec( + mode=mode, loss=loss, train_op=array_ops.identity(loss)) + + features, labels = make_dummy_features_labels() + estimator_spec = model_fn_with_summary( + features=features, labels=labels, mode=_TRAIN, params={}) + + with self.test_session() as sess: + self.assertEqual(sess.run(estimator_spec.loss), _EXPECTED_LOSS) + + +def _test_eval_metric_fn(eval_tensor_1, eval_tensor_2): + return { + 'metric_1': (eval_tensor_1, eval_tensor_1), + 'metric_2': (eval_tensor_2, eval_tensor_2), + } + + +class XlaDecoratorEvaluationTest(test.TestCase): + + def _verify_evaluation_result(self, eval_model_fn): + features, labels = make_dummy_features_labels() + estimator_spec = eval_model_fn( + features=features, labels=labels, mode=_EVAL, params={}) + + with self.test_session() as sess: + self.assertEqual(sess.run(estimator_spec.loss), _EXPECTED_LOSS) + self.assertEqual( + sess.run(estimator_spec.eval_metric_ops['metric_1'][0]), + _EXPECTED_FEATURE + _EXPECTED_LABEL) + self.assertEqual( + sess.run(estimator_spec.eval_metric_ops['metric_1'][1]), + _EXPECTED_FEATURE + _EXPECTED_LABEL) + self.assertEqual( + sess.run(estimator_spec.eval_metric_ops['metric_2'][0]), + _EXPECTED_FEATURE - _EXPECTED_LABEL) + 
self.assertEqual( + sess.run(estimator_spec.eval_metric_ops['metric_2'][1]), + _EXPECTED_FEATURE - _EXPECTED_LABEL) + + def test_eval_base_estimator_spec_eval_metric_ops_disallowed(self): + + @xla.estimator_model_fn + def eval_model_fn_return_estimator_spec(features, labels, mode, params): + del features, labels, params + loss = constant_op.constant(_EXPECTED_LOSS) + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=loss, + eval_metric_ops={ + 'metric': (array_ops.identity(loss), control_flow_ops.no_op()) + }) + + with self.assertRaisesRegexp( + ValueError, 'EstimatorSpec.eval_metric_ops is not supported with XLA ' + 'compilation. Please use TPUEstimatorSpec.eval_metrics instead.'): + self._verify_evaluation_result(eval_model_fn_return_estimator_spec) + + def test_eval_base_estimator_spec_no_eval_metric_ops(self): + + @xla.estimator_model_fn + def eval_model_fn_no_eval_metric_ops(features, labels, mode, params): + del features, labels, params + return model_fn_lib.EstimatorSpec( + mode=mode, loss=constant_op.constant(_EXPECTED_LOSS)) + + features, labels = make_dummy_features_labels() + estimator_spec = eval_model_fn_no_eval_metric_ops( + features=features, labels=labels, mode=_EVAL, params={}) + with self.test_session() as sess: + self.assertEqual(sess.run(estimator_spec.loss), _EXPECTED_LOSS) + + def test_eval_no_eval_metrics(self): + + @xla.estimator_model_fn + def eval_model_fn_no_eval_metrics(features, labels, mode, params): + del features, labels, params + return tpu_estimator.TPUEstimatorSpec( + mode=mode, loss=constant_op.constant(_EXPECTED_LOSS)) + + features, labels = make_dummy_features_labels() + estimator_spec = eval_model_fn_no_eval_metrics( + features=features, labels=labels, mode=_EVAL, params={}) + + self.assertEqual(estimator_spec.eval_metric_ops, {}) + with self.test_session() as sess: + self.assertEqual(sess.run(estimator_spec.loss), _EXPECTED_LOSS) + + def test_eval_fn_missing_input_tensor(self): + + @xla.estimator_model_fn + def 
eval_model_fn(features, labels, mode, params): + del params + dummy_eval_metric_fn_tensors_dict = { + 'eval_tensor_1': features + labels, + } + return tpu_estimator.TPUEstimatorSpec( + mode=mode, + loss=constant_op.constant(_EXPECTED_LOSS), + eval_metrics=(_test_eval_metric_fn, + dummy_eval_metric_fn_tensors_dict)) + + with self.assertRaisesRegexp( + ValueError, + re.escape("Arguments ['eval_tensor_2'] are needed by metric_fn (first " + 'element of TPUEstimatorSpec.eval_metrics) but they are not ' + 'provided by evaluation tensors (second element of ' + 'TPUEstimatorSpec.eval_metrics).')): + self._verify_evaluation_result(eval_model_fn) + + def test_eval_fn_extraneous_input_tensor(self): + + @xla.estimator_model_fn + def eval_model_fn(features, labels, mode, params): + del params + dummy_eval_metric_fn_tensors_dict = { + 'eval_tensor_1': features + labels, + 'eval_tensor_2': features - labels, + 'extra_tensor': features * 2 - labels, + } + return tpu_estimator.TPUEstimatorSpec( + mode=mode, + loss=constant_op.constant(_EXPECTED_LOSS), + eval_metrics=(_test_eval_metric_fn, + dummy_eval_metric_fn_tensors_dict)) + + with self.assertRaisesRegexp( + ValueError, + re.escape("Arguments ['extra_tensor'] are provided by evaluation " + 'tensors (second element of TPUEstimatorSpec.eval_metrics) ' + 'but they are not needed by metric_fn (first element of ' + 'TPUEstimatorSpec.eval_metrics).')): + self._verify_evaluation_result(eval_model_fn) + + def test_eval_tensors_as_list(self): + + @xla.estimator_model_fn + def eval_model_fn(features, labels, mode, params): + del params + dummy_eval_metric_fn_tensors = [features + labels, features - labels] + return tpu_estimator.TPUEstimatorSpec( + mode=mode, + loss=constant_op.constant(_EXPECTED_LOSS), + eval_metrics=(_test_eval_metric_fn, dummy_eval_metric_fn_tensors)) + + self._verify_evaluation_result(eval_model_fn) + + def test_eval_tensors_as_dict(self): + + @xla.estimator_model_fn + def eval_model_fn(features, labels, mode, 
params): + del params + dummy_eval_metric_fn_tensors_dict = { + 'eval_tensor_1': features + labels, + 'eval_tensor_2': features - labels, + } + return tpu_estimator.TPUEstimatorSpec( + mode=mode, + loss=constant_op.constant(_EXPECTED_LOSS), + eval_metrics=(_test_eval_metric_fn, + dummy_eval_metric_fn_tensors_dict)) + + self._verify_evaluation_result(eval_model_fn) + + def test_model_with_summary(self): + """Tests that summary ops are disabled.""" + + @xla.estimator_model_fn + def model_fn_with_summary(features, labels, mode, params): + del features, labels, params + loss = constant_op.constant(_EXPECTED_LOSS) + summary.scalar('loss_scalar_summary', loss) + summary.histogram('loss_histogram_summary', loss) + summary.image('loss_image_summary', loss) + return tpu_estimator.TPUEstimatorSpec(mode=mode, loss=loss) + + features, labels = make_dummy_features_labels() + estimator_spec = model_fn_with_summary( + features=features, labels=labels, mode=_EVAL, params={}) + + with self.test_session() as sess: + self.assertEqual(sess.run(estimator_spec.loss), _EXPECTED_LOSS) + + +class XlaDecoratorScaffoldTest(test.TestCase, parameterized.TestCase): + + def _make_scaffold_fn(self, mode): + + def _scaffold_fn_on_cpu(): + scaffold = training.Scaffold() + self.assertNotIn(mode, self.is_scaffold_fn_called) + self.is_scaffold_fn_called[mode] = True + return scaffold + + return _scaffold_fn_on_cpu + + def test_scaffold_fn_return_none(self): + + @xla.estimator_model_fn + def model_fn(features, labels, mode, params): + del features, labels, params + return tpu_estimator.TPUEstimatorSpec( + mode=mode, + loss=constant_op.constant(_EXPECTED_LOSS), + train_op=control_flow_ops.no_op(), + scaffold_fn=lambda: None) + + features, labels = make_dummy_features_labels() + with self.assertRaisesRegexp( + ValueError, + 'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed'): + model_fn(features=features, labels=labels, mode=_TRAIN, params={}) + + @parameterized.named_parameters( + 
('train_mode', _TRAIN), + ('eval_mode', _EVAL), + # TODO(ycao): Add predict_mode test after PREDICT mode is implemented. + ) + def test_scaffold_fn_in_mode(self, mode): + + @xla.estimator_model_fn + def model_fn(features, labels, mode, params): + del features, labels, params + return tpu_estimator.TPUEstimatorSpec( + mode=mode, + loss=constant_op.constant(_EXPECTED_LOSS), + train_op=control_flow_ops.no_op(), + scaffold_fn=self._make_scaffold_fn(mode)) + + features, labels = make_dummy_features_labels() + + self.is_scaffold_fn_called = {} + model_fn(features=features, labels=labels, mode=mode, params={}) + self.assertTrue(self.is_scaffold_fn_called[mode]) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/constrained_optimization/BUILD b/tensorflow/contrib/constrained_optimization/BUILD index 619153df67c90cea5a5082a411972948bac5fe90..eee4329acbeb38c9f37f79227aeb3acd46dce5e7 100644 --- a/tensorflow/contrib/constrained_optimization/BUILD +++ b/tensorflow/contrib/constrained_optimization/BUILD @@ -42,6 +42,11 @@ py_test( name = "candidates_test", srcs = ["python/candidates_test.py"], srcs_version = "PY2AND3", + tags = [ + # TODO(b/121223093): Re-enable this test after fixing "Distribution + # should match known solution" errors. 
+ "no_mac", + ], deps = [ ":constrained_optimization", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index a268415f0e65206294431a537be18cadbe1a1e84..f5219eb134d07c09b16a544f71d4c18986c19681 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -68,6 +68,7 @@ def RunLSTM(sess, batch_size, time, num_layers=1, + variable_seq_lengths=False, is_training=True, dropout=0., num_dirs=True, @@ -99,6 +100,13 @@ def RunLSTM(sess, num_units).astype(dtype.as_numpy_dtype), dtype=dtype) + if variable_seq_lengths: + lengths_v = np.random.randint(low=1, high=time + 1, size=batch_size) + lengths_v[0] = time # make sure the max sequence has 'time' elems + lengths = ops.convert_to_tensor(lengths_v.astype(np.int32)) + else: + lengths = None + initializer = init_ops.random_uniform_initializer( -0.01, 0.01, dtype=dtype, seed=19980904) @@ -115,6 +123,7 @@ def RunLSTM(sess, outputs_op, state_tuple_op = rnn.dynamic_rnn( cell, inputs, + sequence_length=lengths, initial_state=rnn_cell_impl.LSTMStateTuple( h=initial_h_op, c=initial_c_op), dtype=dtype, @@ -133,6 +142,7 @@ def RunLSTM(sess, cu_initial_h_op, cu_initial_c_op, opaque_params, + sequence_lengths=lengths, dropout=dropout, is_training=is_training, rnn_mode=cudnn_rnn_ops.CUDNN_LSTM) @@ -325,12 +335,19 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, dtype, - rtol=2e-6, - atol=2e-6): + variable_seq_lengths, + rtol=3e-6, + atol=3e-6): with self.session(use_gpu=True) as sess: (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad, cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad, cu_bgrad) = RunLSTM( - sess, num_units, input_size, batch_size, time, num_layers) + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + 
variable_seq_lengths=variable_seq_lengths) self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) for s, cu_s in zip(state_tuple, cu_state_tuple): @@ -341,20 +358,33 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): self.assertAllClose(bgrad, cu_bgrad, rtol=rtol, atol=atol) self.assertAllClose(wgrad, cu_wgrad, rtol=rtol, atol=atol) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @parameterized.named_parameters( + ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def test_training(self, num_units, input_size, batch_size, time, num_layers): + def test_training(self, num_units, input_size, batch_size, time, num_layers, + variable_seq_lengths): if not context.context().num_gpus(): self.skipTest("No GPUs found") - self._test_training_helper(num_units, input_size, batch_size, time, - num_layers, dtypes.float32) + self._test_training_helper( + num_units, + input_size, + batch_size, + time, + num_layers, + dtypes.float32, + variable_seq_lengths=variable_seq_lengths) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @parameterized.named_parameters( + ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_training_fp16(self, num_units, input_size, batch_size, time, - num_layers): + num_layers, variable_seq_lengths): if not context.context().num_gpus(): self.skipTest("No GPUs found") self._test_training_helper( @@ -365,12 +395,17 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): num_layers, dtypes.float16, rtol=5e-3, - atol=5e-4) + atol=5e-4, + variable_seq_lengths=variable_seq_lengths) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @parameterized.named_parameters( + ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + 
"variable_seq_lengths": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def test_inference(self, num_units, input_size, batch_size, time, num_layers): + def test_inference(self, num_units, input_size, batch_size, time, num_layers, + variable_seq_lengths): if not context.context().num_gpus(): self.skipTest("No GPUs found") with self.session(use_gpu=True) as sess: @@ -381,7 +416,8 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): batch_size, time, num_layers, - is_training=False) + is_training=False, + variable_seq_lengths=variable_seq_lengths) self.assertAllClose(outputs, cu_outputs) # h @@ -389,11 +425,14 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): # c self.assertAllClose(state_tuple.c, cu_state_tuple.c) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @parameterized.named_parameters( + ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference_fp16(self, num_units, input_size, batch_size, time, - num_layers): + num_layers, variable_seq_lengths): if not context.context().num_gpus(): self.skipTest("No GPUs found") with self.session(use_gpu=True) as sess: @@ -405,7 +444,8 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, is_training=False, - dtype=dtypes.float16) + dtype=dtypes.float16, + variable_seq_lengths=variable_seq_lengths) rtol, atol = 5e-3, 5e-4 self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) @@ -416,11 +456,14 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): self.assertAllClose( state_tuple.c, cu_state_tuple.c, rtol=rtol, atol=atol) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @parameterized.named_parameters( + ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + })) 
@unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference_with_dropout(self, num_units, input_size, batch_size, time, - num_layers): + num_layers, variable_seq_lengths): """Validates that dropout does not affect Cudnn Rnn inference.""" if not context.context().num_gpus(): self.skipTest("No GPUs found") @@ -436,7 +479,8 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, is_training=False, - dropout=0.) + dropout=0., + variable_seq_lengths=variable_seq_lengths) with ops.Graph().as_default() as g: with self.session(use_gpu=True, graph=g) as sess: @@ -448,7 +492,8 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, is_training=False, - dropout=1.) + dropout=1., + variable_seq_lengths=variable_seq_lengths) self.assertAllClose(cu_outputs, cu_outputs2) # h @@ -464,6 +509,7 @@ def RunGRU(sess, time, num_layers=1, is_training=True, + variable_seq_lengths=False, dropout=0., num_dirs=True, dtype=dtypes.float32): @@ -489,6 +535,13 @@ def RunGRU(sess, num_units).astype(dtype.as_numpy_dtype), dtype=dtype) + if variable_seq_lengths: + lengths_v = np.random.randint(low=1, high=time + 1, size=batch_size) + lengths_v[0] = time # make sure the max sequence has 'time' elems + lengths = ops.convert_to_tensor(lengths_v.astype(np.int32)) + else: + lengths = None + initializer = init_ops.random_uniform_initializer( -0.01, 0.01, dtype=dtype, seed=19980904) with variable_scope.variable_scope("test", initializer=initializer): @@ -521,6 +574,7 @@ def RunGRU(sess, outputs_op, h_op = rnn.dynamic_rnn( cell, inputs, + sequence_length=lengths, initial_state=initial_h_op, dtype=dtype, time_major=True, @@ -533,12 +587,14 @@ def RunGRU(sess, num_layers, num_units, input_size) opaque_params = format_converter.tf_canonical_to_opaque(ws + bs) + cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=0) cu_outputs_op, cu_h_op, _ = cudnn_rnn_ops._cudnn_rnn( inputs, cu_initial_h_op, 
array_ops.zeros_like(cu_initial_h_op), # not used opaque_params, + sequence_lengths=lengths, dropout=dropout, is_training=is_training, rnn_mode=cudnn_rnn_ops.CUDNN_GRU) @@ -615,12 +671,19 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, dtype, - rtol=2e-6, - atol=2e-6): + variable_seq_lengths, + rtol=3e-6, + atol=3e-6): with self.session(use_gpu=True) as sess: - (outputs, cu_outputs, h, cu_h, inp_grad, cu_inp_grad, hgrad, - cu_hgrad, wgrad, bgrad, cu_wgrad, cu_bgrad) = RunGRU( - sess, num_units, input_size, batch_size, time, num_layers) + (outputs, cu_outputs, h, cu_h, inp_grad, cu_inp_grad, hgrad, cu_hgrad, + wgrad, bgrad, cu_wgrad, cu_bgrad) = RunGRU( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + variable_seq_lengths=variable_seq_lengths) self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) self.assertAllClose(h, cu_h, rtol=rtol, atol=atol) @@ -631,20 +694,33 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): for wg, cu_wg in zip(wgrad, cu_wgrad): self.assertAllClose(wg, cu_wg, rtol=rtol, atol=atol) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @parameterized.named_parameters( + ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def test_training(self, num_units, input_size, batch_size, time, num_layers): + def test_training(self, num_units, input_size, batch_size, time, num_layers, + variable_seq_lengths): if not context.context().num_gpus(): self.skipTest("No GPUs found") - self._test_training_helper(num_units, input_size, batch_size, time, - num_layers, dtypes.float32) + self._test_training_helper( + num_units, + input_size, + batch_size, + time, + num_layers, + dtypes.float32, + variable_seq_lengths=variable_seq_lengths) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @parameterized.named_parameters( + 
ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_training_fp16(self, num_units, input_size, batch_size, time, - num_layers): + num_layers, variable_seq_lengths): if not context.context().num_gpus(): self.skipTest("No GPUs found") self._test_training_helper( @@ -655,12 +731,17 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): num_layers, dtypes.float16, rtol=5e-3, - atol=5e-4) + atol=5e-4, + variable_seq_lengths=variable_seq_lengths) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @parameterized.named_parameters( + ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def test_inference(self, num_units, input_size, batch_size, time, num_layers): + def test_inference(self, num_units, input_size, batch_size, time, num_layers, + variable_seq_lengths): if not context.context().num_gpus(): self.skipTest("No GPUs found") with self.session(use_gpu=True) as sess: @@ -671,15 +752,19 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): batch_size, time, num_layers, - is_training=False) + is_training=False, + variable_seq_lengths=variable_seq_lengths) self.assertAllClose(outputs, cu_outputs) self.assertAllClose(h, cu_h) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @parameterized.named_parameters( + ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference_fp16(self, num_units, input_size, batch_size, time, - num_layers): + num_layers, variable_seq_lengths): if not context.context().num_gpus(): self.skipTest("No GPUs found") with self.session(use_gpu=True) as sess: @@ -691,17 +776,21 @@ class 
CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, is_training=False, - dtype=dtypes.float16) + dtype=dtypes.float16, + variable_seq_lengths=variable_seq_lengths) rtol, atol = 5e-3, 5e-4 self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) self.assertAllClose(h, cu_h, rtol=rtol, atol=atol) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @parameterized.named_parameters( + ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference_with_dropout(self, num_units, input_size, batch_size, time, - num_layers): + num_layers, variable_seq_lengths): """Validates that dropout does not affect Cudnn Rnn inference.""" # Hand-picked dropouts are used below (0. and 1.) if not context.context().num_gpus(): @@ -717,7 +806,8 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, is_training=False, - dropout=0.) + dropout=0., + variable_seq_lengths=variable_seq_lengths) with ops.Graph().as_default() as g: with self.session(use_gpu=True, graph=g) as sess: @@ -729,7 +819,8 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, is_training=False, - dropout=1.) + dropout=1., + variable_seq_lengths=variable_seq_lengths) self.assertAllClose(cu_outputs, cu_outputs2) self.assertAllClose(cu_h[0], cu_h2[0]) diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py index 8e25637ed91a1559b321ea96efbfaa2910f67158..86ad8ae8073714657c78badb1e0b4a6d8c8ed5f0 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -374,7 +374,11 @@ class _CudnnRNN(base_layer.Layer): "This cell does not yet support object-based saving. 
File a feature " "request if this limitation bothers you.") - def call(self, inputs, initial_state=None, training=True): + def call(self, + inputs, + initial_state=None, + sequence_lengths=None, + training=True): """Runs the forward step for the RNN model. Args: @@ -382,6 +386,9 @@ class _CudnnRNN(base_layer.Layer): initial_state: a tuple of tensor(s) of shape `[num_layers * num_dirs, batch_size, num_units]`. If not provided, use zero initial states. The tuple size is 2 for LSTM and 1 for other RNNs. + sequence_lengths: an int32 array representing the variable sequence + lengths in a batch. The size of the array has to equal the + batch_size. If not provided, the same sequence length will be assumed. training: whether this operation will be used in training or inference. Returns: output: a tensor of shape `[time_len, batch_size, num_dirs * num_units]`. @@ -411,7 +418,7 @@ class _CudnnRNN(base_layer.Layer): # For model that doesn't take input_c, replace with a dummy tensor. c = array_ops.constant([], dtype=dtype) outputs, (output_h, output_c) = self._forward(inputs, h, c, self.kernel, - training) + sequence_lengths, training) if self._rnn_mode == CUDNN_LSTM: return outputs, (output_h, output_c) else: @@ -475,7 +482,7 @@ class _CudnnRNN(base_layer.Layer): dropout=self._dropout, direction=self._direction) - def _forward(self, inputs, h, c, opaque_params, training): + def _forward(self, inputs, h, c, opaque_params, sequence_lengths, training): output, output_h, output_c = cudnn_rnn_ops._cudnn_rnn( # pylint:disable=protected-access inputs, h, @@ -483,6 +490,7 @@ class _CudnnRNN(base_layer.Layer): opaque_params, training, self._rnn_mode, + sequence_lengths=sequence_lengths, input_mode=self._input_mode, direction=self._direction, dropout=self._dropout, diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 1ce29b42d52ff67477161278ed11016c2e73041d..1facc83972faf229f243af5bc534bcb98aff5440 
100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -955,6 +955,7 @@ def _cudnn_rnn(inputs, params, is_training, rnn_mode, + sequence_lengths=None, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -972,6 +973,10 @@ def _cudnn_rnn(inputs, params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference rnn_mode: one of ('lstm', 'gru', 'rnn_relu', 'rnn_tanh'). + sequence_lengths: an int32 array representing the variable sequence lengths + in a batch. The size of the array has to equal the batch_size. Default to + None, in which case sequences in the batch are assumed to have the same + length, which is inferred from inputs. input_mode: indicate whether there is a linear projection between the input and the actual computation before the first layer. It could be 'linear_input', 'skip_input' or 'auto_select'. @@ -1010,7 +1015,10 @@ def _cudnn_rnn(inputs, "seed2": seed2, "name": name } - if use_cudnn_v2 != "1": + if sequence_lengths is not None: + args["sequence_lengths"] = sequence_lengths + outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(**args) + elif use_cudnn_v2 != "1": outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(**args) else: outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv2(**args) @@ -1022,6 +1030,7 @@ def cudnn_lstm(inputs, input_c, params, is_training, + sequence_lengths=None, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -1051,12 +1060,17 @@ def cudnn_lstm(inputs, dropout: whether to enable dropout. With it is 0, dropout is disabled. seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. + sequence_lengths: an int32 array representing the variable sequence lengths + in a batch. The size of the array has to equal the batch_size. 
Default to + None, in which case sequences in the batch are assumed to have the same + length, which is inferred from inputs. name: name of the operation. Returns: outputs, output_h, output_c """ return _cudnn_rnn(inputs, input_h, input_c, params, is_training, CUDNN_LSTM, - input_mode, direction, dropout, seed, name) + sequence_lengths, input_mode, direction, dropout, seed, + name) def _cudnn_rnn_no_input_c(inputs, @@ -1064,6 +1078,7 @@ def _cudnn_rnn_no_input_c(inputs, params, is_training, rnn_mode, + sequence_lengths=None, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -1079,6 +1094,10 @@ def _cudnn_rnn_no_input_c(inputs, params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference rnn_mode: one of ('lstm', 'gru', 'rnn_relu', 'rnn_tanh'). + sequence_lengths: an int32 array representing the variable sequence lengths + in a batch. The size of the array has to equal the batch_size. Default to + None, in which case sequences in the batch are assumed to have the same + length, which is inferred from inputs. input_mode: indicate whether there is a linear projection between the input and the actual computation before the first layer. It could be 'linear_input', 'skip_input' or 'auto_select'. 
@@ -1098,8 +1117,8 @@ def _cudnn_rnn_no_input_c(inputs, """ input_c = array_ops.constant([], dtype=input_h.dtype) outputs, output_h, _ = _cudnn_rnn(inputs, input_h, input_c, params, - is_training, rnn_mode, input_mode, - direction, dropout, seed, name) + is_training, rnn_mode, sequence_lengths, + input_mode, direction, dropout, seed, name) return outputs, output_h @@ -1107,6 +1126,7 @@ def cudnn_gru(inputs, input_h, params, is_training, + sequence_lengths=None, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -1129,6 +1149,10 @@ def cudnn_gru(inputs, 'skip_input' is only allowed when input_size == num_units; 'auto_select' implies 'skip_input' when input_size == num_units; otherwise, it implies 'linear_input'. + sequence_lengths: an int32 array representing the variable sequence lengths + in a batch. The size of the array has to equal the batch_size. Default to + None, in which case sequences in the batch are assumed to have the same + length, which is inferred from inputs. direction: the direction model that the model operates. Could be either 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. @@ -1139,7 +1163,8 @@ def cudnn_gru(inputs, outputs, output_h """ return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training, CUDNN_GRU, - input_mode, direction, dropout, seed, name) + sequence_lengths, input_mode, direction, dropout, + seed, name) def cudnn_rnn_relu(inputs, @@ -1150,6 +1175,7 @@ def cudnn_rnn_relu(inputs, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., seed=0, + sequence_lengths=None, name=None): """Cudnn RNN Relu. @@ -1162,30 +1188,34 @@ def cudnn_rnn_relu(inputs, is_training: whether this operation will be used in training or inference input_mode: indicate whether there is a linear projection between the input and the actual computation before the first layer. It could be - 'linear_input', 'skip_input' or 'auto_select'. 
- 'linear_input' (default) always applies a linear projection of input - onto RNN hidden state. (standard RNN behavior). - 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. + 'linear_input', 'skip_input' or 'auto_select'. 'linear_input' (default) + always applies a linear projection of input onto RNN hidden state. + (standard RNN behavior). 'skip_input' is only allowed when input_size == + num_units; 'auto_select' implies 'skip_input' when input_size == + num_units; otherwise, it implies 'linear_input'. direction: the direction model that the model operates. Could be either - 'unidirectional' or 'bidirectional' + 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. seed: the op seed used for initializing dropout. See `tf.set_random_seed` - for behavior. + for behavior. + sequence_lengths: an int32 array representing the variable sequence lengths + in a batch. The size of the array has to equal the batch_size. If not + provided, the same sequence length will be assumed. name: name of the operation. + Returns: outputs, output_h """ return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training, - CUDNN_RNN_RELU, input_mode, direction, dropout, - seed, name) + CUDNN_RNN_RELU, sequence_lengths, input_mode, + direction, dropout, seed, name) def cudnn_rnn_tanh(inputs, input_h, params, is_training, + sequence_lengths=None, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -1208,6 +1238,10 @@ def cudnn_rnn_tanh(inputs, 'skip_input' is only allowed when input_size == num_units; 'auto_select' implies 'skip_input' when input_size == num_units; otherwise, it implies 'linear_input'. + sequence_lengths: an int32 array representing the variable sequence lengths + in a batch. The size of the array has to equal the batch_size. 
Default to + None, in which case sequences in the batch are assumed to have the same + length, which is inferred from inputs. direction: the direction model that the model operates. Could be either 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. @@ -1218,8 +1252,8 @@ def cudnn_rnn_tanh(inputs, outputs, output_h """ return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training, - CUDNN_RNN_TANH, input_mode, direction, dropout, - seed, name) + CUDNN_RNN_TANH, sequence_lengths, input_mode, + direction, dropout, seed, name) def cudnn_rnn_opaque_params_to_canonical(rnn_mode, @@ -1497,7 +1531,13 @@ class _CudnnRNN(object): input_mode=self._input_mode, direction=self._direction) - def __call__(self, input_data, input_h, input_c, params, is_training=True): + def __call__(self, + input_data, + input_h, + input_c, + params, + is_training=True, + sequence_lengths=None): """Runs the forward step for the RNN model. Args: @@ -1509,6 +1549,10 @@ class _CudnnRNN(object): A Tensor of the same shape as input_h. params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference. + sequence_lengths: an int32 array representing the variable sequence + lengths in a batch. The size of the array has to equal the batch_size. + Default to None, in which case sequences in the batch are assumed to + have the same length, which is inferred from inputs. Returns: output: the output sequence. output_h: the final state for h. 
@@ -1521,6 +1565,7 @@ class _CudnnRNN(object): params, is_training, self._rnn_mode, + sequence_lengths=sequence_lengths, input_mode=self._input_mode, direction=self._direction, dropout=self._dropout, @@ -1615,7 +1660,13 @@ class CudnnLSTM(_CudnnRNN): dropout=dropout, seed=seed) - def __call__(self, input_data, input_h, input_c, params, is_training=True): + def __call__(self, + input_data, + input_h, + input_c, + params, + sequence_lengths=None, + is_training=True): """Runs the forward step for the Cudnn LSTM model. Args: @@ -1626,6 +1677,10 @@ class CudnnLSTM(_CudnnRNN): input_c: the initial hidden state for c. A Tensor of the same shape as input_h. params: the parameter buffer created for this model. + sequence_lengths: an int32 array representing the variable sequence + lengths in a batch. The size of the array has to equal the batch_size. + Default to None, in which case sequences in the batch are assumed to + have the same length, which is inferred from inputs. is_training: whether this operation will be used in training or inference. Returns: output: the output sequence. @@ -1633,7 +1688,12 @@ class CudnnLSTM(_CudnnRNN): output_c: the final state for c. """ output, output_h, output_c = super(CudnnLSTM, self).__call__( - input_data, input_h, input_c, params, is_training=is_training) + input_data, + input_h, + input_c, + params, + sequence_lengths=sequence_lengths, + is_training=is_training) return (output, output_h, output_c) @@ -1687,7 +1747,12 @@ class _CudnnRNNNoInputC(_CudnnRNN): dropout=dropout, seed=seed) - def __call__(self, input_data, input_h, params, is_training=True): + def __call__(self, + input_data, + input_h, + params, + sequence_lengths=None, + is_training=True): """Runs the forward step for the Cudnn LSTM model. Args: @@ -1696,6 +1761,10 @@ class _CudnnRNNNoInputC(_CudnnRNN): input_h: the initial hidden state for h. A Tensor of shape [num_layers, batch_size, num_units]. params: the parameter buffer created for this model. 
+ sequence_lengths: an int32 array representing the variable sequence + lengths in a batch. The size of the array has to equal the batch_size. + Default to None, in which case sequences in the batch are assumed to + have the same length, which is inferred from inputs. is_training: whether this operation will be used in training or inference. Returns: output: the output sequence. @@ -1707,6 +1776,7 @@ class _CudnnRNNNoInputC(_CudnnRNN): params, is_training, self._rnn_mode, + sequence_lengths=sequence_lengths, input_mode=self._input_mode, direction=self._direction, dropout=self._dropout, diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index c0152156a1ba70297adb7054622b15ca04f859cd..c6bf5215c9406d03d2704e46903b3aa57e7e68d9 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -389,13 +389,11 @@ class LMDBDataset(dataset_ops.DatasetSource): Args: filenames: A `tf.string` tensor containing one or more filenames. 
""" - super(LMDBDataset, self).__init__() self._filenames = ops.convert_to_tensor( filenames, dtype=dtypes.string, name="filenames") - - def _as_variant_tensor(self): - return gen_experimental_dataset_ops.experimental_lmdb_dataset( + variant_tensor = gen_experimental_dataset_ops.experimental_lmdb_dataset( self._filenames, **dataset_ops.flat_structure(self)) + super(LMDBDataset, self).__init__(variant_tensor) @property def _element_structure(self): diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py index 5c6ee6bfdc7167d14b292f8f763adafca4e3a72c..6708e01d08135a132b797e317cd2a241c3428f40 100644 --- a/tensorflow/contrib/data/python/ops/sliding.py +++ b/tensorflow/contrib/data/python/ops/sliding.py @@ -30,7 +30,6 @@ class _SlideDataset(dataset_ops.UnaryDataset): def __init__(self, input_dataset, window_size, window_shift, window_stride): """See `sliding_window_batch` for details.""" - super(_SlideDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._window_size = ops.convert_to_tensor( window_size, dtype=dtypes.int64, name="window_stride") @@ -43,14 +42,13 @@ class _SlideDataset(dataset_ops.UnaryDataset): input_dataset.output_types, input_dataset.output_shapes, input_dataset.output_classes) self._structure = input_structure._batch(None) # pylint: disable=protected-access - - def _as_variant_tensor(self): - return ged_ops.experimental_sliding_window_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + variant_tensor = ged_ops.experimental_sliding_window_dataset( + self._input_dataset._variant_tensor, # pylint: disable=protected-access window_size=self._window_size, window_shift=self._window_shift, window_stride=self._window_stride, **dataset_ops.flat_structure(self)) + super(_SlideDataset, self).__init__(input_dataset, variant_tensor) @property def _element_structure(self): diff --git a/tensorflow/contrib/distribute/python/BUILD 
b/tensorflow/contrib/distribute/python/BUILD index 4ea1fa050a4527cf4c82b91f66d00b884c9af5d0..d2fb878f96f55200d870447b45f3d0a37c6b0f86 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -1,5 +1,8 @@ # Implementation of a prototype TF distributed computation library. +load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + package( default_visibility = [ "//tensorflow:internal", @@ -10,9 +13,6 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "py_test") -load("//tensorflow:tensorflow.bzl", "cuda_py_test") - # TODO(priyag): Figure out testonly issues that are preventing us from # including our tests in pip for now. @@ -584,7 +584,10 @@ cuda_py_test( py_library( name = "keras_test_lib", testonly = 1, - srcs = ["keras_test.py"], + srcs = [ + "keras_backward_compat_test.py", + "keras_test.py", + ], deps = [ ":combinations", "//tensorflow/contrib/distribute/python:mirrored_strategy", @@ -611,7 +614,57 @@ cuda_py_test( "no_oss", # TODO(b/117919883): Fix python error. "no_pip", "no_windows_gpu", - "noguitar", # TODO(b/120025010): Re-enable this test on Guitar. + "notsan", + ], +) + +# TODO(b/121200287): Remove this in 2.0 +cuda_py_test( + name = "keras_backward_compat_test", + srcs = ["keras_backward_compat_test.py"], + additional_deps = [ + ":keras_test_lib", + ], + shard_count = 16, + tags = [ + "multi_and_single_gpu", + "no_oss", # TODO(b/117919883): Fix python error. 
+ "no_pip", + "no_windows_gpu", + "notsan", + ], +) + +py_library( + name = "keras_correctness_test_lib", + testonly = 1, + srcs = ["keras_correctness_test.py"], + deps = [ + ":combinations", + "//tensorflow/contrib/distribute/python:mirrored_strategy", + "//tensorflow/contrib/distribute/python:tpu_strategy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:training", + "//tensorflow/python/eager:test", + "//tensorflow/python/estimator:estimator_py", + "//tensorflow/python/keras", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +cuda_py_test( + name = "keras_correctness_test", + srcs = ["keras_correctness_test.py"], + additional_deps = [ + ":keras_correctness_test_lib", + ], + shard_count = 16, + tags = [ + "multi_and_single_gpu", + "no_oss", # TODO(b/117919883): Fix python error. + "no_pip", + "no_windows_gpu", "notsan", ], ) diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py index e988b63a28718e509df0d5ce42423ba4616b0e60..12197c3d0dedee23d12732b8d4398f43bfc61caa 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -70,6 +70,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): self._cross_device_ops = None self._num_gpus_per_worker = num_gpus_per_worker self._initialize_local_worker(num_gpus_per_worker) + assert isinstance(self._get_cross_device_ops(), + cross_device_ops_lib.CollectiveAllReduce) def _initialize_local_worker(self, num_gpus_per_worker): """Initializes the object for local training.""" @@ -77,16 +79,16 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): self._num_workers = 1 if num_gpus_per_worker: - local_devices = [ + local_devices = tuple( "/device:GPU:%d" % i for i in range(num_gpus_per_worker) - ] + ) else: - local_devices = ["/device:CPU:0"] 
+ local_devices = ("/device:CPU:0",) self._worker_device = device_util.canonicalize("/device:CPU:0") self._collective_keys = cross_device_utils.CollectiveKeys() self._initialize_local(local_devices) - self._cross_tower_ops = cross_device_ops_lib.CollectiveAllReduce( + self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( num_workers=self._num_workers, num_gpus_per_worker=num_gpus_per_worker, collective_keys=self._collective_keys) @@ -104,7 +106,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): if task_type is None or task_id is None: raise ValueError("When `cluster_spec` is given, you must also specify " "`task_type` and `task_id`") - if task_type not in ["chief", "worker"]: + if task_type not in ("chief", "worker"): raise ValueError( "Unrecognized task_type: %r, valid task types are: \"chief\", " "\"worker\"." % task_type) @@ -119,16 +121,18 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): self._worker_device = "/job:%s/task:%d" % (task_type, task_id) if num_gpus_per_worker: - local_devices = [ + local_devices = tuple( "%s/device:GPU:%d" % (self._worker_device, i) for i in range(num_gpus_per_worker) - ] + ) else: - local_devices = [self._worker_device] + local_devices = (self._worker_device,) self._collective_keys = cross_device_utils.CollectiveKeys() self._initialize_local(local_devices) - self._cross_tower_ops = cross_device_ops_lib.CollectiveAllReduce( + self._input_workers = values.InputWorkers( + self._device_map, [(self._worker_device, self.worker_devices)]) + self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( num_workers=self._num_workers, num_gpus_per_worker=num_gpus_per_worker, collective_keys=self._collective_keys) @@ -149,13 +153,18 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): def _create_variable(self, next_creator, *args, **kwargs): colocate_with = kwargs.pop("colocate_with", None) - devices = self._get_devices_from(colocate_with) - group_size 
= len(devices) * self._num_workers - group_key = self._collective_keys.get_group_key(self._devices) + if colocate_with is None: + device_map = self._device_map + logical_device = 0 # TODO(josh11b): Get logical device from scope here. + else: + device_map = colocate_with.device_map + logical_device = colocate_with.logical_device + group_size = device_map.num_replicas_in_graph * self._num_workers + group_key = self._collective_keys.get_group_key(self.worker_devices) def _real_mirrored_creator(devices, *args, **kwargs): """Creates one MirroredVariable on the current worker.""" - index = {} + value_list = [] unique_var_name = ops.get_default_graph().unique_name( kwargs["name"], mark_as_used=False).rstrip("/") collective_instance_key = self._collective_keys.get_instance_key( @@ -172,7 +181,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): with ops.device(d): if i > 0: # Give replicas meaningful distinct names: - var0name = index[devices[0]].name.split(":")[0] + var0name = value_list[0].name.split(":")[0] # We append a / to variable names created on replicas with id > 0 to # ensure that we ignore the name scope and instead use the given # name as the absolute name of the variable. @@ -208,22 +217,24 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): assert unique_var_name == actual_var_name, "%r vs %r" % ( unique_var_name, actual_var_name) assert not isinstance(v, values.DistributedVariable) - index[d] = v - return index + value_list.append(v) + return value_list # pylint: disable=protected-access return mirrored_strategy._create_mirrored_variable( - devices, _real_mirrored_creator, *args, **kwargs) + self._container_strategy(), device_map, logical_device, + _real_mirrored_creator, *args, **kwargs) def _distribute_dataset(self, dataset_fn): """Distributes the dataset to each local GPU.""" # TODO(yuefengz): shard the dataset. 
+ worker_index = 0 return values.PerReplicaDataset( - self._call_dataset_fn(dataset_fn), self._devices, True) + self._call_dataset_fn(dataset_fn), self._input_workers, worker_index, + prefetch_on_device=True) def _make_dataset_iterator(self, dataset): - worker_device_pairs = [(self._worker_device, self._devices)] - return values.DatasetIterator(dataset, worker_device_pairs, + return values.DatasetIterator(dataset, self._input_workers, self._num_replicas_in_sync) def _make_input_fn_iterator( @@ -242,7 +253,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): num_replicas_in_sync=self._num_replicas_in_sync) return values.InputFunctionIterator( - input_fn, [(self._worker_device, self._devices)], [input_context]) + input_fn, self._input_workers, [input_context]) def _configure(self, session_config=None, @@ -267,6 +278,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): # already been initialized with a `cluster_spec`. self._initialize_multi_worker(self._num_gpus_per_worker, cluster_spec, task_type, task_id) + assert isinstance(self._get_cross_device_ops(), + cross_device_ops_lib.CollectiveAllReduce) if session_config: session_config.CopyFrom(self._update_config_proto(session_config)) @@ -328,7 +341,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): @property def _num_replicas_in_sync(self): - return len(self._devices) * self._num_workers + return len(self.worker_devices) * self._num_workers # TODO(priyag): Delete this once all strategies use global batch size. 
@property diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py index 8a9e583f0afaac37a2057bae9b1ed79de43d68bc..0fb672dded7624e798592d2f5c01945aa830021e 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -82,7 +82,7 @@ class CollectiveAllReduceStrategyTestBase( instance_key_with_id_start=num_gpus * 10000 + CollectiveAllReduceStrategyTestBase.collective_key_base) distribution.extended._collective_keys = collective_keys - distribution.extended._inferred_cross_device_ops._collective_keys = ( + distribution.extended._cross_device_ops._collective_keys = ( collective_keys) if task_type and task_id is not None: return distribution, 'grpc://' + self._cluster_spec[task_type][ @@ -128,7 +128,7 @@ class CollectiveAllReduceStrategyTestBase( before_list = [] after_list = [] for g, v in g_v: - fetched = d.read_var(v) + fetched = d.extended.read_var(v) before_list.append(fetched) with ops.control_dependencies([fetched]): # TODO(yuefengz): support non-Mirrored variable as destinations. 
@@ -136,7 +136,7 @@ class CollectiveAllReduceStrategyTestBase( reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies( d.update(v, update, g, grouped=False)): - after_list.append(d.read_var(v)) + after_list.append(d.extended.read_var(v)) return before_list, after_list before_out, after_out = step() @@ -252,21 +252,22 @@ class CollectiveAllReduceStrategyTestBase( for expected_value in expected_values: next_element = iterator.get_next() - computed_value = sess.run( - [values.select_device(d, next_element) for d in devices]) + computed_value = sess.run([values.select_replica(r, next_element) + for r in range(len(devices))]) self.assertEqual(expected_value, computed_value) with self.assertRaises(errors.OutOfRangeError): next_element = iterator.get_next() - sess.run([values.select_device(d, next_element) for d in devices]) + sess.run([values.select_replica(r, next_element) + for r in range(len(devices))]) # After re-initializing the iterator, should be able to iterate again. 
sess.run(iterator.initialize()) for expected_value in expected_values: next_element = iterator.get_next() - computed_value = sess.run( - [values.select_device(d, next_element) for d in devices]) + computed_value = sess.run([values.select_replica(r, next_element) + for r in range(len(devices))]) self.assertEqual(expected_value, computed_value) diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index 365ce5cdec79f1914f0c9ccdf59a7dc59e6f819e..4a934953ad2d4c6ecbe2bde2333a49bf8fd72821 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -324,7 +324,7 @@ class NamedDistribution(object): # pylint: disable=g-long-lambda default_strategy = NamedDistribution( "Default", - distribution_strategy_context._get_default_distribution_strategy, # pylint: disable=protected-access + distribution_strategy_context._get_default_strategy, # pylint: disable=protected-access required_gpus=None) one_device_strategy = NamedDistribution( "OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"), diff --git a/tensorflow/contrib/distribute/python/cross_device_ops_test.py b/tensorflow/contrib/distribute/python/cross_device_ops_test.py index d6e9521c1c1115ffdbdcf375ad4017bacb962832..54cce2988383fcf5e063726948fbbf62c7094ce5 100644 --- a/tensorflow/contrib/distribute/python/cross_device_ops_test.py +++ b/tensorflow/contrib/distribute/python/cross_device_ops_test.py @@ -40,8 +40,16 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +def _get_devices(devices): + if isinstance(devices, (tuple, list)): + return tuple(device_util.resolve(d) for d in devices) + elif isinstance(devices, value_lib.DistributedValues): + return devices.devices + return (device_util.resolve(devices),) + + def _make_per_replica(values, devices, regroup=False): - devices = cross_device_ops_lib.get_devices_from(devices) + devices = 
_get_devices(devices) assert len(values) == len(devices) # We simulate the result of regroup called on PerReplica which strips the @@ -51,12 +59,12 @@ def _make_per_replica(values, devices, regroup=False): placed_v = array_ops.identity(values[0]) return placed_v - index = {} + index = [] for d, v in zip(devices, values): with ops.device(d): placed_v = array_ops.identity(v) - index[d] = placed_v - return value_lib.PerReplica(index) + index.append(placed_v) + return value_lib.PerReplica(value_lib.ReplicaDeviceMap(devices), index) # pylint: disable=g-doc-args,g-doc-return-or-yield @@ -66,9 +74,9 @@ def _fake_mirrored(value, devices): All components of the returned Mirrored have the same objects, which is not true in reality. """ - devices = cross_device_ops_lib.get_devices_from(devices) - return value_lib.Mirrored( - {d: v for d, v in zip(devices, [value] * len(devices))}) + devices = _get_devices(devices) + return value_lib.Mirrored(value_lib.ReplicaDeviceMap(devices), + [value] * len(devices)) def _make_indexed_slices(values, indices, dense_shape, device): @@ -81,9 +89,9 @@ def _make_indexed_slices(values, indices, dense_shape, device): def _make_mirrored_indexed_slices(devices, values, indices, dense_shape): - return value_lib.Mirrored({ - d: _make_indexed_slices(values, indices, dense_shape, d) for d in devices - }) + values = [_make_indexed_slices(values, indices, dense_shape, d) + for d in devices] + return value_lib.Mirrored(value_lib.ReplicaDeviceMap(devices), values) _cpu_device = "/device:CPU:0" @@ -107,16 +115,16 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase): else: self.assertEqual(type(left), type(right)) self.assertEqual(set(left.devices), set(right.devices)) - if isinstance(list(left._index.values())[0], ops.IndexedSlices): - for (d, v) in left._index.items(): - self._assert_indexed_slices_equal(v, right._index[d]) + if isinstance(left.values[0], ops.IndexedSlices): + for d in left.devices: + 
self._assert_indexed_slices_equal(left.get(d), right.get(d)) elif context.executing_eagerly(): - self.assertEqual([v.numpy() for v in left._index.values()], - list(right._index.values())) + self.assertEqual([v.numpy() for v in left.values], + list(right.values)) else: with self.cached_session() as sess: self.assertEqual( - sess.run(list(left._index.values())), list(right._index.values())) + sess.run(list(left.values)), list(right.values)) def _testReductionAndBroadcast(self, cross_device_ops, distribution): devices = distribution.extended.worker_devices @@ -280,7 +288,8 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): devices = ["/cpu:0", "/gpu:0"] t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0]) t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1]) - per_replica = value_lib.PerReplica({devices[0]: t0, devices[1]: t1}) + per_replica = value_lib.PerReplica( + value_lib.ReplicaDeviceMap(devices), (t0, t1)) result = cross_device_ops_lib._simple_reduce( per_replica, devices[0], math_ops.add_n, reduce_util.ReduceOp.SUM) @@ -314,7 +323,8 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): t0 = _make_indexed_slices([[1., 2.]], [1], dense_shape, devices[0]) t1 = _make_indexed_slices( [[3., 4.], [5., 6.]], [1, 3], dense_shape, devices[1]) - per_replica = value_lib.PerReplica({devices[0]: t0, devices[1]: t1}) + per_replica = value_lib.PerReplica( + value_lib.ReplicaDeviceMap(devices), (t0, t1)) if batch_reduce: result = cross_device_ops_instance.batch_reduce( @@ -474,8 +484,8 @@ class MultiWorkerCollectiveAllReduceTest( run_options.experimental.collective_graph_key = 6 left_values = np.array( - sess.run(list(left._index.values()), options=run_options)).flatten() - right_values = np.array(list(right._index.values())).flatten() + sess.run(list(left.values), options=run_options)).flatten() + right_values = np.array(list(right.values)).flatten() self.assertEqual(len(left_values), len(right_values)) for l, r in 
zip(left_values, right_values): self.assertEqual(l, r) @@ -496,7 +506,7 @@ class MultiWorkerCollectiveAllReduceTest( # Collective ops doesn't support scalar tensors, so we have to construct # 1-d tensors. values = [constant_op.constant([float(d)]) for d in range(len(devices))] - per_replica = _make_per_replica(values, devices, regroup=True) + per_replica = _make_per_replica(values, devices) mean = np.array([(len(devices) - 1.) / 2.]) values_2 = [constant_op.constant([d + 1.0]) for d in range(len(devices))] diff --git a/tensorflow/contrib/distribute/python/cross_device_utils_test.py b/tensorflow/contrib/distribute/python/cross_device_utils_test.py index 2303a31677afbd12a0b8e7eea3ecf7c7736c46ad..275aac2eeca575e927878d1ece63ce37ed38e8a0 100644 --- a/tensorflow/contrib/distribute/python/cross_device_utils_test.py +++ b/tensorflow/contrib/distribute/python/cross_device_utils_test.py @@ -103,7 +103,8 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) t1 = math_ops._as_indexed_slices( constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) - per_replica = value_lib.PerReplica({"/gpu:0": t0, "/cpu:0": t1}) + device_map = value_lib.ReplicaDeviceMap(("/gpu:0", "/cpu:0")) + per_replica = value_lib.PerReplica(device_map, (t0, t1)) self.assertTrue(cross_device_utils.contains_indexed_slices(per_replica)) @combinations.generate(combinations.combine( diff --git a/tensorflow/contrib/distribute/python/estimator_training_test.py b/tensorflow/contrib/distribute/python/estimator_training_test.py index b369a7fefe6f35cf5a9b64451419cf4f72a99471..3f55a8a1c8b88d1b8e4031547fa3fbe519983630 100644 --- a/tensorflow/contrib/distribute/python/estimator_training_test.py +++ b/tensorflow/contrib/distribute/python/estimator_training_test.py @@ -375,11 +375,13 @@ class DistributeCoordinatorIntegrationTest( threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn, cluster_spec, train_distribute, eval_distribute) + 
threads_to_join = [] for task_type, ts in threads.items(): if task_type == PS: continue for t in ts: - t.join() + threads_to_join.append(t) + self.join_independent_workers(threads_to_join) estimator = self._get_estimator(train_distribute, eval_distribute) self._inspect_train_and_eval_events(estimator) @@ -413,8 +415,7 @@ class DistributeCoordinatorIntegrationTest( threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn, cluster_spec, train_distribute, eval_distribute) - threads[WORKER][0].join() - threads[EVALUATOR][0].join() + self.join_independent_workers([threads[WORKER][0], threads[EVALUATOR][0]]) estimator = self._get_estimator(train_distribute, eval_distribute) self._inspect_train_and_eval_events(estimator) diff --git a/tensorflow/contrib/distribute/python/keras_backward_compat_test.py b/tensorflow/contrib/distribute/python/keras_backward_compat_test.py new file mode 100644 index 0000000000000000000000000000000000000000..93c0280c8215712071457cafb9c6040f7d97fa60 --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_backward_compat_test.py @@ -0,0 +1,1417 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for tf.keras models using DistributionStrategy.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import tpu_strategy +from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import values +from tensorflow.python.eager import test +from tensorflow.python.estimator import keras as keras_lib +from tensorflow.python.estimator import run_config as run_config_lib +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import random_seed +from tensorflow.python.framework import test_util +from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.engine import distributed_training_utils +from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras +from tensorflow.python.ops.parsing_ops import gen_parsing_ops +from tensorflow.python.platform import gfile +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import gradient_descent +from tensorflow.python.training import rmsprop + +_RANDOM_SEED = 1337 +_TRAIN_SIZE = 200 +_INPUT_SIZE = (10,) +_NUM_CLASS = 2 + + +# TODO(anjalisridhar): Add a decorator that will allow us to run these tests as +# part of the tf.keras unit tests suite. 
+def simple_sequential_model(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(16, activation='relu', input_shape=_INPUT_SIZE)) + model.add(keras.layers.Dropout(0.1)) + model.add(keras.layers.Dense(_NUM_CLASS, activation='softmax')) + return model + + +def simple_functional_model(): + a = keras.layers.Input(shape=_INPUT_SIZE) + b = keras.layers.Dense(16, activation='relu')(a) + b = keras.layers.Dropout(0.1)(b) + b = keras.layers.Dense(_NUM_CLASS, activation='softmax')(b) + model = keras.models.Model(inputs=[a], outputs=[b]) + return model + + +def multi_inputs_multi_outputs_model(): + input_a = keras.layers.Input(shape=(16,), name='input_a') + input_b = keras.layers.Input(shape=(16,), name='input_b') + input_m = keras.layers.Input(shape=(8,), dtype='string', name='input_m') + dense = keras.layers.Dense(8, name='dense_1') + + interm_a = dense(input_a) + # Read m + interm_m = keras.layers.Lambda(gen_parsing_ops.string_to_number)(input_m) + interm_s = keras.layers.Lambda(lambda k: k[0] * k[1])([interm_m, interm_a]) + interm_b = dense(input_b) + merged = keras.layers.concatenate([interm_s, interm_b], name='merge') + output_c = keras.layers.Dense(3, activation='softmax', name='dense_2')(merged) + output_d = keras.layers.Dense(2, activation='softmax', name='dense_3')(merged) + model = keras.models.Model( + inputs=[input_a, input_b, input_m], outputs=[output_c, output_d]) + model.compile( + loss='categorical_crossentropy', + optimizer=gradient_descent.GradientDescentOptimizer(0.001), + metrics={ + 'dense_2': 'categorical_accuracy', + 'dense_3': 'categorical_accuracy' + }) + return model + + +def get_ds_train_input_fn(): + np.random.seed(_RANDOM_SEED) + (x_train, y_train), _ = testing_utils.get_test_data( + train_samples=_TRAIN_SIZE, + test_samples=50, + input_shape=_INPUT_SIZE, + num_classes=_NUM_CLASS) + y_train = keras.utils.to_categorical(y_train) + + dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) + dataset = dataset.batch(32) 
+ return dataset + + +def get_ds_test_input_fn(): + np.random.seed(_RANDOM_SEED) + _, (x_test, y_test) = testing_utils.get_test_data( + train_samples=_TRAIN_SIZE, + test_samples=50, + input_shape=_INPUT_SIZE, + num_classes=_NUM_CLASS) + y_test = keras.utils.to_categorical(y_test) + + dataset = dataset_ops.Dataset.from_tensor_slices((x_test, y_test)) + dataset = dataset.batch(32) + return dataset + + +def get_multi_inputs_multi_outputs_data(): + (a_train, c_train), (a_test, c_test) = testing_utils.get_test_data( + train_samples=_TRAIN_SIZE, + test_samples=50, + input_shape=(16,), + num_classes=3, + random_seed=_RANDOM_SEED) + (b_train, d_train), (b_test, d_test) = testing_utils.get_test_data( + train_samples=_TRAIN_SIZE, + test_samples=50, + input_shape=(16,), + num_classes=2, + random_seed=_RANDOM_SEED) + (m_train, _), (m_test, _) = testing_utils.get_test_data( + train_samples=_TRAIN_SIZE, + test_samples=50, + input_shape=(8,), + num_classes=2, + random_seed=_RANDOM_SEED) + + c_train = keras.utils.to_categorical(c_train) + c_test = keras.utils.to_categorical(c_test) + d_train = keras.utils.to_categorical(d_train) + d_test = keras.utils.to_categorical(d_test) + + train_data = { + 'input_a': a_train, + 'input_b': b_train, + 'input_m': m_train, + 'output_c': c_train, + 'output_d': d_train + } + test_data = { + 'input_a': a_test, + 'input_b': b_test, + 'input_m': m_test, + 'output_c': c_test, + 'output_d': d_test + } + + return (train_data, test_data) + + +def batch_wrapper(dataset, batch_size, distribution, repeat=None): + if repeat: + dataset = dataset.repeat(repeat) + # TPUs currently require fully defined input shapes, drop_remainder ensures + # the input will have fully defined shapes. 
+ if isinstance(distribution, tpu_strategy.TPUStrategy): + return dataset.batch(batch_size, drop_remainder=True) + else: + return dataset.batch(batch_size) + + +def get_model(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + return model + + +def get_dataset(distribution): + inputs = np.zeros((10, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = batch_wrapper(dataset, 10, distribution) + return dataset + + +def get_predict_dataset(distribution): + inputs = np.zeros((10, 3), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices(inputs) + dataset = dataset.repeat(100) + dataset = batch_wrapper(dataset, 10, distribution) + return dataset + + +def multi_input_output_model(): + a = keras.layers.Input(shape=(3,), name='input_a') + b = keras.layers.Input(shape=(5,), name='input_b') + # TODO(anjalisridhar): Change the output dimension of the second Dense layer + # once the iterator output validation issue has been fixed. + dense_1 = keras.layers.Dense(7, name='dense_1') + dense_2 = keras.layers.Dense(7, name='dense_2') + c = dense_1(a) + d = dense_2(b) + e = keras.layers.Dropout(0.5, name='dropout')(c) + model = keras.models.Model([a, b], [d, e]) + return model + + +def get_correctness_test_inputs(use_numpy, use_validation_data, + with_distribution, + x_train, y_train, x_predict): + """Generates the inputs for correctness check when enable Keras with DS.""" + training_epochs = 2 + global_batch_size = 64 + batch_size = global_batch_size + # TODO(b/118776054): Use global batch size for Keras/DS support. 
+ use_per_core_batch_size = ( + with_distribution and + not distributed_training_utils.global_batch_size_supported( + with_distribution)) + if use_per_core_batch_size: + batch_size //= with_distribution.num_replicas_in_sync + + if use_numpy: + training_inputs = { + 'batch_size': batch_size, + 'x': x_train, + 'y': y_train, + 'epochs': training_epochs, + 'shuffle': False, + } + + if use_validation_data: + eval_inputs = None + training_inputs['validation_data'] = (x_train, y_train) + else: + eval_inputs = { + 'batch_size': batch_size, + 'x': x_train, + 'y': y_train, + } + predict_inputs = { + 'x': np.array(x_predict, dtype=np.float32), + } + else: + # For dataset inputs, we do not pass batch_size to + # keras.fit/evaluate/predict. The batch size is part of the dataset. + train_dataset = dataset_ops.Dataset.from_tensor_slices( + (x_train, y_train)) + x = batch_wrapper( + train_dataset, batch_size, with_distribution, repeat=training_epochs) + + training_inputs = { + 'batch_size': None, + 'x': x, + 'y': None, + 'epochs': training_epochs, + 'shuffle': False, + 'steps_per_epoch': len(x_train) // global_batch_size, + } + if use_validation_data: + eval_inputs = None # Remove the eval_inputs + eval_dataset = dataset_ops.Dataset.from_tensor_slices( + (x_train, y_train)) + x = batch_wrapper(eval_dataset, batch_size, with_distribution) + training_inputs['validation_data'] = x + training_inputs['validation_steps'] = 5 + else: + eval_inputs = { + 'batch_size': None, + 'x': x, + 'y': None, + 'steps': 20, + } + + predict_batch_size = len(x_predict) + if use_per_core_batch_size: + predict_batch_size //= with_distribution.num_replicas_in_sync + predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict) + predict_dataset = batch_wrapper(predict_dataset, + predict_batch_size, with_distribution) + predict_inputs = { + 'steps': 1, + 'x': predict_dataset, + } + + return training_inputs, eval_inputs, predict_inputs + + +strategies_minus_tpu = [ + combinations.default_strategy, + 
combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus] + +tpu_strategies = [ + combinations.tpu_strategy, # steps_per_run=2 + combinations.tpu_strategy_one_step] + + +def strategy_minus_tpu_combinations(): + return combinations.combine( + distribution=strategies_minus_tpu, + mode=['graph', 'eager']) + + +def tpu_strategy_combinations(): + return combinations.combine( + distribution=tpu_strategies, + mode=['graph']) + + +def all_strategy_combinations(): + return strategy_minus_tpu_combinations() + tpu_strategy_combinations() + + +# TODO(priyag): Add v2 optimizers here. +def strategy_and_optimizer_combinations(): + return combinations.times( + all_strategy_combinations(), + combinations.combine( + optimizer=[combinations.adagrad_optimizer_v1_fn, + combinations.adam_optimizer_v1_fn, + combinations.gradient_descent_optimizer_v1_fn, + combinations.rmsprop_optimizer_v1_fn])) + + +def strategy_and_input_combinations(): + return ( + combinations.times( + combinations.combine(distribution=strategies_minus_tpu), + combinations.combine(mode=['graph'], + use_numpy=[True, False], + use_validation_data=[True, False]) + + combinations.combine(mode=['eager'], + use_numpy=[False], + use_validation_data=[False])) + + combinations.times( + combinations.combine(distribution=tpu_strategies), + combinations.combine(mode=['graph'], + use_numpy=[True, False], + use_validation_data=[True, False]))) + + +def strategy_for_numpy_input_combinations(): + return combinations.combine( + distribution=strategies_minus_tpu + tpu_strategies, + mode=['graph']) + + +class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, + parameterized.TestCase): + + def setUp(self): + self._base_dir = os.path.join(self.get_temp_dir(), + 'keras_mirrored_strategy_test') + gfile.MakeDirs(self._base_dir) + self._config = 
run_config_lib.RunConfig( + tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir) + + def tearDown(self): + writer_cache.FileWriterCache.clear() + if os.path.isdir(self._base_dir): + gfile.DeleteRecursively(self._base_dir) + + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph'])) + def test_train_functional_with_distribution_strategy(self, distribution): + keras_model = simple_functional_model() + keras_model.compile( + loss='categorical_crossentropy', + metrics=[keras.metrics.CategoricalAccuracy()], + optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01)) + config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, + model_dir=self._base_dir, + train_distribute=distribution, + eval_distribute=distribution) + with self.cached_session(): + est_keras = keras_lib.model_to_estimator( + keras_model=keras_model, config=config) + before_eval_results = est_keras.evaluate( + input_fn=get_ds_test_input_fn, steps=1) + est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE / 16) + after_eval_results = est_keras.evaluate(input_fn=get_ds_test_input_fn, + steps=1) + self.assertLess(after_eval_results['loss'], before_eval_results['loss']) + + writer_cache.FileWriterCache.clear() + gfile.DeleteRecursively(self._config.model_dir) + + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph'])) + def test_train_sequential_with_distribution_strategy(self, distribution): + keras_model = simple_sequential_model() + keras_model.compile( + loss='categorical_crossentropy', + 
metrics=[keras.metrics.CategoricalAccuracy()], + optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01)) + config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, + model_dir=self._base_dir, + train_distribute=distribution) + with self.cached_session(): + est_keras = keras_lib.model_to_estimator( + keras_model=keras_model, config=config) + before_eval_results = est_keras.evaluate( + input_fn=get_ds_test_input_fn, steps=1) + est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE / 16) + after_eval_results = est_keras.evaluate(input_fn=get_ds_test_input_fn, + steps=1) + self.assertLess(after_eval_results['loss'], before_eval_results['loss']) + + writer_cache.FileWriterCache.clear() + gfile.DeleteRecursively(self._config.model_dir) + + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph'])) + def test_multi_inputs_multi_outputs_with_input_fn_as_dict(self, distribution): + train_data, test_data = get_multi_inputs_multi_outputs_data() + + def train_input_fn(): + input_dict = { + 'input_a': train_data['input_a'], + 'input_b': train_data['input_b'], + 'input_m': train_data['input_m'].astype(np.str) + } + output_dict = { + 'dense_2': train_data['output_c'], + 'dense_3': train_data['output_d'] + } + return dataset_ops.Dataset.from_tensor_slices((input_dict, + output_dict)).batch(16) + + def eval_input_fn(): + input_dict = { + 'input_a': test_data['input_a'], + 'input_b': test_data['input_b'], + 'input_m': test_data['input_m'].astype(np.str) + } + output_dict = { + 'dense_2': test_data['output_c'], + 'dense_3': test_data['output_d'] + } + return dataset_ops.Dataset.from_tensor_slices((input_dict, + output_dict)).batch(16) + + self.do_test_multi_inputs_multi_outputs_with_input_fn( + distribution, train_input_fn, eval_input_fn) + + def do_test_multi_inputs_multi_outputs_with_input_fn( + self, distribution, train_input_fn, 
eval_input_fn): + config = run_config_lib.RunConfig( + tf_random_seed=_RANDOM_SEED, + model_dir=self._base_dir, + train_distribute=distribution) + with self.cached_session(): + model = multi_inputs_multi_outputs_model() + est_keras = keras_lib.model_to_estimator(keras_model=model, config=config) + baseline_eval_results = est_keras.evaluate( + input_fn=eval_input_fn, steps=1) + est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16) + eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1) + self.assertLess(eval_results['loss'], baseline_eval_results['loss']) + + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph'])) + def test_keras_optimizer_with_distribution_strategy(self, distribution): + keras_model = simple_sequential_model() + keras_model.compile( + loss='categorical_crossentropy', + optimizer=keras.optimizers.rmsprop(lr=0.01)) + + config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, + model_dir=self._base_dir, + train_distribute=distribution) + with self.cached_session(): + est_keras = keras_lib.model_to_estimator(keras_model=keras_model, + config=config) + with self.assertRaisesRegexp(ValueError, + 'Only TensorFlow native optimizers are ' + 'supported with DistributionStrategy.'): + est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE / 16) + + writer_cache.FileWriterCache.clear() + gfile.DeleteRecursively(self._config.model_dir) + + +class TestDistributionStrategyWithNumpyArrays(test.TestCase, + parameterized.TestCase): + + @combinations.generate(strategy_for_numpy_input_combinations()) + def test_creating_var_with_numpy_arrays(self, distribution): + with self.cached_session(): + x = np.asarray(np.random.random((64, 3)), dtype=np.float32) + var_x = distributed_training_utils.get_var_for_numpy(distribution, x) + val = self.evaluate(var_x.value()) + # Verify that the numpy value is 
copied to the variable. + self.assertAllEqual(x, val) + + @combinations.generate(strategy_for_numpy_input_combinations()) + def test_calculating_input_params_no_steps_no_batch_size(self, distribution): + # Calculate the per_replica_batch_size scaling factor for strategies + # that use per_core_batch_size + replica_scale_factor = 1.0 + if not distributed_training_utils.global_batch_size_supported(distribution): + replica_scale_factor = distribution.num_replicas_in_sync + + with self.cached_session(): + # Input samples of different sizes + input_20_samples = np.zeros((20, 3), dtype=np.float32) + input_63_samples = np.zeros((63, 3), dtype=np.float32) + input_64_samples = np.zeros((64, 3), dtype=np.float32) + + # Default global batch size 32 for input with 64 samples run in 2 steps + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=None, batch_size=None) + self.assertEqual(batch_size, 32 // replica_scale_factor) + self.assertEqual(steps, 2) + + # Computed global batch size 20 is lower than 32 if we pass less samples. + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_20_samples, steps=None, batch_size=None) + self.assertEqual(batch_size, 20 // replica_scale_factor) + self.assertEqual(steps, 1) + + # Default global batch size 32 cannot be used with 63 samples. 
+ with self.assertRaisesRegexp(ValueError, 'not divisible by batch size'): + distributed_training_utils.get_input_params( + distribution, input_63_samples, steps=None, batch_size=None) + + @combinations.generate(strategy_for_numpy_input_combinations()) + def test_calculating_input_params_with_steps_no_batch_size(self, + distribution): + # Calculate the per_replica_batch_size scaling factor for strategies + # that use per_core_batch_size + replica_scale_factor = 1.0 + if not distributed_training_utils.global_batch_size_supported(distribution): + replica_scale_factor = distribution.num_replicas_in_sync + + with self.cached_session(): + # Input samples of different sizes + input_63_samples = np.zeros((63, 3), dtype=np.float32) + input_64_samples = np.zeros((64, 3), dtype=np.float32) + + # Computed global batch size is correct for number of specified 1 step + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=1, batch_size=None) + self.assertEqual(batch_size, 64 // replica_scale_factor) + self.assertEqual(steps, 1) + + # Computed global batch size is correct for number of specified 2 steps + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=2, batch_size=None) + self.assertEqual(batch_size, 32 // replica_scale_factor) + self.assertEqual(steps, 2) + + # All samples can not be consumed in specified number of steps + with self.assertRaisesRegexp(ValueError, 'not divisible by steps'): + distributed_training_utils.get_input_params( + distribution, input_63_samples, steps=2, batch_size=None) + + # This cases is different for different strategies due to the + # difference in supported batch size being global or per-replica. 
+ if replica_scale_factor == 1: + # Computed global batch size is correct even if not sharadable + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_63_samples, steps=3, batch_size=None) + self.assertEqual(batch_size, 21) + self.assertEqual(steps, 3) + else: + # Computed global batch size can not be sharded across replicas + with self.assertRaisesRegexp(ValueError, 'could not be sharded evenly ' + 'across the sync replicas'): + distributed_training_utils.get_input_params( + distribution, input_63_samples, steps=1, batch_size=None) + + @combinations.generate(strategy_for_numpy_input_combinations()) + def test_calculating_input_params_no_steps_with_batch_size(self, + distribution): + # Calculate the per_replica_batch_size scaling factor for strategies + # that use per_core_batch_size + replica_scale_factor = 1.0 + if not distributed_training_utils.global_batch_size_supported(distribution): + replica_scale_factor = distribution.num_replicas_in_sync + + with self.cached_session(): + input_64_samples = np.zeros((64, 3), dtype=np.float32) + + # Computed steps is correct for specified batch size + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=None, batch_size=16) + self.assertEqual(batch_size, 16) + self.assertEqual(steps, 4 // replica_scale_factor) + + # Computed steps is correct for specified batch size + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=None, batch_size=32) + self.assertEqual(batch_size, 32) + self.assertEqual(steps, 2 // replica_scale_factor) + + # Number of samples is not divisible by the global batch size + with self.assertRaisesRegexp(ValueError, 'not divisible by batch size'): + distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=None, batch_size=20) + + # Number of samples is not divisible by the global batch size + with self.assertRaisesRegexp(ValueError, 'not 
divisible by batch size'): + distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=None, batch_size=3) + + @combinations.generate(strategy_for_numpy_input_combinations()) + def test_calculating_input_params_with_steps_with_batch_size(self, + distribution): + with self.cached_session(): + input_64_samples = np.zeros((64, 3), dtype=np.float32) + + # No change to steps and batch size if both specified and feasible + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=5, batch_size=3) + self.assertEqual(batch_size, 3) + self.assertEqual(steps, 5) + + # Number of samples is less than global batch size * steps + with self.assertRaisesRegexp(ValueError, 'less than samples required'): + distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=10, batch_size=13) + + @combinations.generate(strategy_for_numpy_input_combinations()) + def test_calling_model_with_numpy_arrays(self, distribution): + with self.cached_session(): + model = get_model() + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + + inputs = np.zeros((64, 3), dtype=np.float32) + targets = np.zeros((64, 4), dtype=np.float32) + + # Call fit with validation data + model.fit(inputs, targets, epochs=1, batch_size=2, verbose=0, + validation_data=(inputs, targets)) + + # TODO(anjalisridhar): We need tests for when the batch size and steps are + # smaller and results in a 0 batch_size and steps value. 
+ model.evaluate(inputs, targets) + # with steps + model.evaluate(inputs, targets, steps=2) + # with batch_size + model.evaluate(inputs, targets, batch_size=8) + + model.predict(inputs) + # with steps + model.predict(inputs, steps=2) + # with batch_size + model.predict(inputs, batch_size=8) + + @combinations.generate(strategy_for_numpy_input_combinations()) + def test_calling_model_with_nested_numpy_arrays(self, distribution): + with self.cached_session(): + model = multi_input_output_model() + + optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + input_a_np = np.asarray(np.random.random((64, 3)), dtype=np.float32) + input_b_np = np.asarray(np.random.random((64, 5)), dtype=np.float32) + inputs = [input_a_np, input_b_np] + + output_d_np = np.asarray(np.random.random((64, 7)), dtype=np.float32) + output_e_np = np.asarray(np.random.random((64, 7)), dtype=np.float32) + targets = [output_d_np, output_e_np] + + # Call fit with validation data + model.fit(inputs, targets, epochs=1, batch_size=8, verbose=0) + + # TODO(anjalisridhar): We need tests for when the batch size and steps are + # smaller and results in a 0 batch_size and steps value. 
+ model.evaluate(inputs, targets) + # with steps + model.evaluate(inputs, targets, steps=2) + # with batch_size + model.evaluate(inputs, targets, batch_size=8) + + model.predict(inputs) + # with steps + model.predict(inputs, steps=2) + # with batch_size + model.predict(inputs, batch_size=8) + + @combinations.generate(combinations.combine( + distribution=strategies_minus_tpu, mode=['graph'])) + def test_numpy_with_sample_weights(self, distribution): + model = get_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + inputs = np.zeros((20, 3), np.float32) + targets = np.zeros((20, 4), np.float32) + sample_weights = np.ones((20), np.float32) + + model.fit(inputs, targets, sample_weight=sample_weights, epochs=1, + steps_per_epoch=2, verbose=1) + + @combinations.generate(strategy_for_numpy_input_combinations()) + def test_flatten_predict_outputs(self, distribution): + with self.cached_session(): + model = multi_input_output_model() + + optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + # We take 6 input samples with each input having a dimension of 3 or 5. + input_a_np = np.asarray(np.random.random((6, 3)), dtype=np.float32) + input_b_np = np.asarray(np.random.random((6, 5)), dtype=np.float32) + inputs = [input_a_np, input_b_np] + + outs = model.predict(inputs, steps=1) + # `predict` a list that is equal in length to the number of model outputs. + # In this test our model has two outputs and each element of `outs` + # corresponds to all the samples of one of the model outputs. + self.assertLen(outs, 2) + # Each of the output samples have a dimension of 7. We should process all + # the available input samples(6). 
+ self.assertAllEqual([6, 7], outs[0].shape) + self.assertAllEqual([6, 7], outs[1].shape) + + +class TestDistributionStrategyWithDatasets(test.TestCase, + parameterized.TestCase): + + @combinations.generate(all_strategy_combinations()) + def test_calling_model_on_same_dataset(self, distribution): + with self.cached_session(): + model = get_model() + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + + dataset = get_dataset(distribution) + + # Call fit with validation data + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + validation_data=dataset, validation_steps=2) + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + validation_data=dataset, validation_steps=2) + model.predict(get_predict_dataset(distribution), steps=2) + + @combinations.generate(all_strategy_combinations()) + def test_model_interleaved_eval_same_as_direct_eval(self, distribution): + with self.cached_session(): + user_controlled_model = get_model() + user_controlled_model.compile( + gradient_descent.GradientDescentOptimizer(0.001), + loss='mse', + metrics=['mae', keras.metrics.CategoricalAccuracy()], + distribute=distribution) + + interleaved_model = get_model() + interleaved_model.set_weights(user_controlled_model.get_weights()) + interleaved_model.compile( + gradient_descent.GradientDescentOptimizer(0.001), + loss='mse', + metrics=['mae', keras.metrics.CategoricalAccuracy()], + distribute=distribution) + + dataset = get_dataset(distribution) + + # Call fit with validation interleaved + interleaved_output = interleaved_model.fit( + dataset, epochs=2, steps_per_epoch=2, verbose=1, + validation_data=dataset, validation_steps=2, shuffle=False) + + # Manually control the validation running after each epoch. 
+ user_controlled_output = [] + for _ in range(2): + user_controlled_model.fit( + dataset, epochs=1, steps_per_epoch=2, verbose=1, shuffle=False) + user_controlled_output.append( + user_controlled_model.evaluate(dataset, steps=2)) + + self.assertEqual(interleaved_output.history['val_loss'], + [x[0] for x in user_controlled_output]) + self.assertEqual(interleaved_output.history['val_mean_absolute_error'], + [x[1] for x in user_controlled_output]) + self.assertEqual(interleaved_output.history['val_categorical_accuracy'], + [x[2] for x in user_controlled_output]) + + # TODO(priyag): Enable this test for TPU. Currently tuples/dict don't work + # as clone_model's input_tensors argument only seems to accept list and not + # tuples or dict. + + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + def test_fit_with_tuple_and_dict_dataset_inputs(self, distribution): + with self.cached_session(): + model = multi_input_output_model() + + optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + + input_a_np = np.random.random((10, 3)) + input_b_np = np.random.random((10, 5)) + output_d_np = np.random.random((10, 7)) + output_e_np = np.random.random((10, 7)) + + # Test with tuples + dataset_tuple = dataset_ops.Dataset.from_tensor_slices(( + (input_a_np, input_b_np), (output_d_np, output_e_np))) + dataset_tuple = dataset_tuple.repeat(100) + dataset_tuple = dataset_tuple.batch(10) + + model.fit(dataset_tuple, epochs=1, steps_per_epoch=2, verbose=1) + + # Test with dict + dataset_dict = dataset_ops.Dataset.from_tensor_slices(( + {'input_a': input_a_np, 'input_b': input_b_np}, + (output_d_np, output_e_np))) + dataset_dict = dataset_dict.repeat(100) + dataset_dict = 
dataset_dict.batch(10) + + model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1) + + @combinations.generate(all_strategy_combinations()) + def test_fit_eval_and_predict_methods_on_dataset(self, distribution): + with self.cached_session(): + model = get_model() + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + + dataset = get_dataset(distribution) + + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) + model.evaluate(dataset, steps=2, verbose=1) + model.predict(get_predict_dataset(distribution), steps=2) + + @combinations.generate(strategy_and_optimizer_combinations()) + def test_fit_eval_and_predict_with_optimizer(self, distribution, optimizer): + with self.cached_session(): + model = get_model() + + loss = 'mse' + model.compile(optimizer(), loss, distribute=distribution) + + dataset = get_dataset(distribution) + + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) + model.evaluate(dataset, steps=2, verbose=1) + model.predict(get_predict_dataset(distribution), steps=2) + + @combinations.generate(strategy_minus_tpu_combinations()) + def test_dataset_with_sample_weights(self, distribution): + model = get_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + inputs = np.zeros((10, 3), np.float32) + targets = np.zeros((10, 4), np.float32) + sample_weights = np.ones((10), np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets, + sample_weights)) + dataset = dataset.repeat() + dataset = dataset.batch(10) + + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) + model.evaluate(dataset, steps=2, verbose=1) + model.predict(dataset, steps=2) + + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + 
combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + # TODO(b/120943676, b/120957836): Re-enable once the validation code is + # restored. + def DISABLED_test_dataset_wrong_input_shape(self, distribution): + with self.cached_session(): + model = get_model() + + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + # Wrong input shape + inputs = np.zeros((10, 5), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + with self.assertRaisesRegexp(ValueError, + 'expected input to have shape'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) + + @combinations.generate(combinations.combine( + distribution=[combinations.mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + # TODO(b/120943676, b/120957836): Re-enable once the validation code is + # restored. 
+ def DISABLED_test_dataset_no_batch_input_validation(self, distribution): + with self.cached_session(): + model = get_model() + + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + # User forgets to batch the dataset + inputs = np.zeros((10, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + + with self.assertRaisesRegexp(ValueError, 'expected input to have shape'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) + + @combinations.generate(combinations.combine( + distribution=[combinations.tpu_strategy_one_step], + mode=['graph'])) + def test_dataset_input_shape_fully_defined(self, distribution): + with self.cached_session(): + model = get_model() + + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + dataset = get_dataset(distribution) + # Input shapes are not fully known. Batch dimension is unknown as we are + # not using the drop_remainder argument. + dataset = dataset.repeat(100).batch(10) + + with self.assertRaisesRegexp(ValueError, 'requires fully defined shapes'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) + + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph', 'eager'])) + def test_learning_phase_value(self, distribution): + # TODO(anjalisridhar): Modify this test to use Lambdas since we can compare + # meaningful values. Currently we don't pass the learning phase if the + # Lambda layer uses the learning phase. 
+ with self.cached_session(): + x = keras.layers.Input(shape=(1,), name='input') + y = keras.layers.Dense(1, kernel_initializer='ones')(x) + z = keras.layers.Dropout(0.9999)(y) + model = keras.Model(x, z) + initial_weights = model.get_weights() + + optimizer = gradient_descent.GradientDescentOptimizer(0.005) + loss = 'mse' + metrics = ['acc'] + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + + batch_size = 8 + if isinstance(distribution, mirrored_strategy.CoreMirroredStrategy): + # CoreMirroredStrategy uses global batch size. + batch_size = 8 * distribution.num_replicas_in_sync + + inputs = np.ones((10, 1), dtype=np.float32) + targets = np.ones((10, 1), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat().batch(batch_size) + hist = model.fit(dataset, epochs=1, steps_per_epoch=20, verbose=1) + self.assertAlmostEqual(hist.history['acc'][0], 0, 0) + + model.set_weights(initial_weights) + # TODO(psv/anjalisridhar): Enable these lines after we fix b/117431185. 
+ # evaluate_output = model.evaluate(dataset, steps=20) + # self.assertAlmostEqual(evaluate_output[1], 1, 0) + + inputs = np.ones((10, 1), dtype=np.float32) + predict_dataset = dataset_ops.Dataset.from_tensor_slices(inputs) + + predict_dataset = predict_dataset.repeat().batch(batch_size) + output = model.predict(predict_dataset, steps=10) + # `predict` runs for 10 steps + ref_output = np.ones((160, 1), dtype=np.float32) + self.assertArrayNear(output, ref_output, 1e-1) + + @combinations.generate(strategy_minus_tpu_combinations()) + def testOptimizerWithCallbacks(self, distribution): + with self.cached_session(): + model = get_model() + + optimizer = gradient_descent_keras.SGD(0.01) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + dataset = get_dataset(distribution) + + def schedule(_): + return 0.001 + + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + callbacks=[keras.callbacks.LearningRateScheduler(schedule)]) + grouped_models = distribution.unwrap(model._distributed_model) + with distribution.scope(): + for m in grouped_models: + self.assertAllClose(0.001, keras.backend.get_value( + m.optimizer.lr), atol=1e-05, rtol=1e-05) + + +class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): + + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + def test_validating_dataset_input_tensors_with_shape_mismatch(self, + distribution): + with self.cached_session(): + a = constant_op.constant([1, 2], shape=(1, 2)) + b = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2)) + device_map = values.ReplicaDeviceMap(('/device:CPU:0', '/device:GPU:0')) + x = values.DistributedValues(device_map, (a, b)) + y = values.DistributedValues(device_map, (a, a)) + with distribution.scope(): + # Removed device and input tensor shape details from the error message + # since the order of 
the device and the corresponding input tensor shape + # is not deterministic over different runs. + with self.assertRaisesRegexp(ValueError, + 'Input tensor shapes do not match for ' + 'distributed tensor inputs ' + 'DistributedValues:.+'): + distributed_training_utils.validate_distributed_dataset_inputs( + distribution, x, y) + + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + def test_validating_dataset_input_tensors_with_dtype_mismatch(self, + distribution): + with self.cached_session(): + a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32) + b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64) + device_map = values.ReplicaDeviceMap(('/device:CPU:0', '/device:GPU:0')) + x = values.DistributedValues(device_map, (a, b)) + y = values.DistributedValues(device_map, (a, a)) + with distribution.scope(): + # Removed device and input tensor dtype details from the error message + # since the order of the device and the corresponding input tensor dtype + # is not deterministic over different runs. 
+ with self.assertRaisesRegexp(ValueError, + 'Input tensor dtypes do not match for ' + 'distributed tensor inputs ' + 'DistributedValues:.+'): + distributed_training_utils.validate_distributed_dataset_inputs( + distribution, x, y) + + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + def test_unsupported_features(self, distribution): + with self.cached_session(): + model = get_model() + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + + dataset = get_dataset(distribution) + + # Test with validation split + with self.assertRaisesRegexp( + ValueError, '`validation_split` argument is not ' + 'supported when input `x` is a dataset or a ' + 'dataset iterator.+'): + model.fit(dataset, + epochs=1, steps_per_epoch=2, verbose=0, + validation_split=0.5, validation_steps=2) + + # Test with sample weight. + sample_weight = np.random.random((10,)) + with self.assertRaisesRegexp( + ValueError, '`sample_weight` argument is not supported when input ' + '`x` is a dataset or a dataset iterator.'): + model.fit( + dataset, + epochs=1, + steps_per_epoch=2, + verbose=0, + sample_weight=sample_weight) + + # Test with not specifying the `steps` argument. 
+ with self.assertRaisesRegexp( + ValueError, 'you should specify the `steps_per_epoch` argument'): + model.fit(dataset, epochs=1, verbose=0) + with self.assertRaisesRegexp(ValueError, + 'you should specify the `steps` argument'): + model.evaluate(dataset, verbose=0) + + with self.assertRaisesRegexp(ValueError, + 'you should specify the `steps` argument'): + model.predict(dataset, verbose=0) + + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + def test_calling_with_unsupported_predefined_callbacks(self, distribution): + with self.cached_session(): + model = get_model() + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + + dataset = get_dataset(distribution) + + def schedule(_): + return 0.001 + with self.assertRaisesRegexp(ValueError, + 'You must specify a Keras Optimizer V2 when ' + 'using'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + callbacks=[keras.callbacks.LearningRateScheduler(schedule)]) + + with self.assertRaisesRegexp(ValueError, + 'You must specify a Keras Optimizer V2 when ' + 'using'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + callbacks=[keras.callbacks.ReduceLROnPlateau()]) + + +class TestDistributionStrategyWithLossMasking(test.TestCase, + parameterized.TestCase): + + # TODO(priyag): Enable all strategies for this test. Currently it does not + # work for TPU due to some invalid datatype. 
+ @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + def test_masking(self, distribution): + with self.cached_session(): + np.random.seed(1337) + x = np.array([[[1], [1]], [[0], [0]]]) + model = keras.models.Sequential() + model.add(keras.layers.Masking(mask_value=0, input_shape=(2, 1))) + model.add( + keras.layers.TimeDistributed( + keras.layers.Dense(1, kernel_initializer='one'))) + model.compile(loss='mse', + optimizer=gradient_descent.GradientDescentOptimizer(0.01), + distribute=distribution) + y = np.array([[[1], [1]], [[1], [1]]]) + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + hist = model.fit(x=dataset, epochs=1, steps_per_epoch=2) + self.assertEqual(hist.history['loss'][0], 0) + + +class TestDistributionStrategyWithNormalizationLayer( + test.TestCase, parameterized.TestCase): + + @combinations.generate(all_strategy_combinations()) + def test_batchnorm_correctness(self, distribution): + with self.cached_session(): + model = keras.models.Sequential() + norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8) + model.add(norm) + model.compile(loss='mse', + optimizer=gradient_descent.GradientDescentOptimizer(0.01), + distribute=distribution) + + # centered on 5.0, variance 10.0 + x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10)) + x = x.astype('float32') + dataset = dataset_ops.Dataset.from_tensor_slices((x, x)) + dataset = dataset.repeat(100) + dataset = batch_wrapper(dataset, 32, distribution) + + predict_dataset = dataset_ops.Dataset.from_tensor_slices(x) + predict_dataset = predict_dataset.repeat(100) + predict_dataset = batch_wrapper(predict_dataset, 32, distribution) + + model.fit(dataset, epochs=4, verbose=0, steps_per_epoch=10) + out = model.predict(predict_dataset, steps=2) + out -= 
keras.backend.eval(norm.beta) + out /= keras.backend.eval(norm.gamma) + np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1) + np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) + + +class TestDistributionStrategyCorrectness(test.TestCase, + parameterized.TestCase): + + @combinations.generate(all_strategy_combinations()) + def test_metric_correctness(self, distribution): + with self.cached_session(): + keras.backend.set_image_data_format('channels_last') + num_samples = 10000 + + x_train = np.random.randint(0, 2, num_samples) + x_train = np.reshape(x_train, (num_samples, 1)) + y_train = x_train + x_train = x_train.astype('float32') + y_train = y_train.astype('float32') + + # Create identity model. + model = keras.Sequential() + model.add( + keras.layers.Dense(1, input_shape=(1,), kernel_initializer='ones')) + model.compile( + loss=keras.losses.mean_squared_error, + optimizer=gradient_descent.GradientDescentOptimizer(0.5), + metrics=[keras.metrics.BinaryAccuracy()], + distribute=distribution) + + batch_size = 64 + if not distributed_training_utils.global_batch_size_supported( + distribution): + batch_size //= distribution.num_replicas_in_sync + train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) + train_dataset = batch_wrapper(train_dataset, batch_size, distribution) + + history = model.fit(x=train_dataset, epochs=2, steps_per_epoch=10) + self.assertEqual(history.history['binary_accuracy'], [1.0, 1.0]) + + @combinations.generate(all_strategy_combinations()) + def test_eval_metrics_correctness(self, distribution): + with self.cached_session(): + model = keras.Sequential() + model.add( + keras.layers.Dense( + 3, activation='relu', input_dim=4, kernel_initializer='ones')) + model.add( + keras.layers.Dense( + 1, activation='sigmoid', kernel_initializer='ones')) + model.compile( + loss='mae', + metrics=['accuracy', keras.metrics.BinaryAccuracy()], + optimizer=gradient_descent.GradientDescentOptimizer(0.001), + distribute=distribution) + + # 
verify correctness of stateful and stateless metrics. + x = np.ones((100, 4)).astype('float32') + y = np.ones((100, 1)).astype('float32') + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat() + dataset = batch_wrapper(dataset, 4, distribution) + outs = model.evaluate(dataset, steps=10) + self.assertEqual(outs[1], 1.) + self.assertEqual(outs[2], 1.) + + y = np.zeros((100, 1)).astype('float32') + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat() + dataset = batch_wrapper(dataset, 4, distribution) + outs = model.evaluate(dataset, steps=10) + self.assertEqual(outs[1], 0.) + self.assertEqual(outs[2], 0.) + + @combinations.generate(strategy_and_input_combinations()) + def test_correctness(self, distribution, use_numpy, use_validation_data): + with self.cached_session(): + default_tolerance = 1e-5 + tol_table = {} + + if isinstance(distribution, ( + mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy, + distribute_lib._DefaultDistributionStrategy)): # pylint: disable=protected-access + # TODO(b/119257215): Weights are not exactly the same, so use larger + # tolerance for now. Predict should be related to weights. + tol_table = { + 'weights_1': 1e-4, + 'weights_2': 1e-4, + 'predict_result_1': 1e-4, + } + + keras.backend.set_image_data_format('channels_last') + np.random.seed(_RANDOM_SEED) + random_seed.set_random_seed(_RANDOM_SEED) + + # Train, eval, and predict datasets are created with the same input numpy + # arrays. + # TODO(xiejw): Change this back to 10000, once we support final partial + # batch. + num_samples = 9984 + x_train = np.random.rand(num_samples, 1) + y_train = 3 * x_train + x_train = x_train.astype('float32') + y_train = y_train.astype('float32') + x_predict = [[1.], [2.], [3.], [4.]] + + # The model is built once and the initial weights are saved. + # This is used to initialize the model for both the distribution and + # non-distribution run. 
In addition, we add few non-linear layers to make + # it non-trivial. + def _create_model(): + model = keras.Sequential() + model.add(keras.layers.Dense(10, activation='relu', input_shape=(1,))) + model.add(keras.layers.Dense(10, activation='relu')) + model.add(keras.layers.Dense(10, activation='relu')) + model.add(keras.layers.Dense(1)) + return model + + model = _create_model() + initial_weights = model.get_weights() + del model # avoid accident usage. + + def fit_eval_and_predict(with_distribution=None): + model = _create_model() + # We have initialized the model to the same weight for the distribution + # and non-distribution run. + model.set_weights(initial_weights) + model.compile( + loss=keras.losses.mean_squared_error, + optimizer=gradient_descent_keras.SGD(0.5), + metrics=['mse'], + distribute=with_distribution) + + training_inputs, eval_inputs, predict_inputs = ( + get_correctness_test_inputs(use_numpy, use_validation_data, + with_distribution, + x_train, y_train, x_predict)) + + result = {} + result['training_history_1'] = model.fit(**training_inputs).history + + if eval_inputs is not None: + result['eval_result_1'] = model.evaluate(**eval_inputs) + + result['weights_1'] = model.get_weights() + result['predict_result_1'] = model.predict(**predict_inputs) + + # Train and eval again to mimic user's flow. + + result['training_history_2'] = model.fit(**training_inputs).history + + if eval_inputs is not None: + result['eval_result_2'] = model.evaluate(**eval_inputs) + + result['weights_2'] = model.get_weights() + + return result + + results_with_ds = fit_eval_and_predict(with_distribution=distribution) + results_without_ds = fit_eval_and_predict(with_distribution=None) + + # Verify that the weights, training history, eval results, predict outputs + # are the same within some limits of tolerance. 
+ for key in results_with_ds: + if (key.startswith('training_history') and + isinstance(distribution, tpu_strategy.TPUStrategy) and + distribution.extended.steps_per_run > 1): + # TODO(b/119894254): Enable this test for all cases once the + # underlying bug is fixed. + continue + + tolerance = tol_table.get(key, default_tolerance) + + self.assertAllClose( + results_with_ds[key], + results_without_ds[key], + atol=tolerance, + rtol=tolerance, + msg='Fail to assert {}.'.format(key)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_correctness_test.py b/tensorflow/contrib/distribute/python/keras_correctness_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e078731610882bfe6d5a97b1636d9a4a1325b047 --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_correctness_test.py @@ -0,0 +1,362 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Correctness tests for tf.keras using DistributionStrategy.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import tpu_strategy +from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.eager import test +from tensorflow.python.framework import random_seed +from tensorflow.python.keras.engine import distributed_training_utils +from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras +from tensorflow.python.training import gradient_descent + +_RANDOM_SEED = 1337 + +# Note: Please make sure the tests in this file are also covered in +# keras_backward_compat_test for features that are supported with both APIs. + + +def batch_wrapper(dataset, batch_size, distribution, repeat=None): + if repeat: + dataset = dataset.repeat(repeat) + # TPUs currently require fully defined input shapes, drop_remainder ensures + # the input will have fully defined shapes. + if isinstance(distribution, tpu_strategy.TPUStrategy): + return dataset.batch(batch_size, drop_remainder=True) + else: + return dataset.batch(batch_size) + + +def get_correctness_test_inputs(use_numpy, use_validation_data, + with_distribution, + x_train, y_train, x_predict): + """Generates the inputs for correctness check when enable Keras with DS.""" + training_epochs = 2 + global_batch_size = 64 + batch_size = global_batch_size + # TODO(b/118776054): Use global batch size for Keras/DS support. 
+ use_per_core_batch_size = ( + with_distribution and + not distributed_training_utils.global_batch_size_supported( + with_distribution)) + if use_per_core_batch_size: + batch_size //= with_distribution.num_replicas_in_sync + + if use_numpy: + training_inputs = { + 'batch_size': batch_size, + 'x': x_train, + 'y': y_train, + 'epochs': training_epochs, + 'shuffle': False, + } + + if use_validation_data: + eval_inputs = None + training_inputs['validation_data'] = (x_train, y_train) + else: + eval_inputs = { + 'batch_size': batch_size, + 'x': x_train, + 'y': y_train, + } + predict_inputs = { + 'x': np.array(x_predict, dtype=np.float32), + } + else: + # For dataset inputs, we do not pass batch_size to + # keras.fit/evaluate/predict. The batch size is part of the dataset. + train_dataset = dataset_ops.Dataset.from_tensor_slices( + (x_train, y_train)) + x = batch_wrapper( + train_dataset, batch_size, with_distribution, repeat=training_epochs) + + training_inputs = { + 'batch_size': None, + 'x': x, + 'y': None, + 'epochs': training_epochs, + 'shuffle': False, + 'steps_per_epoch': len(x_train) // global_batch_size, + } + if use_validation_data: + eval_inputs = None # Remove the eval_inputs + eval_dataset = dataset_ops.Dataset.from_tensor_slices( + (x_train, y_train)) + x = batch_wrapper(eval_dataset, batch_size, with_distribution) + training_inputs['validation_data'] = x + training_inputs['validation_steps'] = 5 + else: + eval_inputs = { + 'batch_size': None, + 'x': x, + 'y': None, + 'steps': 20, + } + + predict_batch_size = len(x_predict) + if use_per_core_batch_size: + predict_batch_size //= with_distribution.num_replicas_in_sync + predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict) + predict_dataset = batch_wrapper(predict_dataset, + predict_batch_size, with_distribution) + predict_inputs = { + 'steps': 1, + 'x': predict_dataset, + } + + return training_inputs, eval_inputs, predict_inputs + + +strategies_minus_tpu = [ + combinations.default_strategy, + 
combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus] + +tpu_strategies = [ + combinations.tpu_strategy, # steps_per_run=2 + combinations.tpu_strategy_one_step] + + +def strategy_minus_tpu_combinations(): + return combinations.combine( + distribution=strategies_minus_tpu, + mode=['graph', 'eager']) + + +def tpu_strategy_combinations(): + return combinations.combine( + distribution=tpu_strategies, + mode=['graph']) + + +def all_strategy_combinations(): + return strategy_minus_tpu_combinations() + tpu_strategy_combinations() + + +def strategy_and_input_combinations(): + return ( + combinations.times( + combinations.combine(distribution=strategies_minus_tpu), + combinations.combine(mode=['graph'], + use_numpy=[True, False], + use_validation_data=[True, False]) + + combinations.combine(mode=['eager'], + use_numpy=[False], + use_validation_data=[False])) + + combinations.times( + combinations.combine(distribution=tpu_strategies), + combinations.combine(mode=['graph'], + use_numpy=[True, False], + use_validation_data=[True, False]))) + + +class TestDistributionStrategyCorrectness(test.TestCase, + parameterized.TestCase): + + @combinations.generate(all_strategy_combinations()) + def test_metric_correctness(self, distribution): + with self.cached_session(): + keras.backend.set_image_data_format('channels_last') + num_samples = 10000 + + x_train = np.random.randint(0, 2, num_samples) + x_train = np.reshape(x_train, (num_samples, 1)) + y_train = x_train + x_train = x_train.astype('float32') + y_train = y_train.astype('float32') + + # Create identity model. 
+ with distribution.scope(): + model = keras.Sequential() + model.add( + keras.layers.Dense(1, input_shape=(1,), kernel_initializer='ones')) + model.compile( + loss=keras.losses.mean_squared_error, + optimizer=gradient_descent.GradientDescentOptimizer(0.5), + metrics=[keras.metrics.BinaryAccuracy()]) + + batch_size = 64 + if not distributed_training_utils.global_batch_size_supported( + distribution): + batch_size //= distribution.num_replicas_in_sync + train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) + train_dataset = batch_wrapper(train_dataset, batch_size, distribution) + + history = model.fit(x=train_dataset, epochs=2, steps_per_epoch=10) + self.assertEqual(history.history['binary_accuracy'], [1.0, 1.0]) + + @combinations.generate(all_strategy_combinations()) + def test_eval_metrics_correctness(self, distribution): + with self.cached_session(): + with distribution.scope(): + model = keras.Sequential() + model.add( + keras.layers.Dense( + 3, activation='relu', input_dim=4, kernel_initializer='ones')) + model.add( + keras.layers.Dense( + 1, activation='sigmoid', kernel_initializer='ones')) + model.compile( + loss='mae', + metrics=['accuracy', keras.metrics.BinaryAccuracy()], + optimizer=gradient_descent.GradientDescentOptimizer(0.001)) + + # verify correctness of stateful and stateless metrics. + x = np.ones((100, 4)).astype('float32') + y = np.ones((100, 1)).astype('float32') + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat() + dataset = batch_wrapper(dataset, 4, distribution) + outs = model.evaluate(dataset, steps=10) + self.assertEqual(outs[1], 1.) + self.assertEqual(outs[2], 1.) + + y = np.zeros((100, 1)).astype('float32') + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat() + dataset = batch_wrapper(dataset, 4, distribution) + outs = model.evaluate(dataset, steps=10) + self.assertEqual(outs[1], 0.) + self.assertEqual(outs[2], 0.) 
+ + @combinations.generate(strategy_and_input_combinations()) + def test_correctness(self, distribution, use_numpy, use_validation_data): + with self.cached_session(): + default_tolerance = 1e-5 + tol_table = {} + + if isinstance(distribution, ( + mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy, + distribute_lib._DefaultDistributionStrategy)): # pylint: disable=protected-access + # TODO(b/119257215): Weights are not exactly the same, so use larger + # tolerance for now. Predict should be related to weights. + tol_table = { + 'weights_1': 1e-4, + 'weights_2': 1e-4, + 'predict_result_1': 1e-4, + } + + keras.backend.set_image_data_format('channels_last') + np.random.seed(_RANDOM_SEED) + random_seed.set_random_seed(_RANDOM_SEED) + + # Train, eval, and predict datasets are created with the same input numpy + # arrays. + # TODO(xiejw): Change this back to 10000, once we support final partial + # batch. + num_samples = 9984 + x_train = np.random.rand(num_samples, 1) + y_train = 3 * x_train + x_train = x_train.astype('float32') + y_train = y_train.astype('float32') + x_predict = [[1.], [2.], [3.], [4.]] + + # The model is built once and the initial weights are saved. + # This is used to initialize the model for both the distribution and + # non-distribution run. In addition, we add few non-linear layers to make + # it non-trivial. + def _create_model(): + model = keras.Sequential() + model.add(keras.layers.Dense(10, activation='relu', input_shape=(1,))) + model.add(keras.layers.Dense(10, activation='relu')) + model.add(keras.layers.Dense(10, activation='relu')) + model.add(keras.layers.Dense(1)) + return model + + model = _create_model() + initial_weights = model.get_weights() + del model # avoid accident usage. + + def _build_and_compile_model(): + # We have initialized the model to the same weight for the distribution + # and non-distribution run. 
+ model = _create_model() + model.set_weights(initial_weights) + model.compile( + loss=keras.losses.mean_squared_error, + optimizer=gradient_descent_keras.SGD(0.5), + metrics=['mse']) + return model + + def fit_eval_and_predict(with_distribution=None): + if with_distribution: + with with_distribution.scope(): + model = _build_and_compile_model() + else: + model = _build_and_compile_model() + + training_inputs, eval_inputs, predict_inputs = ( + get_correctness_test_inputs(use_numpy, use_validation_data, + with_distribution, + x_train, y_train, x_predict)) + + result = {} + result['training_history_1'] = model.fit(**training_inputs).history + + if eval_inputs is not None: + result['eval_result_1'] = model.evaluate(**eval_inputs) + + result['weights_1'] = model.get_weights() + result['predict_result_1'] = model.predict(**predict_inputs) + + # Train and eval again to mimic user's flow. + + result['training_history_2'] = model.fit(**training_inputs).history + + if eval_inputs is not None: + result['eval_result_2'] = model.evaluate(**eval_inputs) + + result['weights_2'] = model.get_weights() + + return result + + results_with_ds = fit_eval_and_predict(with_distribution=distribution) + results_without_ds = fit_eval_and_predict(with_distribution=None) + + # Verify that the weights, training history, eval results, predict outputs + # are the same within some limits of tolerance. + for key in results_with_ds: + if (key.startswith('training_history') and + isinstance(distribution, tpu_strategy.TPUStrategy) and + distribution.extended.steps_per_run > 1): + # TODO(b/119894254): Enable this test for all cases once the + # underlying bug is fixed. 
+ continue + + tolerance = tol_table.get(key, default_tolerance) + + self.assertAllClose( + results_with_ds[key], + results_without_ds[key], + atol=tolerance, + rtol=tolerance, + msg='Fail to assert {}.'.format(key)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py index 6dfd85bcc4f3784e2744fd876a7190cc9581d96a..cce93b3c10a2ac7bd1c594a5027b9d51629bb915 100644 --- a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py @@ -18,24 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import shutil -import tempfile from absl.testing import parameterized import numpy as np -import six from tensorflow.contrib.distribute.python import combinations -from tensorflow.core.protobuf import config_pb2 from tensorflow.python import keras -from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import distribution_strategy_context as ds_context -from tensorflow.python.estimator import run_config -from tensorflow.python.estimator import training -from tensorflow.python.estimator.canned import dnn_linear_combined -from tensorflow.python.estimator.canned import prediction_keys -from tensorflow.python.estimator.export import export -from tensorflow.python.estimator.inputs import numpy_io -from tensorflow.python.feature_column import feature_column_lib as feature_column from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -44,103 +32,7 @@ from tensorflow.python.keras.optimizer_v2 import gradient_descent from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -from tensorflow.python.platform import gfile 
from tensorflow.python.platform import test -from tensorflow.python.summary.writer import writer_cache - - -class KerasOptimizerV2IntegrationTest(test.TestCase, parameterized.TestCase): - - def setUp(self): - self._model_dir = tempfile.mkdtemp() - - def dataset_input_fn(self, x, y, batch_size): - - def input_fn(): - dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) - dataset = dataset.repeat(1).batch(batch_size) - return dataset - - return input_fn - - @combinations.generate( - combinations.combine( - mode=['graph'], - distribution=[ - combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_two_gpus - ], - use_train_and_evaluate=[True, False])) - def test_complete_flow_with_mode(self, distribution, use_train_and_evaluate): - label_dimension = 2 - input_dimension = label_dimension - batch_size = 10 - data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) - data = data.reshape(batch_size, label_dimension) - train_input_fn = self.dataset_input_fn( - x={'x': data}, - y=data, - batch_size=batch_size // distribution.num_replicas_in_sync) - eval_input_fn = self.dataset_input_fn( - x={'x': data}, - y=data, - batch_size=batch_size // distribution.num_replicas_in_sync) - predict_input_fn = numpy_io.numpy_input_fn( - x={'x': data}, batch_size=batch_size, shuffle=False) - - linear_feature_columns = [ - feature_column.numeric_column('x', shape=(input_dimension,)) - ] - dnn_feature_columns = [ - feature_column.numeric_column('x', shape=(input_dimension,)) - ] - feature_columns = linear_feature_columns + dnn_feature_columns - session_config = config_pb2.ConfigProto( - log_device_placement=True, allow_soft_placement=True) - estimator = dnn_linear_combined.DNNLinearCombinedRegressor( - linear_feature_columns=linear_feature_columns, - dnn_hidden_units=(2, 2), - 
dnn_feature_columns=dnn_feature_columns, - label_dimension=label_dimension, - model_dir=self._model_dir, - dnn_optimizer=adam.Adam(0.001), - linear_optimizer=adam.Adam(0.001), - config=run_config.RunConfig( - train_distribute=distribution, - eval_distribute=distribution, - session_config=session_config)) - - num_steps = 2 - if use_train_and_evaluate: - scores, _ = training.train_and_evaluate( - estimator, training.TrainSpec(train_input_fn, max_steps=num_steps), - training.EvalSpec(eval_input_fn)) - else: - estimator.train(train_input_fn, steps=num_steps) - scores = estimator.evaluate(eval_input_fn) - - self.assertIn('loss', six.iterkeys(scores)) - - predictions = np.array([ - x[prediction_keys.PredictionKeys.PREDICTIONS] - for x in estimator.predict(predict_input_fn) - ]) - self.assertAllEqual((batch_size, label_dimension), predictions.shape) - - feature_spec = feature_column.make_parse_example_spec(feature_columns) - serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( - feature_spec) - export_dir = estimator.export_savedmodel(tempfile.mkdtemp(), - serving_input_receiver_fn) - self.assertTrue(gfile.Exists(export_dir)) - - def tearDown(self): - if self._model_dir: - writer_cache.FileWriterCache.clear() - shutil.rmtree(self._model_dir) def get_model(): @@ -162,7 +54,9 @@ class MirroredStrategyOptimizerV2Test(test.TestCase, parameterized.TestCase): var = variables.Variable( 2.0, name='var', aggregation=variable_scope.VariableAggregation.SUM) # grad for cpu is 1, grad for gpu is 2, avg grad is 1.5. 
- loss = math_ops.cast(_replica_id() + 1, dtype=dtypes.float32) * var + def loss(): + return math_ops.cast(_replica_id() + 1, dtype=dtypes.float32) * var + optimizer = adam.Adam(learning_rate=0.01, beta_1=0.2, beta_2=0.2) train_op = optimizer.minimize(loss, var_list=[var]) m = optimizer.get_slot(var, 'm') @@ -177,12 +71,12 @@ class MirroredStrategyOptimizerV2Test(test.TestCase, parameterized.TestCase): self.assertAllClose( var_val, self.evaluate( - [distribution.read_var(var), + [distribution.extended.read_var(var), var.get(devices[0]), var.get(devices[1])])) self.assertAllClose([0, 0, 0], self.evaluate([ - distribution.read_var(counter), + distribution.extended.read_var(counter), counter.get(devices[0]), counter.get(devices[1]) ])) @@ -195,7 +89,7 @@ class MirroredStrategyOptimizerV2Test(test.TestCase, parameterized.TestCase): self.assertAllClose( m_val, self.evaluate( - [distribution.read_var(m), + [distribution.extended.read_var(m), m.get(devices[0]), m.get(devices[1])])) # v(1) = beta2 * v(0) + (1-beta2) * grad^2 = 0.2 * 0 + 0.8 * 2.25 @@ -203,7 +97,7 @@ class MirroredStrategyOptimizerV2Test(test.TestCase, parameterized.TestCase): self.assertAllClose( v_val, self.evaluate( - [distribution.read_var(v), + [distribution.extended.read_var(v), v.get(devices[0]), v.get(devices[1])])) # var(1) = var(0) - lr * m(1) * sqrt(1 - beta2) / sqrt(v(1)) / (1 - beta1) @@ -212,12 +106,12 @@ class MirroredStrategyOptimizerV2Test(test.TestCase, parameterized.TestCase): self.assertAllClose( var_val, self.evaluate( - [distribution.read_var(var), + [distribution.extended.read_var(var), var.get(devices[0]), var.get(devices[1])])) self.assertAllClose([1, 1, 1], self.evaluate([ - distribution.read_var(counter), + distribution.extended.read_var(counter), counter.get(devices[0]), counter.get(devices[1]) ])) @@ -228,7 +122,7 @@ class MirroredStrategyOptimizerV2Test(test.TestCase, parameterized.TestCase): self.assertAllClose( m_val, self.evaluate( - [distribution.read_var(m), + 
[distribution.extended.read_var(m), m.get(devices[0]), m.get(devices[1])])) # v(2) = beta2 * v(1) + (1-beta2) * grad^2 = 0.2 * 1.8 + 0.8 * 2.25 @@ -236,12 +130,12 @@ class MirroredStrategyOptimizerV2Test(test.TestCase, parameterized.TestCase): self.assertAllClose( v_val, self.evaluate( - [distribution.read_var(v), + [distribution.extended.read_var(v), v.get(devices[0]), v.get(devices[1])])) self.assertAllClose([2, 2, 2], self.evaluate([ - distribution.read_var(counter), + distribution.extended.read_var(counter), counter.get(devices[0]), counter.get(devices[1]) ])) @@ -254,11 +148,12 @@ class MirroredStrategyOptimizerV2Test(test.TestCase, parameterized.TestCase): def testOptimizerWithKerasModelAndNumpyArrays(self, distribution): with self.cached_session(): - model = get_model() - optimizer = gradient_descent.SGD(0.001) - loss = 'mse' - metrics = ['mae'] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + with distribution.scope(): + model = get_model() + optimizer = gradient_descent.SGD(0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics) inputs = np.zeros((64, 3), dtype=np.float32) targets = np.zeros((64, 4), dtype=np.float32) diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index e530ab6f173d568e554168b30aea01d9129dcf9b..84e9aea228352e0a6010fe95529407818d020b5f 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -32,7 +32,6 @@ from tensorflow.python.estimator import keras as keras_lib from tensorflow.python.estimator import run_config as run_config_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import random_seed from tensorflow.python.framework import test_util from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import 
distributed_training_utils @@ -48,6 +47,9 @@ _TRAIN_SIZE = 200 _INPUT_SIZE = (10,) _NUM_CLASS = 2 +# Note: Please make sure the tests in this file are also covered in +# keras_backward_compat_test for features that are supported with both APIs. + # TODO(anjalisridhar): Add a decorator that will allow us to run these tests as # part of the tf.keras unit tests suite. @@ -165,7 +167,9 @@ def get_multi_inputs_multi_outputs_data(): return (train_data, test_data) -def batch_wrapper(dataset, batch_size, distribution): +def batch_wrapper(dataset, batch_size, distribution, repeat=None): + if repeat: + dataset = dataset.repeat(repeat) # TPUs currently require fully defined input shapes, drop_remainder ensures # the input will have fully defined shapes. if isinstance(distribution, tpu_strategy.TPUStrategy): @@ -212,85 +216,6 @@ def multi_input_output_model(): return model -def get_correctness_test_inputs(use_numpy, use_validation_data, - with_distribution, - x_train, y_train, x_predict): - """Generates the inputs for correctness check when enable Keras with DS.""" - global_batch_size = 64 - batch_size = global_batch_size - # TODO(b/118776054): Use global batch size for Keras/DS support. - use_per_core_batch_size = ( - with_distribution and - not distributed_training_utils.global_batch_size_supported( - with_distribution)) - if use_per_core_batch_size: - batch_size //= with_distribution.num_replicas_in_sync - - if use_numpy: - training_inputs = { - 'batch_size': batch_size, - 'x': x_train, - 'y': y_train, - 'epochs': 1, - 'shuffle': False, - } - - if use_validation_data: - eval_inputs = None - training_inputs['validation_data'] = (x_train, y_train) - else: - eval_inputs = { - 'batch_size': batch_size, - 'x': x_train, - 'y': y_train, - } - predict_inputs = { - 'x': np.array(x_predict, dtype=np.float32), - } - else: - # For dataset inputs, we do not pass batch_size to - # keras.fit/evaluate/predict. The batch size is part of the dataset. 
- train_dataset = dataset_ops.Dataset.from_tensor_slices( - (x_train, y_train)) - x = batch_wrapper(train_dataset, batch_size, with_distribution) - - training_inputs = { - 'batch_size': None, - 'x': x, - 'y': None, - 'epochs': 1, - 'shuffle': False, - 'steps_per_epoch': len(x_train) // global_batch_size, - } - if use_validation_data: - eval_inputs = None # Remove the eval_inputs - eval_dataset = dataset_ops.Dataset.from_tensor_slices( - (x_train, y_train)) - x = batch_wrapper(eval_dataset, batch_size, with_distribution) - training_inputs['validation_data'] = x - training_inputs['validation_steps'] = 5 - else: - eval_inputs = { - 'batch_size': None, - 'x': x, - 'y': None, - 'steps': 20, - } - - predict_batch_size = len(x_predict) - if use_per_core_batch_size: - predict_batch_size //= with_distribution.num_replicas_in_sync - predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict) - predict_dataset = batch_wrapper(predict_dataset, - predict_batch_size, with_distribution) - predict_inputs = { - 'steps': 1, - 'x': predict_dataset, - } - - return training_inputs, eval_inputs, predict_inputs - - strategies_minus_tpu = [ combinations.default_strategy, combinations.one_device_strategy, @@ -331,23 +256,6 @@ def strategy_and_optimizer_combinations(): combinations.rmsprop_optimizer_v1_fn])) -def strategy_and_input_combinations(): - return ( - combinations.times( - combinations.combine(distribution=strategies_minus_tpu), - combinations.combine(mode=['graph'], - use_numpy=[True, False], - use_validation_data=[True, False]) - + combinations.combine(mode=['eager'], - use_numpy=[False], - use_validation_data=[False])) + - combinations.times( - combinations.combine(distribution=tpu_strategies), - combinations.combine(mode=['graph'], - use_numpy=[True, False], - use_validation_data=[True, False]))) - - def strategy_for_numpy_input_combinations(): return combinations.combine( distribution=strategies_minus_tpu + tpu_strategies, @@ -371,7 +279,9 @@ class 
TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, @combinations.generate(combinations.combine( distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_two_gpus], mode=['graph'])) def test_train_functional_with_distribution_strategy(self, distribution): @@ -399,7 +309,9 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, @combinations.generate(combinations.combine( distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_two_gpus], mode=['graph'])) def test_train_sequential_with_distribution_strategy(self, distribution): @@ -426,8 +338,8 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, @combinations.generate(combinations.combine( distribution=[ - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], mode=['graph'])) def test_multi_inputs_multi_outputs_with_input_fn_as_dict(self, distribution): train_data, test_data = get_multi_inputs_multi_outputs_data() @@ -478,8 +390,8 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, @combinations.generate(combinations.combine( distribution=[ - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], mode=['graph'])) def test_keras_optimizer_with_distribution_strategy(self, distribution): keras_model = simple_sequential_model() @@ -645,12 +557,12 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, 
@combinations.generate(strategy_for_numpy_input_combinations()) def test_calling_model_with_numpy_arrays(self, distribution): with self.cached_session(): - model = get_model() - - optimizer = gradient_descent.GradientDescentOptimizer(0.001) - loss = 'mse' - metrics = ['mae'] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + with distribution.scope(): + model = get_model() + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics) inputs = np.zeros((64, 3), dtype=np.float32) targets = np.zeros((64, 4), dtype=np.float32) @@ -676,11 +588,12 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, @combinations.generate(strategy_for_numpy_input_combinations()) def test_calling_model_with_nested_numpy_arrays(self, distribution): with self.cached_session(): - model = multi_input_output_model() - - optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001) - loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) + with distribution.scope(): + model = multi_input_output_model() + optimizer = gradient_descent.GradientDescentOptimizer( + learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) input_a_np = np.asarray(np.random.random((64, 3)), dtype=np.float32) input_b_np = np.asarray(np.random.random((64, 5)), dtype=np.float32) @@ -710,26 +623,29 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, @combinations.generate(combinations.combine( distribution=strategies_minus_tpu, mode=['graph'])) def test_numpy_with_sample_weights(self, distribution): - model = get_model() - optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) - loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) + with self.cached_session(): + with distribution.scope(): + model = get_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) - inputs = 
np.zeros((20, 3), np.float32) - targets = np.zeros((20, 4), np.float32) - sample_weights = np.ones((20), np.float32) + inputs = np.zeros((20, 3), np.float32) + targets = np.zeros((20, 4), np.float32) + sample_weights = np.ones((20), np.float32) - model.fit(inputs, targets, sample_weight=sample_weights, epochs=1, - steps_per_epoch=2, verbose=1) + model.fit(inputs, targets, sample_weight=sample_weights, epochs=1, + steps_per_epoch=2, verbose=1) @combinations.generate(strategy_for_numpy_input_combinations()) def test_flatten_predict_outputs(self, distribution): with self.cached_session(): - model = multi_input_output_model() - - optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001) - loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) + with distribution.scope(): + model = multi_input_output_model() + optimizer = gradient_descent.GradientDescentOptimizer( + learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) # We take 6 input samples with each input having a dimension of 3 or 5. 
input_a_np = np.asarray(np.random.random((6, 3)), dtype=np.float32) @@ -753,12 +669,12 @@ class TestDistributionStrategyWithDatasets(test.TestCase, @combinations.generate(all_strategy_combinations()) def test_calling_model_on_same_dataset(self, distribution): with self.cached_session(): - model = get_model() - - optimizer = gradient_descent.GradientDescentOptimizer(0.001) - loss = 'mse' - metrics = ['mae', keras.metrics.CategoricalAccuracy()] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + with distribution.scope(): + model = get_model() + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics) dataset = get_dataset(distribution) @@ -772,20 +688,19 @@ class TestDistributionStrategyWithDatasets(test.TestCase, @combinations.generate(all_strategy_combinations()) def test_model_interleaved_eval_same_as_direct_eval(self, distribution): with self.cached_session(): - user_controlled_model = get_model() - user_controlled_model.compile( - gradient_descent.GradientDescentOptimizer(0.001), - loss='mse', - metrics=['mae', keras.metrics.CategoricalAccuracy()], - distribute=distribution) - - interleaved_model = get_model() - interleaved_model.set_weights(user_controlled_model.get_weights()) - interleaved_model.compile( - gradient_descent.GradientDescentOptimizer(0.001), - loss='mse', - metrics=['mae', keras.metrics.CategoricalAccuracy()], - distribute=distribution) + with distribution.scope(): + user_controlled_model = get_model() + user_controlled_model.compile( + gradient_descent.GradientDescentOptimizer(0.001), + loss='mse', + metrics=['mae', keras.metrics.CategoricalAccuracy()]) + + interleaved_model = get_model() + interleaved_model.set_weights(user_controlled_model.get_weights()) + interleaved_model.compile( + gradient_descent.GradientDescentOptimizer(0.001), + loss='mse', + metrics=['mae', 
keras.metrics.CategoricalAccuracy()]) dataset = get_dataset(distribution) @@ -820,12 +735,13 @@ class TestDistributionStrategyWithDatasets(test.TestCase, mode=['graph', 'eager'])) def test_fit_with_tuple_and_dict_dataset_inputs(self, distribution): with self.cached_session(): - model = multi_input_output_model() - - optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001) - loss = 'mse' - metrics = ['mae', keras.metrics.CategoricalAccuracy()] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + with distribution.scope(): + model = multi_input_output_model() + optimizer = gradient_descent.GradientDescentOptimizer( + learning_rate=0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics) input_a_np = np.random.random((10, 3)) input_b_np = np.random.random((10, 5)) @@ -852,12 +768,12 @@ class TestDistributionStrategyWithDatasets(test.TestCase, @combinations.generate(all_strategy_combinations()) def test_fit_eval_and_predict_methods_on_dataset(self, distribution): with self.cached_session(): - model = get_model() - - optimizer = gradient_descent.GradientDescentOptimizer(0.001) - loss = 'mse' - metrics = ['mae', keras.metrics.CategoricalAccuracy()] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + with distribution.scope(): + model = get_model() + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics) dataset = get_dataset(distribution) @@ -868,10 +784,10 @@ class TestDistributionStrategyWithDatasets(test.TestCase, @combinations.generate(strategy_and_optimizer_combinations()) def test_fit_eval_and_predict_with_optimizer(self, distribution, optimizer): with self.cached_session(): - model = get_model() - - loss = 'mse' - model.compile(optimizer(), loss, distribute=distribution) + with distribution.scope(): + 
model = get_model() + loss = 'mse' + model.compile(optimizer(), loss) dataset = get_dataset(distribution) @@ -881,35 +797,39 @@ class TestDistributionStrategyWithDatasets(test.TestCase, @combinations.generate(strategy_minus_tpu_combinations()) def test_dataset_with_sample_weights(self, distribution): - model = get_model() - optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) - loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) - - inputs = np.zeros((10, 3), np.float32) - targets = np.zeros((10, 4), np.float32) - sample_weights = np.ones((10), np.float32) - dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets, - sample_weights)) - dataset = dataset.repeat() - dataset = dataset.batch(10) - - model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) - model.evaluate(dataset, steps=2, verbose=1) - model.predict(dataset, steps=2) + with self.cached_session(): + with distribution.scope(): + model = get_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) + + inputs = np.zeros((10, 3), np.float32) + targets = np.zeros((10, 4), np.float32) + sample_weights = np.ones((10), np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets, + sample_weights)) + dataset = dataset.repeat() + dataset = dataset.batch(10) + + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) + model.evaluate(dataset, steps=2, verbose=1) + model.predict(dataset, steps=2) @combinations.generate(combinations.combine( distribution=[ - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], mode=['graph', 'eager'])) - def test_dataset_wrong_input_shape(self, distribution): + # TODO(b/120943676, b/120957836): Re-enable once the validation code is + # restored. 
+ def DISABLED_test_dataset_wrong_input_shape(self, distribution): with self.cached_session(): - model = get_model() - - optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) - loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) + with distribution.scope(): + model = get_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) # Wrong input shape inputs = np.zeros((10, 5), dtype=np.float32) @@ -923,15 +843,17 @@ class TestDistributionStrategyWithDatasets(test.TestCase, model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) @combinations.generate(combinations.combine( - distribution=[combinations.mirrored_strategy_with_two_gpus], + distribution=[combinations.mirrored_strategy_with_gpu_and_cpu], mode=['graph', 'eager'])) - def test_dataset_no_batch_input_validation(self, distribution): + # TODO(b/120943676, b/120957836): Re-enable once the validation code is + # restored. + def DISABLED_test_dataset_no_batch_input_validation(self, distribution): with self.cached_session(): - model = get_model() - - optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) - loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) + with distribution.scope(): + model = get_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) # User forgets to batch the dataset inputs = np.zeros((10, 3), dtype=np.float32) @@ -947,11 +869,11 @@ class TestDistributionStrategyWithDatasets(test.TestCase, mode=['graph'])) def test_dataset_input_shape_fully_defined(self, distribution): with self.cached_session(): - model = get_model() - - optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) - loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) + with distribution.scope(): + model = get_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) dataset = 
get_dataset(distribution) # Input shapes are not fully known. Batch dimension is unknown as we are @@ -963,7 +885,9 @@ class TestDistributionStrategyWithDatasets(test.TestCase, @combinations.generate(combinations.combine( distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_two_gpus], mode=['graph', 'eager'])) def test_learning_phase_value(self, distribution): @@ -971,16 +895,17 @@ class TestDistributionStrategyWithDatasets(test.TestCase, # meaningful values. Currently we don't pass the learning phase if the # Lambda layer uses the learning phase. with self.cached_session(): - x = keras.layers.Input(shape=(1,), name='input') - y = keras.layers.Dense(1, kernel_initializer='ones')(x) - z = keras.layers.Dropout(0.9999)(y) - model = keras.Model(x, z) - initial_weights = model.get_weights() + with distribution.scope(): + x = keras.layers.Input(shape=(1,), name='input') + y = keras.layers.Dense(1, kernel_initializer='ones')(x) + z = keras.layers.Dropout(0.9999)(y) + model = keras.Model(x, z) + initial_weights = model.get_weights() - optimizer = gradient_descent.GradientDescentOptimizer(0.005) - loss = 'mse' - metrics = ['acc'] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + optimizer = gradient_descent.GradientDescentOptimizer(0.005) + loss = 'mse' + metrics = ['acc'] + model.compile(optimizer, loss, metrics=metrics) batch_size = 8 if isinstance(distribution, mirrored_strategy.CoreMirroredStrategy): @@ -994,7 +919,8 @@ class TestDistributionStrategyWithDatasets(test.TestCase, hist = model.fit(dataset, epochs=1, steps_per_epoch=20, verbose=1) self.assertAlmostEqual(hist.history['acc'][0], 0, 0) - model.set_weights(initial_weights) + with distribution.scope(): + model.set_weights(initial_weights) # TODO(psv/anjalisridhar): Enable these lines after we fix b/117431185. 
# evaluate_output = model.evaluate(dataset, steps=20) # self.assertAlmostEqual(evaluate_output[1], 1, 0) @@ -1008,14 +934,17 @@ class TestDistributionStrategyWithDatasets(test.TestCase, ref_output = np.ones((160, 1), dtype=np.float32) self.assertArrayNear(output, ref_output, 1e-1) - @combinations.generate(strategy_minus_tpu_combinations()) + @combinations.generate(all_strategy_combinations()) def testOptimizerWithCallbacks(self, distribution): with self.cached_session(): - model = get_model() - - optimizer = gradient_descent_keras.SGD(0.01) - loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) + # TODO(b/120946189): Investigate why default strategy + eager fails. + if '_Default' in distribution.__class__.__name__: + self.skipTest('Disable the test for default strategy.') + with distribution.scope(): + model = get_model() + optimizer = gradient_descent_keras.SGD(0.01) + loss = 'mse' + model.compile(optimizer, loss) dataset = get_dataset(distribution) @@ -1024,11 +953,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase, model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, callbacks=[keras.callbacks.LearningRateScheduler(schedule)]) - grouped_models = distribution.unwrap(model._grouped_model) - with distribution.scope(): - for m in grouped_models: - self.assertAllClose(0.001, keras.backend.get_value( - m.optimizer.lr), atol=1e-05, rtol=1e-05) + self.assertAllClose(0.001, keras.backend.get_value(model.optimizer.lr)) class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): @@ -1043,16 +968,17 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): with self.cached_session(): a = constant_op.constant([1, 2], shape=(1, 2)) b = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2)) - x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b}) - y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) - with distribution.scope(): - # Removed device and input tensor 
shape details from the error message - # since the order of the device and the corresponding input tensor shape - # is not deterministic over different runs. - with self.assertRaisesRegexp(ValueError, - 'Input tensor shapes do not match for ' - 'distributed tensor inputs ' - 'DistributedValues:.+'): + device_map = values.ReplicaDeviceMap(('/device:CPU:0', '/device:GPU:0')) + x = values.DistributedValues(device_map, (a, b)) + y = values.DistributedValues(device_map, (a, a)) + # Removed device and input tensor shape details from the error message + # since the order of the device and the corresponding input tensor shape + # is not deterministic over different runs. + with self.assertRaisesRegexp(ValueError, + 'Input tensor shapes do not match for ' + 'distributed tensor inputs ' + 'DistributedValues:.+'): + with distribution.scope(): distributed_training_utils.validate_distributed_dataset_inputs( distribution, x, y) @@ -1066,32 +992,33 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): with self.cached_session(): a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32) b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64) - x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b}) - y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) - with distribution.scope(): - # Removed device and input tensor dtype details from the error message - # since the order of the device and the corresponding input tensor dtype - # is not deterministic over different runs. 
- with self.assertRaisesRegexp(ValueError, - 'Input tensor dtypes do not match for ' - 'distributed tensor inputs ' - 'DistributedValues:.+'): + device_map = values.ReplicaDeviceMap(('/device:CPU:0', '/device:GPU:0')) + x = values.DistributedValues(device_map, (a, b)) + y = values.DistributedValues(device_map, (a, a)) + # Removed device and input tensor dtype details from the error message + # since the order of the device and the corresponding input tensor dtype + # is not deterministic over different runs. + with self.assertRaisesRegexp(ValueError, + 'Input tensor dtypes do not match for ' + 'distributed tensor inputs ' + 'DistributedValues:.+'): + with distribution.scope(): distributed_training_utils.validate_distributed_dataset_inputs( distribution, x, y) @combinations.generate(combinations.combine( distribution=[ - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], mode=['graph', 'eager'])) def test_unsupported_features(self, distribution): with self.cached_session(): - model = get_model() - - optimizer = gradient_descent.GradientDescentOptimizer(0.001) - loss = 'mse' - metrics = ['mae'] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + with distribution.scope(): + model = get_model() + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics) dataset = get_dataset(distribution) @@ -1130,17 +1057,17 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): @combinations.generate(combinations.combine( distribution=[ - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], mode=['graph', 'eager'])) def 
test_calling_with_unsupported_predefined_callbacks(self, distribution): with self.cached_session(): - model = get_model() - - optimizer = gradient_descent.GradientDescentOptimizer(0.001) - loss = 'mse' - metrics = ['mae'] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + with distribution.scope(): + model = get_model() + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics) dataset = get_dataset(distribution) @@ -1157,12 +1084,6 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): 'using'): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, callbacks=[keras.callbacks.ReduceLROnPlateau()]) - with self.assertRaisesRegexp(ValueError, - 'histogram_freq in the TensorBoard callback ' - 'is not supported when using ' - 'DistributionStrategy.'): - model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, - callbacks=[keras.callbacks.TensorBoard(histogram_freq=10)]) class TestDistributionStrategyWithLossMasking(test.TestCase, @@ -1172,21 +1093,21 @@ class TestDistributionStrategyWithLossMasking(test.TestCase, # work for TPU due to some invalid datatype. 
@combinations.generate(combinations.combine( distribution=[ - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], mode=['graph', 'eager'])) def test_masking(self, distribution): with self.cached_session(): np.random.seed(1337) x = np.array([[[1], [1]], [[0], [0]]]) - model = keras.models.Sequential() - model.add(keras.layers.Masking(mask_value=0, input_shape=(2, 1))) - model.add( - keras.layers.TimeDistributed( - keras.layers.Dense(1, kernel_initializer='one'))) - model.compile(loss='mse', - optimizer=gradient_descent.GradientDescentOptimizer(0.01), - distribute=distribution) + with distribution.scope(): + model = keras.models.Sequential() + model.add(keras.layers.Masking(mask_value=0, input_shape=(2, 1))) + model.add( + keras.layers.TimeDistributed( + keras.layers.Dense(1, kernel_initializer='one'))) + model.compile(loss='mse', + optimizer=gradient_descent.GradientDescentOptimizer(0.01)) y = np.array([[[1], [1]], [[1], [1]]]) dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) dataset = dataset.repeat(100) @@ -1201,12 +1122,12 @@ class TestDistributionStrategyWithNormalizationLayer( @combinations.generate(all_strategy_combinations()) def test_batchnorm_correctness(self, distribution): with self.cached_session(): - model = keras.models.Sequential() - norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8) - model.add(norm) - model.compile(loss='mse', - optimizer=gradient_descent.GradientDescentOptimizer(0.01), - distribute=distribution) + with distribution.scope(): + model = keras.models.Sequential() + norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8) + model.add(norm) + model.compile(loss='mse', + optimizer=gradient_descent.GradientDescentOptimizer(0.01)) # centered on 5.0, variance 10.0 x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10)) @@ -1227,144 +1148,5 @@ 
class TestDistributionStrategyWithNormalizationLayer( np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) -class TestDistributionStrategyCorrectness(test.TestCase, - parameterized.TestCase): - - @combinations.generate(all_strategy_combinations()) - def test_metric_correctness(self, distribution): - with self.cached_session(): - keras.backend.set_image_data_format('channels_last') - num_samples = 10000 - - x_train = np.random.randint(0, 2, num_samples) - x_train = np.reshape(x_train, (num_samples, 1)) - y_train = x_train - x_train = x_train.astype('float32') - y_train = y_train.astype('float32') - - # Create identity model. - model = keras.Sequential() - model.add( - keras.layers.Dense(1, input_shape=(1,), kernel_initializer='ones')) - model.compile( - loss=keras.losses.mean_squared_error, - optimizer=gradient_descent.GradientDescentOptimizer(0.5), - metrics=[keras.metrics.BinaryAccuracy()], - distribute=distribution) - - batch_size = 64 - if not distributed_training_utils.global_batch_size_supported( - distribution): - batch_size //= distribution.num_replicas_in_sync - train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) - train_dataset = batch_wrapper(train_dataset, batch_size, distribution) - - history = model.fit(x=train_dataset, epochs=1, steps_per_epoch=10) - self.assertEqual(history.history['binary_accuracy'], [1.0]) - - @combinations.generate(strategy_and_input_combinations()) - def test_correctness(self, distribution, use_numpy, use_validation_data): - - with self.cached_session(): - tolerance = 1e-5 - - if isinstance(distribution, (mirrored_strategy.MirroredStrategy, - mirrored_strategy.CoreMirroredStrategy)): - # TODO(b/119257215): use the default one once the flakyness is fixed. - tolerance = 1e-4 - - if (use_validation_data and - not isinstance(distribution, tpu_strategy.TPUStrategy)): - # TODO(b/120435565): Enable tests with use_validation_data once the - # the underlying bug is fixed. 
- return - - keras.backend.set_image_data_format('channels_last') - np.random.seed(_RANDOM_SEED) - random_seed.set_random_seed(_RANDOM_SEED) - - # Train, eval, and predict datasets are created with the same input numpy - # arrays. - # TODO(xiejw): Change this back to 10000, once we support final partial - # batch. - num_samples = 9984 - x_train = np.random.rand(num_samples, 1) - y_train = 3 * x_train - x_train = x_train.astype('float32') - y_train = y_train.astype('float32') - x_predict = [[1.], [2.], [3.], [4.]] - - # The model is built once and the initial weights are saved. - # This is used to initialize the model for both the distribution and - # non-distribution run. In addition, we add few non-linear layers to make - # it non-trivial. - def _create_model(): - model = keras.Sequential() - model.add(keras.layers.Dense(10, activation='relu', input_shape=(1,))) - model.add(keras.layers.Dense(10, activation='relu')) - model.add(keras.layers.Dense(10, activation='relu')) - model.add(keras.layers.Dense(1)) - return model - - model = _create_model() - initial_weights = model.get_weights() - del model # avoid accident usage. - - def fit_eval_and_predict(with_distribution=None): - model = _create_model() - # We have initialized the model to the same weight for the distribution - # and non-distribution run. - model.set_weights(initial_weights) - model.compile( - loss=keras.losses.mean_squared_error, - optimizer=gradient_descent_keras.SGD(0.5), - distribute=with_distribution) - - training_inputs, eval_inputs, predict_inputs = ( - get_correctness_test_inputs(use_numpy, use_validation_data, - with_distribution, - x_train, y_train, x_predict)) - - traning_history = model.fit(**training_inputs).history - - if eval_inputs is not None: - eval_result = model.evaluate(**eval_inputs) - else: - # Creates a dummy identical eval_result to be compared later. 
- eval_result = 1.0 - - weights = model.get_weights() - predict_result = model.predict(**predict_inputs) - - return weights, traning_history, eval_result, predict_result - - wts_with_ds, history_with_ds, eval_with_ds, predict_with_ds = ( - fit_eval_and_predict(with_distribution=distribution)) - - (wts_without_ds, history_without_ds, eval_without_ds, - predict_without_ds) = fit_eval_and_predict(with_distribution=None) - - # Verify that the weights, training history, eval results, predict outputs - # are the same within some limits of tolerance. - self.assertAllClose( - wts_with_ds, wts_without_ds, atol=tolerance, rtol=tolerance, - msg='Fail to assert weights after training.') - - self.assertAllClose( - eval_with_ds, eval_without_ds, atol=tolerance, rtol=tolerance, - msg='Fail to assert eval results.') - self.assertAllClose( - predict_with_ds, predict_without_ds, atol=tolerance, rtol=tolerance, - msg='Fail to assert predict results.') - - if not (isinstance(distribution, tpu_strategy.TPUStrategy) - and distribution.extended.steps_per_run > 1): - # TODO(b/119894254): Enable this test for all cases once the underlying - # bug is fixed. 
- self.assertAllClose( - history_with_ds, history_without_ds, atol=tolerance, rtol=tolerance, - msg='Fail to assert training history.') - - if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py index 8ac659abe96370b751ed1556cc699fe20788a0fd..32a0d199434e0627122fd4e47cf8894079ef3a1e 100644 --- a/tensorflow/contrib/distribute/python/metrics_v1_test.py +++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py @@ -100,7 +100,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): if isinstance(distribution, tpu_strategy.TPUStrategy): def step_fn(ctx, inputs): value, update = distribution.call_for_each_replica( - metric_fn, args=inputs) + metric_fn, args=(inputs,)) ctx.set_non_tensor_output(name="value", output=value) return distribution.group(update) @@ -115,7 +115,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): distribution.extended.steps_per_run) else: value, update = distribution.call_for_each_replica( - metric_fn, iterator.get_next()) + metric_fn, args=(iterator.get_next(),)) update = distribution.group(update) # TODO(josh11b): Once we switch to using a global batch size for input, # replace "distribution.num_replicas_in_sync" with "1". 
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index dcc9df4cda51b87e95fb166a726170a8817715fc..824c4b09371fcc8d590f2d2b2be8f39b4a585b27 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -67,7 +67,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def step_fn(ctx, inputs): del ctx # Unused return distribution.group( - distribution.call_for_each_replica(model_fn, args=inputs)) + distribution.call_for_each_replica(model_fn, args=(inputs,))) iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) @@ -161,7 +161,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def step_fn(ctx, inputs): del ctx # Unused return distribution.group( - distribution.call_for_each_replica(model_fn, args=inputs)) + distribution.call_for_each_replica(model_fn, args=(inputs,))) iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) @@ -230,9 +230,9 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def step_fn(ctx, inputs): del ctx # Unused fetches = distribution.unwrap( - distribution.call_for_each_replica(model_fn, args=inputs)) + distribution.call_for_each_replica(model_fn, args=(inputs,))) if update_ops_in_cross_replica_mode: - fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS) + fetches += tuple(ops.get_collection(ops.GraphKeys.UPDATE_OPS)) return control_flow_ops.group(fetches) iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) @@ -302,8 +302,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): with distribution.scope(): all_vars = [] - def model_fn(x, y): - + def model_fn(inputs): + x, y = inputs def loss_fn(): # Use fixed initialization to make the steps deterministic. 
w = variable_scope.get_variable("w", initializer=[[2.]]) @@ -327,7 +327,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def step_fn(ctx, inputs): del ctx # Unused return distribution.group( - distribution.call_for_each_replica(model_fn, args=inputs)) + distribution.call_for_each_replica(model_fn, args=(inputs,))) iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) @@ -413,7 +413,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def step_fn(output_context, inputs): (train_op, loss) = distribution.call_for_each_replica( - model_fn, args=(output_context,) + inputs) + model_fn, args=(output_context, inputs)) output_context.set_last_step_output( name="cross_replica_loss_reduced", output=loss, @@ -443,7 +443,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): step_fn, iterator, iterations=2, initial_loop_values=initial_loop_values) - self.assertEqual({key1: [value1]}, ctx.non_tensor_outputs) + self.assertEqual({key1: (value1,)}, ctx.non_tensor_outputs) self._verify_loss_output( initial_loss(), loss_output=ctx.last_step_outputs["replica_loss_reduced"], diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 20f1a08d4261b931a9353738147fba7d7dff9225..71e50b83b079bc73a7b178356f0f26adbd98638f 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -20,7 +20,6 @@ from __future__ import print_function import functools -from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import mirrored_strategy from tensorflow.python.distribute import values @@ -28,7 +27,6 @@ from tensorflow.python.distribute import values # pylint: disable=protected-access,invalid-name _call_for_each_replica = mirrored_strategy._call_for_each_replica 
-_reduce_non_distributed_value = mirrored_strategy._reduce_non_distributed_value _create_mirrored_variable = mirrored_strategy._create_mirrored_variable all_local_devices = mirrored_strategy.all_local_devices CoreMirroredStrategy = mirrored_strategy.MirroredStrategy @@ -137,21 +135,16 @@ class MirroredExtended(CoreMirroredExtended): Returns: An `InputIterator` which returns inputs for each step of the computation. """ - if self._local_mode: - worker = device_util.canonicalize("/device:CPU:0") - worker_device_pairs = [(worker, self._devices)] - else: - worker_device_pairs = self._worker_devices - return values.DatasetIterator(dataset, worker_device_pairs) + return values.DatasetIterator(dataset, self._input_workers) def _distribute_dataset(self, dataset_fn): if self._local_mode: return values.PerReplicaDataset( - self._call_dataset_fn(dataset_fn), self._devices) + self._call_dataset_fn(dataset_fn), self._input_workers, 0) else: return values.MultiWorkerDataset( functools.partial(self._call_dataset_fn, dataset_fn), - self._worker_devices, + self._input_workers, auto_shard=self._auto_shard_dataset) # TODO(priyag): Delete this once all strategies use global batch size. 
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 66512f983e1c80e0c7937d104cd4f73bfd934eb8..a6348d2457a008f79ba4e4b580122bfc5d562c62 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -180,9 +180,37 @@ class MirroredStrategyVariableCreatorStackTest( variable_scope.variable_creator_scope(main_thread_creator): result = distribution.extended.call_for_each_replica(model_fn) result = distribution.unwrap(result) - expected = ["main_thread:thread_0", "main_thread:thread_1"] + expected = ("main_thread:thread_0", "main_thread:thread_1") self.assertEqual(expected, result) +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) +class MirroredStrategyCallForEachReplicaTest(test.TestCase): + + def testExecutingEagerlyOutsideFunction(self, distribution): + """Verify we preserve the value of executing_eagerly_outside_functions().""" + def model_fn(): + return ops.executing_eagerly_outside_functions() + + originally = ops.executing_eagerly_outside_functions() + with distribution.scope(): + in_scope = ops.executing_eagerly_outside_functions() + in_model_fn = distribution.extended.call_for_each_replica(model_fn) + unwrapped = distribution.unwrap(in_model_fn) + self.assertEqual(in_scope, unwrapped[0]) + self.assertEqual(in_scope, originally) + + # Verify this all again, but this time in a FuncGraph. 
+ with func_graph.FuncGraph("fg").as_default(), distribution.scope(): + in_scope = ops.executing_eagerly_outside_functions() + in_model_fn = distribution.extended.call_for_each_replica(model_fn) + unwrapped = distribution.unwrap(in_model_fn) + self.assertEqual(in_scope, unwrapped[0]) + self.assertEqual(in_scope, originally) + @combinations.generate(combinations.combine( distribution=[ @@ -530,10 +558,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): return v with distribution.scope(): - names = values.DistributedValues({ - "/device:CPU:0": "foo", - "/device:GPU:0": "bar" - }) + device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0")) + names = values.DistributedValues(device_map, ("foo", "bar")) with self.assertRaises(RuntimeError): _ = distribution.extended.call_for_each_replica(model_fn, args=(names,)) @@ -667,6 +693,15 @@ class MirroredStrategyVariableCreationTest(test.TestCase): distribution.extended.worker_devices[0]).read_value())) self.assertEqual(10.0, self.evaluate(ret_v_sum)) + def testVarDistributeStrategy(self, distribution): + with distribution.scope(): + mirrored = variable_scope.variable(1.0) + replica_local = variable_scope.variable( + 1.0, + synchronization=variable_scope.VariableSynchronization.ON_READ) + self.assertIs(distribution, mirrored.distribute_strategy) + self.assertIs(distribution, replica_local.distribute_strategy) + @combinations.generate(combinations.combine( distribution=[ @@ -1095,7 +1130,7 @@ class ReplicaLocalVariableAssignTest(test.TestCase): # When we read the value using `read_var` we should see the SUM of each of # values on each of the replicas. self.assertEqual(2.0, self.evaluate( - distribution.read_var(replica_local_var))) + distribution.extended.read_var(replica_local_var))) # Assigning 6.0 in cross replica context will assign a value of # 6.0/num_replicas to each replica. 
tlv_ops = replica_local_var.assign(6.0) @@ -1104,7 +1139,7 @@ class ReplicaLocalVariableAssignTest(test.TestCase): # The value on all the replicas are added before being returned by # `read_var`. self.assertEqual(6.0, self.evaluate( - distribution.read_var(replica_local_var))) + distribution.extended.read_var(replica_local_var))) def testAssignReplicaLocalVarMeanAggregation(self, distribution): def model_fn(): @@ -1123,13 +1158,13 @@ class ReplicaLocalVariableAssignTest(test.TestCase): # When we read the value using `read_var` we should see the MEAN of values # on all replicas which is the value assigned in replica context. self.assertEqual(1.0, self.evaluate( - distribution.read_var(replica_local_var))) + distribution.extended.read_var(replica_local_var))) tlv_ops = replica_local_var.assign(6.0) self.evaluate(tlv_ops) # On reading the replica local var we should get the MEAN of all values # which is equal to the value assigned. self.assertEqual(6.0, self.evaluate( - distribution.read_var(replica_local_var))) + distribution.extended.read_var(replica_local_var))) class MockModel(object): @@ -1182,9 +1217,9 @@ class MirroredStrategyDefunTest(test.TestCase): result = distribution.extended.call_for_each_replica( model_fn, args=[mock_model] + inputs) - for device in devices: - device_result = values.select_device(device, result) - device_expected_result = values.select_device(device, expected_result) + for r in range(len(devices)): + device_result = values.select_replica(r, result) + device_expected_result = values.select_replica(r, expected_result) self.assertAllClose(device_expected_result, self.evaluate(device_result)) @@ -1265,9 +1300,9 @@ class MirroredStrategyDefunTest(test.TestCase): def fn1(mock_model, factor): return mock_model(factor) - factors = values.PerReplica({"CPU:0": 5.0, "GPU:0": 3.0}) - expected_result = values.PerReplica({"CPU:0": 5.0 * 1.25, - "GPU:0": 3.0 * 1.25}) + device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0")) + factors 
= values.PerReplica(device_map, (5.0, 3.0)) + expected_result = values.PerReplica(device_map, (5.0 * 1.25, 3.0 * 1.25)) self._call_and_check(distribution, fn1, [factors], expected_result, [fn1]) def testTrain(self, distribution): diff --git a/tensorflow/contrib/distribute/python/moving_averages_test.py b/tensorflow/contrib/distribute/python/moving_averages_test.py index c492d8bafc9024ed059f05b92e5466f3702726b9..8f13e9153ea7a951dd722c4549882c97e79b57fe 100644 --- a/tensorflow/contrib/distribute/python/moving_averages_test.py +++ b/tensorflow/contrib/distribute/python/moving_averages_test.py @@ -139,6 +139,27 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase): (2.0 * 0.25 + 0.0) / (1.0 * 0.25 + 1.0)], var.eval()) + @combinations.generate(all_combinations) + def testAssignVariable(self, distribution): + + def replica_fn(): + var = variables.Variable([10.0, 11.0]) + # Here we expect to check the case when input value are variable. + val = variables.Variable([1., 2.]) + decay = 0.25 + assign = moving_averages.assign_moving_average( + var, val, decay, zero_debias=False) + return var, assign + + with distribution.scope(), self.cached_session() as sess: + var, assign = distribution.call_for_each_replica(replica_fn) + variables.global_variables_initializer().run() + self.assertAllClose([10.0, 11.0], var.eval()) + sess.run(distribution.unwrap(assign)) + self.assertAllClose( + [10 * 0.25 + 1. * (1 - 0.25), 11 * 0.25 + 2. 
* (1 - 0.25)], + var.eval()) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py index 147c9b83f866fd364ea23cf7988692a7b5f61b9c..b05aac431f65b4281d9ed9c2fa95c210d55f4008 100644 --- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py +++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py @@ -40,6 +40,7 @@ from tensorflow.python.client import session from tensorflow.python.estimator import run_config from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import coordinator from tensorflow.python.training import server_lib ASSIGNED_PORTS = set() @@ -360,6 +361,7 @@ class IndependentWorkerTestBase(test.TestCase): self._mock_os_env = MockOsEnv() self._mock_context = test.mock.patch.object(os, 'environ', self._mock_os_env) + self._coord = coordinator.Coordinator() super(IndependentWorkerTestBase, self).setUp() self._mock_context.__enter__() @@ -368,8 +370,9 @@ class IndependentWorkerTestBase(test.TestCase): super(IndependentWorkerTestBase, self).tearDown() def _task_thread(self, task_fn, tf_config, *args, **kwargs): - os.environ['TF_CONFIG'] = json.dumps(tf_config) - task_fn(*args, **kwargs) + with self._coord.stop_on_exception(): + os.environ['TF_CONFIG'] = json.dumps(tf_config) + task_fn(*args, **kwargs) def _run_task_in_thread(self, task_fn, cluster_spec, task_type, task_id, *args, **kwargs): @@ -403,3 +406,6 @@ class IndependentWorkerTestBase(test.TestCase): *args, **kwargs) threads[task_type].append(t) return threads + + def join_independent_workers(self, worker_threads): + self._coord.join(worker_threads) diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index e322b6acb84c166a885c9aaa3002f331903a5063..5986bc4661f2615a16fcd8d5bf503f1f0dd3d504 
100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -51,6 +51,10 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): super(OneDeviceExtended, self).__init__(container_strategy) self._device = device self._default_device = device + worker = device_util.canonicalize("/device:CPU:0") + worker_device_pairs = [(worker, [self._device])] + device_map = values.SingleDeviceMap(device) + self._input_workers = values.InputWorkers(device_map, worker_device_pairs) def _create_variable(self, next_creator, *args, **kwargs): colocate_with = kwargs.pop("colocate_with", None) @@ -60,7 +64,7 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): if isinstance(colocate_with, six.string_types): with ops.device(colocate_with): return next_creator(*args, **kwargs) - if (isinstance(colocate_with, list) and len(colocate_with) == 1 and + if (isinstance(colocate_with, (list, tuple)) and len(colocate_with) == 1 and isinstance(colocate_with[0], six.string_types)): with ops.device(colocate_with[0]): return next_creator(*args, **kwargs) @@ -69,23 +73,18 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): def _make_dataset_iterator(self, dataset): """Make iterator from dataset without splitting the batch.""" - worker = device_util.canonicalize("/device:CPU:0") - worker_device_pairs = [(worker, [self._device])] - return values.DatasetIterator(dataset, worker_device_pairs) + return values.DatasetIterator(dataset, self._input_workers) def _distribute_dataset(self, dataset_fn): return values.PerReplicaDataset( - self._call_dataset_fn(dataset_fn), [self._device]) + self._call_dataset_fn(dataset_fn), self._input_workers, 0) def _make_input_fn_iterator( self, input_fn, replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): - worker = device_util.canonicalize("/device:CPU:0") - worker_device_pairs = [(worker, [self._device])] return 
values.InputFunctionIterator( - input_fn, worker_device_pairs, - [distribute_lib.InputContext()]) + input_fn, self._input_workers, [distribute_lib.InputContext()]) def _broadcast_to(self, tensor, destinations): del destinations @@ -102,10 +101,7 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): def body(i, *args): """A wrapper around `fn` to create the while loop body.""" del args - fn_inputs = iterator.get_next() - if not isinstance(fn_inputs, tuple): - fn_inputs = (fn_inputs,) - fn_result = fn(ctx, fn_inputs) + fn_result = fn(ctx, iterator.get_next()) flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) with ops.control_dependencies([fn_result]): return [i + 1] + flat_last_step_outputs @@ -166,7 +162,7 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): return array_ops.identity(replica_local_var) def _unwrap(self, value): - return [value] + return (value,) def value_container(self, value): return value @@ -177,15 +173,15 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): @property def worker_devices(self): - return [self._device] + return (self._device,) @property def parameter_devices(self): - return [self._device] + return (self._device,) def non_slot_devices(self, var_list): del var_list - return [self._device] + return (self._device,) @property def experimental_should_init(self): @@ -208,12 +204,11 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): class _OneDeviceReplicaContext(distribute_lib.ReplicaContext): """ReplicaContext for OneDeviceStrategy.""" - def __init__(self, distribution_strategy): + def __init__(self, strategy): + zero = constant_op.constant(0, dtypes.int32) distribute_lib.ReplicaContext.__init__( - self, - distribution_strategy, - replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)) + self, strategy, replica_id_in_sync_group=zero) @property def devices(self): - return [self._distribution_strategy.extended.worker_devices[0]] + return 
self._strategy.extended.worker_devices diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index eaeb4d703015fc0762359b24dc23888c01e69111..2fd0c4d6ea6f9b92c2fd0569485972c1066af9a1 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -139,22 +139,22 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): "`task_type` and `task_id`") cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec) - self._worker_device = "/job:%s/task:%d" % (self._task_type, self._task_id) + worker_device = "/job:%s/task:%d" % (self._task_type, self._task_id) # Define compute devices which is a list of device strings and one for each # replica. When there are GPUs, replicate operations on these GPUs. # Otherwise, place operations on CPU. if num_gpus_per_worker > 0: - self._compute_devices = [ - "%s/device:GPU:%d" % (self._worker_device, i) + compute_devices = tuple( + "%s/device:GPU:%d" % (worker_device, i) for i in range(num_gpus_per_worker) - ] + ) else: - self._compute_devices = [self._worker_device] + compute_devices = (worker_device,) - self._compute_devices = list( - map(device_util.resolve, self._compute_devices)) - self._canonical_compute_device_set = set(self._compute_devices) + self._device_map = values.ReplicaDeviceMap(compute_devices) + self._input_workers = values.InputWorkers( + self._device_map, [(worker_device, compute_devices)]) # In distributed mode, place variables on ps jobs in a round-robin fashion. 
# Note that devices returned from `replica_device_setter` are not @@ -169,19 +169,19 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): raise ValueError("The cluster spec needs to have `ps` jobs.") self._variable_device = device_setter.replica_device_setter( ps_tasks=num_ps_replicas, - worker_device=self._worker_device, + worker_device=worker_device, merge_devices=True, cluster=cluster_spec) # The `_parameter_devices` is needed for the `parameter_devices` property # and is a list of all variable devices. Here parameter devices are all # tasks of the "ps" job. - self._parameter_devices = map("/job:ps/task:{}".format, - range(num_ps_replicas)) + self._parameter_devices = tuple(map("/job:ps/task:{}".format, + range(num_ps_replicas))) # Add a default device so that ops without specified devices will not end up # on other workers. - self._default_device = self._worker_device + self._default_device = worker_device self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type, task_id) @@ -192,36 +192,36 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): logging.info( "Multi-worker ParameterServerStrategy with " "cluster_spec = %r, task_type = %r, task_id = %r, " - "num_ps_replicas = %r, is_chief = %r, compute_devices = %r, " + "num_ps_replicas = %r, is_chief = %r, device_map = %r, " "variable_device = %r", cluster_spec.as_dict(), task_type, task_id, - num_ps_replicas, self._is_chief, self._compute_devices, + num_ps_replicas, self._is_chief, self._device_map, self._variable_device) def _initialize_local(self, num_gpus_per_worker): """Initialize internal devices for local training.""" - self._worker_device = device_util.canonicalize("/device:CPU:0") + worker_device = device_util.canonicalize("/device:CPU:0") # Define compute devices which is a list of device strings and one for each # replica. When there are GPUs, replicate operations on these GPUs. # Otherwise, place operations on CPU. 
if num_gpus_per_worker > 0: - self._compute_devices = list( + compute_devices = tuple( map("/device:GPU:{}".format, range(num_gpus_per_worker))) else: - self._compute_devices = [_LOCAL_CPU] + compute_devices = (_LOCAL_CPU,) - self._compute_devices = list( - map(device_util.resolve, self._compute_devices)) - self._canonical_compute_device_set = set(self._compute_devices) + self._device_map = values.ReplicaDeviceMap(compute_devices) + self._input_workers = values.InputWorkers( + self._device_map, [(worker_device, compute_devices)]) # If there is only one GPU, put everything on that GPU. Otherwise, place # variables on CPU. if num_gpus_per_worker == 1: - assert len(list(self._compute_devices)) == 1 + assert len(compute_devices) == 1 self._variable_device = _LOCAL_GPU_0 - self._parameter_devices = [_LOCAL_GPU_0] + self._parameter_devices = (_LOCAL_GPU_0,) else: self._variable_device = _LOCAL_CPU - self._parameter_devices = [_LOCAL_CPU] + self._parameter_devices = (_LOCAL_CPU,) self._is_chief = True self._cluster_spec = None @@ -230,16 +230,16 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): logging.info( "ParameterServerStrategy with compute_devices = %r, " - "variable_device = %r", self._compute_devices, self._variable_device) + "variable_device = %r", compute_devices, self._variable_device) def _distribute_dataset(self, dataset_fn): """Distributes the dataset to each local GPU.""" return values.PerReplicaDataset( - self._call_dataset_fn(dataset_fn), self._compute_devices, True) + self._call_dataset_fn(dataset_fn), self._input_workers, 0, + prefetch_on_device=True) def _make_dataset_iterator(self, dataset): - worker_device_pairs = [(self._worker_device, self._compute_devices)] - return values.DatasetIterator(dataset, worker_device_pairs, + return values.DatasetIterator(dataset, self._input_workers, self._num_replicas_in_sync) def _make_input_fn_iterator( @@ -259,9 +259,8 @@ class 
ParameterServerExtended(distribute_lib.DistributionStrategyExtended): num_input_pipelines=num_input_pipelines, input_pipeline_id=input_pipeline_id, num_replicas_in_sync=self._num_replicas_in_sync) - worker_device_pairs = [(self._worker_device, self._compute_devices)] return values.InputFunctionIterator( - input_fn, worker_device_pairs, [input_context]) + input_fn, self._input_workers, [input_context]) def _broadcast_to(self, tensor, destinations): # This is both a fast path for Python constants, and a way to delay @@ -272,7 +271,9 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): if isinstance(tensor, (float, int)): return tensor if not cross_device_ops_lib.check_destinations(destinations): - destinations = self._compute_devices + # TODO(josh11b): Use current logical device instead of 0 here. + destinations = values.LogicalDeviceSpec( + device_map=self._device_map, logical_device=0) return self._cross_device_ops.broadcast(tensor, destinations) def _allow_variable_partition(self): @@ -302,7 +303,8 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): # Create and wrap the variable. v = next_creator(*args, **kwargs) - wrapped = values.AggregatingVariable(v, aggregation) + wrapped = values.AggregatingVariable( + self._container_strategy(), v, aggregation) # Add the wrapped variable to the requested collections. 
# The handling of eager mode and the global step matches @@ -338,7 +340,7 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): def _call_for_each_replica(self, fn, args, kwargs): # pylint: disable=protected-access return mirrored_strategy._call_for_each_replica( - self._container_strategy(), fn, args, kwargs) + self._container_strategy(), self._device_map, fn, args, kwargs) def _verify_destinations_not_different_worker(self, destinations): if not self._cluster_spec: @@ -350,14 +352,14 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): if d_spec.job == self._task_type and d_spec.task != self._task_id: raise ValueError( "Cannot reduce to another worker: %r, current worker is %r" % - (d, self._worker_device)) + (d, self._input_workers.worker_devices[0])) def _reduce_to(self, reduce_op, value, destinations): self._verify_destinations_not_different_worker(destinations) if not isinstance(value, values.DistributedValues): # pylint: disable=protected-access - return mirrored_strategy._reduce_non_distributed_value( - self, reduce_op, value, destinations) + return cross_device_ops_lib.reduce_non_distributed_value( + reduce_op, self._device_map, value, destinations) return self._cross_device_ops.reduce( reduce_op, value, destinations=destinations) @@ -373,7 +375,7 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): def _select_fn(x): # pylint: disable=g-missing-docstring if isinstance(x, values.Mirrored): if len(x.devices) == 1: - return list(x._index.values())[0] # pylint: disable=protected-access + return x.primary else: raise ValueError( "You cannot update variable with a Mirrored object with multiple " @@ -415,11 +417,8 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): def _unwrap(self, val): if isinstance(val, values.DistributedValues): - # Return in a deterministic order. 
- if set(val.devices) == self._canonical_compute_device_set: - return [val.get(device=d) for d in self._compute_devices] - return [val.get(device=d) for d in sorted(val.devices)] - return [val] + return val.values + return (val,) def value_container(self, val): if (hasattr(val, "_aggregating_container") and @@ -493,16 +492,19 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): @property def _num_replicas_in_sync(self): - return len(self._compute_devices) + return self._device_map.num_replicas_in_graph @property def worker_devices(self): - # Make a copy to prevent users from accidentally mutating our copy. - return list(self._compute_devices) + return self._device_map.all_devices + + @property + def worker_devices_by_replica(self): + return self._device_map.devices_by_replica @property def parameter_devices(self): - return list(self._parameter_devices) + return self._parameter_devices def non_slot_devices(self, var_list): return min(var_list, key=lambda x: x.name) diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index 83d7473666a65e438a1c0119d2a12bf54e53c8fc..e6ae16d8565f9d0225e2fd1b2ffbf5e86d0ef33e 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -477,7 +477,7 @@ class ParameterServerStrategyTestBase( before_list = [] after_list = [] for g, v in g_v: - fetched = d.read_var(v) + fetched = d.extended.read_var(v) before_list.append(fetched) with ops.control_dependencies([fetched]): # TODO(yuefengz): support non-Mirrored variable as destinations. 
@@ -485,7 +485,7 @@ class ParameterServerStrategyTestBase( reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies( d.update(v, update, g, grouped=False)): - after_list.append(d.read_var(v)) + after_list.append(d.extended.read_var(v)) return before_list, after_list before_out, after_out = step() @@ -532,21 +532,22 @@ class ParameterServerStrategyTestBase( for expected_value in expected_values: next_element = iterator.get_next() - computed_value = sess.run( - [values.select_device(d, next_element) for d in devices]) + computed_value = sess.run([values.select_replica(r, next_element) + for r in range(len(devices))]) self.assertEqual(expected_value, computed_value) with self.assertRaises(errors.OutOfRangeError): next_element = iterator.get_next() - sess.run([values.select_device(d, next_element) for d in devices]) + sess.run([values.select_replica(r, next_element) + for r in range(len(devices))]) # After re-initializing the iterator, should be able to iterate again. sess.run(iterator.initialize()) for expected_value in expected_values: next_element = iterator.get_next() - computed_value = sess.run( - [values.select_device(d, next_element) for d in devices]) + computed_value = sess.run([values.select_replica(r, next_element) + for r in range(len(devices))]) self.assertEqual(expected_value, computed_value) @@ -715,6 +716,7 @@ class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, id(get_step), get_step.__class__.__name__))) self.assertIs(values.AggregatingVariable, type(created_step)) self.assertIs(values.AggregatingVariable, type(get_step)) + self.assertIs(distribution, created_step.distribute_strategy) def testValueContainer(self): distribution = parameter_server_strategy.ParameterServerStrategy( diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py index c928b6d9f1f21508edd753f94c38ab2723cc0a9f..faeb96bcb7c516b1e494661ef2cbe8dad476ab55 100644 --- 
a/tensorflow/contrib/distribute/python/step_fn.py +++ b/tensorflow/contrib/distribute/python/step_fn.py @@ -100,7 +100,7 @@ class StandardSingleLossStep(StandardInputStep): gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn) grads_and_vars = self.distribution.call_for_each_replica( - gradients_fn, args=(ctx,) + inputs) + gradients_fn, args=(ctx, inputs)) # If threads use layers, then we need to run the first step # sequentially, so that layers.build() is not executed in parallel. # Otherwise, multiple sets of mirrored variables are going to be diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index d50b142c5e9ad36522b11a77219140a7b40d9bf6..6e5280e35632d3f3cb6a4fe172a15fb7f508354c 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -112,7 +112,7 @@ class DistributionTestBase(test.TestCase): before_list = [] after_list = [] for g, v in g_v: - fetched = d.read_var(v) + fetched = d.extended.read_var(v) before_list.append(fetched) # control_dependencies irrelevant but harmless in eager execution with ops.control_dependencies([fetched]): @@ -120,7 +120,7 @@ class DistributionTestBase(test.TestCase): reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies(d.update( v, update, g, grouped=False)): - after_list.append(d.read_var(v)) + after_list.append(d.extended.read_var(v)) return before_list, after_list for i in range(10): @@ -168,14 +168,14 @@ class DistributionTestBase(test.TestCase): before_list = [] after_list = [] for g, v in g_v: - fetched = d.read_var(v) + fetched = d.extended.read_var(v) before_list.append(fetched) with ops.control_dependencies([fetched]): g = d.extended.reduce_to( reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies(d.update( v, update, g, grouped=False)): - after_list.append(d.read_var(v)) + 
after_list.append(d.extended.read_var(v)) return before_list, after_list before_out, after_out = step() @@ -254,12 +254,13 @@ class DistributionTestBase(test.TestCase): for expected_value in expected_values: next_element = iterator.get_next() computed_value = evaluate( - [values.select_device(d, next_element) for d in devices]) + [values.select_replica(r, next_element) for r in range(len(devices))]) self.assertEqual(expected_value, computed_value) with self.assertRaises(errors.OutOfRangeError): next_element = iterator.get_next() - evaluate([values.select_device(d, next_element) for d in devices]) + evaluate( + [values.select_replica(r, next_element) for r in range(len(devices))]) # After re-initializing the iterator, should be able to iterate again. evaluate(iterator.initialize()) @@ -267,7 +268,7 @@ class DistributionTestBase(test.TestCase): for expected_value in expected_values: next_element = iterator.get_next() computed_value = evaluate( - [values.select_device(d, next_element) for d in devices]) + [values.select_replica(r, next_element) for r in range(len(devices))]) self.assertEqual(expected_value, computed_value) def _test_global_step_update(self, strategy): @@ -290,4 +291,4 @@ class DistributionTestBase(test.TestCase): self.evaluate(strategy.group(train_ops)) global_step_tensors = strategy.unwrap(value) global_step_values = self.evaluate(global_step_tensors) - self.assertEqual([1] * len(global_step_tensors), global_step_values) + self.assertEqual((1,) * len(global_step_tensors), global_step_values) diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 806ff0ac61529626e3a29b77a082e045cd479ed8..e081a735e2dcc2f84ead67d8a1e84507e46c23af 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -35,6 +35,7 @@ from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib from 
tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import values +from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver_lib from tensorflow.python.eager import context from tensorflow.python.eager import tape from tensorflow.python.framework import constant_op @@ -66,8 +67,9 @@ def get_tpu_system_metadata(tpu_cluster_resolver): # TODO(jhseu): Deduplicate with MirroredStrategy? -def _create_tpu_mirrored_variable(devices, real_mirrored_creator, *args, - **kwargs): # pylint: disable=g-missing-docstring +def _create_tpu_mirrored_variable( # pylint: disable=missing-docstring + strategy, device_map, logical_device, real_mirrored_creator, + *args, **kwargs): # Figure out what collections this variable should be added to. # We'll add the TPUMirroredVariable to those collections instead. collections = kwargs.pop("collections", None) @@ -97,8 +99,11 @@ def _create_tpu_mirrored_variable(devices, real_mirrored_creator, *args, # was never recorded on the tape instead of having to do this manually # here. 
with tape.stop_recording(): - index = real_mirrored_creator(devices, *args, **kwargs) - result = values.TPUMirroredVariable(index, index[devices[0]], aggregation) + devices = device_map.logical_to_actual_devices(logical_device) + value_list = real_mirrored_creator(devices, *args, **kwargs) + result = values.TPUMirroredVariable( + strategy, device_map, value_list, aggregation, + logical_device=logical_device) if not context.executing_eagerly(): g = ops.get_default_graph() @@ -110,7 +115,7 @@ def _create_tpu_mirrored_variable(devices, real_mirrored_creator, *args, if kwargs.get("trainable", True): collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) - for v in index.values(): + for v in value_list: l.remove(v) g.add_to_collections(collections, result) return result @@ -119,7 +124,10 @@ def _create_tpu_mirrored_variable(devices, real_mirrored_creator, *args, class TPUStrategy(distribute_lib.DistributionStrategy): """TPU distribution strategy implementation.""" - def __init__(self, tpu_cluster_resolver, steps_per_run, num_cores=None): + def __init__(self, + tpu_cluster_resolver=None, + steps_per_run=None, + num_cores=None): """Initializes the TPUStrategy object. Args: @@ -145,12 +153,26 @@ class TPUStrategy(distribute_lib.DistributionStrategy): class TPUExtended(distribute_lib.DistributionStrategyExtended): """Implementation of TPUStrategy.""" - # Track what TPU devices have been initialized. + # Track what TPU devices have been initialized. This is *intentionally* + # shared across all instances of TPUExtended as we want to keep track of which + # devices are initialized globally. 
_initialized_devices = [] - def __init__(self, container_strategy, tpu_cluster_resolver, steps_per_run, + def __init__(self, + container_strategy, + tpu_cluster_resolver=None, + steps_per_run=None, num_cores=None): super(TPUExtended, self).__init__(container_strategy) + + if tpu_cluster_resolver is None: + tpu_cluster_resolver = resolver_lib.TPUClusterResolver("") + + if steps_per_run is None: + # TODO(frankchn): Warn when we are being used by DS/Keras and this is + # not specified. + steps_per_run = 1 + self._tpu_cluster_resolver = tpu_cluster_resolver self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver) # TODO(sourabhbajaj): Change this from num_cores to metadata_override @@ -158,13 +180,24 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): # TODO(jhseu): Switch to DeviceAssignment to support pods and model # parallelism. - device_map = {d.name: i for i, d in enumerate(self._tpu_metadata.devices) - if "device:TPU:" in d.name} - self._device_index = values.PerReplica(device_map) + self._device_index = { + d.name: i for i, d in enumerate(self._tpu_metadata.devices) + if "device:TPU:" in d.name + } self._host_device = self.get_host_cpu_device(0) - self._tpu_devices = sorted(device_map.keys()) + self._tpu_devices = tuple(sorted(self._device_index.keys())) # Only create variables for the number of replicas we're running. self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync] + self._device_map = values.ReplicaDeviceMap(self._tpu_devices) + + # For input: + input_device_map = values.ReplicaDeviceMap(tuple( + self.get_host_cpu_device(hid) for hid in range(self.num_hosts))) + worker_devices = [ + (self.get_host(hid), [self.get_host_cpu_device(hid)]) + for hid in range(self.num_hosts) + ] + self._input_workers = values.InputWorkers(input_device_map, worker_devices) # TODO(sourabhbajaj): Remove this once performance of running one step # at a time is comparable to multiple steps. 
@@ -261,20 +294,13 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): def _make_dataset_iterator(self, dataset): """Make iterators for each of the TPU hosts.""" - worker_devices = [ - (self.get_host(hid), [self.get_host_cpu_device(hid)]) - for hid in range(self.num_hosts) - ] - return values.DatasetIterator(dataset, worker_devices, + return values.DatasetIterator(dataset, self._input_workers, self._num_replicas_in_sync) def _distribute_dataset(self, dataset_fn): - worker_devices = [ - (self.get_host(hid), [self.get_host_cpu_device(hid)]) - for hid in range(self.num_hosts) - ] return values.MultiWorkerDataset( - functools.partial(self._call_dataset_fn, dataset_fn), worker_devices) + functools.partial(self._call_dataset_fn, dataset_fn), + self._input_workers) # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have @@ -307,10 +333,7 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): def run_fn(): """Single step on the TPU device.""" - fn_inputs = dequeue_fn() - if not isinstance(fn_inputs, tuple): - fn_inputs = (fn_inputs,) - fn_result = fn(ctx, fn_inputs) + fn_result = fn(ctx, dequeue_fn()) flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) if flat_last_step_outputs: with ops.control_dependencies([fn_result]): @@ -417,22 +440,23 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): else: return [] - def _get_devices_from(self, colocate_with=None): - # TODO(jhseu): Change this when we support model parallelism. - return self._tpu_devices - def _create_variable(self, next_creator, *args, **kwargs): """Create a TPUMirroredVariable. See `DistributionStrategy.scope`.""" colocate_with = kwargs.pop("colocate_with", None) - devices = self._get_devices_from(colocate_with) + if colocate_with is None: + device_map = self._device_map + logical_device = 0 # TODO(josh11b): Get logical device from scope here. 
+ else: + device_map = colocate_with.device_map + logical_device = colocate_with.logical_device def _real_mirrored_creator(devices, *args, **kwargs): # pylint: disable=g-missing-docstring - index = {} + value_list = [] for i, d in enumerate(devices): with ops.device(d): if i > 0: # Give replicas meaningful distinct names: - var0name = index[devices[0]].name.split(":")[0] + var0name = value_list[0].name.split(":")[0] # We append a / to variable names created on replicas with id > 0 to # ensure that we ignore the name scope and instead use the given # name as the absolute name of the variable. @@ -440,20 +464,21 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): # Initialize replicas with the same value: if context.executing_eagerly(): kwargs["initial_value"] = array_ops.identity( - index[devices[0]].value()) + value_list[0].value()) else: def initial_value_fn(device=d): with ops.device(device): - return array_ops.identity(index[devices[0]].initial_value) + return array_ops.identity(value_list[0].initial_value) kwargs["initial_value"] = initial_value_fn with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): v = next_creator(*args, **kwargs) assert not isinstance(v, values.TPUMirroredVariable) - index[d] = v - return index + value_list.append(v) + return value_list - return _create_tpu_mirrored_variable(devices, _real_mirrored_creator, *args, - **kwargs) + return _create_tpu_mirrored_variable( + self._container_strategy(), device_map, logical_device, + _real_mirrored_creator, *args, **kwargs) def _reduce_to(self, reduce_op, value, destinations): if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access @@ -465,6 +490,14 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): "Currently only support sum & mean in TPUStrategy.") return tpu_ops.cross_replica_sum(value) + if not isinstance(value, values.DistributedValues): + # This function handles reducing values that are not PerReplica or + # Mirrored 
values. For example, the same value could be present on all + # replicas in which case `value` would be a single value or value could + # be 0. + return cross_device_ops_lib.reduce_non_distributed_value( + reduce_op, self._device_map, value, destinations) + # Validate that the destination is same as the host device # Note we don't do this when in replicate context as the reduction is # performed on the TPU device itself. @@ -486,19 +519,19 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): if group: return fn(var, *args, **kwargs) else: - return [fn(var, *args, **kwargs)] + return (fn(var, *args, **kwargs),) # Otherwise, we revert to MirroredStrategy behavior and update each variable # directly. - updates = {} - for d, v in var._index.items(): # pylint: disable=protected-access - name = "update_%d" % self._device_index.get(d) + updates = [] + for i, (d, v) in enumerate(zip(var.devices, var.values)): + name = "update_%d" % i with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name): # If args and kwargs are not mirrored, the value is returned as is. - updates[d] = fn(v, - *values.select_device_mirrored(d, args), - **values.select_device_mirrored(d, kwargs)) - return values.update_regroup(self, updates, group) + updates.append(fn(v, + *values.select_device_mirrored(d, args), + **values.select_device_mirrored(d, kwargs))) + return values.update_regroup(self, self._device_map, updates, group) def read_var(self, var): assert isinstance(var, values.TPUMirroredVariable) @@ -507,13 +540,18 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): def _unwrap(self, val): if isinstance(val, values.DistributedValues): # Return in a deterministic order. 
- return [val.get(device=d) for d in sorted(val.devices)] + return tuple(val.get(device=d) for d in sorted(val.devices)) elif isinstance(val, list): # TODO(josh11b): We need to remove this case; per device values should # be represented using a PerReplica wrapper instead of a list with # one entry per device. - return val - return [val] + return tuple(val) + elif isinstance(val, values.TPUMirroredVariable): + # pylint: disable=protected-access + if values._enclosing_tpu_context() is not None: + return (val,) + return val.values + return (val,) def value_container(self, value): return value @@ -606,17 +644,16 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): class _TPUReplicaContext(distribute_lib.ReplicaContext): """Replication Context class for TPU Strategy.""" - # TODO(sourabhbajaj): Call for each tower should be updating this. - def __init__(self, distribution_strategy): + # TODO(sourabhbajaj): Call for each replica should be updating this. + def __init__(self, strategy): + # TODO(b/118385803): properly initialize replica_id, instead of always 0 + replica_id = constant_op.constant(0, dtypes.int32) distribute_lib.ReplicaContext.__init__( - self, - distribution_strategy, - # TODO(b/118385803): properly initialize replica_id, instead of always 0 - replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)) + self, strategy, replica_id_in_sync_group=replica_id) @property def devices(self): distribute_lib.require_replica_context(self) - ds = self._distribution_strategy + ds = self._strategy replica_id = tensor_util.constant_value(self._replica_id_in_sync_group) - return [ds.extended.worker_devices[replica_id]] + return (ds.extended.worker_devices[replica_id],) diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index 538b859f3d1ece55b460f6dbf8f01540a6013381..73efb524b93a367d98395d4e83ac4bf136318a27 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ 
b/tensorflow/contrib/distribute/python/values_test.py @@ -24,6 +24,7 @@ from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.data.experimental.ops import batching from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib @@ -51,7 +52,8 @@ class DistributedValuesTest(test.TestCase): with ops.device("/device:CPU:0"): one = constant_op.constant(1) two = constant_op.constant(2) - v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two}) + device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0")) + v = values.DistributedValues(device_map, (one, two)) self.assertEqual(two, v.get("/device:GPU:0")) self.assertEqual(one, v.get()) with self.assertRaises(ValueError): @@ -63,24 +65,26 @@ class DistributedValuesTest(test.TestCase): ops.device("/device:CPU:0"): one = constant_op.constant(1) two = constant_op.constant(2) - v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two}) + device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0")) + v = values.DistributedValues(device_map, (one, two)) self.assertEqual(two, v.get("/device:GPU:0")) self.assertEqual(one, v.get()) with self.assertRaises(ValueError): self.assertIsNone(v.get("/device:GPU:2")) def testCanonicalization(self): - canonical_cpu = ["/job:localhost/replica:0/task:0/device:CPU:0"] - v = values.DistributedValues({"": 42}) - self.assertEqual(canonical_cpu, list(v._index.keys())) - v = values.DistributedValues({"/device:CPU:0": 42}) - self.assertEqual(canonical_cpu, list(v._index.keys())) - v = values.DistributedValues({"/cpu:0": 42}) - self.assertEqual(canonical_cpu, list(v._index.keys())) - v = values.DistributedValues({"/CPU:0": 42}) - self.assertEqual(canonical_cpu, 
list(v._index.keys())) + canonical_cpu = ("/job:localhost/replica:0/task:0/device:CPU:0",) + v = values.DistributedValues(values.SingleDeviceMap(""), (42,)) + self.assertEqual(canonical_cpu, v.devices) + v = values.DistributedValues(values.SingleDeviceMap("/device:CPU:0"), (42,)) + self.assertEqual(canonical_cpu, v.devices) + v = values.DistributedValues(values.SingleDeviceMap("/cpu:0"), (42,)) + self.assertEqual(canonical_cpu, v.devices) + v = values.DistributedValues(values.SingleDeviceMap("/CPU:0"), (42,)) + self.assertEqual(canonical_cpu, v.devices) with self.assertRaises(AssertionError): - v = values.DistributedValues({"/device:cpu:0": 42}) + v = values.DistributedValues( + values.SingleDeviceMap("/device:cpu:0"), (42,)) def testIsTensorLike(self): with context.graph_mode(), \ @@ -88,7 +92,8 @@ class DistributedValuesTest(test.TestCase): ops.device("/device:CPU:0"): one = constant_op.constant(1) two = constant_op.constant(2) - v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two}) + device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0")) + v = values.DistributedValues(device_map, (one, two)) self.assertEqual(two, v.get("/device:GPU:0")) self.assertEqual(one, v.get()) self.assertTrue(v.is_tensor_like) @@ -100,7 +105,8 @@ class DistributedValuesTest(test.TestCase): ops.device("/device:CPU:0"): one = constant_op.constant(1) two = 2.0 - v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two}) + device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0")) + v = values.DistributedValues(device_map, (one, two)) self.assertEqual(two, v.get("/device:GPU:0")) self.assertEqual(one, v.get()) self.assertFalse(v.is_tensor_like) @@ -118,8 +124,8 @@ class DistributedDelegateTest(test.TestCase): def __init__(self, x): self.x = x - v = values.DistributedDelegate( - {"/device:CPU:0": Foo(7), "/device:GPU:0": Foo(8)}) + device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0")) + v = 
values.DistributedDelegate(device_map, (Foo(7), Foo(8))) self.assertEqual(7, v.x) with self.assertRaises(AttributeError): _ = v.y @@ -127,7 +133,8 @@ class DistributedDelegateTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testOperatorOverride(self): with ops.device("/device:CPU:0"): - v = values.DistributedDelegate({"/device:CPU:0": 7, "/device:GPU:0": 8}) + device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0")) + v = values.DistributedDelegate(device_map, (7, 8)) # v should act like int(7). self.assertEqual(8, v + 1) self.assertEqual(10, 3 + v) @@ -178,16 +185,15 @@ def _nested_value(d): def _make_mirrored(): v = [] - index = {} devices = ["/device:GPU:0", "/device:CPU:0"] for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]): with ops.device(d): v.append(variable_scope.get_variable( name=n, initializer=init, use_resource=True)) - index[d] = v[-1] - mirrored = values.MirroredVariable(index, v[0], + device_map = values.ReplicaDeviceMap(devices) + mirrored = values.MirroredVariable(None, device_map, v, variable_scope.VariableAggregation.SUM) - return v, devices, mirrored + return v, device_map, mirrored class RegroupAndSelectDeviceTest(test.TestCase): @@ -204,8 +210,9 @@ class RegroupAndSelectDeviceTest(test.TestCase): self.assertEqual(expected[i], result.get(_device_str(i))) def testNested(self): - result = values.regroup({_device_str(0): _nested_value("1"), - _device_str(1): _nested_value("2")}) + device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1))) + result = values.regroup(device_map, + (_nested_value("1"), _nested_value("2"))) self.assertIsInstance(result, tuple) self.assertEqual(3, len(result)) self._is_per_replica(result[0], ["a1", "a2"]) @@ -221,11 +228,11 @@ class RegroupAndSelectDeviceTest(test.TestCase): self._is_per_replica(result[1][1]["c"], ["d1", "d2"]) self._is_per_replica(result[1][1]["e"], ["f1", "f2"]) - # Also test that we can undo the merge using select_device() + # Also test that we 
can undo the merge using select_replica() self.assertEqual(_nested_value("1"), - values.select_device(_device_str(0), result)) + values.select_replica(0, result)) self.assertEqual(_nested_value("2"), - values.select_device(_device_str(1), result)) + values.select_replica(1, result)) # select_device_mirrored() should fail due to non-mirrored values with self.assertRaises(TypeError): values.select_device_mirrored(_device_str(0), result) @@ -235,8 +242,9 @@ class RegroupAndSelectDeviceTest(test.TestCase): def testWrapClass(self): # Normally a mirrored value would be the same across devices, but # for a test it is convenient to be able to tell the values apart. - result = values.regroup({_device_str(0): _nested_value("1"), - _device_str(1): _nested_value("2")}, + device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1))) + result = values.regroup(device_map, + (_nested_value("1"), _nested_value("2")), values.Mirrored) self.assertIsInstance(result, tuple) self.assertEqual(3, len(result)) @@ -253,11 +261,11 @@ class RegroupAndSelectDeviceTest(test.TestCase): self._is_per_replica(result[1][1]["c"], ["d1", "d2"], values.Mirrored) self._is_per_replica(result[1][1]["e"], ["f1", "f2"], values.Mirrored) - # Also test that we can undo the merge using select_device() + # Also test that we can undo the merge using select_replica() self.assertEqual(_nested_value("1"), - values.select_device(_device_str(0), result)) + values.select_replica(0, result)) self.assertEqual(_nested_value("2"), - values.select_device(_device_str(1), result)) + values.select_replica(1, result)) # Values are marked as mirrored, so select_device_mirrored() is allowed. 
self.assertEqual(_nested_value("1"), values.select_device_mirrored(_device_str(0), result)) @@ -267,63 +275,66 @@ class RegroupAndSelectDeviceTest(test.TestCase): def testMirroredContainer(self): if context.num_gpus() < 1 and context.executing_eagerly(): self.skipTest("A GPU is not available for this test in eager mode.") - v, devices, mirrored = _make_mirrored() - result = values.regroup(dict(zip(devices, v))) + v, device_map, mirrored = _make_mirrored() + result = values.regroup(device_map, v) self.assertIs(mirrored, result) def testSameId(self): foo = object() - result = values.regroup({_device_str(0): ("a", foo), - _device_str(1): ("b", foo)}) + device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1))) + result = values.regroup(device_map, (("a", foo), ("b", foo))) self.assertIsInstance(result, tuple) self.assertEqual(2, len(result)) self._is_per_replica(result[0], ["a", "b"]) self.assertIs(foo, result[1]) - # Test select_device(), should undo the merge done by regroup(). - result_0 = values.select_device(_device_str(0), result) + # Test select_replica(), should undo the merge done by regroup(). + result_0 = values.select_replica(0, result) self.assertIsInstance(result_0, tuple) self.assertEqual(2, len(result_0)) self.assertEqual("a", result_0[0]) self.assertIs(foo, result_0[1]) - result_1 = values.select_device(_device_str(1), result) + result_1 = values.select_replica(1, result) self.assertIsInstance(result_1, tuple) self.assertEqual(2, len(result_1)) self.assertEqual("b", result_1[0]) self.assertIs(foo, result_1[1]) def testOneDevice(self): - result = values.regroup({_device_str(0): _nested_value("1")}) - # On one device regroup() and select_device() are basically identity. + device_map = values.ReplicaDeviceMap((_device_str(0),)) + result = values.regroup(device_map, (_nested_value("1"),)) + # On one device regroup() and select_replica() are basically identity. 
self.assertEqual(_nested_value("1"), result) self.assertEqual(_nested_value("1"), - values.select_device(_device_str(0), result)) + values.select_replica(0, result)) # The one exception has to do with MirroredVariables. d = "/device:CPU:0" with ops.device(d): v = variable_scope.get_variable( name="v", initializer=1., use_resource=True) - index = {d: v} - mirrored = values.MirroredVariable(index, v, + device_map = values.ReplicaDeviceMap((d,)) + mirrored = values.MirroredVariable(None, device_map, (v,), variable_scope.VariableAggregation.SUM) - result = values.regroup(index) + result = values.regroup(device_map, (v,)) self.assertIs(mirrored, result) def testNamedTupleEstimatorSpec(self): with context.graph_mode(), ops.Graph().as_default(): - created_estimator_specs = {} - to_regroup = {} + devices = [] + created_estimator_specs = [] for device_id in range(3): spec = model_fn_lib.EstimatorSpec( mode=model_fn_lib.ModeKeys.TRAIN, loss=constant_op.constant(device_id / 2), train_op=array_ops.identity(constant_op.constant(device_id))) - created_estimator_specs[device_id] = spec - to_regroup[_device_str(device_id)] = spec + devices.append(_device_str(device_id)) + created_estimator_specs.append(spec) - merged_estimator_spec = values.regroup(to_regroup) + device_map = values.ReplicaDeviceMap(devices) + merged_estimator_spec = values.regroup( + device_map, created_estimator_specs) self.assertTrue( isinstance(merged_estimator_spec, model_fn_lib.EstimatorSpec)) @@ -337,10 +348,10 @@ class RegroupAndSelectDeviceTest(test.TestCase): # Scaffold is populated by `EstimatorSpec.__new__`. 
self.assertEqual(created_estimator_specs[device_id].scaffold, merged_estimator_spec.scaffold.get(d)) - # Also test that we can undo the merge using select_device() + # Also test that we can undo the merge using select_replica() self.assertEqual(created_estimator_specs[device_id], - values.select_device(_device_str(device_id), - merged_estimator_spec)) + values.select_replica(device_id, + merged_estimator_spec)) class PerReplicaDatasetTest(test.TestCase): @@ -349,7 +360,9 @@ class PerReplicaDatasetTest(test.TestCase): config.allow_soft_placement = True def _test_iterator(self, devices, dataset, expected_values): - per_replica_dataset = values.PerReplicaDataset(dataset, devices) + device_map = values.ReplicaDeviceMap(devices) + input_workers = values.InputWorkers(device_map) + per_replica_dataset = values.PerReplicaDataset(dataset, input_workers, 0) if context.executing_eagerly(): iterator = per_replica_dataset.make_one_shot_iterator() else: @@ -357,15 +370,13 @@ class PerReplicaDatasetTest(test.TestCase): self.evaluate([iterator.initializer]) for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = self.evaluate( - [values.select_device(d, next_element) for d in devices]) + next_element = iterator.get_next_as_list() + computed_value = self.evaluate(next_element) self.assertEqual(expected_value, computed_value) with self.assertRaises(errors.OutOfRangeError): - next_element = iterator.get_next() - self.evaluate([ - values.select_device(d, next_element) for d in devices]) + next_element = iterator.get_next_as_list() + self.evaluate(next_element) @test_util.run_in_graph_and_eager_modes def testOneDevice(self): @@ -421,11 +432,13 @@ class PerReplicaDatasetTest(test.TestCase): dataset = dataset_ops.Dataset.from_tensor_slices( random_ops.random_uniform((10,))) - per_replica_dataset = values.PerReplicaDataset(dataset, devices) + device_map = values.ReplicaDeviceMap(devices) + input_workers = values.InputWorkers(device_map) + 
per_replica_dataset = values.PerReplicaDataset(dataset, input_workers, 0) iterator = per_replica_dataset.make_initializable_iterator() self.evaluate(iterator.initializer) - next_element = iterator.get_next() + next_element = iterator.get_next_as_list() for _ in range(10): self.evaluate(next_element) @@ -443,35 +456,39 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): def _test_iterator(self, sess, iterator, devices, expected_values): next_element = iterator.get_next() - for device in devices: - v = values.select_device(device, next_element) + for r, device in enumerate(devices): + v = values.select_replica(r, next_element) # The `v` here can be a tuple. for element in nest.flatten(v): self.assertTrue(element.device in device) for expected_value in expected_values: - actual = sess.run( - [values.select_device(d, next_element) for d in devices]) + t = [values.select_replica(r, next_element) for r in range(len(devices))] + actual = sess.run(t) self.assertEqual(expected_value, actual) with self.assertRaises(errors.OutOfRangeError): - sess.run([values.select_device(d, next_element) for d in devices]) + sess.run([values.select_replica(r, next_element) + for r in range(len(devices))]) def _test_dataset(self, dataset_fn, worker_devices, devices, - expected_values, auto_shard=True): + expected_values): + device_map = values.ReplicaDeviceMap(devices) + input_workers = values.InputWorkers(device_map, worker_devices) multi_worker_dataset = values.MultiWorkerDataset( - dataset_fn, worker_devices, auto_shard=auto_shard) + dataset_fn, input_workers) multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() with self.cached_session() as sess: sess.run(multi_worker_iterator.initializer) self._test_iterator(sess, multi_worker_iterator, devices, expected_values) def _cpu_devices(self): - worker_devices = [ + worker_devices = ( ("/job:worker/replica:0/task:0", ["/job:worker/replica:0/task:0/device:CPU:0"]), ("/job:worker/replica:0/task:1", - 
["/job:worker/replica:0/task:1/device:CPU:0"])] + ["/job:worker/replica:0/task:1/device:CPU:0"]) + ) devices = [ "/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:1/device:CPU:0" @@ -479,16 +496,16 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): return worker_devices, devices def _cpu_and_one_gpu_devices(self): - worker_devices = [ - ("/job:worker/replica:0/task:0", [ + worker_devices = ( + ("/job:worker/replica:0/task:0", ( "/job:worker/replica:0/task:0/device:GPU:0", "/job:worker/replica:0/task:0/device:CPU:0" - ]), - ("/job:worker/replica:0/task:1", [ + )), + ("/job:worker/replica:0/task:1", ( "/job:worker/replica:0/task:1/device:GPU:0", "/job:worker/replica:0/task:1/device:CPU:0" - ]) - ] + )) + ) devices = [ "/job:worker/replica:0/task:0/device:GPU:0", "/job:worker/replica:0/task:0/device:CPU:0", @@ -501,16 +518,9 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): worker_devices, devices = self._cpu_devices() with context.graph_mode(): dataset_fn = lambda: dataset_ops.Dataset.range(8) - self._test_dataset(dataset_fn, worker_devices, devices, - [[0, 1], [2, 3], [4, 5], [6, 7]]) - - def testDataDistributionNoAutoShard(self): - worker_devices, devices = self._cpu_devices() - with context.graph_mode(): - dataset_fn = lambda: dataset_ops.Dataset.range(4) - self._test_dataset(dataset_fn, worker_devices, devices, - [[0, 0], [1, 1], [2, 2], [3, 3]], - auto_shard=False) + self._test_dataset( + dataset_fn, worker_devices, devices, + [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]) def testDataDistributionTwoDevicePerWorker(self): if context.num_gpus() < 1: @@ -518,8 +528,9 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): worker_devices, devices = self._cpu_and_one_gpu_devices() with context.graph_mode(): dataset_fn = lambda: dataset_ops.Dataset.range(8) - self._test_dataset(dataset_fn, worker_devices, devices, - [[0, 2, 1, 3], [4, 6, 5, 7]]) + 
self._test_dataset( + dataset_fn, worker_devices, devices, + [[0, 1, 0, 1], [2, 3, 2, 3], [4, 5, 4, 5], [6, 7, 6, 7]]) def testTupleDataset(self): worker_devices, devices = self._cpu_devices() @@ -531,9 +542,7 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): dataset2 = dataset_ops.Dataset.range(8).map(lambda x: x**2) return dataset_ops.Dataset.zip((dataset1, dataset2)) - expected_values = [ - [(i, i**2), (i + 1, (i + 1)**2)] for i in range(0, 8, 2) - ] + expected_values = [[(i, i**2), (i, i**2)] for i in range(8)] self._test_dataset(dataset_fn, worker_devices, devices, expected_values) @@ -541,34 +550,38 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): worker_devices, devices = self._cpu_devices() with context.graph_mode(), self.cached_session() as sess: dataset_fn = lambda: dataset_ops.Dataset.range(8) + device_map = values.ReplicaDeviceMap(devices) + input_workers = values.InputWorkers(device_map, worker_devices) multi_worker_dataset = values.MultiWorkerDataset( - dataset_fn, worker_devices, auto_shard=True) + dataset_fn, input_workers) multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() sess.run(multi_worker_iterator.initializer) - self._test_iterator(sess, multi_worker_iterator, devices, - [[0, 1], [2, 3], [4, 5], [6, 7]]) + self._test_iterator( + sess, multi_worker_iterator, devices, + [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]) # After re-initializing the iterator, should be able to iterate again. sess.run(multi_worker_iterator.initializer) - self._test_iterator(sess, multi_worker_iterator, devices, - [[0, 1], [2, 3], [4, 5], [6, 7]]) + self._test_iterator( + sess, multi_worker_iterator, devices, + [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]) def testValueErrorForIterator(self): # Incompatiable arguments. 
+ d1 = "/device:GPU:0" + d2 = "/device:GPU:1" + device_map = values.ReplicaDeviceMap([d1, d2]) + input_workers = values.InputWorkers( + device_map, (("w1", (d1,)), ("w2", (d2,)))) with self.assertRaises(ValueError): - values.MultiWorkerDataIterator({"w1": None}, {"w1": "d1", "w2": "d2"}) + values.MultiWorkerDataIterator([("w1", None)], input_workers) - # Test duplicated devices under same worker. - worker_devices, _ = self._cpu_devices() - worker_devices[0][1].append("/job:worker/replica:0/task:0/device:CPU:0") - with context.graph_mode(): - dataset_fn = lambda: dataset_ops.Dataset.range(8) - multi_worker_dataset = values.MultiWorkerDataset( - dataset_fn, worker_devices, auto_shard=True) - multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() - with self.assertRaises(ValueError): - multi_worker_iterator.get_next() + def testDuplicateDevices(self): + _, devices = self._cpu_devices() + devices.append("/job:worker/replica:0/task:0/device:CPU:0") + with self.assertRaises(ValueError): + _ = values.ReplicaDeviceMap(devices) class InputIteratorTestBase(test.TestCase): @@ -576,16 +589,18 @@ class InputIteratorTestBase(test.TestCase): def _test_iterator(self, input_type, dataset_fn, worker_device_pairs, expected_values, sess=None, split_batch_by=None): devices = nest.flatten([ds for _, ds in worker_device_pairs]) + device_map = values.ReplicaDeviceMap(devices) + input_workers = values.InputWorkers(device_map, worker_device_pairs) if input_type == "input_fn": input_contexts = [ distribute_lib.InputContext() for _ in worker_device_pairs] input_fn = lambda _: dataset_fn() - iterator = values.InputFunctionIterator(input_fn, worker_device_pairs, - input_contexts) + iterator = values.InputFunctionIterator( + input_fn, input_workers, input_contexts) else: - iterator = values.DatasetIterator(dataset_fn(), worker_device_pairs, - split_batch_by) + iterator = values.DatasetIterator( + dataset_fn(), input_workers, split_batch_by) evaluate = lambda x: sess.run(x) if 
sess else self.evaluate(x) @@ -594,12 +609,13 @@ class InputIteratorTestBase(test.TestCase): for expected_value in expected_values: next_element = iterator.get_next() computed_value = evaluate( - [values.select_device(d, next_element) for d in devices]) + [values.select_replica(r, next_element) for r in range(len(devices))]) self.assertAllEqual(expected_value, computed_value) with self.assertRaises(errors.OutOfRangeError): next_element = iterator.get_next() - evaluate([values.select_device(d, next_element) for d in devices]) + evaluate([values.select_replica(r, next_element) + for r in range(len(devices))]) # After re-initializing the iterator, should be able to iterate again. evaluate(control_flow_ops.group(iterator.initialize())) @@ -607,7 +623,7 @@ class InputIteratorTestBase(test.TestCase): for expected_value in expected_values: next_element = iterator.get_next() computed_value = evaluate( - [values.select_device(d, next_element) for d in devices]) + [values.select_replica(r, next_element) for r in range(len(devices))]) self.assertAllEqual(expected_value, computed_value) @@ -748,6 +764,34 @@ class InputIteratorMultiWorkerTest( expected_values, sess) +class SplitDatasetBatchTest(test.TestCase): + + def testBatchDataset(self): + dataset = dataset_ops.Dataset.range(100).batch(20) + split_batch_by = 2 + result_dataset = values._split_dataset_batch(dataset, split_batch_by) + expected_values = [range(i, i+10) for i in range(0, 100, 10)] + result = [self.evaluate(el) for el in result_dataset] + self.assertAllEqual(expected_values, result) + + def testMapAndBatchDataset(self): + dataset = dataset_ops.Dataset.range(100) + dataset = dataset.apply(batching.map_and_batch(lambda x: x, 20)) + split_batch_by = 2 + result_dataset = values._split_dataset_batch(dataset, split_batch_by) + expected_values = [range(i, i+10) for i in range(0, 100, 10)] + result = [self.evaluate(el) for el in result_dataset] + self.assertAllEqual(expected_values, result) + + def 
testPrefetchDataset(self): + dataset = dataset_ops.Dataset.range(100).batch(20).prefetch(1) + split_batch_by = 2 + result_dataset = values._split_dataset_batch(dataset, split_batch_by) + expected_values = [range(i, i+10) for i in range(0, 100, 10)] + result = [self.evaluate(el) for el in result_dataset] + self.assertAllEqual(expected_values, result) + + class MirroredVariableTest(test.TestCase, parameterized.TestCase): config = config_pb2.ConfigProto() @@ -768,8 +812,8 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase): def testVariableOnAnotherDevice(self): v = variable_scope.get_variable( name="v", initializer=[1.], use_resource=True) - index = {"/job:foo/device:CPU:0": v} - mirrored = values.MirroredVariable(index, v, + device_map = values.ReplicaDeviceMap(("/job:foo/device:CPU:0",)) + mirrored = values.MirroredVariable(None, device_map, (v,), variable_scope.VariableAggregation.MEAN) self.assertEqual(v.name, mirrored.name) @@ -797,7 +841,8 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase): self.skipTest("A GPU is not available for this test in eager mode.") with self.cached_session(config=self.config) as sess: - v, devices, mirrored = _make_mirrored() + v, device_map, mirrored = _make_mirrored() + devices = device_map.all_devices # Overwrite the initial values. self._assign_mirrored(devices, v, [3., 4.]) @@ -815,7 +860,8 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase): def _save_mirrored(self): """Save variables with mirroring, returns save_path.""" with self.session(graph=ops.Graph()) as sess: - v, devices, mirrored = _make_mirrored() + v, device_map, mirrored = _make_mirrored() + devices = device_map.all_devices # Overwrite the initial values. 
self._assign_mirrored(devices, v, [3., 4.]) @@ -860,7 +906,8 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase): def _restore_mirrored(self, save_path): """Restore to variables with mirroring in a fresh graph.""" with self.session(graph=ops.Graph()) as sess: - v, devices, mirrored = _make_mirrored() + v, device_map, mirrored = _make_mirrored() + devices = device_map.all_devices # Overwrite the initial values. self._assign_mirrored(devices, v, [7., 8.]) @@ -904,25 +951,24 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase): with ops.device("/device:GPU:0"): v = variable_scope.get_variable( name="v", initializer=1., use_resource=True) - mirrored = values.MirroredVariable({ - "/device:GPU:0": v - }, v, variable_scope.VariableAggregation.MEAN) + mirrored = values.MirroredVariable( + distribution, values.ReplicaDeviceMap(("/device:GPU:0",)), (v,), + variable_scope.VariableAggregation.MEAN) sess.run(variables_lib.global_variables_initializer()) sess.run({"complicated": mirrored}) -_devices = ["/device:GPU:0", "/device:CPU:0"] +_devices = ("/device:GPU:0", "/device:CPU:0") -def _make_replica_local(method): +def _make_replica_local(method, strategy=None): + device_map = values.ReplicaDeviceMap(_devices) v = [] - index = {} for d, n, init in zip(_devices, ["v", "v/replica"], [1., 2.]): with ops.device(d): v.append(variable_scope.get_variable( name=n, initializer=init, use_resource=True)) - index[d] = v[-1] - replica_local = values.ReplicaLocalVariable(index, v[0], method) + replica_local = values.ReplicaLocalVariable(strategy, device_map, v, method) return v, replica_local @@ -948,9 +994,9 @@ class ReplicaLocalVariablePropertiesTest(test.TestCase): def testVariableOnAnotherDevice(self): v = variable_scope.get_variable( name="v", initializer=[1.], use_resource=True) - index = {"/job:foo/device:CPU:0": v} + device_map = values.ReplicaDeviceMap(("/job:foo/device:CPU:0",)) replica_local = values.ReplicaLocalVariable( - index, v, 
variable_scope.VariableAggregation.MEAN) + None, device_map, (v,), variable_scope.VariableAggregation.MEAN) self.assertEqual(v.name, replica_local.name) self.assertEqual(v.dtype, replica_local.dtype) @@ -997,7 +1043,7 @@ class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): def testSaveAndRestoreReplicaLocalSumOneGraph(self, distribution): with self.cached_session() as sess: v, replica_local = _make_replica_local( - variable_scope.VariableAggregation.SUM) + variable_scope.VariableAggregation.SUM, distribution) # Overwrite the initial values. self._assign_replica_local(_devices, v, [3., 4.]) @@ -1020,7 +1066,7 @@ class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): with self.cached_session() as sess: v, replica_local = _make_replica_local( - variable_scope.VariableAggregation.MEAN) + variable_scope.VariableAggregation.MEAN, distribution) # Overwrite the initial values. self._assign_replica_local(_devices, v, [3., 4.]) @@ -1040,7 +1086,7 @@ class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): """Save variables with mirroring, returns save_path.""" with self.session(graph=ops.Graph()) as sess: v, replica_local = _make_replica_local( - variable_scope.VariableAggregation.MEAN) + variable_scope.VariableAggregation.MEAN, distribution) # Overwrite the initial values. self._assign_replica_local(_devices, v, [3., 4.]) @@ -1056,7 +1102,7 @@ class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): def _save_replica_local_sum(self, distribution): """Save variables with mirroring, returns save_path.""" with self.session(graph=ops.Graph()) as sess: - v, replica_local = _make_replica_local("sum") + v, replica_local = _make_replica_local("sum", distribution) # Overwrite the initial values. 
self._assign_replica_local(_devices, v, [1.5, 2.]) @@ -1103,7 +1149,7 @@ class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): """Restore to variables with mirroring in a fresh graph.""" with self.session(graph=ops.Graph()) as sess: v, replica_local = _make_replica_local( - variable_scope.VariableAggregation.MEAN) + variable_scope.VariableAggregation.MEAN, distribution) # Overwrite the initial values. self._assign_replica_local(_devices, v, [7., 8.]) @@ -1118,7 +1164,7 @@ class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): """Restore to variables with mirroring in a fresh graph.""" with self.session(graph=ops.Graph()) as sess: v, replica_local = _make_replica_local( - variable_scope.VariableAggregation.SUM) + variable_scope.VariableAggregation.SUM, distribution) # Overwrite the initial values. self._assign_replica_local(_devices, v, [7., 8.]) diff --git a/tensorflow/contrib/distributions/python/ops/sample_stats.py b/tensorflow/contrib/distributions/python/ops/sample_stats.py index 978e627d6638ddeea9df288d389354f0ac53d115..19e99e03803e7f4cdfdb023feb04daaba68eceed 100644 --- a/tensorflow/contrib/distributions/python/ops/sample_stats.py +++ b/tensorflow/contrib/distributions/python/ops/sample_stats.py @@ -300,7 +300,7 @@ def percentile(x, raise ValueError("Argument 'interpolation' must be in %s. Found %s" % (allowed_interpolations, interpolation)) - with ops.name_scope(name, [x, q]): + with ops.name_scope(name, values=[x, q]): x = ops.convert_to_tensor(x, name="x") # Double is needed here and below, else we get the wrong index if the array # is huge along axis. 
diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index 77052a75a70bec1162feb2b126d247924b3a2e36..8966a9befcd3db4a3f397b319e80f37f84ad236b 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -15,7 +15,6 @@ py_library( ":metrics", ":network", ":parameter_server", - ":remote", ":saver", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", @@ -31,6 +30,7 @@ py_library( "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:execution_callbacks", "//tensorflow/python/eager:function", + "//tensorflow/python/eager:remote", ], ) @@ -238,24 +238,12 @@ py_test( ], ) -py_library( - name = "remote", - srcs = ["remote.py"], - srcs_version = "PY2AND3", - visibility = ["//tensorflow:internal"], - deps = [ - "//tensorflow/core:protos_all_py", - "//tensorflow/python:platform", - "//tensorflow/python/eager:context", - ], -) - cuda_py_test( name = "remote_test", srcs = ["remote_test.py"], additional_deps = [ ":parameter_server", - ":remote", + "//tensorflow/python/eager:remote", "//tensorflow/contrib/eager/python:tfe", "//tensorflow/python:array_ops", "//tensorflow/python:client", diff --git a/tensorflow/contrib/eager/python/remote_test.py b/tensorflow/contrib/eager/python/remote_test.py index 3926de15e71c9917f88fc3f58740b8c75354ab26..f540d9b37b69c7be3b0662b07bd6e9cb8220fadc 100644 --- a/tensorflow/contrib/eager/python/remote_test.py +++ b/tensorflow/contrib/eager/python/remote_test.py @@ -24,12 +24,12 @@ import os import numpy as np from tensorflow.contrib.eager.python import parameter_server -from tensorflow.contrib.eager.python import remote from tensorflow.core.protobuf import cluster_pb2 from tensorflow.core.protobuf import tensorflow_server_pb2 from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import function +from tensorflow.python.eager import remote from tensorflow.python.framework 
import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index 33c988fd9065e7fbe7b9aeb85cad82eb3c119f76..31481d7685c79b76c40b1f8041441a0e71d3b00e 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -41,6 +41,8 @@ To use, at program startup, call `tf.enable_eager_execution()`. @@add_execution_callback @@clear_execution_callbacks +@@errstate +@@ExecutionCallback @@inf_callback @@inf_nan_callback @@nan_callback @@ -97,7 +99,6 @@ from tensorflow.contrib.eager.python.network import Network from tensorflow.contrib.eager.python.network import Sequential from tensorflow.contrib.eager.python.network import save_network_checkpoint from tensorflow.contrib.eager.python.network import restore_network_checkpoint -from tensorflow.contrib.eager.python.remote import connect_to_remote_host from tensorflow.contrib.eager.python.saver import get_optimizer_variables from tensorflow.contrib.eager.python.saver import restore_variables_on_create from tensorflow.contrib.eager.python.saver import Saver @@ -119,10 +120,13 @@ from tensorflow.python.eager.context import set_server_def from tensorflow.python.eager.def_function import function from tensorflow.python.eager.execution_callbacks import add_execution_callback from tensorflow.python.eager.execution_callbacks import clear_execution_callbacks +from tensorflow.python.eager.execution_callbacks import errstate +from tensorflow.python.eager.execution_callbacks import ExecutionCallback from tensorflow.python.eager.execution_callbacks import inf_callback from tensorflow.python.eager.execution_callbacks import inf_nan_callback from tensorflow.python.eager.execution_callbacks import nan_callback from tensorflow.python.eager.execution_callbacks import seterr +from tensorflow.python.eager.remote import connect_to_remote_host from tensorflow.python.framework.tensor_spec 
import TensorSpec from tensorflow.python.framework.ops import enable_eager_execution from tensorflow.python.framework.ops import enable_eager_execution_internal as enable_remote_eager_execution diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD index 1cd83bdb5de7c2f6dc91c980750b49aca1a7790b..4c1d1a29f20b5574b63cf87ecf62db95f92902cd 100644 --- a/tensorflow/contrib/feature_column/BUILD +++ b/tensorflow/contrib/feature_column/BUILD @@ -110,8 +110,8 @@ py_test( "//tensorflow/python:parsing_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", - "//tensorflow/python/feature_column", "//tensorflow/python/feature_column:feature_column_py", + "//tensorflow/python/feature_column:feature_column_v2_test", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py index 0d34ad161855476b6a4cd9a258521dbe122b4140..83b93ec332044f754f9dcde8d7c5c19b26e53a4a 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py @@ -203,7 +203,8 @@ def sequence_categorical_column_with_identity( columns = [watches_embedding] features = tf.parse_example(..., features=make_parse_example_spec(columns)) - input_layer, sequence_length = sequence_input_layer(features, columns) + sequence_feature_layer = SequenceFeatureLayer(columns) + input_layer, sequence_length = sequence_feature_layer(features) rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) outputs, state = tf.nn.dynamic_rnn( @@ -219,15 +220,17 @@ def sequence_categorical_column_with_identity( `[0, num_buckets)`, and will replace out-of-range inputs. Returns: - A `_SequenceCategoricalColumn`. + A `SequenceCategoricalColumn`. 
Raises: ValueError: if `num_buckets` is less than one. ValueError: if `default_value` is not in range `[0, num_buckets)`. """ - return fc_old._SequenceCategoricalColumn( - fc_old._categorical_column_with_identity( - key=key, num_buckets=num_buckets, default_value=default_value)) + return fc.SequenceCategoricalColumn( + fc.categorical_column_with_identity( + key=key, + num_buckets=num_buckets, + default_value=default_value)) def sequence_categorical_column_with_hash_bucket( @@ -247,7 +250,8 @@ def sequence_categorical_column_with_hash_bucket( columns = [tokens_embedding] features = tf.parse_example(..., features=make_parse_example_spec(columns)) - input_layer, sequence_length = sequence_input_layer(features, columns) + sequence_feature_layer = SequenceFeatureLayer(columns) + input_layer, sequence_length = sequence_feature_layer(features) rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) outputs, state = tf.nn.dynamic_rnn( @@ -260,15 +264,17 @@ def sequence_categorical_column_with_hash_bucket( dtype: The type of features. Only string and integer types are supported. Returns: - A `_SequenceCategoricalColumn`. + A `SequenceCategoricalColumn`. Raises: ValueError: `hash_bucket_size` is not greater than 1. ValueError: `dtype` is neither string nor integer. 
""" - return fc_old._SequenceCategoricalColumn( - fc_old._categorical_column_with_hash_bucket( - key=key, hash_bucket_size=hash_bucket_size, dtype=dtype)) + return fc.SequenceCategoricalColumn( + fc.categorical_column_with_hash_bucket( + key=key, + hash_bucket_size=hash_bucket_size, + dtype=dtype)) def sequence_categorical_column_with_vocabulary_file( @@ -290,7 +296,8 @@ def sequence_categorical_column_with_vocabulary_file( columns = [states_embedding] features = tf.parse_example(..., features=make_parse_example_spec(columns)) - input_layer, sequence_length = sequence_input_layer(features, columns) + sequence_feature_layer = SequenceFeatureLayer(columns) + input_layer, sequence_length = sequence_feature_layer(features) rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) outputs, state = tf.nn.dynamic_rnn( @@ -314,7 +321,7 @@ def sequence_categorical_column_with_vocabulary_file( dtype: The type of features. Only string and integer types are supported. Returns: - A `_SequenceCategoricalColumn`. + A `SequenceCategoricalColumn`. Raises: ValueError: `vocabulary_file` is missing or cannot be opened. @@ -323,8 +330,8 @@ def sequence_categorical_column_with_vocabulary_file( ValueError: `num_oov_buckets` and `default_value` are both specified. ValueError: `dtype` is neither string nor integer. 
""" - return fc_old._SequenceCategoricalColumn( - fc_old._categorical_column_with_vocabulary_file( + return fc.SequenceCategoricalColumn( + fc.categorical_column_with_vocabulary_file( key=key, vocabulary_file=vocabulary_file, vocabulary_size=vocabulary_size, @@ -351,7 +358,8 @@ def sequence_categorical_column_with_vocabulary_list( columns = [colors_embedding] features = tf.parse_example(..., features=make_parse_example_spec(columns)) - input_layer, sequence_length = sequence_input_layer(features, columns) + sequence_feature_layer = SequenceFeatureLayer(columns) + input_layer, sequence_length = sequence_feature_layer(features) rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) outputs, state = tf.nn.dynamic_rnn( @@ -375,7 +383,7 @@ def sequence_categorical_column_with_vocabulary_list( with `default_value`. Returns: - A `_SequenceCategoricalColumn`. + A `SequenceCategoricalColumn`. Raises: ValueError: if `vocabulary_list` is empty, or contains duplicate keys. @@ -383,8 +391,8 @@ def sequence_categorical_column_with_vocabulary_list( ValueError: `num_oov_buckets` and `default_value` are both specified. ValueError: if `dtype` is not integer or string. 
""" - return fc_old._SequenceCategoricalColumn( - fc_old._categorical_column_with_vocabulary_list( + return fc.SequenceCategoricalColumn( + fc.categorical_column_with_vocabulary_list( key=key, vocabulary_list=vocabulary_list, dtype=dtype, diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py index ca4398a142065de0be7bee57cd7e54670bbae12e..be012a87690c24c6d9b7808790393e1aa6d01211 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py @@ -26,7 +26,7 @@ from tensorflow.contrib.feature_column.python.feature_column import sequence_fea from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column_v2 as sfc from tensorflow.python.feature_column import feature_column as fc_old from tensorflow.python.feature_column import feature_column_lib as fc -from tensorflow.python.feature_column.feature_column import _LazyBuilder +from tensorflow.python.feature_column.feature_column_v2_test import _TestStateManager from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -131,7 +131,7 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): feature_columns=[embedding_column_b, embedding_column_a]) global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual( + self.assertCountEqual( ('sequence_input_layer/aaa_embedding/embedding_weights:0', 'sequence_input_layer/bbb_embedding/embedding_weights:0'), tuple([v.name for v in global_vars])) @@ -223,7 +223,7 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): feature_columns=shared_embedding_columns) global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - 
self.assertItemsEqual( + self.assertCountEqual( ('sequence_input_layer/aaa_bbb_shared_embedding/embedding_weights:0',), tuple([v.name for v in global_vars])) with monitored_session.MonitoredSession() as sess: @@ -670,6 +670,23 @@ def _assert_sparse_tensor_indices_shape(test_case, expected, actual): test_case.assertAllEqual(expected.dense_shape, actual.dense_shape) +def _get_sequence_dense_tensor(column, features): + return column.get_sequence_dense_tensor( + fc.FeatureTransformationCache(features), None) + + +def _get_sequence_dense_tensor_state(column, features): + state_manager = _TestStateManager() + column.create_state(state_manager) + return column.get_sequence_dense_tensor( + fc.FeatureTransformationCache(features), state_manager) + + +def _get_sparse_tensors(column, features): + return column.get_sparse_tensors( + fc.FeatureTransformationCache(features), None) + + class SequenceCategoricalColumnWithIdentityTest( test.TestCase, parameterized.TestCase): @@ -698,7 +715,7 @@ class SequenceCategoricalColumnWithIdentityTest( expected = sparse_tensor.SparseTensorValue(**expected_args) column = sfc.sequence_categorical_column_with_identity('aaa', num_buckets=9) - id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + id_weight_pair = _get_sparse_tensors(column, {'aaa': inputs}) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: @@ -737,7 +754,7 @@ class SequenceCategoricalColumnWithHashBucketTest( column = sfc.sequence_categorical_column_with_hash_bucket( 'aaa', hash_bucket_size=10) - id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + id_weight_pair = _get_sparse_tensors(column, {'aaa': inputs}) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: @@ -790,7 +807,7 @@ class SequenceCategoricalColumnWithVocabularyFileTest( vocabulary_file=self._wire_vocabulary_file_name, vocabulary_size=self._wire_vocabulary_size) - 
id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + id_weight_pair = _get_sparse_tensors(column, {'aaa': inputs}) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: @@ -814,8 +831,7 @@ class SequenceCategoricalColumnWithVocabularyFileTest( input_placeholder_shape[1] = None input_placeholder = array_ops.sparse_placeholder( dtypes.string, shape=input_placeholder_shape) - id_weight_pair = column._get_sparse_tensors( - _LazyBuilder({'aaa': input_placeholder})) + id_weight_pair = _get_sparse_tensors(column, {'aaa': input_placeholder}) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: @@ -855,7 +871,7 @@ class SequenceCategoricalColumnWithVocabularyListTest( key='aaa', vocabulary_list=('omar', 'stringer', 'marlo')) - id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + id_weight_pair = _get_sparse_tensors(column, {'aaa': inputs}) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: @@ -922,13 +938,12 @@ class SequenceEmbeddingColumnTest( categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc_old._embedding_column( - categorical_column, - dimension=embedding_dimension, + embedding_column = fc.embedding_column( + categorical_column, dimension=embedding_dimension, initializer=_initializer) - embedding_lookup, _ = embedding_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': inputs})) + embedding_lookup, _ = _get_sequence_dense_tensor_state( + embedding_column, {'aaa': inputs}) global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) self.assertItemsEqual( @@ -961,10 +976,11 @@ class SequenceEmbeddingColumnTest( categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc_old._embedding_column(categorical_column, dimension=2) + 
embedding_column = fc.embedding_column( + categorical_column, dimension=2) - _, sequence_length = embedding_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': inputs})) + _, sequence_length = _get_sequence_dense_tensor_state( + embedding_column, {'aaa': inputs}) with monitored_session.MonitoredSession() as sess: sequence_length = sess.run(sequence_length) @@ -988,10 +1004,11 @@ class SequenceEmbeddingColumnTest( categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc_old._embedding_column(categorical_column, dimension=2) + embedding_column = fc.embedding_column( + categorical_column, dimension=2) - _, sequence_length = embedding_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) + _, sequence_length = _get_sequence_dense_tensor_state( + embedding_column, {'aaa': sparse_input}) with monitored_session.MonitoredSession() as sess: self.assertAllEqual( @@ -1058,22 +1075,18 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): key='aaa', num_buckets=vocabulary_size) categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc.shared_embedding_columns_v2( [categorical_column_a, categorical_column_b], dimension=embedding_dimension, initializer=_initializer) - embedding_lookup_a = shared_embedding_columns[0]._get_sequence_dense_tensor( - _LazyBuilder({ - 'aaa': sparse_input_a - }))[0] - embedding_lookup_b = shared_embedding_columns[1]._get_sequence_dense_tensor( - _LazyBuilder({ - 'bbb': sparse_input_b - }))[0] + embedding_lookup_a = _get_sequence_dense_tensor( + shared_embedding_columns[0], {'aaa': sparse_input_a})[0] + embedding_lookup_b = _get_sequence_dense_tensor( + shared_embedding_columns[1], {'bbb': sparse_input_b})[0] global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - 
self.assertItemsEqual(('embedding_weights:0',), + self.assertItemsEqual(('aaa_bbb_shared_embedding:0',), tuple([v.name for v in global_vars])) with monitored_session.MonitoredSession() as sess: self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) @@ -1104,17 +1117,13 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): expected_sequence_length_b = [2, 1] categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc.shared_embedding_columns_v2( [categorical_column_a, categorical_column_b], dimension=2) - sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor( - _LazyBuilder({ - 'aaa': sparse_input_a - }))[1] - sequence_length_b = shared_embedding_columns[1]._get_sequence_dense_tensor( - _LazyBuilder({ - 'bbb': sparse_input_b - }))[1] + sequence_length_a = _get_sequence_dense_tensor( + shared_embedding_columns[0], {'aaa': sparse_input_a})[1] + sequence_length_b = _get_sequence_dense_tensor( + shared_embedding_columns[1], {'bbb': sparse_input_b})[1] with monitored_session.MonitoredSession() as sess: sequence_length_a = sess.run(sequence_length_a) @@ -1155,17 +1164,13 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc.shared_embedding_columns_v2( [categorical_column_a, categorical_column_b], dimension=2) - sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor( - _LazyBuilder({ - 'aaa': sparse_input_a - }))[1] - sequence_length_b = shared_embedding_columns[1]._get_sequence_dense_tensor( - _LazyBuilder({ - 'bbb': sparse_input_b - }))[1] + sequence_length_a = _get_sequence_dense_tensor( + shared_embedding_columns[0], {'aaa': sparse_input_a})[1] + sequence_length_b = 
_get_sequence_dense_tensor( + shared_embedding_columns[1], {'bbb': sparse_input_b})[1] with monitored_session.MonitoredSession() as sess: self.assertAllEqual( @@ -1221,10 +1226,10 @@ class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column = fc_old._indicator_column(categorical_column) + indicator_column = fc.indicator_column(categorical_column) - indicator_tensor, _ = indicator_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': inputs})) + indicator_tensor, _ = _get_sequence_dense_tensor( + indicator_column, {'aaa': inputs}) with monitored_session.MonitoredSession() as sess: self.assertAllEqual(expected, indicator_tensor.eval(session=sess)) @@ -1253,10 +1258,10 @@ class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column = fc_old._indicator_column(categorical_column) + indicator_column = fc.indicator_column(categorical_column) - _, sequence_length = indicator_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': inputs})) + _, sequence_length = _get_sequence_dense_tensor( + indicator_column, {'aaa': inputs}) with monitored_session.MonitoredSession() as sess: sequence_length = sess.run(sequence_length) @@ -1282,19 +1287,14 @@ class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): key='aaa', num_buckets=vocabulary_size) indicator_column = fc.indicator_column(categorical_column) - _, sequence_length = indicator_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) + _, sequence_length = _get_sequence_dense_tensor( + indicator_column, {'aaa': sparse_input}) with monitored_session.MonitoredSession() as sess: self.assertAllEqual( expected_sequence_length, sequence_length.eval(session=sess)) -def _get_sequence_dense_tensor(column, 
features): - return column.get_sequence_dense_tensor( - fc.FeatureTransformationCache(features), None) - - class SequenceNumericColumnTest(test.TestCase, parameterized.TestCase): def test_defaults(self): diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc index 93b1aaa85e88e00c1b12a388321a4d6fb10f1611..c541c71f996c7a1b36cf28ae9a1783f8dca0a72c 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc @@ -522,7 +522,7 @@ void LaunchFusedConv2DBiasActivationOp:: auto bias_ptr = AsDeviceMemory(bias.template flat().data(), bias.template flat().size()); - static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit( + static int64 ConvolveScratchSize = GetDnnWorkspaceLimit( // default value is in bytes despite the name of the environment variable "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB ); @@ -570,7 +570,7 @@ void LaunchFusedConv2DBiasActivationOp:: for (auto profile_algorithm : algorithms) { // TODO(zhengxq): profile each algorithm multiple times to better // accuracy. 
- CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); + DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); dnn::ProfileResult profile_result; bool cudnn_launch_status = stream @@ -609,7 +609,7 @@ void LaunchFusedConv2DBiasActivationOp:: algorithm_config); } - CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); + DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); bool cudnn_launch_status = stream ->ThenFusedConvolveWithAlgorithm( diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index 3593b501bb738b8f58dce4e40cffbdf410f136b3..adb72228217892fffc10b0e2630edcd9d3e38a02 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -233,13 +233,14 @@ def _get_estimator_spec( estimator_spec = _get_eval_estimator_spec( gan_model, gan_loss, get_eval_metric_ops_fn) else: # model_fn_lib.ModeKeys.TRAIN: - gopt = (generator_optimizer() if callable(generator_optimizer) else - generator_optimizer) - dopt = (discriminator_optimizer() if callable(discriminator_optimizer) - else discriminator_optimizer) + if callable(generator_optimizer): + generator_optimizer = generator_optimizer() + if callable(discriminator_optimizer): + discriminator_optimizer = discriminator_optimizer() get_hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks() estimator_spec = _get_train_estimator_spec( - gan_model, gan_loss, gopt, dopt, get_hooks_fn, is_chief=is_chief) + gan_model, gan_loss, generator_optimizer, discriminator_optimizer, + get_hooks_fn, is_chief=is_chief) return estimator_spec diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py index bc9021050bc010ce75c3091fef868549686c0e90..5a3d29cf0b3cb1bbe03cb5ba4f327caf46432b76 100644 --- 
a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -75,8 +75,8 @@ class GetGANModelTest(test.TestCase, parameterized.TestCase): def test_get_gan_model(self, mode): with ops.Graph().as_default(): generator_inputs = {'x': array_ops.ones([3, 4])} - real_data = (array_ops.zeros([3, 4]) if - mode != model_fn_lib.ModeKeys.PREDICT else None) + is_predict = mode == model_fn_lib.ModeKeys.PREDICT + real_data = array_ops.zeros([3, 4]) if not is_predict else None gan_model = estimator._get_gan_model( mode, generator_fn, discriminator_fn, real_data, generator_inputs, add_summaries=False) @@ -139,6 +139,7 @@ class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase): @classmethod def setUpClass(cls): + super(GetEstimatorSpecTest, cls).setUpClass() cls._generator_optimizer = training.GradientDescentOptimizer(1.0) cls._discriminator_optimizer = training.GradientDescentOptimizer(1.0) @@ -200,7 +201,6 @@ class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase): self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt))) -# TODO(joelshor): Add pandas test. class GANEstimatorIntegrationTest(test.TestCase): def setUp(self): @@ -231,11 +231,11 @@ class GANEstimatorIntegrationTest(test.TestCase): get_eval_metric_ops_fn=get_metrics, model_dir=self._model_dir) - # TRAIN + # Train. num_steps = 10 est.train(train_input_fn, steps=num_steps) - # EVALUTE + # Evaluate. scores = est.evaluate(eval_input_fn) self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) self.assertIn('loss', six.iterkeys(scores)) @@ -243,7 +243,7 @@ class GANEstimatorIntegrationTest(test.TestCase): scores['loss']) self.assertIn('mse_custom_metric', six.iterkeys(scores)) - # PREDICT + # Predict. 
predictions = np.array([x for x in est.predict(predict_input_fn)]) self.assertAllEqual(prediction_size, predictions.shape) diff --git a/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py b/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py index e2594faf85bcf91cbe09f266e4d4211d20bdee17..364fa4eb461c62784803f0c309e3b7c5855df199 100644 --- a/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py +++ b/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py @@ -64,6 +64,9 @@ def condition_tensor(tensor, conditioning): """ tensor.shape[1:].assert_is_fully_defined() num_features = tensor.shape[1:].num_elements() + if conditioning.shape.ndims < 2: + raise ValueError('conditioning must be at least 2D, but saw shape: %s' + % conditioning.shape) mapped_conditioning = layers.linear( layers.flatten(conditioning), num_features) diff --git a/tensorflow/contrib/gan/python/features/python/conditioning_utils_test.py b/tensorflow/contrib/gan/python/features/python/conditioning_utils_test.py index 0aad769793761be69ee9d1e3416e44c7b3d8cea0..f5c7d53cf2c9aa08ba0074950983ef3ecd90168b 100644 --- a/tensorflow/contrib/gan/python/features/python/conditioning_utils_test.py +++ b/tensorflow/contrib/gan/python/features/python/conditioning_utils_test.py @@ -45,7 +45,7 @@ class ConditioningUtilsTest(test.TestCase): array_ops.placeholder(dtypes.float32, (5, None)), array_ops.placeholder(dtypes.float32, (5, 1))) - with self.assertRaisesRegexp(ValueError, 'expected min_ndim=2'): + with self.assertRaisesRegexp(ValueError, 'at least 2D'): conditioning_utils.condition_tensor( array_ops.placeholder(dtypes.float32, (5, 2)), array_ops.placeholder(dtypes.float32, (5))) diff --git a/tensorflow/contrib/gdr/BUILD b/tensorflow/contrib/gdr/BUILD index e534fdc17749974ebe713c2730682bea6d7a85e4..704be917b3680a1b5712f4f1dc5059b354db8610 100644 --- a/tensorflow/contrib/gdr/BUILD +++ b/tensorflow/contrib/gdr/BUILD @@ -37,7 +37,7 @@ 
tf_proto_library_cc( ], ) -tf_cuda_library( +cc_library( name = "gdr_memory_manager", srcs = ["gdr_memory_manager.cc"], hdrs = ["gdr_memory_manager.h"], @@ -58,7 +58,7 @@ tf_cuda_library( ], ) -tf_cuda_library( +cc_library( name = "gdr_worker", srcs = ["gdr_worker.cc"], hdrs = ["gdr_worker.h"], diff --git a/tensorflow/contrib/gdr/gdr.proto b/tensorflow/contrib/gdr/gdr.proto index c0b89245b150bfa49cb527d25b6e1f324f353b25..bd438787c3374be6ead4f6233101fd1f548643ea 100644 --- a/tensorflow/contrib/gdr/gdr.proto +++ b/tensorflow/contrib/gdr/gdr.proto @@ -9,5 +9,4 @@ message RemoteMemoryRegion { uint64 addr = 3; uint32 rkey = 4; uint32 tensor_key = 5; - uint64 checksum = 6; } diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.cc b/tensorflow/contrib/gdr/gdr_memory_manager.cc index 53587fcf3050f313c85485f77ce411cba7faccff..ce1875151597f926aeb6392e7fc8307312da123f 100644 --- a/tensorflow/contrib/gdr/gdr_memory_manager.cc +++ b/tensorflow/contrib/gdr/gdr_memory_manager.cc @@ -26,17 +26,14 @@ limitations under the License. 
#include #include #include -#include #include "tensorflow/contrib/gdr/gdr.pb.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/dma_helper.h" -#include "tensorflow/core/common_runtime/process_state.h" -#if GOOGLE_CUDA #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" -#include "tensorflow/core/common_runtime/gpu/gpu_util.h" -#endif // GOOGLE_CUDA +#include "tensorflow/core/common_runtime/process_state.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/numa.h" @@ -81,10 +78,6 @@ int TryToReadNumaNode(ibv_device* device) { int32 value; if (strings::safe_strto32(content, &value)) { if (value < 0) { - LOG(INFO) << "Successful NUMA node read from SysFS had negative value (" - << value - << "), but there must be at least one NUMA node" - ", so returning NUMA node zero"; return port::kNUMANoAffinity; } LOG(INFO) << "NUMA node for device: " << device->name << " is " << value; @@ -114,7 +107,7 @@ class GdrMemoryManager : public RemoteMemoryManager { public: GdrMemoryManager(const string& host, const string& port); - virtual ~GdrMemoryManager(); + virtual ~GdrMemoryManager() {} virtual Status Init() override; @@ -140,7 +133,7 @@ class GdrMemoryManager : public RemoteMemoryManager { return ptr < reinterpret_cast(other->addr) + other->length; } - ibv_mr* FindMemoryRegion(void* addr, size_t length); + ibv_mr* FindMemoryRegion(const Tensor* tensor); void InsertMemoryRegion(void* addr, size_t length, const std::string& allocator_name); @@ -152,7 +145,6 @@ class GdrMemoryManager : public RemoteMemoryManager { const string port_; RdmaEndpointPtr listening_; std::atomic stopped_; - int epfd_; int numa_node_; // Server side endpoints @@ -163,15 +155,19 @@ class GdrMemoryManager : public RemoteMemoryManager { std::atomic next_key_; // Server side on-the-fly 
tensor buffers - mutex server_mu_; - std::map tensor_buffers_ - GUARDED_BY(server_mu_); + mutex buf_mu_; + std::map tensor_buffers_ GUARDED_BY(buf_mu_); // Client side endpoints mutex client_mu_; std::map, RdmaEndpointPtr> clients_ GUARDED_BY(client_mu_); + // Client side callbacks + mutex callback_mu_; + std::map tensor_callbacks_ + GUARDED_BY(callback_mu_); + // Managed memory regions mutex alloc_mu_; std::vector mrs_ GUARDED_BY(alloc_mu_); @@ -184,16 +180,9 @@ GdrMemoryManager::GdrMemoryManager(const string& host, const string& port) port_(port), listening_(nullptr, EndpointDeleter), stopped_(true), - next_key_(0) {} - -GdrMemoryManager::~GdrMemoryManager() { close(epfd_); } + next_key_(static_cast(random::New64())) {} Status GdrMemoryManager::Init() { - epfd_ = epoll_create1(0); - if (epfd_ == -1) { - return errors::Unavailable(strerror(errno), ": ", "epoll_create"); - } - rdma_addrinfo* addrinfo; rdma_addrinfo hints = {}; hints.ai_port_space = RDMA_PS_TCP; @@ -206,7 +195,7 @@ Status GdrMemoryManager::Init() { ibv_qp_init_attr init_attr = {}; init_attr.qp_type = IBV_QPT_RC; - init_attr.cap.max_recv_wr = 32; + init_attr.cap.max_recv_wr = 1024; init_attr.cap.max_send_wr = 1; init_attr.cap.max_recv_sge = 1; init_attr.cap.max_send_sge = 1; @@ -239,14 +228,6 @@ Status GdrMemoryManager::Init() { "cannot set server to non-blocking mode"); } - epoll_event event = {}; - event.events = EPOLLIN | EPOLLPRI; - event.data.ptr = listening_.get(); - if (epoll_ctl(epfd_, EPOLL_CTL_ADD, listening_->channel->fd, &event)) { - return errors::Unavailable(strerror(errno), ": ", - "cannot add server to epoll"); - } - numa_node_ = TryToReadNumaNode(listening_->verbs->device); SubAllocator::Visitor alloc_visitor = [this](void* ptr, int numa_node, @@ -265,121 +246,114 @@ Status GdrMemoryManager::Init() { ProcessState::singleton()->AddCPUFreeVisitor(free_visitor); LOG(INFO) << "Instrumenting CPU allocator(s)"; -#if GOOGLE_CUDA for (int numa_idx = 0; numa_idx < port::NUMANumNodes(); 
++numa_idx) { GPUProcessState::singleton()->AddCUDAHostAllocVisitor(numa_idx, alloc_visitor); GPUProcessState::singleton()->AddCUDAHostFreeVisitor(numa_idx, free_visitor); } + if (IsGDRAvailable()) { SubAllocator::Visitor cuda_alloc_visitor = [this](void* ptr, int gpu_id, size_t num_bytes) { VLOG(2) << "Registering RDMA capable memory region on GPU " << gpu_id; InsertMemoryRegion(ptr, num_bytes, strings::StrCat("GPU:", gpu_id)); }; - for (int numa_idx = 0; numa_idx < port::NUMANumNodes(); ++numa_idx) { - GPUProcessState::singleton()->AddGPUAllocVisitor(numa_idx, - cuda_alloc_visitor); - } - VLOG(1) << "Instrumenting GPU allocator(s) for all Numas"; + GPUProcessState::singleton()->AddGPUAllocVisitor(numa_node_, + cuda_alloc_visitor); + LOG(INFO) << "Instrumenting GPU allocator for NUMA " << numa_node_; } -#endif // GOOGLE_CUDA + return Status::OK(); } void GdrMemoryManager::Run() { stopped_ = false; while (!stopped_) { - epoll_event events[32]; - int ret = epoll_wait(epfd_, events, 32, 1); - if (ret == -1) { - LOG(ERROR) << "epoll_wait: " << strerror(errno); - return; - } - for (int i = 0; i < ret; i++) { - rdma_cm_id* id = static_cast(events[i].data.ptr); - if (id == listening_.get()) { - // Accept incoming connections - if (!rdma_get_request(listening_.get(), &id)) { - if (!rdma_accept(id, nullptr)) { - LOG(INFO) << "Accepted new RDMA connection"; - if (ibv_req_notify_cq(id->recv_cq, 0)) { - LOG(ERROR) << strerror(errno) << ": ibv_req_notify_cq failed"; - EndpointDeleter(id); - continue; - } - for (int i = 0; i < 32; i++) { - if (rdma_post_recvv(id, nullptr, nullptr, 0)) { - LOG(ERROR) << strerror(errno) << ": rdma_post_recvv failed"; - EndpointDeleter(id); - continue; - } - } - int flags = fcntl(id->recv_cq_channel->fd, F_GETFL, 0); - if (fcntl(id->recv_cq_channel->fd, F_SETFL, flags | O_NONBLOCK)) { - LOG(ERROR) << strerror(errno) - << ": cannot set server_client to non-blocking mode"; - EndpointDeleter(id); - continue; - } - epoll_event event = {}; - 
event.events = EPOLLIN | EPOLLPRI; - event.data.ptr = id; - if (epoll_ctl(epfd_, EPOLL_CTL_ADD, id->recv_cq_channel->fd, - &event)) { - LOG(ERROR) << strerror(errno) - << ": cannot add server client to epoll"; - EndpointDeleter(id); - continue; - } - server_clients_.push_back({id, EndpointDeleter}); + rdma_cm_id* id = nullptr; + // Accept incoming connections + if (!rdma_get_request(listening_.get(), &id)) { + if (!rdma_accept(id, nullptr)) { + LOG(INFO) << "Accepted new RDMA connection"; + for (int i = 0; i < 1024; i++) { + if (rdma_post_recvv(id, nullptr, nullptr, 0)) { + LOG(ERROR) << strerror(errno) << ": rdma_post_recvv failed"; + EndpointDeleter(id); + continue; } } - } else { - // Polling work completions - ibv_cq* cq; - void* context; - if (!ibv_get_cq_event(id->recv_cq_channel, &cq, &context)) { - ibv_ack_cq_events(id->recv_cq, 1); - if (ibv_req_notify_cq(id->recv_cq, 0)) { - LOG(ERROR) << strerror(errno) << ": ibv_req_notify_cq failed"; - continue; + server_clients_.push_back({id, EndpointDeleter}); + } + } + // Polling server side work completions + for (const auto& client : server_clients_) { + ibv_wc wc[32]; + int ret = ibv_poll_cq(client->recv_cq, 32, wc); + if (ret < 0) { + LOG(ERROR) << "ibv_poll_cq failed"; + continue; + } + for (int i = 0; i < ret; i++) { + if (wc[i].opcode != IBV_WC_RECV_RDMA_WITH_IMM) { + LOG(ERROR) << "Received unknown operation " << wc[i].opcode; + } + if (wc[i].status != 0) { + LOG(ERROR) << ibv_wc_status_str(wc[i].status); + } + TensorKey tensor_key = ntohl(wc[i].imm_data); + + if (rdma_post_recvv(client.get(), nullptr, nullptr, 0)) { + perror("rdma_post_recvv"); + LOG(ERROR) << "rdma_post_recvv failed"; + } + + mutex_lock l(buf_mu_); + auto iter = tensor_buffers_.find(tensor_key); + if (iter == std::end(tensor_buffers_)) { + LOG(ERROR) << "Cannot find tensor buffer for tensor key " + << tensor_key; + } else { + const TensorBuffer* buffer = iter->second; + buffer->Unref(); + tensor_buffers_.erase(iter); + } + } + } + // 
Polling client side work completions + if (client_mu_.try_lock()) { + for (const auto& client : clients_) { + ibv_wc wc[32]; + int ret = ibv_poll_cq(client.second->send_cq, 32, wc); + for (int i = 0; i < ret; i++) { + Status s; + if (wc[i].status) { + s = errors::Unavailable(ibv_wc_status_str(wc[i].status)); + } else { + s = Status::OK(); } - ibv_wc wc[32]; - int ret = ibv_poll_cq(id->recv_cq, 32, wc); - if (ret < 0) { - LOG(ERROR) << "ibv_poll_cq failed"; - continue; + TensorKey key = wc[i].wr_id; + + ibv_send_wr wr = {}; + wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; + wr.imm_data = htonl(key); + ibv_send_wr* bad_wr; + if (ibv_post_send(client.second->qp, &wr, &bad_wr)) { + LOG(ERROR) << strerror(errno) + << ": ibv_post_send failed for tensor_key " << key; } - for (int i = 0; i < ret; i++) { - if (wc[i].opcode != IBV_WC_RECV_RDMA_WITH_IMM) { - LOG(ERROR) << "Received unknown operation " << wc[i].opcode; - } - if (wc[i].status != 0) { - LOG(ERROR) << ibv_wc_status_str(wc[i].status); - } - TensorKey tensor_key = ntohl(wc[i].imm_data); - { - mutex_lock l(server_mu_); - auto iter = tensor_buffers_.find(tensor_key); - if (iter == std::end(tensor_buffers_)) { - LOG(ERROR) << "Cannot find tensor buffer for tensor key " - << tensor_key; - } else { - const TensorBuffer* buffer = iter->second; - buffer->Unref(); - tensor_buffers_.erase(iter); - } - } - if (rdma_post_recvv(id, nullptr, nullptr, 0)) { - perror("rdma_post_recvv"); - LOG(ERROR) << "rdma_post_recvv failed"; - continue; - } + + mutex_lock l(callback_mu_); + auto iter = tensor_callbacks_.find(key); + if (iter != std::end(tensor_callbacks_)) { + iter->second(s); + tensor_callbacks_.erase(iter); + } else { + LOG(WARNING) << "Cannot find client callback with tensor key " + << key; } } } + client_mu_.unlock(); } } } @@ -390,116 +364,58 @@ void GdrMemoryManager::TransportOptionsFromTensor( ::google::protobuf::Any* mutable_transport_options, const Tensor& tensor, Device* device, DeviceContext* device_context, bool on_host, 
StatusCallback done) { - auto buffer = DMAHelper::buffer(&tensor); - void* addr = buffer->data(); - size_t length = buffer->size(); - if (length == 0) { - done(errors::Unavailable("Cannot register tensor buffer of size 0")); - return; - } + ibv_mr* mr = FindMemoryRegion(&tensor); + const TensorBuffer* buffer = DMAHelper::buffer(&tensor); - ibv_mr* mr = FindMemoryRegion(addr, length); - -#if GOOGLE_CUDA - if (device->tensorflow_gpu_device_info() && !on_host) { - Allocator* alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0); - Tensor* host_copy = new Tensor(alloc, tensor.dtype(), tensor.shape()); - GPUUtil::CopyGPUTensorToCPU( - device, device_context, &tensor, host_copy, - [done, host_copy, mutable_transport_options, this](const Status& s) { - if (!s.ok()) { - done(s); - delete host_copy; - return; - } - auto buffer = DMAHelper::buffer(host_copy); - void* addr = buffer->data(); - size_t length = buffer->size(); - ibv_mr* mr = FindMemoryRegion(addr, length); - - if (mr == nullptr) { - done(errors::Unavailable("Cannot find pinned memory region")); - delete host_copy; - return; - } - - buffer->Ref(); - TensorKey tensor_key = next_key_++; - { - mutex_lock l(server_mu_); - tensor_buffers_.insert(std::make_pair(tensor_key, buffer)); - } - - uint64_t checksum = 0; - if (VLOG_IS_ON(2)) { - checksum = GPUUtil::Checksum(*host_copy); - } - - RemoteMemoryRegion remote_mr; - remote_mr.set_host(host_); - remote_mr.set_port(port_); - remote_mr.set_addr(reinterpret_cast(addr)); - remote_mr.set_rkey(mr->rkey); - remote_mr.set_tensor_key(tensor_key); - remote_mr.set_checksum(checksum); - mutable_transport_options->PackFrom(remote_mr); - - done(Status::OK()); - delete host_copy; - }); - return; - } -#endif + Tensor* copy = nullptr; if (mr == nullptr) { - Allocator* alloc = ProcessState::singleton()->GetCPUAllocator(numa_node_); - Tensor host_copy(alloc, tensor.dtype(), tensor.shape()); - - std::memcpy(DMAHelper::buffer(&host_copy)->data(), buffer->data(), length); - VLOG(2) 
<< "Copying " << length << " bytes unpinned tensor buffer"; - - buffer = DMAHelper::buffer(&host_copy); - addr = buffer->data(); - length = buffer->size(); - - mr = FindMemoryRegion(addr, length); + AllocatorAttributes alloc_attrs; + alloc_attrs.set_gpu_compatible(true); + alloc_attrs.set_nic_compatible(true); + alloc_attrs.set_on_host(true); + Allocator* alloc = device->GetAllocator(alloc_attrs); + copy = new Tensor(alloc, tensor.dtype(), tensor.shape()); + + mr = FindMemoryRegion(copy); + buffer = DMAHelper::buffer(copy); if (mr == nullptr) { done(errors::Unavailable("Cannot find pinned memory region")); + delete copy; return; } - - buffer->Ref(); - } else { - buffer->Ref(); } TensorKey tensor_key = next_key_++; + buffer->Ref(); { - mutex_lock l(server_mu_); + mutex_lock l(buf_mu_); tensor_buffers_.insert(std::make_pair(tensor_key, buffer)); } - uint64_t checksum = 0; - if (VLOG_IS_ON(2)) { -#ifdef GOOGLE_CUDA - if (device->tensorflow_gpu_device_info() && !on_host) { - checksum = GPUUtil::Checksum(device, device_context, tensor); - } else { - checksum = GPUUtil::Checksum(tensor); - } -#endif - } - RemoteMemoryRegion remote_mr; remote_mr.set_host(host_); remote_mr.set_port(port_); - remote_mr.set_addr(reinterpret_cast(addr)); + remote_mr.set_addr(reinterpret_cast(buffer->data())); remote_mr.set_rkey(mr->rkey); remote_mr.set_tensor_key(tensor_key); - remote_mr.set_checksum(checksum); mutable_transport_options->PackFrom(remote_mr); - done(Status::OK()); + if (copy && device->tensorflow_gpu_device_info() && !on_host) { + device_context->CopyDeviceTensorToCPU(&tensor, "" /* tensor_name */, device, + copy, [done, copy](const Status& s) { + done(s); + delete copy; + }); + return; + } else if (copy) { + std::memcpy(buffer->data(), DMAHelper::buffer(&tensor)->data(), + buffer->size()); + done(Status::OK()); + delete copy; // OK to delete; we have reffed the underlying TensorBuffer + } else { + done(Status::OK()); + } } void GdrMemoryManager::TensorFromTransportOptions( @@ 
-512,42 +428,10 @@ void GdrMemoryManager::TensorFromTransportOptions( return; } - auto buffer = DMAHelper::buffer(tensor); - void* addr = buffer->data(); - size_t length = buffer->size(); - ibv_mr* mr = FindMemoryRegion(addr, length); - - Tensor host_copy; -#if GOOGLE_CUDA - if (mr == nullptr && !on_host) { - Allocator* alloc = - GPUProcessState::singleton()->GetCUDAHostAllocator(numa_node_); - host_copy = Tensor(alloc, tensor->dtype(), tensor->shape()); - buffer = DMAHelper::buffer(&host_copy); - addr = buffer->data(); - length = buffer->size(); - mr = FindMemoryRegion(addr, length); - } -#endif // GOOGLE_CUDA - - if (mr == nullptr) { - Allocator* alloc = ProcessState::singleton()->GetCPUAllocator(numa_node_); - host_copy = Tensor(alloc, tensor->dtype(), tensor->shape()); - - buffer = DMAHelper::buffer(&host_copy); - addr = buffer->data(); - length = buffer->size(); - - mr = FindMemoryRegion(addr, length); - if (mr == nullptr) { - done(errors::Unavailable("Cannot find pinned memory region")); - return; - } - } - - decltype(clients_)::iterator iter; - bool success; + rdma_cm_id* id = nullptr; { + decltype(clients_)::iterator iter; + bool success; mutex_lock l(client_mu_); std::tie(iter, success) = clients_.insert( std::make_pair(std::make_pair(remote_mr.host(), remote_mr.port()), @@ -560,93 +444,94 @@ void GdrMemoryManager::TensorFromTransportOptions( return; } } - } - rdma_cm_id* id = iter->second.get(); - - uint64_t start = Env::Default()->NowMicros(); - - if (rdma_post_read(id, nullptr, buffer->data(), buffer->size(), mr, 0, - remote_mr.addr(), remote_mr.rkey())) { - done(errors::Unavailable(strerror(errno), ": ", "rdma_post_read failed")); - return; + id = iter->second.get(); } - ibv_send_wr wr = {}; - wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; - wr.imm_data = htonl(remote_mr.tensor_key()); - wr.send_flags = IBV_SEND_SIGNALED; - ibv_send_wr* bad_wr; - if (ibv_post_send(id->qp, &wr, &bad_wr)) { - done(errors::Unavailable(strerror(errno), ": ", "ibv_post_send 
failed")); - return; - } + ibv_mr* mr = FindMemoryRegion(tensor); + const TensorBuffer* buffer = DMAHelper::buffer(tensor); - ibv_wc wc = {}; - int ret; - while ((ret = ibv_poll_cq(id->send_cq, 1, &wc)) == 0) - ; - if (ret < 0 || wc.status) { - done(errors::Unavailable(ibv_wc_status_str(wc.status))); - return; - } + const Tensor* copy = nullptr; -#if GOOGLE_CUDA - if (device->tensorflow_gpu_device_info() && !on_host && - host_copy.NumElements() > 0) { - uint64_t checksum = 0; - if (VLOG_IS_ON(2)) { - checksum = GPUUtil::Checksum(host_copy); - CHECK(checksum == remote_mr.checksum()) - << "Checksum mismatch: " << checksum << "!=" << remote_mr.checksum(); + if (mr == nullptr) { + AllocatorAttributes alloc_attrs; + alloc_attrs.set_gpu_compatible(true); + alloc_attrs.set_nic_compatible(true); + alloc_attrs.set_on_host(true); + Allocator* alloc = device->GetAllocator(alloc_attrs); + copy = new Tensor(alloc, tensor->dtype(), tensor->shape()); + + mr = FindMemoryRegion(copy); + buffer = DMAHelper::buffer(copy); + if (mr == nullptr) { + done(errors::Unavailable("Cannot find pinned memory region")); + delete copy; + return; } - Tensor* ref = new Tensor; - std::swap(host_copy, *ref); - GPUUtil::CopyCPUTensorToGPU( - ref, device_context, device, tensor, - [ref, done, buffer, remote_mr, start](const Status& s) { - if (!s.ok()) { - done(s); - delete ref; - return; - } - uint64_t end = Env::Default()->NowMicros(); - - VLOG(2) << "RDMA from remote memory region " << remote_mr.rkey() - << " of size " << buffer->size() << " with tensor key " - << remote_mr.tensor_key() << " took " << (end - start) - << " micros"; - done(Status::OK()); - delete ref; - }); - return; } -#endif // GOOGLE_CUDA - if ((on_host || !device->tensorflow_gpu_device_info()) && - host_copy.NumElements() > 0) { - std::memcpy(DMAHelper::buffer(tensor)->data(), addr, length); - VLOG(2) << "Copying " << length << " bytes unpinned tensor buffer"; - } + uint64_t start = Env::Default()->NowMicros(); - uint64_t end = 
Env::Default()->NowMicros(); + TensorKey tensor_key = remote_mr.tensor_key(); - VLOG(2) << "RDMA from remote memory region " << remote_mr.rkey() - << " of size " << buffer->size() << " with tensor key " - << remote_mr.tensor_key() << " took " << (end - start) << " micros"; + StatusCallback callback = [done, copy, device, device_context, on_host, + tensor, start, tensor_key](const Status& s) { + if (!s.ok()) { + done(s); + if (copy) { + delete copy; + } + return; + } - uint64_t checksum = 0; - if (VLOG_IS_ON(2)) { -#ifdef GOOGLE_CUDA - if (device->tensorflow_gpu_device_info() && !on_host) { - checksum = GPUUtil::Checksum(device, device_context, *tensor); + VLOG(2) << "RDMA of tensor " << tensor_key << " of size " + << DMAHelper::buffer(tensor)->size() << " took " + << (Env::Default()->NowMicros() - start) << " micros"; + + if (copy && device->tensorflow_gpu_device_info() && !on_host) { + device_context->CopyCPUTensorToDevice(copy, device, tensor, + [done, copy](const Status& s) { + done(s); + delete copy; + }); + } else if (copy) { + std::memcpy(DMAHelper::buffer(tensor)->data(), + DMAHelper::buffer(copy)->data(), + DMAHelper::buffer(copy)->size()); + done(s); + delete copy; } else { - checksum = GPUUtil::Checksum(*tensor); + done(s); + } + }; + + { + mutex_lock l(callback_mu_); + if (tensor_callbacks_.find(tensor_key) == std::end(tensor_callbacks_)) { + tensor_callbacks_.insert(std::make_pair(tensor_key, std::move(callback))); + } else { + done(errors::Unavailable("Received duplicated tensor key")); + if (copy) { + delete copy; + } + return; + } + } + + if (rdma_post_read(id, reinterpret_cast(tensor_key), buffer->data(), + buffer->size(), mr, IBV_SEND_SIGNALED, remote_mr.addr(), + remote_mr.rkey())) { + done(errors::Unavailable(strerror(errno), ": ", "rdma_post_read failed")); + { + mutex_lock l(callback_mu_); + auto iter = tensor_callbacks_.find(tensor_key); + if (iter != std::end(tensor_callbacks_)) { + tensor_callbacks_.erase(iter); + } + } + if (copy) { + 
delete copy; } - CHECK(checksum == remote_mr.checksum()) - << "Checksum mismatch: " << checksum << "!=" << remote_mr.checksum(); -#endif } - done(Status::OK()); } Status GdrMemoryManager::CreateEndpoint(const string& host, const string& port, @@ -663,7 +548,7 @@ Status GdrMemoryManager::CreateEndpoint(const string& host, const string& port, ibv_qp_init_attr init_attr = {}; init_attr.qp_type = IBV_QPT_RC; init_attr.cap.max_recv_wr = 1; - init_attr.cap.max_send_wr = 32; + init_attr.cap.max_send_wr = 1024; init_attr.cap.max_recv_sge = 1; init_attr.cap.max_send_sge = 1; @@ -687,8 +572,8 @@ Status GdrMemoryManager::CreateEndpoint(const string& host, const string& port, return Status::OK(); } -ibv_mr* GdrMemoryManager::FindMemoryRegion(void* addr, size_t length) { - if (length == 0) return nullptr; +ibv_mr* GdrMemoryManager::FindMemoryRegion(const Tensor* tensor) { + const void* addr = DMAHelper::buffer(tensor)->data(); mutex_lock l(alloc_mu_); auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator); if (iter == std::end(mrs_) || iter->get()->addr > addr) { diff --git a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc index fbccbead03fc0d641db40ede661bf3677d44c45d..5f8c300155770ed03ad12a9fa5ac74456edaf024 100644 --- a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc +++ b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc @@ -58,11 +58,9 @@ class GdrRecvTensorCall : public BaseRecvTensorCall { resp_.InitAlloc(dst_device_, recv_args_.alloc_attrs); StatusCallback cb = [this, recv_done](const Status& s) { bool dma_ok = resp_.metadata().has_transport_options(); - if (s.ok() && tensor().TotalBytes() > 0 && (!is_dead()) && dma_ok) { + if (s.ok() && tensor().TotalBytes() > 1024 && (!is_dead()) && dma_ok) { auto transport_options = resp_.metadata().transport_options(); - const bool on_host = - (dst_device_->tensorflow_gpu_device_info() == nullptr) || - recv_args_.alloc_attrs.on_host(); + const bool on_host = 
recv_args_.alloc_attrs.on_host(); remote_memory_manager_->TensorFromTransportOptions( const_cast(&tensor()), transport_options, dst_device_, recv_args_.device_context, on_host, @@ -70,9 +68,6 @@ class GdrRecvTensorCall : public BaseRecvTensorCall { if (!s.ok()) { mutex_lock l(mu_); status_.Update(s); - LOG(ERROR) << "Cannot find pinned memory region from allocator " - << dst_device_->GetAllocator(recv_args_.alloc_attrs) - ->Name(); } recv_done(); }); diff --git a/tensorflow/contrib/gdr/gdr_server_lib.cc b/tensorflow/contrib/gdr/gdr_server_lib.cc index b3f48ec1dd9c75055f4e1ea76eb203b6ccf94718..dc0d5d548b80d36409778ef34e63171441f10142 100644 --- a/tensorflow/contrib/gdr/gdr_server_lib.cc +++ b/tensorflow/contrib/gdr/gdr_server_lib.cc @@ -74,9 +74,8 @@ Status GdrServer::Start() { } Status GdrServer::Stop() { - TF_RETURN_IF_ERROR(GrpcServer::Stop()); remote_memory_manager_->Stop(); - return Status::OK(); + return GrpcServer::Stop(); } Status GdrServer::Join() { diff --git a/tensorflow/contrib/gdr/gdr_worker.cc b/tensorflow/contrib/gdr/gdr_worker.cc index 867cb83f42034c8e9061e333ea671457745f92c3..016e5ea27b397830c69b6e1761b5994ebcfa9c3d 100644 --- a/tensorflow/contrib/gdr/gdr_worker.cc +++ b/tensorflow/contrib/gdr/gdr_worker.cc @@ -18,9 +18,6 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/dma_helper.h" -#if GOOGLE_CUDA -#include "tensorflow/core/common_runtime/gpu/gpu_util.h" -#endif // GOOGLE_CUDA #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/distributed_runtime/graph_mgr.h" @@ -78,7 +75,7 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, const bool dma_ok = request->dma_ok(); env_->rendezvous_mgr->RecvLocalAsync( step_id, parsed, - [this, opts, response, done, src_dev, dma_ok]( + [this, opts, response, done, src_dev, request, dma_ok]( const Status& status, const Rendezvous::Args& send_args, const Rendezvous::Args&, const Tensor& val, const bool is_dead) { opts->ClearCancelCallback(); @@ -89,10 +86,8 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, // 3) the tensor has the on_host allocation attribute, // i.e. it's in CPU RAM *independent of its assigned // device type*. - const bool on_host = - (src_dev->tensorflow_gpu_device_info() == nullptr) || - send_args.alloc_attrs.on_host(); - if (val.TotalBytes() > 0 && (!is_dead) && + const bool on_host = send_args.alloc_attrs.on_host(); + if (val.TotalBytes() > 1024 && (!is_dead) && DMAHelper::CanUseDMA(&val) && dma_ok) { // DMA cases. RecvTensorResponse* proto = new RecvTensorResponse; @@ -117,8 +112,7 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, } else { // Non-DMA cases. 
if (src_dev->tensorflow_gpu_device_info() && (!on_host)) { -#if GOOGLE_CUDA - const DeviceContext* send_dev_context = send_args.device_context; + DeviceContext* send_dev_context = send_args.device_context; AllocatorAttributes alloc_attrs; alloc_attrs.set_gpu_compatible(true); alloc_attrs.set_on_host(true); @@ -127,7 +121,8 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, CHECK(send_dev_context) << "send dev name: " << src_dev->name() << " gpu_info: " << src_dev->tensorflow_gpu_device_info(); - // "val" is on a GPU. Uses GPUUtil to fill the response proto. + // "val" is on an accelerator device. Uses the device_context to + // fill the copy on host. StatusCallback copy_ready = [response, done, copy, is_dead](const Status& s) { // The value is now ready to be returned on the wire. @@ -136,11 +131,8 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, delete copy; }; - GPUUtil::CopyGPUTensorToCPU(src_dev, send_dev_context, &val, copy, - copy_ready); -#else - done(errors::Internal("No GPU device in process")); -#endif // GOOGLE_CUDA + send_dev_context->CopyDeviceTensorToCPU( + &val, request->rendezvous_key(), src_dev, copy, copy_ready); } else { grpc::EncodeTensorToByteBuffer(is_dead, val, response); done(Status::OK()); diff --git a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py index 5c5599858ee6879a5703d65658bf4bbd881c7e72..71eac729a8a81c2f59f9ed5d7f42fb7b1c3e1b5c 100644 --- a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py +++ b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py @@ -23,11 +23,16 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.util import deprecation class SequenceFileDataset(dataset_ops.DatasetSource): """A Sequence File Dataset that reads the sequence file.""" + 
@deprecation.deprecated( + None, + "tf.contrib.hadoop will be removed in 2.0, the support for Apache Hadoop " + "will continue to be provided through the tensorflow/io GitHub project.") def __init__(self, filenames): """Create a `SequenceFileDataset`. @@ -50,13 +55,11 @@ class SequenceFileDataset(dataset_ops.DatasetSource): Args: filenames: A `tf.string` tensor containing one or more filenames. """ - super(SequenceFileDataset, self).__init__() self._filenames = ops.convert_to_tensor( filenames, dtype=dtypes.string, name="filenames") - - def _as_variant_tensor(self): - return gen_dataset_ops.sequence_file_dataset( + variant_tensor = gen_dataset_ops.sequence_file_dataset( self._filenames, self._element_structure._flat_types) # pylint: disable=protected-access + super(SequenceFileDataset, self).__init__(variant_tensor) @property def _element_structure(self): diff --git a/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py index e4762c91b193f9c5e32fa2642e702e61e8e5e57f..66e654ca636a5a051c6f9cd35bf9001dfbcbf7f4 100644 --- a/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py +++ b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py @@ -31,6 +31,7 @@ from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.util import deprecation @six.add_metaclass(abc.ABCMeta) @@ -699,6 +700,10 @@ class IgniteDataset(dataset_ops.DatasetSource): Ignite Binary Client Protocol. 
""" + @deprecation.deprecated( + None, + "tf.contrib.ignite will be removed in 2.0, the support for Apache Ignite " + "will continue to be provided through the tensorflow/io GitHub project.") def __init__(self, cache_name, host="localhost", diff --git a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py index 2b86331099ccae03664462987ee0c141d766c10f..b399e1b6c2ac47db205b5d8bbc81875ef5c08a31 100644 --- a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py +++ b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py @@ -23,12 +23,17 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.util import deprecation class KafkaDataset(dataset_ops.DatasetSource): """A Kafka Dataset that consumes the message. """ + @deprecation.deprecated( + None, + "tf.contrib.kafka will be removed in 2.0, the support for Apache Kafka " + "will continue to be provided through the tensorflow/io GitHub project.") def __init__(self, topics, servers="localhost", diff --git a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py index 20395395281768ac429984a1e3552cfd187527a2..2b1d478a9b0fd12ca25c72da6872acccfd7285fc 100644 --- a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py +++ b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py @@ -23,6 +23,7 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.util import deprecation class KinesisDataset(dataset_ops.DatasetSource): @@ -50,6 +51,10 @@ class KinesisDataset(dataset_ops.DatasetSource): is returned immediately instead. 
""" + @deprecation.deprecated( + None, + "tf.contrib.kinesis will be removed in 2.0, the support for Kinesis " + "will continue to be provided through the tensorflow/io GitHub project.") def __init__(self, stream, shard="", diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 0a4d2c6d4cb5cad7da93cea89478bc0fca2ac4d6..d791418c9d0f887058ceb535092fa8122da1aa75 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1459,13 +1459,6 @@ class DropoutTest(test.TestCase): class FlattenTest(test.TestCase): - def testInvalidRank(self): - with ops.Graph().as_default() as g, self.session(g): - inputs = array_ops.placeholder(dtype=dtypes.float32) - inputs.set_shape(tensor_shape.TensorShape((5,))) - with self.assertRaisesRegexp(ValueError, 'incompatible with the layer'): - _layers.flatten(inputs) - def testUnknownLastDim(self): with ops.Graph().as_default() as g, self.session(g): inputs = array_ops.placeholder(dtype=dtypes.float32) @@ -1502,6 +1495,12 @@ class FlattenTest(test.TestCase): images.get_shape().num_elements()) self.assertEqual(output.get_shape()[0], images.get_shape()[0]) + def testFlatten0D(self): + with self.cached_session(): + scalars = random_ops.random_uniform((5,), seed=1, name='scalars') + output = _layers.flatten(scalars) + self.assertEqual(output.shape, (5, 1)) + def testFlattenBatchSize(self): height, width = 3, 3 with self.cached_session() as sess: diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py index 8466dc36d13e223aed4f1dfe8e39a6f91c99fa55..d49834dc860a8b4341ddd3720fde52281f7474f7 100644 --- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py +++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py @@ -12,7 +12,12 @@ # See the License for the 
specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for SdcaModel.""" +"""Tests for SdcaModel (deprecated). + +This module and all its submodules are deprecated. To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py index f3f1dcd98db5ae24af154d1f0851a0688d2bc611..c056a12fa5307a7e9ac4cf30e1386ddfd5cd7d75 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py @@ -12,7 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Proximal stochastic dual coordinate ascent optimizer for linear models.""" +# pylint: disable=line-too-long +"""Proximal stochastic dual coordinate ascent optimizer for linear models (deprecated). + +This module and all its submodules are deprecated. To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. 
+""" +# pylint: enable=line-too-long from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -40,6 +47,7 @@ from tensorflow.python.ops import variables as var_ops from tensorflow.python.ops.nn import log_poisson_loss from tensorflow.python.ops.nn import sigmoid_cross_entropy_with_logits from tensorflow.python.summary import summary +from tensorflow.python.util import deprecation __all__ = ['SdcaModel'] @@ -48,7 +56,7 @@ __all__ = ['SdcaModel'] class SdcaModel(object): """Stochastic dual coordinate ascent solver for linear models. - Loss functions supported: + Loss functions supported: * Binary logistic loss * Squared loss @@ -109,6 +117,10 @@ class SdcaModel(object): ``` """ + @deprecation.deprecated( + None, 'This class is deprecated. To UPDATE or USE linear optimizers, ' + 'please check its latest version in core: ' + 'tensorflow_estimator/python/estimator/canned/linear_optimizer/.') def __init__(self, examples, variables, options): """Create a new sdca optimizer.""" diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py index a001555e8f257c88a52fdb40d4181f5cd9c92e84..a28394964a12013c43d85701b5a0ab5c559afd62 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Sharded mutable dense hash table.""" +"""Sharded mutable dense hash table (deprecated). + +This module and all its submodules are deprecated. To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. 
+""" from __future__ import absolute_import from __future__ import division @@ -28,6 +33,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.util import deprecation # TODO(rohanj): This should subclass Checkpointable and implement @@ -45,6 +51,10 @@ class ShardedMutableDenseHashTable(object): # TODO(andreasst): consider moving this to lookup module + @deprecation.deprecated( + None, 'This class is deprecated. To UPDATE or USE linear optimizers, ' + 'please check its latest version in core: ' + 'tensorflow_estimator/python/estimator/canned/linear_optimizer/.') def __init__(self, key_dtype, value_dtype, diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py index 2b56d0fa3a8b8564b7c73a62bd99cc900d6f5c54..2d1457f9e4cc576da696be191e718814dd9ff4e5 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for sharded_mutable_dense_hashtable.py.""" +"""Tests for sharded_mutable_dense_hashtable.py (deprecated). + +This module and all its submodules are deprecated. To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. 
+""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column.py b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column.py index 003795233ff2b28e33fc10388ef25efb63c43bb0..64730f8eed1ff9bfcd4a980dceb28abb98e39f73 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Sparse feature column.""" +"""Sparse feature column (deprecated). + +This module and all its submodules are deprecated. To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. +""" from __future__ import absolute_import from __future__ import division @@ -21,6 +26,7 @@ from __future__ import print_function from tensorflow.python.framework import dtypes from tensorflow.python.framework.ops import internal_convert_to_tensor from tensorflow.python.framework.ops import name_scope +from tensorflow.python.util import deprecation class SparseFeatureColumn(object): @@ -68,6 +74,10 @@ class SparseFeatureColumn(object): @@feature_values """ + @deprecation.deprecated( + None, 'This class is deprecated. To UPDATE or USE linear optimizers, ' + 'please check its latest version in core: ' + 'tensorflow_estimator/python/estimator/canned/linear_optimizer/.') def __init__(self, example_indices, feature_indices, feature_values): """Creates a `SparseFeatureColumn` representation. 
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py index 51c4f68543da2f563481cc2d35b556796616cf9d..0ae780e1a100c7dadde7196803f2ae0d4bcb2334 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for sparse_feature_column.py.""" +"""Tests for sparse_feature_column.py (deprecated). + +This module and all its submodules are deprecated. To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index e52fb5ab1431e086f99b4033a6216636a83bad79..229a72a780d5ccce8263444ffeae7700f6ac8613 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -91,7 +91,7 @@ def index_table_from_tensor(mapping, The bucket ID range is `[mapping size, mapping size + num_oov_buckets - 1]`. The underlying table must be initialized by calling - `tf.tables_initializer.run()` or `table.initializer.run()` once. + `session.run(tf.tables_initializer)` or `session.run(table.init)` once. Elements in `mapping` cannot have duplicates, otherwise when executing the table initializer op, it will throw a `FailedPreconditionError`. @@ -158,7 +158,7 @@ def string_to_index(tensor, mapping, default_value=-1, name=None): will throw a FailedPreconditionError. The underlying table must be initialized by calling - `tf.tables_initializer.run()` once. 
+ `session.run(tf.tables_initializer)` once. For example: @@ -202,7 +202,7 @@ def index_to_string_table_from_tensor(mapping, default_value="UNK", name=None): (an out-of-vocabulary entry) is assigned the `default_value` The underlying table must be initialized by calling - `tf.tables_initializer.run()` or `table.initializer.run()` once. + `session.run(tf.tables_initializer)` or `session.run(table.init)` once. Elements in `mapping` cannot have duplicates, otherwise when executing the table initializer op, it will throw a `FailedPreconditionError`. @@ -257,7 +257,7 @@ def index_to_string(tensor, mapping, default_value="UNK", name=None): (an out-of-vocabulary entry) is assigned the `default_value` The underlying table must be initialized by calling - `tf.tables_initializer.run()` once. + `session.run(tf.tables_initializer)` once. For example: diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh index b396c527673902d61072dc9cf7d2766476be8369..2a5232b476712a96f84be0f4725beb78bc138297 100755 --- a/tensorflow/contrib/makefile/download_dependencies.sh +++ b/tensorflow/contrib/makefile/download_dependencies.sh @@ -30,11 +30,13 @@ EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" NSYNC_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" -# Note: The Protobuf source in `tensorflow/workspace.bzl` in TensorFlow -# 1.10 branch does not work. `make distclean` fails and blocks the build -# process. For now we're hardcoding to the version which is used by -# TensorFlow 1.9. 
-PROTOBUF_URL="https://mirror.bazel.build/github.com/google/protobuf/archive/v3.6.0.tar.gz" + +# Note: The protobuf repo needs to be cloned due to its submodules. +# These variables contain the GitHub repo and the sha, from `tensorflow/workspace.bzl`, +# from which to clone it from and checkout to. +readonly PROTOBUF_REPO="https://github.com/protocolbuffers/protobuf.git" +readonly PROTOBUF_TAG="$(grep -o 'https://github.com/protocolbuffers/protobuf/archive/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1 | awk '{print substr($0, index($0, "archive") + 8, index($0, "tar") - index($0, "archive") - 9) }')" + # TODO (yongtang): Replace the following with 'https://mirror.bazel.build/github.com/google/re2/.*tar\.gz' once # the archive has been propagated in mirror.bazel.build. RE2_URL="$(grep -o 'https://github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" @@ -91,11 +93,34 @@ download_and_extract() { find "${dir}" -type f -name '*BUILD' -delete } +function clone_repository() { + local repo_url="${1}" + local destination_directory="${2}" + local commit_sha="${3}" + + if [[ -d "${destination_directory}" ]]; then + rm -rf "${destination_directory}" + fi + + git clone "${repo_url}" "${destination_directory}" + + pushd "$(pwd)" 1>/dev/null + + cd "${destination_directory}" + + if [[ -n "${commit_sha}" ]]; then + git checkout "${PROTOBUF_TAG}" + fi + + git submodule update --init + + popd 1>/dev/null +} + download_and_extract "${EIGEN_URL}" "${DOWNLOADS_DIR}/eigen" download_and_extract "${GEMMLOWP_URL}" "${DOWNLOADS_DIR}/gemmlowp" download_and_extract "${GOOGLETEST_URL}" "${DOWNLOADS_DIR}/googletest" download_and_extract "${NSYNC_URL}" "${DOWNLOADS_DIR}/nsync" -download_and_extract "${PROTOBUF_URL}" "${DOWNLOADS_DIR}/protobuf" download_and_extract "${RE2_URL}" "${DOWNLOADS_DIR}/re2" download_and_extract "${FFT2D_URL}" "${DOWNLOADS_DIR}/fft2d" download_and_extract "${DOUBLE_CONVERSION_URL}" "${DOWNLOADS_DIR}/double_conversion" @@ -106,6 +131,8 @@ download_and_extract 
"${CUB_URL}" "${DOWNLOADS_DIR}/cub/external/cub_archive" download_and_extract "${FARMHASH_URL}" "${DOWNLOADS_DIR}/farmhash" download_and_extract "${FLATBUFFERS_URL}" "${DOWNLOADS_DIR}/flatbuffers" +clone_repository "${PROTOBUF_REPO}" "${DOWNLOADS_DIR}/protobuf" "${PROTOBUF_TAG}" + replace_by_sed 's#static uint32x4_t p4ui_CONJ_XOR = vld1q_u32( conj_XOR_DATA );#static uint32x4_t p4ui_CONJ_XOR; // = vld1q_u32( conj_XOR_DATA ); - Removed by script#' \ "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h" replace_by_sed 's#static uint32x2_t p2ui_CONJ_XOR = vld1_u32( conj_XOR_DATA );#static uint32x2_t p2ui_CONJ_XOR;// = vld1_u32( conj_XOR_DATA ); - Removed by scripts#' \ diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index 655c7eefcb978d40c8bc16a23685e03ed71bfb63..2cd7d6d519a55423a96526b541845392d9ec6bc2 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -119,6 +119,7 @@ tensorflow/core/kernels/fake_quant_ops.cc tensorflow/core/kernels/fifo_queue.cc tensorflow/core/kernels/fifo_queue_op.cc tensorflow/core/kernels/fill_functor.cc +tensorflow/core/kernels/fft_ops.cc tensorflow/core/kernels/function_ops.cc tensorflow/core/kernels/fused_batch_norm_op.cc tensorflow/core/kernels/gather_functor.cc diff --git a/tensorflow/contrib/metrics/python/metrics/classification.py b/tensorflow/contrib/metrics/python/metrics/classification.py index 062deb74b165329d8e72efa73b9d81f4174f8831..9aabc4bec3053871e3ff6cd3a88fd76d293f48cc 100644 --- a/tensorflow/contrib/metrics/python/metrics/classification.py +++ b/tensorflow/contrib/metrics/python/metrics/classification.py @@ -18,13 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from 
tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics_impl from tensorflow.python.ops import variable_scope -from tensorflow.python.training import distribution_strategy_context # TODO(nsilberman): move into metrics/python/ops/ diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md index 45a60d79482787df4564ae3360f8252af93c7a26..710a262f33872ada8d090d796f80dc06c2a27f84 100644 --- a/tensorflow/contrib/model_pruning/README.md +++ b/tensorflow/contrib/model_pruning/README.md @@ -53,7 +53,6 @@ The pruning library allows for specification of the following hyper parameters: | weight_sparsity_map | list of strings | [""] | list of weight variable name (or layer name):target sparsity pairs. Eg. [conv1:0.9,conv2/kernel:0.8]. For layers/weights not in this list, sparsity as specified by the target_sparsity hyperparameter is used. | | threshold_decay | float | 0.0 | The decay factor to use for exponential decay of the thresholds | | pruning_frequency | integer | 10 | How often should the masks be updated? (in # of global_steps) | -| nbins | integer | 256 | Number of bins to use for histogram computation. Note: When running on TPUs, a large (>1024) value for `nbins` may adversely affect the training time. 
| | block_height|integer | 1 | Number of rows in a block for block sparse matrices| | block_width |integer | 1 | Number of cols in a block for block sparse matrices| | block_pooling_function| string | AVG | The function to use to pool weight values in a block: average (AVG) or max (MAX)| diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py index f6b4373edd0544555dd16a373802d2feb5d674b1..9966f7cf798d206fffbaeb4d16b6500a90d113e4 100644 --- a/tensorflow/contrib/model_pruning/python/pruning.py +++ b/tensorflow/contrib/model_pruning/python/pruning.py @@ -214,7 +214,7 @@ def get_pruning_hparams(): target_sparsity=0.5, sparsity_function_begin_step=0, sparsity_function_end_step=100, - sparsity_function_exponent=3, + sparsity_function_exponent=3.0, use_tpu=False) @@ -397,28 +397,26 @@ class Pruning(object): raise ValueError('Sparsity variable undefined') sparsity = self._get_sparsity(weights.op.name) - with ops.name_scope(weights.op.name + '_pruning_ops'): abs_weights = math_ops.abs(weights) - max_value = math_ops.reduce_max(abs_weights) - cdf_fn = pruning_utils.compute_cdf_from_histogram - if self._spec.use_tpu: - cdf_fn = pruning_utils.compute_cdf - - norm_cdf = cdf_fn(abs_weights, [0.0, max_value], nbins=self._spec.nbins) - current_threshold = math_ops.multiply( - math_ops.div( - math_ops.reduce_sum( - math_ops.cast( - math_ops.less(norm_cdf, sparsity), dtypes.float32)), - float(self._spec.nbins)), max_value) - + k = math_ops.cast( + math_ops.round( + math_ops.cast(array_ops.size(abs_weights), dtypes.float32) * + (1 - sparsity)), dtypes.int32) + # Sort the entire array + values, _ = nn_ops.top_k( + array_ops.reshape(abs_weights, [-1]), k=array_ops.size(abs_weights)) + # Grab the (k-1) th value + current_threshold = array_ops.gather(values, k - 1) smoothed_threshold = math_ops.add_n([ math_ops.multiply(current_threshold, 1 - self._spec.threshold_decay), math_ops.multiply(threshold, self._spec.threshold_decay) 
]) + new_mask = math_ops.cast( - math_ops.greater(abs_weights, smoothed_threshold), dtypes.float32) + math_ops.greater_equal(abs_weights, smoothed_threshold), + dtypes.float32) + return smoothed_threshold, new_mask def _maybe_update_block_mask(self, weights, threshold): diff --git a/tensorflow/contrib/model_pruning/python/pruning_test.py b/tensorflow/contrib/model_pruning/python/pruning_test.py index 1b6da5ce2b4ebb3ea3b204c4ed12bed8db951447..835614d8822147dadb029107ae0e917cc955eef0 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_test.py +++ b/tensorflow/contrib/model_pruning/python/pruning_test.py @@ -102,7 +102,7 @@ class PruningTest(test.TestCase): weights = variables.VariableV1( math_ops.linspace(1.0, 100.0, 100), name="weights") masked_weights = pruning.apply_mask(weights) - sparsity = variables.VariableV1(0.5, name="sparsity") + sparsity = variables.VariableV1(0.95, name="sparsity") p = pruning.Pruning(sparsity=sparsity) p._spec.threshold_decay = 0.0 mask_update_op = p.mask_update_op() @@ -111,7 +111,7 @@ class PruningTest(test.TestCase): self.assertAllEqual(np.count_nonzero(masked_weights_val), 100) session.run(mask_update_op) masked_weights_val = masked_weights.eval() - self.assertAllEqual(np.count_nonzero(masked_weights_val), 50) + self.assertAllEqual(np.count_nonzero(masked_weights_val), 5) def _blockMasking(self, hparams, weights, expected_mask): diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils.py b/tensorflow/contrib/model_pruning/python/pruning_utils.py index 14fc51229ab53a77e8089040e8a8576babd0fafd..8f2ba036469bd02328a831a3d1de2ffbd10f5004 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_utils.py +++ b/tensorflow/contrib/model_pruning/python/pruning_utils.py @@ -25,16 +25,12 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import clip_ops from 
tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope -_NBINS = 256 - def weight_mask_variable(var, scope): """Create a mask for the weights. @@ -165,128 +161,6 @@ def expand_tensor(tensor, block_dims): return expanded_tensor -def _histogram(values, value_range, nbins=100, dtype=dtypes.int32, name=None): - """Return histogram of values. - - Given the tensor `values`, this operation returns a rank 1 histogram counting - the number of entries in `values` that fell into every bin. The bins are - equal width and determined by the arguments `value_range` and `nbins`. - - Args: - values: Numeric `Tensor`. - value_range: Shape [2] `Tensor` of same `dtype` as `values`. - values <= value_range[0] will be mapped to hist[0], - values >= value_range[1] will be mapped to hist[-1]. - nbins: Scalar `int32 Tensor`. Number of histogram bins. - dtype: dtype for returned histogram. - name: A name for this operation (defaults to 'histogram'). - - Returns: - A 1-D `Tensor` holding histogram of values. - - """ - with ops.name_scope(name, 'histogram', [values, value_range, nbins]) as scope: - values = ops.convert_to_tensor(values, name='values') - values = array_ops.reshape(values, [-1]) - nbins_float = np.float32(nbins) - - # Map tensor values that fall within value_range to [0, 1]. - scaled_values = math_ops.truediv( - values - value_range[0], - value_range[1] - value_range[0], - name='scaled_values') - - # map tensor values within the open interval value_range to {0,.., nbins-1}, - # values outside the open interval will be zero or less, or nbins or more. - indices = math_ops.floor(nbins_float * scaled_values, name='indices') - - # Clip edge cases (e.g. value = value_range[1]) or "outliers." 
- indices = math_ops.cast( - clip_ops.clip_by_value(indices, 0, nbins_float - 1), dtypes.int32) - - return math_ops.unsorted_segment_sum( - array_ops.ones_like(indices, dtype=dtype), indices, nbins, name=scope) - - -def compute_cdf_from_histogram(values, value_range, **kwargs): - """Returns the normalized cumulative distribution of the given values tensor. - - Computes the histogram and uses tf.cumsum to arrive at cdf - - Args: - values: Numeric `Tensor`. - value_range: Shape [2] `Tensor` of same `dtype` as `values`. - **kwargs: keyword arguments: nbins, name - - Returns: - A 1-D `Tensor` holding normalized cdf of values. - - """ - nbins = kwargs.get('nbins', _NBINS) - name = kwargs.get('name', None) - with ops.name_scope(name, 'cdf', [values, value_range, nbins]): - histogram = _histogram( - values, value_range, dtype=dtypes.float32, nbins=nbins) - cdf = math_ops.cumsum(histogram) - return math_ops.div(cdf, math_ops.reduce_max(cdf)) - - -def compute_cdf(values, value_range, **kwargs): - """Returns the normalized cumulative distribution of the given values tensor. - - Uses tf.while_loop to directly compute the cdf of the values. - - Args: - values: Numeric `Tensor`. - value_range: Shape [2] `Tensor` of same `dtype` as `values` - **kwargs: keyword arguments: nbins, name - - Returns: - A 1-D `Tensor` holding normalized cdf of values. - - """ - nbins = kwargs.get('nbins', _NBINS) - name = kwargs.get('name', None) - with ops.name_scope(name, 'cdf', [values, value_range, nbins]): - values = ops.convert_to_tensor(values, name='values') - nbins_float = np.float32(nbins) - - # Map tensor values that fall within value_range to [0, 1]. - scaled_values = math_ops.truediv( - values - value_range[0], - value_range[1] - value_range[0], - name='scaled_values') - - # map tensor values within the open interval value_range to {0,.., nbins-1}, - # values outside the open interval will be zero or less, or nbins or more. 
- indices = math_ops.floor(nbins_float * scaled_values, name='indices') - - # Clip edge cases (e.g. value = value_range[1]) or "outliers." - indices = math_ops.cast( - clip_ops.clip_by_value(indices, 0, nbins_float - 1), dtypes.int32) - - cdf = array_ops.zeros(nbins) - i = constant_op.constant(0) - - def loop_cond(loop_count, _): - return math_ops.less(loop_count, nbins) - - def loop_body(loop_count, cdf): - temp = math_ops.reduce_sum( - math_ops.cast( - math_ops.less_equal(indices, loop_count), dtypes.float32)) - cdf = math_ops.add( - cdf, - array_ops.one_hot( - loop_count, depth=nbins, on_value=temp, off_value=0.0)) - return [loop_count + 1, cdf] - - _, cdf = control_flow_ops.while_loop( - loop_cond, loop_body, [i, cdf], maximum_iterations=nbins) - - return math_ops.div(cdf, math_ops.reduce_max(cdf)) - - def factorized_pool(input_tensor, window_shape, pooling_type, diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py index d6f2bfcb6c2e2beda912eb538d8a4a0a17b486b3..b85bc413155d53cd6d53e98dae0ad626531f61eb 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py +++ b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py @@ -19,13 +19,9 @@ from __future__ import division from __future__ import print_function from absl.testing import parameterized -import numpy as np from tensorflow.contrib.model_pruning.python import pruning_utils -from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope @@ -33,57 +29,6 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test -class PruningUtilsTest(test.TestCase): - - def _compare_cdf(self, values): - abs_values = 
math_ops.abs(values) - max_value = math_ops.reduce_max(abs_values) - with self.cached_session(): - variables.global_variables_initializer().run() - cdf_from_histogram = pruning_utils.compute_cdf_from_histogram( - abs_values, [0.0, max_value], nbins=pruning_utils._NBINS) - cdf = pruning_utils.compute_cdf(abs_values, [0.0, max_value]) - self.assertAllEqual(cdf.eval(), cdf_from_histogram.eval()) - - def testHistogram(self): - width = 10 - height = 10 - nbins = 100 - expected_histogram = np.full(nbins, 1.0) - init = init_ops.constant_initializer(np.linspace(0.0, 1.0, width * height)) - weights = variable_scope.get_variable( - "weights", [width, height], initializer=init) - histogram = pruning_utils._histogram( - weights, [0, 1.0], nbins, dtype=np.float32) - with self.cached_session(): - variables.global_variables_initializer().run() - computed_histogram = histogram.eval() - self.assertAllEqual(expected_histogram, computed_histogram) - - def testCDF(self): - nbins = 5 - weights = constant_op.constant([-1, 0, 1, 1.5, 2, 3, 4, 5, 10, 100]) - abs_weights = math_ops.abs(weights) - norm_cdf = pruning_utils.compute_cdf_from_histogram( - abs_weights, [0.0, 5.0], nbins=nbins) - expected_cdf = np.array([0.1, 0.4, 0.5, 0.6, 1.0], dtype=np.float32) - with self.cached_session() as sess: - variables.global_variables_initializer().run() - norm_cdf_val = sess.run(norm_cdf) - self.assertAllEqual(len(norm_cdf_val), nbins) - self.assertAllEqual(expected_cdf, norm_cdf_val) - - def testCDFEquivalence2D(self): - width = 100 - height = 100 - weights = variable_scope.get_variable("weights", shape=[width, height]) - self._compare_cdf(weights) - - def testCDFEquivalence4D(self): - weights = variable_scope.get_variable("weights", shape=[5, 5, 128, 128]) - self._compare_cdf(weights) - - @parameterized.named_parameters( ("Input_32x32_block_1x1", [32, 32], [1, 1]), # block size 6x6 diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index 
f4ac70eb1a720c2acc3ef942f269228156749cba..0446e823d95f8ecbed6a0c34a83ade009e68448b 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -14,6 +14,7 @@ py_library( name = "opt_py", srcs = [ "__init__.py", + "python/training/adam_gs_optimizer.py", "python/training/adamax.py", "python/training/addsign.py", "python/training/agn_optimizer.py", @@ -22,6 +23,7 @@ py_library( "python/training/external_optimizer.py", "python/training/ggt.py", "python/training/lars_optimizer.py", + "python/training/lazy_adam_gs_optimizer.py", "python/training/lazy_adam_optimizer.py", "python/training/matrix_functions.py", "python/training/model_average_optimizer.py", @@ -60,6 +62,21 @@ py_library( ], ) +py_test( + name = "adam_gs_optimizer_test", + srcs = ["python/training/adam_gs_optimizer_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:training", + "//third_party/py/numpy", + ], +) + py_test( name = "adamax_test", srcs = ["python/training/adamax_test.py"], @@ -148,6 +165,25 @@ py_test( ], ) +py_test( + name = "lazy_adam_gs_optimizer_test", + srcs = ["python/training/lazy_adam_gs_optimizer_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:variables", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + py_test( name = "lazy_adam_optimizer_test", srcs = ["python/training/lazy_adam_optimizer_test.py"], diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py index 
c7ea68efa9a13a471bba3f41d0600855793b20a2..e8fc52342ceabb47da97ca0f3c8a01e419a221a1 100644 --- a/tensorflow/contrib/opt/__init__.py +++ b/tensorflow/contrib/opt/__init__.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import +from tensorflow.contrib.opt.python.training.adam_gs_optimizer import * from tensorflow.contrib.opt.python.training.adamax import * from tensorflow.contrib.opt.python.training.addsign import * from tensorflow.contrib.opt.python.training.agn_optimizer import * @@ -28,6 +29,7 @@ from tensorflow.contrib.opt.python.training.external_optimizer import * from tensorflow.contrib.opt.python.training.lars_optimizer import * from tensorflow.contrib.opt.python.training.ggt import * from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import * +from tensorflow.contrib.opt.python.training.lazy_adam_gs_optimizer import * from tensorflow.contrib.opt.python.training.model_average_optimizer import * from tensorflow.contrib.opt.python.training.moving_average_optimizer import * from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import * @@ -44,12 +46,14 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'AdaMaxOptimizer', + 'AdamGSOptimizer', 'PowerSignOptimizer', 'AddSignOptimizer', 'DelayCompensatedGradientDescentOptimizer', 'DropStaleGradientOptimizer', 'ExternalOptimizerInterface', 'LARSOptimizer', + 'LazyAdamGSOptimizer', 'LazyAdamOptimizer', 'NadamOptimizer', 'MovingAverageOptimizer', diff --git a/tensorflow/contrib/opt/python/training/adam_gs_optimizer.py b/tensorflow/contrib/opt/python/training/adam_gs_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..3fb649ea82e79b3bc78a2da6d5c3e9a071adec6d --- /dev/null +++ b/tensorflow/contrib/opt/python/training/adam_gs_optimizer.py @@ -0,0 +1,217 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Adam rewrite to use global step for computing beta1 & beta2 accumulation.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.training import optimizer +from tensorflow.python.training import training_ops +from tensorflow.python.util.tf_export import tf_export + + +@tf_export("train.AdamOptimizer") +class AdamGSOptimizer(optimizer.Optimizer): + """Optimizer that implements the Adam algorithm. + + See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980) + ([pdf](http://arxiv.org/pdf/1412.6980.pdf)). + """ + + def __init__(self, global_step=0, learning_rate=0.001, + beta1=0.9, beta2=0.999, epsilon=1e-8, + use_locking=False, name="Adam"): + """Construct a new Adam optimizer. + + Branched from tf.train.AdamOptimizer. The only difference is to pass + global step for computing beta1 and beta2 accumulators, instead of having + optimizer keep its own independent beta1 and beta2 accumulators as non-slot + variables. 
+ + Initialization: + + $$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$ + $$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$ + $$t := 0 \text{(Initialize timestep)}$$ + + The update rule for `variable` with gradient `g` uses an optimization + described at the end of section2 of the paper: + + $$t := t + 1$$ + $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$ + + $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$ + $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$ + $$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$ + + The default value of 1e-8 for epsilon might not be a good default in + general. For example, when training an Inception network on ImageNet a + current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the + formulation just before Section 2.1 of the Kingma and Ba paper rather than + the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon + hat" in the paper. + + The sparse implementation of this algorithm (used when the gradient is an + IndexedSlices object, typically because of `tf.gather` or an embedding + lookup in the forward pass) does apply momentum to variable slices even if + they were not used in the forward pass (meaning they have a gradient equal + to zero). Momentum decay (beta1) is also applied to the entire momentum + accumulator. This means that the sparse behavior is equivalent to the dense + behavior (in contrast to some momentum implementations which ignore momentum + unless a variable slice was actually used). + + Args: + global_step: tensorflow variable indicating the step. + learning_rate: A Tensor or a floating point value. The learning rate. + beta1: A float value or a constant float tensor. + The exponential decay rate for the 1st moment estimates. + beta2: A float value or a constant float tensor. + The exponential decay rate for the 2nd moment estimates. + epsilon: A small constant for numerical stability. 
This epsilon is + "epsilon hat" in the Kingma and Ba paper (in the formula just before + Section 2.1), not the epsilon in Algorithm 1 of the paper. + use_locking: If True use locks for update operations. + name: Optional name for the operations created when applying gradients. + Defaults to "Adam". + + @compatibility(eager) + When eager execution is enabled, `learning_rate`, `beta1`, `beta2`, and + `epsilon` can each be a callable that takes no arguments and returns the + actual value to use. This can be useful for changing these values across + different invocations of optimizer functions. + @end_compatibility + """ + super(AdamGSOptimizer, self).__init__(use_locking, name) + self._lr = learning_rate + self._beta1 = beta1 + self._beta2 = beta2 + self._epsilon = epsilon + self._global_step = global_step + self._global_step_on_worker = None + + # Tensor versions of the constructor arguments, created in _prepare(). + self._lr_t = None + self._beta1_t = None + self._beta2_t = None + self._epsilon_t = None + + # Created in SparseApply if needed. + self._updated_lr = None + + def _get_beta_accumulators(self): + return (math_ops.pow(self._beta1_t, self._global_step_on_worker), + math_ops.pow(self._beta2_t, self._global_step_on_worker)) + + def _create_slots(self, var_list): + # Create slots for the first and second moments. 
+ for v in var_list: + self._zeros_slot(v, "m", self._name) + self._zeros_slot(v, "v", self._name) + + def _prepare(self): + lr = self._call_if_callable(self._lr) + beta1 = self._call_if_callable(self._beta1) + beta2 = self._call_if_callable(self._beta2) + epsilon = self._call_if_callable(self._epsilon) + + self._lr_t = ops.convert_to_tensor(lr, name="learning_rate") + self._beta1_t = ops.convert_to_tensor(beta1, name="beta1") + self._beta2_t = ops.convert_to_tensor(beta2, name="beta2") + self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon") + + # Performance optimization so that worker creates a copy of the global step + # to avoid overloading the parameter server holding the global step. + self._global_step_on_worker = math_ops.cast( + array_ops.identity(self._global_step) + 1, dtypes.float32) + + def _apply_dense(self, grad, var): + m = self.get_slot(var, "m") + v = self.get_slot(var, "v") + beta1_power, beta2_power = self._get_beta_accumulators() + return training_ops.apply_adam( + var, m, v, + math_ops.cast(beta1_power, var.dtype.base_dtype), + math_ops.cast(beta2_power, var.dtype.base_dtype), + math_ops.cast(self._lr_t, var.dtype.base_dtype), + math_ops.cast(self._beta1_t, var.dtype.base_dtype), + math_ops.cast(self._beta2_t, var.dtype.base_dtype), + math_ops.cast(self._epsilon_t, var.dtype.base_dtype), + grad, use_locking=self._use_locking).op + + def _resource_apply_dense(self, grad, var): + m = self.get_slot(var, "m") + v = self.get_slot(var, "v") + beta1_power, beta2_power = self._get_beta_accumulators() + return training_ops.resource_apply_adam( + var.handle, m.handle, v.handle, + math_ops.cast(beta1_power, grad.dtype.base_dtype), + math_ops.cast(beta2_power, grad.dtype.base_dtype), + math_ops.cast(self._lr_t, grad.dtype.base_dtype), + math_ops.cast(self._beta1_t, grad.dtype.base_dtype), + math_ops.cast(self._beta2_t, grad.dtype.base_dtype), + math_ops.cast(self._epsilon_t, grad.dtype.base_dtype), + grad, use_locking=self._use_locking) + + def 
_apply_sparse_shared(self, grad, var, indices, scatter_add): + beta1_power, beta2_power = self._get_beta_accumulators() + beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) + beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) + lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) + beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) + beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) + epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) + lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) + # m_t = beta1 * m + (1 - beta1) * g_t + m = self.get_slot(var, "m") + m_scaled_g_values = grad * (1 - beta1_t) + m_t = state_ops.assign(m, m * beta1_t, + use_locking=self._use_locking) + with ops.control_dependencies([m_t]): + m_t = scatter_add(m, indices, m_scaled_g_values) + # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) + v = self.get_slot(var, "v") + v_scaled_g_values = (grad * grad) * (1 - beta2_t) + v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking) + with ops.control_dependencies([v_t]): + v_t = scatter_add(v, indices, v_scaled_g_values) + v_sqrt = math_ops.sqrt(v_t) + var_update = state_ops.assign_sub(var, + lr * m_t / (v_sqrt + epsilon_t), + use_locking=self._use_locking) + return control_flow_ops.group(*[var_update, m_t, v_t]) + + def _apply_sparse(self, grad, var): + return self._apply_sparse_shared( + grad.values, var, grad.indices, + lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda + x, i, v, use_locking=self._use_locking)) + + def _resource_scatter_add(self, x, i, v): + with ops.control_dependencies( + [resource_variable_ops.resource_scatter_add( + x.handle, i, v)]): + return x.value() + + def _resource_apply_sparse(self, grad, var, indices): + return self._apply_sparse_shared( + grad, var, indices, self._resource_scatter_add) diff --git a/tensorflow/contrib/opt/python/training/adam_gs_optimizer_test.py 
b/tensorflow/contrib/opt/python/training/adam_gs_optimizer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c68c965aef3729bebe7d0e0dd707c344321d9e3f --- /dev/null +++ b/tensorflow/contrib/opt/python/training/adam_gs_optimizer_test.py @@ -0,0 +1,382 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for AdamGS.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.opt.python.training import adam_gs_optimizer +from tensorflow.python.client import session +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def adam_update_numpy(param, + g_t, + t, + m, + v, + alpha=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-8): + alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t) + + m_t = beta1 * m + (1 - beta1) * g_t + v_t = beta2 * v + (1 - beta2) * g_t * g_t + 
+ param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon) + return param_t, m_t, v_t + + +class AdamGSOptimizerTest(test.TestCase): + + def doTestSparse(self, use_resource=False): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64)) + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + else: + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0_np_indices = np.array([0, 1], dtype=np.int32) + grads0 = ops.IndexedSlices( + constant_op.constant(grads0_np), + constant_op.constant(grads0_np_indices), constant_op.constant([2])) + grads1_np_indices = np.array([0, 1], dtype=np.int32) + grads1 = ops.IndexedSlices( + constant_op.constant(grads1_np), + constant_op.constant(grads1_np_indices), constant_op.constant([2])) + opt = adam_gs_optimizer.AdamGSOptimizer(global_step=global_step) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**t, + 
self.evaluate(beta2_power)) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + + def testSparse(self): + self.doTestSparse(use_resource=False) + + def testResourceSparse(self): + self.doTestSparse(use_resource=True) + + def testSparseDevicePlacement(self): + for index_dtype in [dtypes.int32, dtypes.int64]: + with self.cached_session(force_gpu=test.is_gpu_available()): + # If a GPU is available, tests that all optimizer ops can be placed on + # it (i.e. they have GPU kernels). + var = variables.Variable([[1.0], [2.0]]) + indices = constant_op.constant([0, 1], dtype=index_dtype) + gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices)) + optimizer = adam_gs_optimizer.AdamGSOptimizer(3.0) + minimize_op = optimizer.minimize(gathered_sum) + variables.global_variables_initializer().run() + minimize_op.run() + + def testSparseRepeatedIndices(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + repeated_index_global_step = variables.Variable( + array_ops.zeros([], dtypes.int64)) + aggregated_global_step = variables.Variable( + array_ops.zeros([], dtypes.int64)) + repeated_index_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + aggregated_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + grad_repeated_index = ops.IndexedSlices( + constant_op.constant( + [0.1, 0.1], shape=[2, 1], dtype=dtype), + constant_op.constant([1, 1]), + constant_op.constant([2, 1])) + grad_aggregated = ops.IndexedSlices( + constant_op.constant( + [0.2], shape=[1, 1], dtype=dtype), + constant_op.constant([1]), + constant_op.constant([2, 1])) + repeated_update = adam_gs_optimizer.AdamGSOptimizer( + 
global_step=repeated_index_global_step).apply_gradients( + [(grad_repeated_index, repeated_index_update_var)], + global_step=repeated_index_global_step) + aggregated_update = adam_gs_optimizer.AdamGSOptimizer( + global_step=aggregated_global_step).apply_gradients( + [(grad_aggregated, aggregated_update_var)], + global_step=aggregated_global_step) + variables.global_variables_initializer().run() + self.assertAllClose(aggregated_update_var.eval(), + self.evaluate(repeated_index_update_var)) + for _ in range(3): + repeated_update.run() + aggregated_update.run() + self.assertAllClose(aggregated_update_var.eval(), + self.evaluate(repeated_index_update_var)) + + def doTestBasic(self, use_resource=False, use_callable_params=False): + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + with self.session(graph=ops.Graph()): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64), name="global_step_%d" % i) + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + else: + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + learning_rate = lambda: 0.001 + beta1 = lambda: 0.9 + beta2 = lambda: 0.999 + epsilon = lambda: 1e-8 + if not use_callable_params: + learning_rate = learning_rate() + beta1 = beta1() + beta2 = beta2() + epsilon = epsilon() + + opt = 
adam_gs_optimizer.AdamGSOptimizer(global_step=global_step, + learning_rate=learning_rate) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + opt_variables = opt.variables() + beta1_power, beta2_power = opt._get_beta_accumulators() + self.assertTrue(beta1_power is not None) + self.assertTrue(beta2_power is not None) + self.assertNotIn(beta1_power, opt_variables) + self.assertNotIn(beta2_power, opt_variables) + + if not context.executing_eagerly(): + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. + self.assertEqual(0, len(opt.variables())) + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + # Run 3 steps of Adam + for t in range(1, 4): + if not context.executing_eagerly(): + self.evaluate(update) + self.assertAllCloseAccordingToType( + 0.9**(t + 1), self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType( + 0.999**(t + 1), self.evaluate(beta2_power)) + else: + if t > 1: + opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + beta1_power, beta2_power = opt._get_beta_accumulators() + self.assertAllCloseAccordingToType( + 0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType( + 0.999**t, self.evaluate(beta2_power)) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + if use_resource: + self.assertEqual("var0_%d/Adam:0" % (i,), + opt.get_slot(var=var0, name="m").name) + + def testBasic(self): + with self.cached_session(): + self.doTestBasic(use_resource=False) + + 
@test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTestBasic(use_resource=True) + + def testBasicCallableParams(self): + with context.eager_mode(): + self.doTestBasic(use_resource=True, use_callable_params=True) + + def testTensorLearningRate(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = adam_gs_optimizer.AdamGSOptimizer( + global_step=global_step, learning_rate=constant_op.constant(0.001)) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**t, + self.evaluate(beta2_power)) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + + def testSharing(self): + for dtype in 
[dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = adam_gs_optimizer.AdamGSOptimizer(global_step=global_step) + update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + # Run 3 steps of intertwined Adam1 and Adam2. 
+ for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**t, + self.evaluate(beta2_power)) + if t % 2 == 0: + update1.run() + else: + update2.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + + def testTwoSessions(self): + optimizer = adam_gs_optimizer.AdamGSOptimizer() + + with context.eager_mode(): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + + g = ops.Graph() + with g.as_default(): + with session.Session(): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + + gg = ops.Graph() + with gg.as_default(): + with session.Session(): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + + # If the optimizer saves any state not keyed by graph the following line + # fails. + optimizer.apply_gradients([(grads0, var0)]) + + def testSlotsUniqueEager(self): + with context.eager_mode(): + v1 = resource_variable_ops.ResourceVariable(1.) + v2 = resource_variable_ops.ResourceVariable(1.) + opt = adam_gs_optimizer.AdamGSOptimizer(1.) + opt.minimize(lambda: v1 + v2) + # There should be two unique slot variables for v1 and v2 respectively. 
+ self.assertEqual(4, len(set(opt.variables()))) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py index 6c203e5519e6a66d20e2509eca3c74eb66bf32c7..fa1a7aaff0aa59a6a64b1f0bf836a273926d785d 100644 --- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py @@ -30,6 +30,7 @@ from tensorflow.python.ops import variables from tensorflow.python.training import optimizer from tensorflow.python.training import saver from tensorflow.python.training import session_run_hook +from tensorflow.python.training.saving import saveable_object_util LOCAL_VARIABLE_NAME = 'local_center_variable' GLOBAL_VARIABLE_NAME = 'global_center_variable' @@ -424,7 +425,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer): if var_list is None: var_list = variables.trainable_variables() if not isinstance(var_list, dict): - var_list = saver.BaseSaverBuilder.OpListToDict(var_list) + var_list = saveable_object_util.op_list_to_dict(var_list) swapped_var_list = {} for key, var in var_list.items(): @@ -464,4 +465,4 @@ class _ElasticAverageOptimizerHook(session_run_hook.SessionRunHook): def after_create_session(self, session, coord): """Run initialization ops""" - session.run(self._variable_init_op) \ No newline at end of file + session.run(self._variable_init_op) diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer.py b/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..8827007e4d7f6722398a8e36bd626377842d92ef --- /dev/null +++ b/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer.py @@ -0,0 +1,114 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""LazyAdam rewrite to use global step for computing beta1 & beta2 accumulation. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.opt.python.training import adam_gs_optimizer +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops + + +class LazyAdamGSOptimizer(adam_gs_optimizer.AdamGSOptimizer): + """Variant of the Adam optimizer that handles sparse updates more efficiently. + + Branched from tf.contrib.opt.LazyAdamGSOptimizer. The only difference is to + pass global step for computing beta1 and beta2 accumulators, instead of having + optimizer keep its own independent beta1 and beta2 accumulators as non-slot + variables. + + The original Adam algorithm maintains two moving-average accumulators for + each trainable variable; the accumulators are updated at every step. + This class provides lazier handling of gradient updates for sparse variables. + It only updates moving-average accumulators for sparse variable indices that + appear in the current batch, rather than updating the accumulators for all + indices. 
Compared with the original Adam optimizer, it can provide large + improvements in model training throughput for some applications. However, it + provides slightly different semantics than the original Adam algorithm, and + may lead to different empirical results. + """ + + def _apply_sparse(self, grad, var): + beta1_power, beta2_power = self._get_beta_accumulators() + beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) + beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) + lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) + beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) + beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) + epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) + lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) + + # \\(m := beta1 * m + (1 - beta1) * g_t\\) + m = self.get_slot(var, "m") + m_t = state_ops.scatter_update(m, grad.indices, + beta1_t * array_ops.gather(m, grad.indices) + + (1 - beta1_t) * grad.values, + use_locking=self._use_locking) + + # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) + v = self.get_slot(var, "v") + v_t = state_ops.scatter_update(v, grad.indices, + beta2_t * array_ops.gather(v, grad.indices) + + (1 - beta2_t) * math_ops.square(grad.values), + use_locking=self._use_locking) + + # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\) + m_t_slice = array_ops.gather(m_t, grad.indices) + v_t_slice = array_ops.gather(v_t, grad.indices) + denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t + var_update = state_ops.scatter_sub(var, grad.indices, + lr * m_t_slice / denominator_slice, + use_locking=self._use_locking) + return control_flow_ops.group(var_update, m_t, v_t) + + def _resource_apply_sparse(self, grad, var, indices): + beta1_power, beta2_power = self._get_beta_accumulators() + beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) + beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) + lr_t = 
math_ops.cast(self._lr_t, var.dtype.base_dtype) + beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) + beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) + epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) + lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) + + # \\(m := beta1 * m + (1 - beta1) * g_t\\) + m = self.get_slot(var, "m") + m_t_slice = beta1_t * array_ops.gather(m, indices) + (1 - beta1_t) * grad + m_update_op = resource_variable_ops.resource_scatter_update(m.handle, + indices, + m_t_slice) + + # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) + v = self.get_slot(var, "v") + v_t_slice = (beta2_t * array_ops.gather(v, indices) + + (1 - beta2_t) * math_ops.square(grad)) + v_update_op = resource_variable_ops.resource_scatter_update(v.handle, + indices, + v_t_slice) + + # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\) + var_slice = lr * m_t_slice / (math_ops.sqrt(v_t_slice) + epsilon_t) + var_update_op = resource_variable_ops.resource_scatter_sub(var.handle, + indices, + var_slice) + + return control_flow_ops.group(var_update_op, m_update_op, v_update_op) diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer_test.py b/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..bdc9a02a546c8399172d0c5b58941b4d80179955 --- /dev/null +++ b/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer_test.py @@ -0,0 +1,402 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for LazyAdamGSOptimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.opt.python.training import lazy_adam_gs_optimizer +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def adam_update_numpy(param, + g_t, + t, + m, + v, + alpha=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-8): + alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t) + + m_t = beta1 * m + (1 - beta1) * g_t + v_t = beta2 * v + (1 - beta2) * g_t * g_t + + param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon) + return param_t, m_t, v_t + + +class LazyAdamGSOptimizerTest(test.TestCase, parameterized.TestCase): + + @parameterized.parameters([False, True]) + def testSparse(self, use_resource): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64)) + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + else: + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + + grads0_np_indices = np.array([0, 1], dtype=np.int32) + grads0 = ops.IndexedSlices( + constant_op.constant(grads0_np), + constant_op.constant(grads0_np_indices), constant_op.constant([2])) + grads1_np_indices = np.array([0, 1], dtype=np.int32) + grads1 = ops.IndexedSlices( + constant_op.constant(grads1_np), + constant_op.constant(grads1_np_indices), constant_op.constant([2])) + opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=global_step) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + @parameterized.parameters([False, 
True]) + def testSparseDevicePlacement(self, use_resource): + for index_dtype in [dtypes.int32, dtypes.int64]: + with self.cached_session(force_gpu=test.is_gpu_available()): + # If a GPU is available, tests that all optimizer ops can be placed on + # it (i.e. they have GPU kernels). + if use_resource: + global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64)) + var = resource_variable_ops.ResourceVariable([[1.0], [2.0]]) + else: + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + var = variables.Variable([[1.0], [2.0]]) + + indices = constant_op.constant([0, 1], dtype=index_dtype) + gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices)) + optimizer = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=global_step, learning_rate=3.0) + minimize_op = optimizer.minimize(gathered_sum, global_step=global_step) + variables.global_variables_initializer().run() + minimize_op.run() + + @parameterized.parameters([False, True]) + def testSparseRepeatedIndices(self, use_resource): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + if use_resource: + repeated_index_global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64)) + aggregated_global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64)) + repeated_index_update_var = resource_variable_ops.ResourceVariable( + [[1.0], [2.0]], dtype=dtype) + aggregated_update_var = resource_variable_ops.ResourceVariable( + [[1.0], [2.0]], dtype=dtype) + else: + repeated_index_global_step = variables.Variable( + array_ops.zeros([], dtypes.int64)) + aggregated_global_step = variables.Variable( + array_ops.zeros([], dtypes.int64)) + repeated_index_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + aggregated_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + + grad_repeated_index = ops.IndexedSlices( + constant_op.constant( + [0.1, 0.1], 
shape=[2, 1], dtype=dtype), + constant_op.constant([1, 1]), + constant_op.constant([2, 1])) + grad_aggregated = ops.IndexedSlices( + constant_op.constant( + [0.2], shape=[1, 1], dtype=dtype), + constant_op.constant([1]), + constant_op.constant([2, 1])) + repeated_update_opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=repeated_index_global_step) + repeated_update = repeated_update_opt.apply_gradients( + [(grad_repeated_index, repeated_index_update_var)], + global_step=repeated_index_global_step) + aggregated_update_opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=aggregated_global_step) + aggregated_update = aggregated_update_opt.apply_gradients( + [(grad_aggregated, aggregated_update_var)], + global_step=aggregated_global_step) + variables.global_variables_initializer().run() + self.assertAllClose(aggregated_update_var.eval(), + repeated_index_update_var.eval()) + for _ in range(3): + repeated_update.run() + aggregated_update.run() + self.assertAllClose(aggregated_update_var.eval(), + repeated_index_update_var.eval()) + + def doTestBasic(self, use_resource=False, use_callable_params=False): + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + with self.session(graph=ops.Graph()): + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64), name="global_step_%d" % i) + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + else: + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + learning_rate = lambda: 0.001 + beta1 = lambda: 0.9 + beta2 = lambda: 0.999 + epsilon = lambda: 1e-8 + if not use_callable_params: + learning_rate = learning_rate() + beta1 = beta1() + beta2 = beta2() + epsilon = epsilon() + + opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=global_step, learning_rate=learning_rate) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + opt_variables = opt.variables() + beta1_power, beta2_power = opt._get_beta_accumulators() + self.assertIsNotNone(beta1_power) + self.assertIsNotNone(beta2_power) + self.assertNotIn(beta1_power, opt_variables) + self.assertNotIn(beta2_power, opt_variables) + + if not context.executing_eagerly(): + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs.
+ self.assertEqual(0, len(opt.variables())) + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + # Run 3 steps of Adam + for t in range(1, 4): + if not context.executing_eagerly(): + self.evaluate(update) + self.assertAllCloseAccordingToType( + 0.9**(t + 1), self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType( + 0.999**(t + 1), self.evaluate(beta2_power)) + else: + if t > 1: + opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + beta1_power, beta2_power = opt._get_beta_accumulators() + self.assertAllCloseAccordingToType( + 0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType( + 0.999**t, self.evaluate(beta2_power)) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + if use_resource: + self.assertEqual("var0_%d/Adam:0" % (i,), + opt.get_slot(var=var0, name="m").name) + + def testBasic(self): + with self.cached_session(): + self.doTestBasic(use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTestBasic(use_resource=True) + + def testBasicCallableParams(self): + with context.eager_mode(): + self.doTestBasic(use_resource=True, use_callable_params=True) + + def testTensorLearningRate(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=global_step, learning_rate=constant_op.constant(0.001)) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testSharing(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=global_step) + update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run 3 steps of intertwined Adam1 and Adam2. 
+ for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + if t % 2 == 0: + update1.run() + else: + update2.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testTwoSessions(self): + optimizer = lazy_adam_gs_optimizer.LazyAdamGSOptimizer() + + with context.eager_mode(): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + + g = ops.Graph() + with g.as_default(): + with self.session(graph=g): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + + gg = ops.Graph() + with gg.as_default(): + with self.session(graph=gg): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + + # If the optimizer saves any state not keyed by graph the following line + # fails. + optimizer.apply_gradients([(grads0, var0)]) + + def testSlotsUniqueEager(self): + with context.eager_mode(): + v1 = resource_variable_ops.ResourceVariable(1.) + v2 = resource_variable_ops.ResourceVariable(1.) + opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer(1.) + opt.minimize(lambda: v1 + v2) + # There should be no non-slot variables, and two unique slot variables
+ self.assertLen(set(opt.variables()), 4) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/opt/python/training/moving_average_optimizer.py b/tensorflow/contrib/opt/python/training/moving_average_optimizer.py index b7fd2d2fb9db3eed15eb1cc2934199939790b1c0..bf3e5c51f78cc3ca3c7c77009c9cf428c4988953 100644 --- a/tensorflow/contrib/opt/python/training/moving_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/moving_average_optimizer.py @@ -26,6 +26,7 @@ from tensorflow.python.ops import variables from tensorflow.python.training import moving_averages from tensorflow.python.training import optimizer from tensorflow.python.training import saver +from tensorflow.python.training.saving import saveable_object_util class MovingAverageOptimizer(optimizer.Optimizer): @@ -165,7 +166,7 @@ class MovingAverageOptimizer(optimizer.Optimizer): if var_list is None: var_list = variables.global_variables() if not isinstance(var_list, dict): - var_list = saver.BaseSaverBuilder.OpListToDict(var_list) + var_list = saveable_object_util.op_list_to_dict(var_list) v_name_to_tensor = {} for k, tensor_or_list in six.iteritems(var_list): diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index 73a556f0b299614b098ceef0fb9d32f148227b03..7fb23abc38d9dc101204ed83808aebe5a8ef1e78 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -25,6 +25,7 @@ import abc import six from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx from tensorflow.python.distribute import reduce_util as ds_reduce_util from tensorflow.python.eager import backprop from tensorflow.python.eager import context @@ -36,7 +37,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from 
tensorflow.python.ops import variables -from tensorflow.python.training import distribution_strategy_context as distribute_ctx from tensorflow.python.training import optimizer as optimizer_v1 from tensorflow.python.training import slot_creator from tensorflow.python.training.checkpointable import base as checkpointable @@ -997,10 +997,10 @@ class OptimizerV2(optimizer_v1.Optimizer): with ops.control_dependencies([update_ops]): finish_updates = distribution.extended.update_non_slot( non_slot_devices, finish, group=False) - # We said grouped=False, which means finish_updates is always a list. - # It will be [None] when finish() returns None. - if finish_updates == [None]: - finish_updates = [update_ops] + # We said group=False, which means finish_updates is always a tuple. + # It will be (None,) when finish() returns None. + if finish_updates == (None,): + finish_updates = (update_ops,) # Update `global_step` (if any). if global_step is None: diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index 21d1b1213090273b5abd8e012f8711db98c94347..7c973fe597181b822e617db1f85a08f1b678e26f 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -685,7 +685,7 @@ def _InsertQuantOp(context, [1; 2^bits - 1] or wide range [0; 2^bits - 1]. producer_scope: The restriction of producer scope. If not None, the new op will be inserted only when the producer is in this scope. - consumer_scope: The restriction of producer scope. If not None, the new op + consumer_scope: The restriction of consumer scope. If not None, the new op will be inserted only when all the consumers are in this scope. 
Raises: ValueError: When producer operation is not directly connected to the diff --git a/tensorflow/contrib/receptive_field/README.md b/tensorflow/contrib/receptive_field/README.md index 79b015a9163f5727caa40b54579c71e57621c92f..d1c41e4c0a11028765c9fc0dc345cb29453baa31 100644 --- a/tensorflow/contrib/receptive_field/README.md +++ b/tensorflow/contrib/receptive_field/README.md @@ -185,5 +185,4 @@ Effective padding (vertical) = 1482 ## Authors -André Araujo (github id: andrefaraujo) and Mark Sandler (github id: -marksandler) +André Araujo (@andrefaraujo) and Mark Sandler (@marksandler) diff --git a/tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py b/tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py index d6fdd12bbe37fb0e0cb12f1d0adc3fce29b19e8a..72f98ccc32e945b48b5f1b570bcca323a5b5f48a 100644 --- a/tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py +++ b/tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Computes Receptive Field (RF) information given a graph protobuf. 
- -For an example of usage, see accompanying file compute_rf.sh -""" +"""Computes Receptive Field (RF) information given a graph protobuf.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py b/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py index a298b4d49038468299b58140758c69675368e855..325929a5937ac60a6134fae064e7633a4c57473d 100644 --- a/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py +++ b/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py @@ -16,8 +16,6 @@ The receptive field (and related parameters) for the different models are printed to stdout, and may also optionally be written to a CSV file. - -For an example of usage, see rf_benchmark.sh """ from __future__ import absolute_import @@ -262,11 +260,11 @@ def _model_rf(graphdef, information will be computed. model_type: Type of model to be used, used only for printing purposes. csv_writer: A CSV writer for RF parameters, which is used if it is not None. - input_resolution: Input resolution to use when computing RF - parameters. This is important for the case where padding can only be - defined if the input resolution is known, which may happen if using SAME - padding. This is assumed the resolution for both height and width. If - None, we consider the resolution is unknown. + input_resolution: Input resolution to use when computing RF parameters. This + is important for the case where padding can only be defined if the input + resolution is known, which may happen if using SAME padding. This is + assumed the resolution for both height and width. If None, we consider the + resolution is unknown. 
""" for desired_end_point_key in desired_end_point_keys: print('- %s:' % desired_end_point_key) @@ -283,10 +281,10 @@ def _model_rf(graphdef, if (receptive_field_x == receptive_field_y) and ( effective_stride_x == effective_stride_y) and ( effective_padding_x == effective_padding_y): - print('Receptive field size = %5s, effective stride = %5s, effective ' - 'padding = %5s' % (str(receptive_field_x), - str(effective_stride_x), - str(effective_padding_x))) + print( + 'Receptive field size = %5s, effective stride = %5s, effective ' + 'padding = %5s' % (str(receptive_field_x), str(effective_stride_x), + str(effective_padding_x))) else: print('Receptive field size: horizontal = %5s, vertical = %5s. ' 'Effective stride: horizontal = %5s, vertical = %5s. Effective ' @@ -362,9 +360,8 @@ def _process_model_rf(model_type='resnet_v1_50', defined if the input resolution is known, which may happen if using SAME padding. The entries in the list are assumed the resolution for both height and width. If one of the elements in the list is None, we consider - it to mean that the resolution is unknown. If the list itself is None, - we use the default list [None, 224, 321]. - + it to mean that the resolution is unknown. If the list itself is None, we + use the default list [None, 224, 321]. """ # Process default value for this list. if input_resolutions is None: @@ -477,8 +474,8 @@ def _mobilenet_v1_rf(csv_writer=None): csv_writer: A CSV writer for RF parameters, which is used if it is not None. 
""" for model_type in _SUPPORTED_MOBILENETV1_VARIANTS: - with slim.arg_scope( - [slim.batch_norm, slim.dropout], is_training=False) as arg_sc: + with slim.arg_scope([slim.batch_norm, slim.dropout], + is_training=False) as arg_sc: _process_model_rf(model_type, csv_writer, arg_sc) diff --git a/tensorflow/contrib/receptive_field/python/util/receptive_field.py b/tensorflow/contrib/receptive_field/python/util/receptive_field.py index b9bd2f09761ab10a62d37e8e2580b93b9b8a4453..9127c772c75279d9c8eacc5a17680beba9247d01 100644 --- a/tensorflow/contrib/receptive_field/python/util/receptive_field.py +++ b/tensorflow/contrib/receptive_field/python/util/receptive_field.py @@ -12,12 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Functions to compute receptive field of a fully-convolutional network. - -Please refer to the following g3doc for detailed explanation on how this -computation is performed, and why it is important: -g3doc/photos/vision/features/delf/g3doc/rf_computation.md -""" +"""Functions to compute receptive field of a fully-convolutional network.""" from __future__ import absolute_import from __future__ import division @@ -96,8 +91,8 @@ class ReceptiveField(object): Args: y: An array of feature coordinates with shape `(..., d)`, where `d` is the number of dimensions of the coordinates. - axis: The dimensions for which to compute the input center coordinates. - If `None` (the default), compute the input center coordinates for all + axis: The dimensions for which to compute the input center coordinates. If + `None` (the default), compute the input center coordinates for all dimensions. Returns: @@ -127,8 +122,8 @@ class ReceptiveField(object): Args: x: An array of input center coordinates with shape `(..., d)`, where `d` is the number of dimensions of the coordinates. 
- axis: The dimensions for which to compute the feature coordinates. - If `None` (the default), compute the feature coordinates for all + axis: The dimensions for which to compute the feature coordinates. If + `None` (the default), compute the feature coordinates for all dimensions. Returns: @@ -274,14 +269,15 @@ def compute_receptive_field_from_graph_def(graph_def, continue # Get params for this layer. - (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, - padding_y, _, _) = parse_layer_parameters.get_layer_params( + (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, padding_y, + _, _) = parse_layer_parameters.get_layer_params( node, name_to_node, node_info[node.name].input_size) - logging.vlog(3, "kernel_size_x = %s, kernel_size_y = %s, " - "stride_x = %s, stride_y = %s, " - "padding_x = %s, padding_y = %s, input size = %s" % - (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, - padding_y, node_info[node.name].input_size)) + logging.vlog( + 3, "kernel_size_x = %s, kernel_size_y = %s, " + "stride_x = %s, stride_y = %s, " + "padding_x = %s, padding_y = %s, input size = %s" % + (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, + padding_y, node_info[node.name].input_size)) if padding_x is None or padding_y is None: undefined_padding = True @@ -352,15 +348,15 @@ def compute_receptive_field_from_graph_def(graph_def, raise ValueError( "Graph is not aligned since effective stride from different " "paths is different in vertical direction") - if (rf_sizes_x[inp_name] - 1 - ) / 2 - effective_paddings_x[inp_name] != ( - rf_size_input_x - 1) / 2 - effective_padding_input_x: + if (rf_sizes_x[inp_name] - + 1) / 2 - effective_paddings_x[inp_name] != ( + rf_size_input_x - 1) / 2 - effective_padding_input_x: raise ValueError( "Graph is not aligned since center shift from different " "paths is different in horizontal direction") - if (rf_sizes_y[inp_name] - 1 - ) / 2 - effective_paddings_y[inp_name] != ( - rf_size_input_y - 1) / 2 - 
effective_padding_input_y: + if (rf_sizes_y[inp_name] - + 1) / 2 - effective_paddings_y[inp_name] != ( + rf_size_input_y - 1) / 2 - effective_padding_input_y: raise ValueError( "Graph is not aligned since center shift from different " "paths is different in vertical direction") diff --git a/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py b/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py index d8ca0eab276b39f025d018edebb78eed7a8433bb..cec4c3c23305034d167a248a637425507750064e 100644 --- a/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py +++ b/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py @@ -164,6 +164,15 @@ class ResamplerOpsTest(xla_test.XLATestCase): expected = [[[0.0], [27.62]]] self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + expected_grad_data = [[[[0.12], [0.27999997]], [[0.18000001], + [0.42000002]]]] + expected_grad_warp = [[[0., 0.], [22.60000038, 35.20000076]]] + + grad_output = np.ones([1, 2, 1], dtype=dtype) + self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output, + expected_grad_data, + expected_grad_warp) + # One of (x, y) is less than 0. for dtype in self.float_types: input_shape = [1, 2, 2, 1] @@ -171,11 +180,21 @@ class ResamplerOpsTest(xla_test.XLATestCase): input_np = np.array(input_data, dtype=dtype).reshape(input_shape) warp_shape = [1, 2, 2] + # -1 is out of bound for grad_warp. warp_data = [-1, 0.1, 0.7, 0.6] warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape) expected = [[[0.0], [27.62]]] self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + expected_grad_data = [[[[0.12], [0.27999997]], [[0.18000001], + [0.42000002]]]] + expected_grad_warp = [[[0., 0.], [22.60000038, 35.20000076]]] + + grad_output = np.ones([1, 2, 1], dtype=dtype) + self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output, + expected_grad_data, + expected_grad_warp) + # Both of (x, y) are greater than image size. 
for dtype in self.float_types: input_shape = [1, 2, 2, 1] @@ -183,11 +202,20 @@ class ResamplerOpsTest(xla_test.XLATestCase): input_np = np.array(input_data, dtype=dtype).reshape(input_shape) warp_shape = [1, 2, 2] + # -0.1 is *inbound* for grad_warp and grad_data, 2.1 is out of bound. warp_data = [-0.1, 0.1, 1.2, 2.1] warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape) expected = [[[0.0], [0.0]]] self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + expected_grad_data = [[[[0.81], [0.0]], [[0.09], [0.0]]]] + expected_grad_warp = [[[10.30, 2.7], [0.0, 0.0]]] + + grad_output = np.ones([1, 2, 1], dtype=dtype) + self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output, + expected_grad_data, + expected_grad_warp) + # One of (x, y) is greater than image size. for dtype in self.float_types: input_shape = [1, 2, 2, 1] @@ -200,6 +228,14 @@ class ResamplerOpsTest(xla_test.XLATestCase): expected = [[[0.0], [0.0]]] self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + expected_grad_data = [[[[0.81], [0.81]], [[0.0], [0.08]]]] + expected_grad_warp = [[[-4.5, 9.5], [-9.9, 39.20]]] + + grad_output = np.ones([1, 2, 1], dtype=dtype) + self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output, + expected_grad_data, + expected_grad_warp) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index e124867415f94fb5052f34f50363ea718d71053b..44b232e0f2b26f16f0300e11cf2764e1157a0050 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -118,6 +118,7 @@ cuda_py_tests( "//tensorflow/python:rnn_cell", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", + "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index 
7d57b0413a3bb51c35e670ce3fdb2cc818f44a58..a0d013c618ea56077098b15b7eed5f9110239516 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import os +from absl.testing import parameterized import numpy as np from tensorflow.contrib import rnn as contrib_rnn @@ -31,6 +32,8 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import test_util +from tensorflow.python.keras import layers as keras_layers +from tensorflow.python.layers import base as base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -805,12 +808,13 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[1], [[0.13248, 0.13248]]) -class DropoutWrapperTest(test.TestCase): +class DropoutWrapperTest(test.TestCase, parameterized.TestCase): def _testDropoutWrapper(self, batch_size=None, time_steps=None, parallel_iterations=None, + wrapper_type=None, **kwargs): with self.cached_session() as sess: with variable_scope.variable_scope( @@ -832,7 +836,7 @@ class DropoutWrapperTest(test.TestCase): constant([[0.1, 0.1, 0.1]] * batch_size, dtype=dtypes.float32) ] * 2) outputs, final_state = rnn.dynamic_rnn( - cell=rnn_cell_impl.DropoutWrapper( + cell=wrapper_type( rnn_cell_impl.LSTMCell(3), dtype=x.dtype, **kwargs), time_major=True, parallel_iterations=parallel_iterations, @@ -845,16 +849,34 @@ class DropoutWrapperTest(test.TestCase): self.assertEqual(res[1].h.shape, (batch_size, 3)) return res - def testWrappedCellProperty(self): + @parameterized.parameters( + [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) + def testDropoutWrapperProperties(self, wrapper_type): cell = rnn_cell_impl.BasicRNNCell(10) - wrapper = 
rnn_cell_impl.DropoutWrapper(cell) + wrapper = wrapper_type(cell) # Github issue 15810 self.assertEqual(wrapper.wrapped_cell, cell) - - def testDropoutWrapperKeepAllConstantInput(self): + self.assertEqual(wrapper.state_size, 10) + self.assertEqual(wrapper.output_size, 10) + + @parameterized.parameters( + [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) + def testDropoutWrapperZeroState(self, wrapper_type): + class _Cell(rnn_cell_impl.BasicRNNCell): + + def zero_state(self, batch_size=None, dtype=None): + return "wrapped_cell_zero_state" + wrapper = wrapper_type(_Cell(10)) + self.assertEqual(wrapper.zero_state(10, dtypes.float32), + "wrapped_cell_zero_state") + + @parameterized.parameters( + [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) + def testDropoutWrapperKeepAllConstantInput(self, wrapper_type): keep = array_ops.ones([]) res = self._testDropoutWrapper( - input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep) + input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep, + wrapper_type=wrapper_type) true_full_output = np.array( [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], dtype=np.float32) @@ -864,10 +886,13 @@ class DropoutWrapperTest(test.TestCase): self.assertAllClose(true_full_output[1], res[1].h) self.assertAllClose(true_full_final_c, res[1].c) - def testDropoutWrapperKeepAll(self): + @parameterized.parameters( + [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) + def testDropoutWrapperKeepAll(self, wrapper_type): keep = variable_scope.get_variable("all", initializer=1.0) res = self._testDropoutWrapper( - input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep) + input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep, + wrapper_type=wrapper_type) true_full_output = np.array( [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], dtype=np.float32) @@ -877,7 +902,9 @@ class DropoutWrapperTest(test.TestCase): 
self.assertAllClose(true_full_output[1], res[1].h) self.assertAllClose(true_full_final_c, res[1].c) - def testDropoutWrapperWithSeed(self): + @parameterized.parameters( + [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) + def testDropoutWrapperWithSeed(self, wrapper_type): keep_some = 0.5 random_seed.set_random_seed(2) ## Use parallel_iterations = 1 in both calls to @@ -889,7 +916,8 @@ class DropoutWrapperTest(test.TestCase): output_keep_prob=keep_some, state_keep_prob=keep_some, seed=10, - parallel_iterations=1) + parallel_iterations=1, + wrapper_type=wrapper_type) # Clear away the graph and the test session (which keeps variables around) ops.reset_default_graph() self._ClearCachedSession() @@ -899,18 +927,22 @@ class DropoutWrapperTest(test.TestCase): output_keep_prob=keep_some, state_keep_prob=keep_some, seed=10, - parallel_iterations=1) + parallel_iterations=1, + wrapper_type=wrapper_type) self.assertAllClose(res_standard_1[0], res_standard_2[0]) self.assertAllClose(res_standard_1[1].c, res_standard_2[1].c) self.assertAllClose(res_standard_1[1].h, res_standard_2[1].h) - def testDropoutWrapperKeepNoOutput(self): + @parameterized.parameters( + [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) + def testDropoutWrapperKeepNoOutput(self, wrapper_type): keep_all = variable_scope.get_variable("all", initializer=1.0) keep_none = variable_scope.get_variable("none", initializer=1e-6) res = self._testDropoutWrapper( input_keep_prob=keep_all, output_keep_prob=keep_none, - state_keep_prob=keep_all) + state_keep_prob=keep_all, + wrapper_type=wrapper_type) true_full_output = np.array( [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], dtype=np.float32) @@ -920,7 +952,9 @@ class DropoutWrapperTest(test.TestCase): self.assertAllClose(true_full_output[1], res[1].h) self.assertAllClose(true_full_final_c, res[1].c) - def testDropoutWrapperKeepNoStateExceptLSTMCellMemory(self): + @parameterized.parameters( + 
[rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) + def testDropoutWrapperKeepNoStateExceptLSTMCellMemory(self, wrapper_type): keep_all = variable_scope.get_variable("all", initializer=1.0) keep_none = variable_scope.get_variable("none", initializer=1e-6) # Even though we dropout state, by default DropoutWrapper never @@ -928,7 +962,8 @@ class DropoutWrapperTest(test.TestCase): res = self._testDropoutWrapper( input_keep_prob=keep_all, output_keep_prob=keep_all, - state_keep_prob=keep_none) + state_keep_prob=keep_none, + wrapper_type=wrapper_type) true_c_state = np.array([[1.713925, 1.713925, 1.713925]], dtype=np.float32) true_full_output = np.array( [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], @@ -941,7 +976,9 @@ class DropoutWrapperTest(test.TestCase): # c state of an LSTMStateTuple is NEVER modified. self.assertAllClose(true_c_state, res[1].c) - def testDropoutWrapperKeepNoInput(self): + @parameterized.parameters( + [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) + def testDropoutWrapperKeepNoInput(self, wrapper_type): keep_all = variable_scope.get_variable("all", initializer=1.0) keep_none = variable_scope.get_variable("none", initializer=1e-6) true_full_output = np.array( @@ -953,12 +990,15 @@ class DropoutWrapperTest(test.TestCase): res = self._testDropoutWrapper( input_keep_prob=keep_none, output_keep_prob=keep_all, - state_keep_prob=keep_all) + state_keep_prob=keep_all, + wrapper_type=wrapper_type) self.assertGreater(np.linalg.norm(res[0] - true_full_output), 1e-4) self.assertGreater(np.linalg.norm(res[1].h - true_full_output[1]), 1e-4) self.assertGreater(np.linalg.norm(res[1].c - true_full_final_c), 1e-4) - def testDropoutWrapperRecurrentOutput(self): + @parameterized.parameters( + [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) + def testDropoutWrapperRecurrentOutput(self, wrapper_type): keep_some = 0.8 keep_all = variable_scope.get_variable("all", initializer=1.0) res = 
self._testDropoutWrapper( @@ -966,6 +1006,7 @@ class DropoutWrapperTest(test.TestCase): output_keep_prob=keep_some, state_keep_prob=keep_all, variational_recurrent=True, + wrapper_type=wrapper_type, input_size=3, batch_size=5, time_steps=7) @@ -974,13 +1015,16 @@ class DropoutWrapperTest(test.TestCase): for m in output_mask[1:]: self.assertAllClose(output_mask[0], m) - def testDropoutWrapperRecurrentStateInputAndOutput(self): + @parameterized.parameters( + [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) + def testDropoutWrapperRecurrentStateInputAndOutput(self, wrapper_type): keep_some = 0.9 res = self._testDropoutWrapper( input_keep_prob=keep_some, output_keep_prob=keep_some, state_keep_prob=keep_some, variational_recurrent=True, + wrapper_type=wrapper_type, input_size=3, batch_size=5, time_steps=7) @@ -1002,7 +1046,10 @@ class DropoutWrapperTest(test.TestCase): for batch_entry in state_h_mask: self.assertAllClose(batch_entry, state_h_mask[0]) - def testDropoutWrapperRecurrentStateInputAndOutputWithSeed(self): + @parameterized.parameters( + [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) + def testDropoutWrapperRecurrentStateInputAndOutputWithSeed( + self, wrapper_type): keep_some = 0.9 random_seed.set_random_seed(2347) np.random.seed(23487) @@ -1011,6 +1058,7 @@ class DropoutWrapperTest(test.TestCase): output_keep_prob=keep_some, state_keep_prob=keep_some, variational_recurrent=True, + wrapper_type=wrapper_type, input_size=3, batch_size=5, time_steps=7, @@ -1024,6 +1072,7 @@ class DropoutWrapperTest(test.TestCase): output_keep_prob=keep_some, state_keep_prob=keep_some, variational_recurrent=True, + wrapper_type=wrapper_type, input_size=3, batch_size=5, time_steps=7, @@ -1050,6 +1099,60 @@ class DropoutWrapperTest(test.TestCase): self.assertAllClose(res0[1].c, res1[1].c) self.assertAllClose(res0[1].h, res1[1].h) + def testDropoutWrapperKerasStyle(self): + """Tests if DropoutWrapperV2 cell is instantiated in keras style scope.""" + 
wrapped_cell_v2 = rnn_cell_impl.DropoutWrapperV2( + rnn_cell_impl.BasicRNNCell(1)) + self.assertTrue(wrapped_cell_v2._keras_style) + + wrapped_cell = rnn_cell_impl.DropoutWrapper(rnn_cell_impl.BasicRNNCell(1)) + self.assertFalse(wrapped_cell._keras_style) + + def testDropoutWrapperV2VariableNames(self): + """Tests that variables names do not depend on wrapper in RNN layer.""" + + def _rnn_input(apply_wrapper): + """Creates a RNN layer with/without wrapper and returns built rnn cell.""" + with base_layer.keras_style_scope(): + base_cell = rnn_cell_impl.MultiRNNCell( + [rnn_cell_impl.BasicRNNCell(1) for _ in range(2)]) + if apply_wrapper: + rnn_cell = rnn_cell_impl.DropoutWrapperV2(base_cell) + else: + rnn_cell = base_cell + rnn_layer = keras_layers.RNN(rnn_cell) + inputs = ops.convert_to_tensor([[[1]]], dtype=dtypes.float32) + _ = rnn_layer(inputs) + return base_cell._cells[0] + + rnn_1 = _rnn_input(True) + ops.reset_default_graph() + rnn_2 = _rnn_input(False) + + self.assertLen(rnn_1.weights, expected_len=2) + self.assertCountEqual([v.name for v in rnn_1.weights], + [v.name for v in rnn_2.weights]) + + def testDropoutWrapperV2Caller(self): + """Tests that DropoutWrapperV2 is using the LayerRNNCell's caller.""" + + with base_layer.keras_style_scope(): + base_cell = rnn_cell_impl.MultiRNNCell( + [rnn_cell_impl.BasicRNNCell(1) for _ in range(2)]) + rnn_cell = rnn_cell_impl.DropoutWrapperV2(base_cell) + inputs = ops.convert_to_tensor([[1]], dtype=dtypes.float32) + state = ops.convert_to_tensor([[1]], dtype=dtypes.float32) + _ = rnn_cell(inputs, [state, state]) + weights = base_cell._cells[0].weights + self.assertLen(weights, expected_len=2) + self.assertTrue(all(["dropout_wrapper" in v.name for v in weights])) + + def testDropoutWrapperV2Build(self): + cell = rnn_cell_impl.LSTMCell(10) + wrapper = rnn_cell_impl.DropoutWrapperV2(cell) + wrapper.build((1,)) + self.assertTrue(cell.built) + def basic_rnn_cell(inputs, state, num_units, scope=None): if state is None: diff 
--git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py index ffba514bb96f5ce8d963cb0a0482738eafe88355..2a4b6eae367fe617e9a19d80f16eb3fda9ade1c0 100644 --- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py +++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py @@ -22,53 +22,57 @@ import os import six from tensorflow.python.client import session -from tensorflow.python.estimator import keras as estimator_keras_util -from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator.export import export as export_helpers from tensorflow.python.framework import ops from tensorflow.python.keras import backend as K from tensorflow.python.keras import models as models_lib from tensorflow.python.keras import optimizers from tensorflow.python.keras.engine import sequential +from tensorflow.python.keras.engine import training_utils from tensorflow.python.keras.metrics import Metric from tensorflow.python.keras.models import model_from_json from tensorflow.python.lib.io import file_io from tensorflow.python.ops import variables -from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import builder as saved_model_builder from tensorflow.python.saved_model import constants +from tensorflow.python.saved_model import save as save_lib from tensorflow.python.saved_model import utils_impl as saved_model_utils from tensorflow.python.training import saver as saver_lib from tensorflow.python.training.checkpointable import util as checkpointable_utils from tensorflow.python.util import compat +from tensorflow.python.util import nest +from tensorflow_estimator.python.estimator import keras as estimator_keras_util +from tensorflow_estimator.python.estimator import model_fn as model_fn_lib +from 
tensorflow_estimator.python.estimator.export import export as export_helpers def save_keras_model( - model, saved_model_path, custom_objects=None, as_text=None): - """Save a `tf.keras.Model` into Tensorflow SavedModel format. + model, saved_model_path, custom_objects=None, as_text=None, + input_signature=None, serving_only=False): + """Saves a `tf.keras.Model` into Tensorflow SavedModel format. `save_model` generates new files/folders under the `saved_model_path` folder: - 1) an asset folder containing the json string of the model's - configuration (topology). - 2) a checkpoint containing the model weights. - 3) a saved_model.pb file containing the model's MetaGraphs. The prediction + 1) a checkpoint containing the model weights. + 2) a saved_model.pb file containing the model's MetaGraphs. The prediction graph is always exported. The evaluaton and training graphs are exported if the following conditions are met: - Evaluation: model loss is defined. - Training: model is compiled with an optimizer defined under `tf.train`. This is because `tf.keras.optimizers.Optimizer` instances cannot be saved to checkpoints. - - Model Requirements: - - Model must be a sequential model or functional model. Subclassed models can - not be saved via this function, unless you provide an implementation for - get_config() and from_config(). - - All variables must be saveable by the model. In general, this condition is - met through the use of layers defined in the keras library. However, - there is currently a bug with variables created in Lambda layer functions - not being saved correctly (see - https://github.com/keras-team/keras/issues/9740). + 3) Model's json configuration, if model.get_config() has been implemented. + This file can be used to reload the model using + tf.keras.models.model_from_json(). Note that if any custom objects were + used, they should be passed to the `custom_objects` argument when loading + the model.
+ + Model limitations: + - Sequential and functional models can always be saved. + - Subclassed models can only be saved when `serving_only=True`. This is due to + the current implementation copying the model in order to export the training + and evaluation graphs. Because the topology of subclassed models cannot be + determined, the subclassed models cannot be cloned. Subclassed models will + be entirely exportable in the future. Note that each mode is exported in separate graphs, so different modes do not share variables. To use the train graph with evaluation or prediction graphs, @@ -94,38 +98,88 @@ def save_keras_model( ``` Args: - model: A `tf.keras.Model` to be saved. + model: A `tf.keras.Model` to be saved. If the model is subclassed, the flag + `serving_only` must be set to True. saved_model_path: a string specifying the path to the SavedModel directory. The SavedModel will be saved to a timestamped folder created within this directory. custom_objects: Optional dictionary mapping string names to custom classes or functions (e.g. custom loss functions). - as_text: whether to write the `SavedModel` proto in text format. + as_text: whether to write the `SavedModel` proto in text format. Currently + unavailable in serving-only mode. + input_signature: A possibly nested sequence of `tf.TensorSpec` objects, used + to specify the expected model inputs. `input_signature`'s nested structure + should match the expected nested structure of the inputs to the model. If + this is not set, this function will attempt to infer the input shapes and + dtypes from the model. Note that if the model is subclassed, the tensor + inputs to the call function should be nested in the first argument (this + is a general requirement for using subclassed models with Keras functions + .fit(), .predict(), etc.). + serving_only: Export only the outputs produced from calling the model in + predict mode. The losses, optimizer, and other training configurations are + not saved. 
If the SavedModel will only be used for serving (rather than + retraining), or if the model is subclassed, this can be set to True. Returns: String path to the SavedModel folder, a subdirectory of `saved_model_path`. Raises: - NotImplementedError: If the model is a subclassed model. - ValueError: If a Sequential model does not have input shapes defined by the - user, and is not built. + NotImplementedError: If the model is a subclassed model, and serving_only is + False. + ValueError: If the input signature cannot be inferred from the model. """ + export_dir = export_helpers.get_timestamped_export_dir(saved_model_path) + + if serving_only: + save_lib.save( + model, export_dir, + signatures=training_utils.trace_model_call(model, input_signature)) + else: + _save_v1_format(model, export_dir, custom_objects, as_text, input_signature) + + try: + _export_model_json(model, export_dir) + except NotImplementedError: + logging.warning('Skipped saving model JSON, subclassed model does not have ' + 'get_config() defined.') + + return export_dir + + +def _export_model_json(model, saved_model_path): + """Saves model configuration as a json string under assets folder.""" + model_json = model.to_json() + model_json_filepath = os.path.join( + saved_model_utils.get_or_create_assets_dir(saved_model_path), + compat.as_text(constants.SAVED_MODEL_FILENAME_JSON)) + file_io.write_string_to_file(model_json_filepath, model_json) + + +def _export_model_variables(model, saved_model_path): + """Saves model weights in checkpoint format under variables folder.""" + saved_model_utils.get_or_create_variables_dir(saved_model_path) + checkpoint_prefix = saved_model_utils.get_variables_path(saved_model_path) + model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True) + return checkpoint_prefix + + +def _save_v1_format(model, path, custom_objects, as_text, input_signature): + """Exports model to v1 SavedModel format.""" if not model._is_graph_network: if isinstance(model, 
sequential.Sequential): # If input shape is not directly set in the model, the exported model - # will assume that the inputs have the same shape as the shape the model - # was built model with. - if not model.built: + # will infer the expected shapes of the input from the model. + if not model.built and input_signature is None: raise ValueError( - 'Sequential model must be built before it can be exported.') + 'Sequential model\'s input shape is unknown. Please build the ' + 'model, or use the input_signature argument to specify the ' + 'model inputs.') else: raise NotImplementedError( - 'Exporting subclassed models is not yet supported.') + 'Subclassed models can only be exported for serving. Please set ' + 'argument serving_only=True.') - export_dir = export_helpers.get_timestamped_export_dir(saved_model_path) - temp_export_dir = export_helpers.get_temp_export_dir(export_dir) - - builder = saved_model_builder._SavedModelBuilder(temp_export_dir) + builder = saved_model_builder._SavedModelBuilder(path) # Manually save variables to export them in an object-based checkpoint. This # skips the `builder.add_meta_graph_and_variables()` step, which saves a @@ -133,7 +187,7 @@ def save_keras_model( # TODO(b/113134168): Add fn to Builder to save with object-based saver. # TODO(b/113178242): This should only export the model json structure. Only # one save is needed once the weights can be copied from the model to clone. - checkpoint_path = _export_model_json_and_variables(model, temp_export_dir) + checkpoint_path = _export_model_variables(model, path) # Export each mode. Use ModeKeys enums defined for `Estimator` to ensure that # Keras models and `Estimator`s are exported with the same format. 
@@ -143,10 +197,12 @@ def save_keras_model( export_args = {'builder': builder, 'model': model, 'custom_objects': custom_objects, - 'checkpoint_path': checkpoint_path} + 'checkpoint_path': checkpoint_path, + 'input_signature': input_signature} has_saved_vars = False if model.optimizer: + # TODO(kathywu): Verify this works with v2 optimizer. if isinstance(model.optimizer, optimizers.TFOptimizer): _export_mode(model_fn_lib.ModeKeys.TRAIN, has_saved_vars, **export_args) has_saved_vars = True @@ -161,34 +217,20 @@ def save_keras_model( builder.save(as_text) - gfile.Rename(temp_export_dir, export_dir) - return export_dir - - -def _export_model_json_and_variables(model, saved_model_path): - """Save model variables and json structure into SavedModel subdirectories.""" - # Save model configuration as a json string under assets folder. - model_json = model.to_json() - model_json_filepath = os.path.join( - saved_model_utils.get_or_create_assets_dir(saved_model_path), - compat.as_text(constants.SAVED_MODEL_FILENAME_JSON)) - file_io.write_string_to_file(model_json_filepath, model_json) - - # Save model weights in checkpoint format under variables folder. - saved_model_utils.get_or_create_variables_dir(saved_model_path) - checkpoint_prefix = saved_model_utils.get_variables_path(saved_model_path) - model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True) - return checkpoint_prefix - def _get_var_list(model): - """Return list of all checkpointed saveable objects in the model.""" + """Returns list of all checkpointed saveable objects in the model.""" return checkpointable_utils.named_saveables(model) +def create_placeholder(spec): + return K.placeholder(shape=spec.shape, dtype=spec.dtype, name=spec.name) + + def _export_mode( - mode, has_saved_vars, builder, model, custom_objects, checkpoint_path): - """Export a model, and optionally save new vars from the clone model. 
+ mode, has_saved_vars, builder, model, custom_objects, checkpoint_path, + input_signature): + """Exports a model, and optionally saves new vars from the clone model. Args: mode: A `tf.estimator.ModeKeys` string. @@ -199,6 +241,8 @@ def _export_mode( custom_objects: A dictionary mapping string names to custom classes or functions. checkpoint_path: String path to checkpoint. + input_signature: Nested TensorSpec containing the expected inputs. Can be + `None`, in which case the signature will be inferred from the model. Raises: ValueError: If the train/eval mode is being exported, but the model does @@ -214,10 +258,16 @@ def _export_mode( K.set_learning_phase(mode == model_fn_lib.ModeKeys.TRAIN) + if input_signature is None: + input_tensors = None + else: + input_tensors = nest.map_structure(create_placeholder, input_signature) + # Clone the model into blank graph. This will create placeholders for inputs # and targets. clone = models_lib.clone_and_build_model( - model, custom_objects=custom_objects, compile_clone=compile_clone) + model, input_tensors=input_tensors, custom_objects=custom_objects, + compile_clone=compile_clone) # Make sure that iterations variable is added to the global step collection, # to ensure that, when the SavedModel graph is loaded, the iterations @@ -271,7 +321,7 @@ def _export_mode( def _create_signature_def_map(model, mode): - """Create a SignatureDef map from a Keras model.""" + """Creates a SignatureDef map from a Keras model.""" inputs_dict = {name: x for name, x in zip(model.input_names, model.inputs)} if model.optimizer: targets_dict = {x.name.split(':')[0]: x @@ -309,14 +359,14 @@ def _create_signature_def_map(model, mode): def _assert_same_non_optimizer_objects(model, model_graph, clone, clone_graph): # pylint: disable=unused-argument - """Assert model and clone contain the same checkpointable objects.""" + """Asserts model and clone contain the same checkpointable objects.""" # TODO(fchollet, kathywu): make sure this works in eager 
mode. return True def load_keras_model(saved_model_path): - """Load a keras.Model from SavedModel. + """Loads a keras.Model from SavedModel. load_model reinstantiates model state by: 1) loading model topology from json (this will eventually come diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py index 93d73e1b484ed810fb347b13e95022dfca3584c2..fbf8138493362d4a3c8a75e1ee1bb2fbe8096499 100644 --- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py +++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py @@ -29,7 +29,9 @@ from tensorflow.python import keras from tensorflow.python.client import session from tensorflow.python.eager import context from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.keras.engine import training from tensorflow.python.keras.utils import tf_utils @@ -215,7 +217,7 @@ class LayerWithLearningPhase(keras.engine.base_layer.Layer): return input_shape -def functional_model(uses_learning_phase): +def functional_model(uses_learning_phase=True): inputs = keras.layers.Input(shape=(3,)) x = keras.layers.Dense(2)(inputs) x = keras.layers.Dense(3)(x) @@ -224,7 +226,7 @@ def functional_model(uses_learning_phase): return keras.models.Model(inputs, x) -def sequential_model(uses_learning_phase): +def sequential_model(uses_learning_phase=True): model = keras.models.Sequential() model.add(keras.layers.Dense(2, input_shape=(3,))) model.add(keras.layers.Dense(3)) @@ -233,7 +235,7 @@ def sequential_model(uses_learning_phase): return model -def sequential_model_without_input_shape(uses_learning_phase): +def sequential_model_without_input_shape(uses_learning_phase=True): 
model = keras.models.Sequential() model.add(keras.layers.Dense(2)) model.add(keras.layers.Dense(3)) @@ -242,10 +244,30 @@ def sequential_model_without_input_shape(uses_learning_phase): return model +class Subclassed(keras.models.Model): + + def __init__(self): + super(Subclassed, self).__init__() + self.dense1 = keras.layers.Dense(2) + self.dense2 = keras.layers.Dense(3) + + def call(self, inputs): + x = self.dense1(inputs) + x = self.dense2(x) + return x + + +def subclassed_model(): + return Subclassed() + + def load_model(sess, path, mode): tags = model_fn_lib.EXPORT_TAG_MAP[mode] - sig_def_key = (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY - if mode == model_fn_lib.ModeKeys.PREDICT else mode) + if mode == model_fn_lib.ModeKeys.PREDICT: + sig_def_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + else: + sig_def_key = mode + meta_graph_def = loader_impl.load(sess, tags, path) inputs = { k: sess.graph.get_tensor_by_name(v.name) @@ -463,13 +485,54 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase): clone.compile(loss='mse', optimizer=keras.optimizers.RMSprop(lr=0.0001)) clone.train_on_batch(input_arr, target_arr) - def testSaveSeqModelWithoutInputShapesRaisesError(self): - """A Sequential model that hasn't been built should raise an error.""" + def testSaveSequentialModelWithoutInputShapes(self): model = sequential_model_without_input_shape(True) - with self.assertRaisesRegexp( - ValueError, 'must be built'): + # A Sequential model that hasn't been built should raise an error. 
+ with self.assertRaisesRegexp(ValueError, 'Please build the model'): keras_saved_model.save_keras_model(model, '') + saved_model_path = self._save_model_dir() + output_path = keras_saved_model.save_keras_model( + model, saved_model_path, + input_signature=tensor_spec.TensorSpec(shape=(10, 11, 12, 13, 14), + dtype=dtypes.float32, + name='spec_input')) + + with session.Session(graph=ops.Graph()) as sess: + inputs, outputs, _ = load_model(sess, output_path, + model_fn_lib.ModeKeys.PREDICT) + self.assertEqual(5, inputs[next(iter(inputs.keys()))].shape.ndims) + self.assertEqual(5, outputs[next(iter(outputs.keys()))].shape.ndims) + self.assertEqual(3, outputs[next(iter(outputs.keys()))].shape[-1]) + + @test_util.run_v2_only + @parameterized.parameters( + { + 'model_builder': sequential_model_without_input_shape, + 'input_signature': [tensor_spec.TensorSpec(shape=[None, 3], + dtype=dtypes.float32)]}, + { + 'model_builder': subclassed_model, + 'input_signature': [tensor_spec.TensorSpec(shape=[None, 3], + dtype=dtypes.float32)]}) + def testServingOnly(self, model_builder, input_signature): + saved_model_path = self._save_model_dir() + input_arr = np.random.random((5, 3)).astype(np.float32) + model = model_builder() + ref_predict = model.predict(input_arr) + + output_path = keras_saved_model.save_keras_model( + model, saved_model_path, serving_only=True, + input_signature=input_signature) + + # Load predict graph, and test predictions + with session.Session(graph=ops.Graph()) as sess: + inputs, outputs, _ = load_model(sess, output_path, + model_fn_lib.ModeKeys.PREDICT) + predictions = sess.run(outputs[next(iter(outputs.keys()))], + {inputs[next(iter(inputs.keys()))]: input_arr}) + self.assertAllClose(ref_predict, predictions, atol=1e-05) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py index 
922f21b98b35dfff19c8c605a25e89c5d2da8d98..d815f81f847ad79ddcc6c6ecf5c050598e185d8d 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell from tensorflow.python.ops import variables from tensorflow.python.ops import variable_scope as vs @@ -992,5 +993,67 @@ class AttentionWrapperTest(test.TestCase): expected_final_alignment_history=expected_final_alignment_history, name='testMultiAttention') + def testCustomizedAttention(self): + batch_size = 2 + max_time = 3 + num_units = 2 + memory = constant_op.constant([[[1., 1.], [2., 2.], [3., 3.]], + [[4., 4.], [5., 5.], [6., 6.]]]) + memory_sequence_length = constant_op.constant([3, 2]) + attention_mechanism = wrapper.BahdanauAttention(num_units, memory, + memory_sequence_length) + + # Sets all returned values to be all ones. + def _customized_attention(unused_attention_mechanism, unused_cell_output, + unused_attention_state, unused_attention_layer): + """Customized attention. + + Returns: + attention: `Tensor` of shape [batch_size, num_units], attention output. + alignments: `Tensor` of shape [batch_size, max_time], sigma value for + each input memory (prob. function of input keys). + next_attention_state: A `Tensor` representing the next state for the + attention. + """ + attention = array_ops.ones([batch_size, num_units]) + alignments = array_ops.ones([batch_size, max_time]) + next_attention_state = alignments + return attention, alignments, next_attention_state + + attention_cell = wrapper.AttentionWrapper( + rnn_cell.LSTMCell(2), + attention_mechanism, + attention_layer_size=None, # don't use attention layer. 
+ output_attention=False, + alignment_history=(), + attention_fn=_customized_attention, + name='attention') + self.assertEqual(num_units, attention_cell.output_size) + + initial_state = attention_cell.zero_state( + batch_size=2, dtype=dtypes.float32) + source_input_emb = array_ops.ones([2, 3, 2]) + source_input_length = constant_op.constant([3, 2]) + + # 'state' is a tuple of + # (cell_state, h, attention, alignments, alignment_history, attention_state) + output, state = rnn.dynamic_rnn( + attention_cell, + inputs=source_input_emb, + sequence_length=source_input_length, + initial_state=initial_state, + dtype=dtypes.float32) + + with self.session() as sess: + sess.run(variables.global_variables_initializer()) + output_value, state_value = sess.run([output, state], feed_dict={}) + self.assertAllEqual(np.array([2, 3, 2]), output_value.shape) + self.assertAllClose(np.array([[1., 1.], [1., 1.]]), state_value.attention) + self.assertAllClose( + np.array([[1., 1., 1.], [1., 1., 1.]]), state_value.alignments) + self.assertAllClose( + np.array([[1., 1., 1.], [1., 1., 1.]]), state_value.attention_state) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index 77e9f848b137911b53e1b4df5dd740fe38af55bb..60ec3efffe771a3a6d6f36ed4b51a34ef9509612 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -1088,7 +1088,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): output_attention=True, initial_cell_state=None, name=None, - attention_layer=None): + attention_layer=None, + attention_fn=None): """Construct the `AttentionWrapper`. **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in @@ -1132,7 +1133,9 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): feed the context and cell output into the attention layer to generate attention at each time step. 
If attention_mechanism is a list, attention_layer_size must be a list of the same length. If - attention_layer is set, this must be None. + attention_layer is set, this must be None. If attention_fn is set, + it must be guaranteed that the outputs of attention_fn also meet the + above requirements. alignment_history: Python boolean, whether to store alignment history from all time steps in the final output state (currently stored as a time major `TensorArray` on which you must call `stack()`). @@ -1158,6 +1161,12 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): the context as attention at each time step. If attention_mechanism is a list, attention_layer must be a list of the same length. If attention_layers_size is set, this must be None. + attention_fn: An optional callable function that allows users to provide + their own customized attention function, which takes input + (attention_mechanism, cell_output, attention_state, attention_layer) and + outputs (attention, alignments, next_attention_state). If provided, + the attention_layer_size should be the size of the outputs of + attention_fn.
Raises: TypeError: `attention_layer_size` is not None and (`attention_mechanism` @@ -1240,6 +1249,10 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): tensor_shape.dimension_value(attention_mechanism.values.shape[-1]) for attention_mechanism in attention_mechanisms) + if attention_fn is None: + attention_fn = _compute_attention + self._attention_fn = attention_fn + self._cell = cell self._attention_mechanisms = attention_mechanisms self._cell_input_fn = cell_input_fn @@ -1443,7 +1456,7 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): all_attention_states = [] maybe_all_histories = [] for i, attention_mechanism in enumerate(self._attention_mechanisms): - attention, alignments, next_attention_state = _compute_attention( + attention, alignments, next_attention_state = self._attention_fn( attention_mechanism, cell_output, previous_attention_state[i], self._attention_layers[i] if self._attention_layers else None) alignment_history = previous_alignment_history[i].write( diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index ab36848f13ab3078cd232c18f140188e12db703b..8f8f057702951094758b277ce060955f3dc6e99d 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -921,6 +921,7 @@ def _get_scores(log_probs, sequence_lengths, length_penalty_weight, """ length_penalty_ = _length_penalty( sequence_lengths=sequence_lengths, penalty_factor=length_penalty_weight) + length_penalty_ = math_ops.cast(length_penalty_, dtype=log_probs.dtype) scores = log_probs / length_penalty_ coverage_penalty_weight = ops.convert_to_tensor( diff --git a/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py b/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py index 8fcd7aeef6a6964902666a4f3c17e05b0c7b52ee..f31bdbd399c9de4f2f5d557b75b1ece6d64a765e 100644 --- 
a/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py +++ b/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python import tf2 from tensorflow.contrib.solvers.python.ops import lanczos from tensorflow.contrib.solvers.python.ops import util from tensorflow.python.framework import constant_op @@ -80,7 +81,8 @@ if __name__ == "__main__": for shape in [[4, 4], [7, 4], [5, 8]]: for orthogonalize in True, False: for steps in range(1, min(shape) + 1): - for use_static_shape in True, False: + # TF2 does not support placeholders so we skip it + for use_static_shape in set([True, tf2.enabled()]): arg_string = "%s_%s_%s_%s_staticshape_%s" % ( dtype.__name__, "_".join(map(str, shape)), orthogonalize, steps, use_static_shape) diff --git a/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py b/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py index 2a9100903aae5689919a6b25fcb18ff192f250b3..841a41a2339824ab8ca15f4bdd74be697cd6fe9f 100644 --- a/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py +++ b/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python import tf2 from tensorflow.contrib.solvers.python.ops import least_squares from tensorflow.contrib.solvers.python.ops import util from tensorflow.python.framework import constant_op @@ -76,7 +77,8 @@ def _get_least_squares_tests(dtype_, use_static_shape_, shape_): if __name__ == "__main__": for dtype in np.float32, np.float64: for shape in [[4, 4], [8, 5], [3, 7]]: - for use_static_shape in True, False: + # TF2 does not support placeholders under eager so we skip it + for use_static_shape in set([True, tf2.enabled()]): arg_string = "%s_%s_staticshape_%s" % (dtype.__name__, "_".join(map(str, shape)), use_static_shape) diff --git 
a/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py b/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py index a0e6eb87bc06fb1303a7eb86fa6760458f20a9b9..10807f7a80617e56abeb6d13ce419a49a2269aac 100644 --- a/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py +++ b/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python import tf2 from tensorflow.contrib.solvers.python.ops import linear_equations from tensorflow.contrib.solvers.python.ops import util from tensorflow.python.framework import constant_op @@ -113,7 +114,8 @@ def _get_linear_equations_tests(dtype_, use_static_shape_, shape_): if __name__ == "__main__": for dtype in np.float32, np.float64: for size in 1, 4, 10: - for use_static_shape in True, False: + # TF2 does not support placeholders under eager so we skip it + for use_static_shape in set([True, tf2.enabled()]): shape = [size, size] arg_string = "%s_%s_staticshape_%s" % (dtype.__name__, size, use_static_shape) diff --git a/tensorflow/contrib/tensorrt/README.md b/tensorflow/contrib/tensorrt/README.md index caf8b6db0dc0a220d593f9c0afc9464ca51a1e05..a9c2ad78a3db409e6e8669c48c4df37c8db19c4b 100644 --- a/tensorflow/contrib/tensorrt/README.md +++ b/tensorflow/contrib/tensorrt/README.md @@ -1,8 +1,46 @@ -# Using TensorRT in TensorFlow +# Using TensorRT in TensorFlow (TF-TRT) -This module provides necessary bindings and introduces TRT_engine_op operator -that wraps a subgraph in TensorRT. This is still a work in progress but should -be useable with most common graphs. +This module provides necessary bindings and introduces `TRTEngineOp` operator +that wraps a subgraph in TensorRT. This module is under active development. + +## Installing TF-TRT + +Currently TensorFlow nightly builds include TF-TRT by default, which means you +don't need to install TF-TRT separately. 
You can pull the latest TF containers +from docker hub or install the latest TF pip package to get access to the latest +TF-TRT. + +If you want to use TF-TRT on the NVIDIA Jetson platform, you can find the download +links for the relevant TensorFlow pip packages here: +https://docs.nvidia.com/deeplearning/dgx/index.html#installing-frameworks-for-jetson + +## Installing TensorRT + +In order to make use of TF-TRT, you will need a local installation of TensorRT. +Installation instructions for compatibility with TensorFlow are provided on the +[TensorFlow GPU support](https://www.tensorflow.org/install/gpu) guide. + +## Examples + +You can find example scripts for running inference on deep learning models in +this repository: https://github.com/tensorflow/tensorrt + +We have used these examples to verify the accuracy and performance of TF-TRT. +For more information see +[Verified Models](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#verified-models). + +## Documentation + +[TF-TRT documentation](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html) +gives an overview of the supported functionalities, provides tutorials and +verified models, and explains best practices with troubleshooting guides. + +## Tests + +TF-TRT includes both Python tests and C++ unit tests. Most of the Python tests are +located in the test directory and they can be executed using `bazel test` or +directly with the Python command. Most of the C++ unit tests are used to test +the conversion functions that convert each TF op to a number of TensorRT layers. ## Compilation @@ -18,12 +56,3 @@ bazel build --config=cuda --config=opt //tensorflow/tools/pip_package:build_pip_ bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/ ``` -After the installation of tensorflow package, TensorRT transformation will be -available.
An example use can be found in test/test_tftrt.py script - -## Installing TensorRT 3.0.4 - -In order to make use of TensorRT integration, you will need a local installation -of TensorRT 3.0.4 from the [NVIDIA Developer website](https://developer.nvidia.com/tensorrt). -Installation instructions for compatibility with TensorFlow are provided on the -[TensorFlow GPU support](https://www.tensorflow.org/install/gpu) guide. diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 3b32f72bc1f220fd6730c71e3d2b3b6b806b748e..bf2de94e04ae3f6817f7a679ce9fd88e750827dd 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -89,49 +89,52 @@ Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { // TODO(laigd): move this set to TrtNodeValidator where it should belong. // LINT.IfChange static const std::set candidate_ops = { - "Identity", - "Snapshot", - "Const", - "Conv2D", - "MaxPool", - "BiasAdd", - "Relu", - "Sigmoid", - "Tanh", + "Abs", "Add", - "Mul", - "Sub", - "Rsqrt", - "Pad", - "Mean", "AvgPool", + "BatchMatMul", + "BiasAdd", "ConcatV2", + "Const", + "Conv2D", "DepthwiseConv2dNative", - "FusedBatchNorm", - "FusedBatchNormV2", "Div", - "RealDiv", - "Rsqrt", - "Reciprocal", "Exp", + "ExpandDims", + "FusedBatchNorm", + "FusedBatchNormV2", + "Identity", "Log", - "Sqrt", - "Abs", - "Neg", - "Transpose", - "Reshape", "MatMul", - "BatchMatMul", - "Softmax", - "Minimum", - "Maximum", - "TopKV2", - "Sum", - "Prod", "Max", + "MaxPool", + "Maximum", + "Mean", "Min", + "Minimum", + "Mul", + "Neg", + "Pad", + "Prod", + "RealDiv", + "Reciprocal", + "Relu", "Relu6", + "Reshape", + "Rsqrt", + "Rsqrt", + "Sigmoid", + "Snapshot", + "Softmax", + "Sqrt", "Square", + "Squeeze", + "StridedSlice", + "Sub", + "Sum", + "Tanh", + "TopKV2", + "Transpose", }; bool is_supported_op_type = (candidate_ops.count(node->type_string()) || @@ 
-320,6 +323,13 @@ tensorflow::Status ConvertGraphDefToTensorRT( return Status::OK(); } +struct EdgePtrCompare { + bool operator()(const tensorflow::Edge* lhs, + const tensorflow::Edge* rhs) const { + return lhs->id() < rhs->id(); + } +}; + // Function to get subsegment information structure. tensorflow::Status GetEngineInfo( const tensorflow::Graph* g, @@ -358,8 +368,12 @@ tensorflow::Status GetEngineInfo( } const int node_id = node->id(); subgraph_node_ids.push_back(node_id); - // Create input connections. - for (const auto edge : node->in_edges()) { + // Create input connections. Sort edges first to make deterministic since + // in_edges is a set of pointers. + std::vector in_edges(node->in_edges().begin(), + node->in_edges().end()); + std::sort(in_edges.begin(), in_edges.end(), EdgePtrCompare()); + for (const auto edge : in_edges) { auto input_node = edge->src(); if (input_node->IsSource() || segment_nodes.count(input_node->name())) { continue; @@ -407,8 +421,12 @@ tensorflow::Status GetEngineInfo( node_id, edge->dst_input(), /*input_edge=*/true, port); } } - // Create output connections. - for (const auto edge : node->out_edges()) { + // Create output connections. Sort edges first to make deterministic since + // out_edges is a set of pointers. + std::vector out_edges(node->out_edges().begin(), + node->out_edges().end()); + std::sort(out_edges.begin(), out_edges.end(), EdgePtrCompare()); + for (const auto edge : out_edges) { auto output_node = edge->dst(); if (output_node->IsSink() || segment_nodes.count(output_node->name())) { continue; @@ -564,6 +582,18 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, } input_shape_protos.at(conn.port_number) = in_shape; input_shapes.at(conn.port_number) = conn.outside_shape; + // Shape must be fully defined (excluding batch dimension) for static + // mode.
+ if (info.engine_type == EngineInfo::EngineType::TRTStatic) { + for (int i = 1; i < conn.outside_shape.dims(); i++) { + if (conn.outside_shape.dim_size(i) <= 0) { + return tensorflow::errors::Internal( + "Input shapes must be fully defined when in static mode. " + "Please try is_dynamic_op=True (shape was ", + conn.outside_shape.DebugString(), ")"); + } + } + } // Rewrire data input if it's not found in original graph. tensorflow::Node* input_node = graph->FindNodeId(conn.outside_id); @@ -585,6 +615,14 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, } } } + // We don't support segments with no inputs. Fall back to native TF here to + // avoid crash later. Constant folding should've folded the ops that make up + // these segments. + if (inputs.empty()) { + return tensorflow::errors::Internal( + "Segment has no inputs (possible " + "constfold failure)"); + } const bool calibrate_int8 = (info.precision_mode == INT8MODE && info.use_calibration); diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index fee095668e5aef44316ff15c1d8572b2ecd960df..adf8831b960172fc29b5d631e5b0533318d4764d 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -120,6 +120,15 @@ inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape, return trt_dims; } +Status TensorShapeArrayToTrtDims(const std::vector& shape, + nvinfer1::Dims* out, + bool ignore_first_dim = false) { + PartialTensorShape tensor_shape; + TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(shape, &tensor_shape)); + *out = TensorShapeToTrtDims(tensor_shape, ignore_first_dim); + return tensorflow::Status::OK(); +} + void GetOutputProperties(const grappler::GraphProperties& graph_properties, const Node* node, const int out_port, PartialTensorShape* shape, @@ -623,6 +632,11 @@ bool TFAttrs::get(const string& key) const { return this->at(key)->b(); } 
+template <> +int TFAttrs::get(const string& key) const { + return this->at(key)->i(); +} + // TODO(jie): reorder4 & reorder2 should be merged? // TODO(aaroey): fix the order of parameters. template @@ -1524,6 +1538,24 @@ enum class ConvolutionType { DEFAULT, DEPTHWISE_CONV }; tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; + if (inputs.at(0).is_weights()) { + return tensorflow::errors::Unimplemented( + node_def.op(), " is only implemented for tensors, not weights, at ", + node_def.name()); + } + if (inputs.at(1).is_tensor()) { + return tensorflow::errors::Unimplemented("Kernel for ", node_def.op(), + " must be constant weights, at ", + node_def.name()); + } + TRT_ShapedWeights weights_rsck = inputs.at(1).weights(); + VLOG(2) << "weight shape: " << weights_rsck.DebugString(); + if (weights_rsck.shape_.nbDims != 4) { + return tensorflow::errors::Internal( + "Conv2D expects kernel of dimension 4, at: " + node_def.name()); + } + if (params->validation_only) return tensorflow::Status::OK(); + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); TFAttrs attrs(node_def); @@ -1545,12 +1577,6 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) { if (num_groups == 0) num_groups = tensor_dim.d[0]; // depthwise convolution VLOG(2) << "groups count: " << num_groups; - TRT_ShapedWeights weights_rsck = inputs.at(1).weights(); - VLOG(2) << "weight shape: " << weights_rsck.DebugString(); - if (weights_rsck.shape_.nbDims != 4) { - return tensorflow::errors::Internal( - "Conv2D expects kernel of dimension 4, at: " + node_def.name()); - } if (params->converter->precision_mode() == FP16MODE) { weights_rsck = ConvertFP32ToFP16(params->weight_store, inputs.at(1).weights()); @@ -1637,7 +1663,7 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, case ConvolutionType::DEPTHWISE_CONV: return ConvertConv2DHelper(params, 0); } - return 
tensorflow::errors::Unimplemented("unsupported convolution type at, " + + return tensorflow::errors::Unimplemented("Unsupported convolution type, at ", params->node_def.name()); } @@ -1880,6 +1906,372 @@ tensorflow::Status ConvertReshape(OpConverterParams* params) { return tensorflow::Status::OK(); } +tensorflow::Status ConvertExpandDims(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + if (inputs.size() != 2) { + return tensorflow::errors::InvalidArgument( + "Two inputs expected for ExpandDims, at ", node_def.name()); + } + if (inputs.at(0).is_weights()) { + return tensorflow::errors::Unimplemented( + "ExpandDims expects tensor for input, at ", node_def.name()); + } + if (!inputs.at(1).is_weights()) { + return tensorflow::errors::InvalidArgument( + "ExpandDims expects weights for axis, at ", node_def.name()); + } + // Get input shape as vector. + TRT_TensorOrWeights input_tensor = inputs.at(0); + const nvinfer1::Dims dims = input_tensor.GetTrtDims(); + std::vector input_dims(dims.d, dims.d + dims.nbDims); + // Add batch dim back. + input_dims.insert(input_dims.begin(), -1); + const int input_rank = input_dims.size(); + // Get axis to expand on. + TRT_ShapedWeights weights = inputs.at(1).weights(); + if (weights.count() != 1) { + return tensorflow::errors::InvalidArgument( + "ExpandDims axis must be a scalar, at ", node_def.name()); + } + const int* weights_ptr = + static_cast(const_cast(weights.GetValues())); + int axis = weights_ptr[0]; + // Make sure axis is valid. + if ((axis < (-input_rank - 1)) || (axis > input_rank)) { + return tensorflow::errors::InvalidArgument( + "Axis for ExpandDims is invalid, must be in the range " + "[-rank(input) - 1, rank(input)], at ", + node_def.name()); + } + // Convert negative axis to corresponding positive axis. 
+ if (axis < 0) axis += input_rank + 1; + if (axis == 0) { + return tensorflow::errors::Unimplemented( + "Modifying batch dimension is not supported for ExpandDims, at ", + node_def.name()); + } + if (params->validation_only) return Status::OK(); + + // ExpandDims: Insert new dim of size 1. + input_dims.insert(input_dims.begin() + axis, 1); + // Reshape tensor. + nvinfer1::Dims new_dims; + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims, + /*ignore_first_dim=*/true)); + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + input_tensor, new_dims, &output_tensor)); + params->outputs->push_back( + TRT_TensorOrWeights(const_cast(output_tensor))); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertSqueeze(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + if (inputs.size() != 1) { + return tensorflow::errors::InvalidArgument( + "One input expected for Squeeze, at ", node_def.name()); + } + if (inputs.at(0).is_weights()) { + return tensorflow::errors::Unimplemented( + "Squeeze expects tensor for input, at ", node_def.name()); + } + // Get input shape. + TRT_TensorOrWeights input_tensor = inputs.at(0); + const nvinfer1::Dims dims = input_tensor.GetTrtDims(); + std::vector input_dims(dims.d, dims.d + dims.nbDims); + // Add batch dim back. + input_dims.insert(input_dims.begin(), -1); + const int input_rank = input_dims.size(); + // Mark axes to remove by setting them to 0. + TFAttrs attrs(node_def); + auto squeeze_dims = attrs.get>("squeeze_dims"); + if (squeeze_dims.size() == 0) { + return tensorflow::errors::Unimplemented( + "Squeeze is only implemented for explicit dims, at ", node_def.name()); + } + for (int axis : squeeze_dims) { + // Make sure axis is valid. 
+ if ((axis < -input_rank) || (axis >= input_rank)) { + return tensorflow::errors::InvalidArgument( + "Axis for Squeeze is invalid, must be in the range " + "[-rank(input), rank(input)), at ", + node_def.name()); + } + // Convert negative axis to corresponding positive axis. + if (axis < 0) axis += input_rank; + // Don't squeeze batch dim. + if (axis == 0) { + return tensorflow::errors::Unimplemented( + "Cannot squeeze batch dimension, at ", node_def.name()); + } + // Make sure target dimension is size 1. + if (input_dims[axis] != 1) { + return tensorflow::errors::InvalidArgument( + "Cannot squeeze a dimension which isn't size 1, at ", + node_def.name()); + } + // Mark dim for removal by setting to 0. + input_dims[axis] = 0; + } + if (params->validation_only) return Status::OK(); + + // Remove all dims which are equal to 0. + input_dims.erase(std::remove(input_dims.begin(), input_dims.end(), 0), + input_dims.end()); + // Reshape tensor. + nvinfer1::Dims new_dims; + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims, + /*ignore_first_dim=*/true)); + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + input_tensor, new_dims, &output_tensor)); + params->outputs->push_back( + TRT_TensorOrWeights(const_cast(output_tensor))); + return tensorflow::Status::OK(); +} + +// Gets the bounds (start or end) from the weights of a StridedSlice op. +tensorflow::Status GetStridedSliceBound(const std::vector& input_dims, + const TRT_ShapedWeights& bound_weights, + int mask, bool begin, string node_name, + std::vector* output_bound) { + const string bound_name = (begin) ? 
"begin" : "end"; + const int* weights_ptr = static_cast(bound_weights.GetValues()); + *output_bound = + std::vector(weights_ptr, weights_ptr + bound_weights.count()); + if (output_bound->size() != input_dims.size()) { + return tensorflow::errors::InvalidArgument( + "StridedSlice \"", bound_name, "\" specified ", + std::to_string(output_bound->size()), " dimensions, but input rank is ", + std::to_string(input_dims.size()), ", at ", node_name); + } + for (int i = 0; i < output_bound->size(); i++) { + if ((1 << i) & mask) { + // Apply mask. + (*output_bound)[i] = (begin) ? 0 : input_dims[i]; + // Masked bound will always result in a valid, non-negative bound, so we + // don't need the following checks. For the common case of using masks on + // a undefined batch dim (-1), we specifically don't want to do the + // following checks because they will erroneously detect an out of range + // bound or try to correct the negative value. + continue; + } + // Make sure bound is valid. + if (((*output_bound)[i] < -input_dims[i]) || + ((*output_bound)[i] > input_dims[i])) { + return tensorflow::errors::InvalidArgument( + bound_name, " value of ", std::to_string((*output_bound)[i]), + " for StridedSlice is invalid, must be in the range " + "[-dim_size(i), dim_size(i)], at ", + node_name); + } + // Convert negative values to their positive equivalent. 
+ if ((*output_bound)[i] < 0) { + (*output_bound)[i] += input_dims[i]; + } + } + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + if (inputs.size() != 4) { + return tensorflow::errors::InvalidArgument( + "StridedSlice expects 4 inputs, at ", node_def.name()); + } + if (!inputs.at(1).is_weights() || !inputs.at(2).is_weights() || + !inputs.at(3).is_weights()) { + return tensorflow::errors::InvalidArgument( + "StridedSlice expects weights for begin, end, and strides, at ", + node_def.name()); + } + if (!inputs.at(0).is_tensor()) { + return tensorflow::errors::Unimplemented( + "StridedSlice is only implemented for tensors, at ", node_def.name()); + } + // Get input dims. + nvinfer1::Dims dims = inputs.at(0).GetTrtDims(); + std::vector input_dims(dims.d, dims.d + dims.nbDims); + if (inputs.at(0).is_tensor()) { + // Temporarily add batch dimension so that indexes line up properly. + input_dims.insert(input_dims.begin(), inputs.at(0).batch_size()); + } + if (input_dims.size() > 4) { + return tensorflow::errors::Unimplemented( + "StridedSlice is not implemented for tensors with rank > 4, at ", + node_def.name()); + } + TFAttrs attrs(node_def); + // Get begin and end bounds per axis. + std::vector begin, end; + TF_RETURN_IF_ERROR(GetStridedSliceBound(input_dims, inputs.at(1).weights(), + attrs.get("begin_mask"), true, + node_def.name(), &begin)); + TF_RETURN_IF_ERROR(GetStridedSliceBound(input_dims, inputs.at(2).weights(), + attrs.get("end_mask"), false, + node_def.name(), &end)); + // Get strides per axis (must all be 1). 
+ TRT_ShapedWeights stride_weights = inputs.at(3).weights(); + const int* stride_weights_ptr = static_cast(stride_weights.GetValues()); + std::vector strides(stride_weights_ptr, + stride_weights_ptr + stride_weights.count()); + for (int x : strides) { + if (x != 1) { + return tensorflow::errors::Unimplemented( + "StridedSlice is only implemented for stride of 1, at ", + node_def.name()); + } + } + // Unsupported mask options. + for (const string& attr : + {"ellipsis_mask", "new_axis_mask", "shrink_axis_mask"}) { + int attr_val = attrs.get(attr); + if (attr_val != 0) { + return tensorflow::errors::Unimplemented( + attr, " is not supported for StridedSlice, at ", node_def.name()); + } + } + + nvinfer1::ITensor* tensor = + const_cast(inputs.at(0).tensor()); + // Reshape if necessary to 4-D, since IPaddingLayer requires a 4-D input. + const bool need_reshape = (input_dims.size() != 4); + int reshape_dims_added = 0; + nvinfer1::Dims reshape_dims; + if (need_reshape) { + // Add new dims after batch dim until tensor is 4D. + while (input_dims.size() < 4) { + input_dims.insert(input_dims.begin() + 1, 1); + begin.insert(begin.begin() + 1, 0); + end.insert(end.begin() + 1, 1); + reshape_dims_added++; + } + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &reshape_dims, + /*ignore_first_dim=*/true)); + } + // Find dimensions which need to be sliced. + std::vector pad_dims; + for (int i = 0; i < input_dims.size(); i++) { + if ((begin[i] != 0) || (end[i] != input_dims[i])) { + if (i == 0) { + return tensorflow::errors::Unimplemented( + "StridedSlice can't modify batch dim, at ", node_def.name()); + } else if ((end[i] - begin[i]) < 0) { + return tensorflow::errors::InvalidArgument( + "New size of sliced dimension is negative, at ", node_def.name()); + } + pad_dims.push_back(i); + } + } + if (pad_dims.size() == 0) { + // No dimensions are changed. We could create a padding layer anyway with + // values of 0. 
+ if (params->validation_only) return Status::OK(); + params->outputs->push_back(inputs.at(0)); + return tensorflow::Status::OK(); + } else if (pad_dims.size() == 1) { + // Only one dim is modified but we have to have 2, mark a second dim which + // will have padding of 0. The dim we add is chosen to avoid an unnecessary + // transpose. + if (pad_dims[0] != 2) { + pad_dims.push_back(2); + } else { + pad_dims.push_back(3); + } + } else if (pad_dims.size() > 2) { + return tensorflow::errors::Unimplemented( + "StridedSlice can only modify 2 dimensions, at ", node_def.name()); + } + std::sort(pad_dims.begin(), pad_dims.end()); + // Convert to pre/post padding values. Since TRT does not have a StridedSlice + // or Slice layer, we instead create an IPaddingLayer with negative padding. + nvinfer1::DimsHW pre_padding, post_padding; + for (int i = 0; i < pad_dims.size(); i++) { + const int axis = pad_dims[i]; + pre_padding.d[i] = -begin[axis]; + post_padding.d[i] = end[axis] - input_dims[axis]; + } + + // IPaddingLayer will always apply the padding to dims 2,3 (input format is + // NCHW). + const bool need_transpose = !(pad_dims[0] == 2 && pad_dims[1] == 3); + std::vector transpose_order(input_dims.size()); + std::vector inv_transpose_order(input_dims.size()); + if (need_transpose) { + if (pad_dims[0] == 1 && pad_dims[1] == 3) { + transpose_order = {0, 2, 1, 3}; + inv_transpose_order = {0, 2, 1, 3}; + } else if (pad_dims[0] == 1 && pad_dims[1] == 2) { + transpose_order = {0, 3, 1, 2}; + inv_transpose_order = {0, 2, 3, 1}; + } + } + if (params->validation_only) return Status::OK(); + + // Start conversion.
+ if (need_reshape) { + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + inputs.at(0), reshape_dims, &output_tensor)); + tensor = const_cast(output_tensor); + } + if (need_transpose) { + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->TransposeTensor( + tensor, transpose_order, &output_tensor)); + tensor = const_cast(output_tensor); + } + + // Add padding layer + nvinfer1::IPaddingLayer* layer = params->converter->network()->addPadding( + *const_cast(tensor), pre_padding, post_padding); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + params->converter->MarkQuantizationRangesAsInferrable(tensor, + layer->getOutput(0)); + tensor = layer->getOutput(0); + + // Restore transpose + if (need_transpose) { + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->TransposeTensor( + tensor, inv_transpose_order, &output_tensor)); + tensor = const_cast(output_tensor); + } + // Restore reshape + if (need_reshape) { + // Calculate output dimensions + for (int i = 0; i < pad_dims.size(); i++) { + const int axis = pad_dims[i]; + input_dims[axis] = end[axis] - begin[axis]; + } + // Remove added 1 dimensions + for (int i = 0; i < reshape_dims_added; i++) { + int value = input_dims[1]; + if (value != 1) { + return tensorflow::errors::Internal( + "StridedSlice error when reshaping, at ", node_def.name()); + } + input_dims.erase(input_dims.begin() + 1); + } + + nvinfer1::Dims new_dims; + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims, + /*ignore_first_dim=*/true)); + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + TRT_TensorOrWeights(tensor), new_dims, &output_tensor)); + tensor = const_cast(output_tensor); + } + + params->outputs->push_back( + TRT_TensorOrWeights(const_cast(tensor))); + return tensorflow::Status::OK(); +} + tensorflow::Status 
ConvertConv2D(OpConverterParams* params) { return ConvertConv2DHelper(params, ConvolutionType::DEFAULT); } @@ -1891,9 +2283,29 @@ tensorflow::Status ConvertConv2DDepthwise(OpConverterParams* params) { tensorflow::Status ConvertPool(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + if (inputs.at(0).is_weights()) { + return tensorflow::errors::Unimplemented( + node_def.op(), " is only implemented for tensors, not weights, at ", + node_def.name()); + } + nvinfer1::PoolingType type; + if (node_def.op() == "MaxPool") { + type = nvinfer1::PoolingType::kMAX; + } else if (node_def.op() == "AvgPool") { + type = nvinfer1::PoolingType::kAVERAGE; + } else { + return tensorflow::errors::Unimplemented( + "Unsupported pooling type: ", node_def.op(), ", at ", node_def.name()); + } TFAttrs attrs(node_def); + const string padding_type = attrs.get("padding"); + if ((padding_type != "SAME") && (padding_type != "VALID")) { + return tensorflow::errors::Unimplemented( + "Unsupported padding type: ", padding_type, ", at ", node_def.name()); + } + if (params->validation_only) return Status::OK(); + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); int h_index = 2; int w_index = 3; const auto data_format = attrs.get("data_format"); @@ -1904,16 +2316,6 @@ tensorflow::Status ConvertPool(OpConverterParams* params) { const_cast(tensor), {0, 3, 1, 2}, &tensor)); } - nvinfer1::PoolingType type; - if (node_def.op() == "MaxPool") { - type = nvinfer1::PoolingType::kMAX; - } else if (node_def.op() == "AvgPool") { - type = nvinfer1::PoolingType::kAVERAGE; - } else { - return tensorflow::errors::Unimplemented("Unsupported pool type: ", - node_def.op()); - } - const auto tf_stride = attrs.get>("strides"); const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); @@ -1922,7 +2324,6 @@ tensorflow::Status ConvertPool(OpConverterParams* params) { auto tensor_dim = 
tensor->getDimensions(); std::vector> padding; - const string padding_type = attrs.get("padding"); if (padding_type == "SAME") { // This is NCHW tensor with no batch dimension. // 1 -> h @@ -1932,9 +2333,6 @@ tensorflow::Status ConvertPool(OpConverterParams* params) { {static_cast(tensor_dim.d[1]), static_cast(tensor_dim.d[2])}); } else if (padding_type == "VALID") { padding = {{0, 0}, {0, 0}}; - } else { - return tensorflow::errors::Unimplemented("Unsupported padding type: ", - padding_type); } if (padding[0].first != padding[0].second || @@ -2701,6 +3099,7 @@ tensorflow::Status ConvertPad(OpConverterParams* params) { return tensorflow::errors::Unimplemented( "Padding layer does not support padding on dimension 1 and 3 yet"); } + if (params->validation_only) return Status::OK(); bool legit_pad = true; nvinfer1::DimsHW pre_padding(0, 0); @@ -2804,6 +3203,7 @@ tensorflow::Status ConvertConcat(OpConverterParams* params) { inputs_vec.push_back(tensor_i); } + if (params->validation_only) return tensorflow::Status::OK(); // nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); nvinfer1::IConcatenationLayer* layer = @@ -2825,12 +3225,35 @@ tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) { auto data_format = attrs.get("data_format"); if (data_format != "NCHW") { return tensorflow::errors::Unimplemented( - "only data_format=NCHW is supported, at " + node_def.name()); + node_def.op(), " only supports data_format=NCHW, at ", node_def.name()); } bool is_training = attrs.get("is_training"); if (is_training) { + // Trying to use batchnorm in training mode is a very common problem. + // Because the error message will only be printed in VLOG(1) by the + // segmenter, we issue a special warning so that users will actually see it. + LOG(WARNING) << node_def.op() << " only supports is_training=false. If you " + << "are using Keras, please call " + << "keras.backend.set_learning_phase(0) before constructing " + << "your model. 
At " << node_def.name(); return tensorflow::errors::Unimplemented( - "only is_training=false is supported, at " + node_def.name()); + node_def.op(), " only supports is_training=false, at ", + node_def.name()); + } + if (inputs.at(0).is_weights()) { + return tensorflow::errors::Unimplemented( + node_def.op(), + " is only implemented for tensor inputs, not weights, at ", + node_def.name()); + } + for (int i = 1; i < 5; i++) { + if (inputs.at(i).is_tensor()) { + return tensorflow::errors::Unimplemented( + node_def.op(), + " must have constant inputs for scale, offset, mean and variance, " + "at ", + node_def.name()); + } } nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); @@ -2845,7 +3268,7 @@ tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) { for (int i = 1; i < 5; i++) { if (inputs.at(i).weights().type_ != parameter_type) { return tensorflow::errors::Unimplemented( - "Inconsistent parameter type for batchnormis not supported, at: " + + "Inconsistent parameter type for batchnorm is not supported, at: " + node_def.name()); } } @@ -2865,6 +3288,8 @@ tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) { "Inconsistent batchnorm parameter count, at: " + node_def.name()); } } + if (params->validation_only) return Status::OK(); + // We could technically have two weights with different shape. // that requires two addScale op, arguably less performant TRT_ShapedWeights combined_scale_weights = @@ -3150,12 +3575,19 @@ static void RegisterValidatableOpConverters( std::unordered_map* registration) { // TODO(laigd): support all op types. 
(*registration)["BiasAdd"] = ConvertBiasAdd; + (*registration)["ConcatV2"] = ConvertConcat; (*registration)["Const"] = ConvertConst; - (*registration)["Transpose"] = ConvertTranspose; - (*registration)["Reshape"] = ConvertReshape; + (*registration)["Conv2D"] = ConvertConv2D; + (*registration)["DepthwiseConv2dNative"] = ConvertConv2DDepthwise; + (*registration)["ExpandDims"] = ConvertExpandDims; (*registration)["MatMul"] = ConvertMatMul; + (*registration)["Pad"] = ConvertPad; (*registration)["Relu6"] = ConvertRelu6; + (*registration)["Reshape"] = ConvertReshape; (*registration)["Square"] = ConvertSquare; + (*registration)["Squeeze"] = ConvertSqueeze; + (*registration)["StridedSlice"] = ConvertStridedSlice; + (*registration)["Transpose"] = ConvertTranspose; for (auto quantization_op_type : {"QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3", @@ -3169,6 +3601,12 @@ static void RegisterValidatableOpConverters( for (auto activation_op_type : {"Relu", "Sigmoid", "Tanh"}) { (*registration)[activation_op_type] = ConvertActivation; } + for (auto pool_op_type : {"AvgPool", "MaxPool"}) { + (*registration)[pool_op_type] = ConvertPool; + } + for (auto normalization_op_type : {"FusedBatchNorm", "FusedBatchNormV2"}) { + (*registration)[normalization_op_type] = ConvertFusedBatchNorm; + } } void TrtNodeValidator::RegisterOpValidators() { @@ -3177,21 +3615,10 @@ void TrtNodeValidator::RegisterOpValidators() { void Converter::RegisterOpConverters() { RegisterValidatableOpConverters(&op_registry_); - - op_registry_["Conv2D"] = ConvertConv2D; - op_registry_["DepthwiseConv2dNative"] = ConvertConv2DDepthwise; - op_registry_["MaxPool"] = ConvertPool; - op_registry_["AvgPool"] = ConvertPool; // TODO(ben,jie): this is a temp hack. 
op_registry_["Identity"] = ConvertIdentity; // Identity should be removed op_registry_["Snapshot"] = ConvertIdentity; // Snapshot should be removed - op_registry_["Pad"] = ConvertPad; - - op_registry_["ConcatV2"] = ConvertConcat; - op_registry_["FusedBatchNorm"] = ConvertFusedBatchNorm; - op_registry_["FusedBatchNormV2"] = ConvertFusedBatchNorm; - op_registry_["Rsqrt"] = ConvertUnary; op_registry_["Reciprocal"] = ConvertUnary; op_registry_["Exp"] = ConvertUnary; diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc index 443033379f0d6554784d44412a02aa8cb035ab08..a2ddfbffa5b0d8c421bcfe054097a9e42b79fe8f 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc @@ -2113,6 +2113,512 @@ TEST_F(OpConverterTest, ConvertActivation) { } } +TEST_F(OpConverterTest, ConvertExpandDims) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_expanddims", "ExpandDims", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Two inputs expected for ExpandDims, at my_expanddims"); + } + + // Get the NodeDef for ExpandDims. + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32); + auto expanddims = + ops::ExpandDims(s.WithOpName("my_expanddims"), input, weights); + const NodeDef& node_def = expanddims.operation.node()->def(); + { + // Input is weights, should fail. + Reset(); + AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); + AddTestWeights("weights", {1}, {1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "ExpandDims expects tensor for input, at my_expanddims"); + } + { + // Axis is a tensor, should fail. 
+ Reset(); + AddTestTensor("input", {1, 2, 3}); + AddTestTensor("weights", {3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "ExpandDims expects weights for axis, at my_expanddims"); + } + { + // Add dim at batch dimension, should fail. + Reset(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights", {1}, {0}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Modifying batch dimension is not supported for ExpandDims, at " + "my_expanddims"); + } + { + // Add dim at batch dimension via negative axis, should fail. + Reset(); + AddTestTensor("input", {1, 2, 3}); + // Input is rank 4 (batch dim included) + AddTestWeights("weights", {1}, {-5}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Modifying batch dimension is not supported for ExpandDims, at " + "my_expanddims"); + } + { + // Axis > rank(input), should fail. + Reset(); + AddTestTensor("input", {1, 2, 3}); + // Input is rank 4 (batch dim included) + AddTestWeights("weights", {1}, {5}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Axis for ExpandDims is invalid, must be in the range " + "[-rank(input) - 1, rank(input)], at my_expanddims"); + } + { + // Axis < -rank(input)-1, should fail. + Reset(); + AddTestTensor("input", {1, 2, 3}); + // Input is rank 4 (batch dim included) + AddTestWeights("weights", {1}, {-6}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Axis for ExpandDims is invalid, must be in the range " + "[-rank(input) - 1, rank(input)], at my_expanddims"); + } + + struct TestParams { + TestParams(const std::vector& input_dims, int axis, + const std::vector& expected_output_dims) + : input_dims(input_dims), + axis(axis), + expected_output_dims(expected_output_dims) {} + std::vector input_dims; + int axis; + std::vector expected_output_dims; + }; + + // Ok. 
+ const int kExpandDimsOKCases = 8; + TestParams ok_params[kExpandDimsOKCases] = { + TestParams{{2, 3}, 1, {1, 2, 3}}, TestParams{{2, 3}, -3, {1, 2, 3}}, + TestParams{{2, 3}, 3, {2, 3, 1}}, TestParams{{2, 3}, -1, {2, 3, 1}}, + TestParams{{2, 3}, 2, {2, 1, 3}}, TestParams{{2, 3}, -2, {2, 1, 3}}, + TestParams{{6}, 1, {1, 6}}, TestParams{{6}, -1, {6, 1}}, + }; + for (int i = 0; i < kExpandDimsOKCases; ++i) { + Reset(); + AddTestTensor("input", ok_params[i].input_dims); + AddTestWeights("weights", {1}, {ok_params[i].axis}); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_expanddims", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); + + std::vector output_data(6); + BuildAndRun({{"input", {1, 2, 3, 4, 5, 6}}}, "my_expanddims", + &output_data); + EXPECT_THAT(output_data, ElementsAre(1, 2, 3, 4, 5, 6)); + } +} + +TEST_F(OpConverterTest, ConvertSqueeze) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_squeeze", "Squeeze", {}); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "One input expected for Squeeze, at my_squeeze"); + } + { + // No attrs, should fail. + Reset(); + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto squeeze = ops::Squeeze(s.WithOpName("my_squeeze"), input); + const NodeDef& node_def = squeeze.operation.node()->def(); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Squeeze is only implemented for explicit dims, at my_squeeze"); + } + + // Get the NodeDef for Squeeze. 
+ auto get_squeeze_nodedef = [](std::vector axis) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + ops::Squeeze::Attrs squeeze_attrs; + squeeze_attrs.axis_ = gtl::ArraySlice(axis); + auto squeeze = + ops::Squeeze(s.WithOpName("my_squeeze"), input, squeeze_attrs); + return squeeze.operation.node()->def(); + }; + + { + // Input is weights, should fail. + Reset(); + NodeDef node_def = get_squeeze_nodedef({0}); + AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Squeeze expects tensor for input, at my_squeeze"); + } + { + // Squeeze batch dim, should fail. + Reset(); + NodeDef node_def = get_squeeze_nodedef({0}); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Cannot squeeze batch dimension, at my_squeeze"); + } + { + // Squeeze batch dim via negative axis, should fail. + Reset(); + NodeDef node_def = get_squeeze_nodedef({-4}); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Cannot squeeze batch dimension, at my_squeeze"); + } + { + // Squeeze >= rank(input), should fail. + Reset(); + NodeDef node_def = get_squeeze_nodedef({4}); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Axis for Squeeze is invalid, must be in the range " + "[-rank(input), rank(input)), at my_squeeze"); + } + { + // Squeeze < -rank(input), should fail. 
+ Reset(); + NodeDef node_def = get_squeeze_nodedef({-5}); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Axis for Squeeze is invalid, must be in the range " + "[-rank(input), rank(input)), at my_squeeze"); + } + + struct TestParams { + TestParams(const std::vector& input_dims, const std::vector& axis, + const std::vector& expected_output_dims) + : input_dims(input_dims), + axis(axis), + expected_output_dims(expected_output_dims) {} + std::vector input_dims; + std::vector axis; + std::vector expected_output_dims; + }; + + // Ok. + const int kSqueezeOKCases = 10; + TestParams ok_params[kSqueezeOKCases] = { + TestParams{{1, 2, 3}, {1}, {2, 3}}, + TestParams{{1, 2, 3}, {-3}, {2, 3}}, + TestParams{{2, 3, 1}, {3}, {2, 3}}, + TestParams{{2, 3, 1}, {-1}, {2, 3}}, + TestParams{{1, 2, 1, 3, 1}, {1, 3, 5}, {2, 3}}, + TestParams{{1, 2, 1, 3, 1}, {3, 1, 5}, {2, 3}}, + TestParams{{1, 2, 1, 3, 1}, {-1, -3, -5}, {2, 3}}, + TestParams{{1, 2, 1, 3, 1}, {1, -3, 5}, {2, 3}}, + TestParams{{1, 6}, {1}, {6}}, + TestParams{{6, 1}, {2}, {6}}, + }; + for (int i = 0; i < kSqueezeOKCases; ++i) { + Reset(); + NodeDef node_def = get_squeeze_nodedef(ok_params[i].axis); + AddTestTensor("input", ok_params[i].input_dims); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_squeeze", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); + + std::vector output_data(6); + BuildAndRun({{"input", {1, 2, 3, 4, 5, 6}}}, "my_squeeze", + &output_data); + EXPECT_THAT(output_data, ElementsAre(1, 2, 3, 4, 5, 6)); + } +} + +TEST_F(OpConverterTest, ConvertStridedSlice) { + { + // Input list is empty, should fail. 
+ NodeDef node_def = MakeNodeDef("my_strided_slice", "StridedSlice", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "StridedSlice expects 4 inputs, at my_strided_slice"); + } + + // Get nodedef for StridedSlice layer. + auto get_strided_slice_nodedef = + [](int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, + int new_axis_mask = 0, int shrink_axis_mask = 0) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto begin = ops::Placeholder(s.WithOpName("begin"), DT_INT32); + auto end = ops::Placeholder(s.WithOpName("end"), DT_INT32); + auto strides = ops::Placeholder(s.WithOpName("strides"), DT_INT32); + ops::StridedSlice::Attrs attrs = ops::StridedSlice::Attrs() + .BeginMask(begin_mask) + .EndMask(end_mask) + .EllipsisMask(ellipsis_mask) + .NewAxisMask(new_axis_mask) + .ShrinkAxisMask(shrink_axis_mask); + auto strided_slice = ops::StridedSlice(s.WithOpName("my_strided_slice"), + input, begin, end, strides, attrs); + return strided_slice.operation.node()->def(); + }; + + { + NodeDef node_def = get_strided_slice_nodedef(); + AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "StridedSlice is only implemented for tensors, at my_strided_slice"); + } + { + // Begin, end, strides are tensors, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestTensor("begin", {4}); + AddTestTensor("end", {4}); + AddTestTensor("strides", {4}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "StridedSlice expects weights for begin, end, and strides, at " + "my_strided_slice"); + } + { + // Non-zero ellipsis_mask, should fail. 
+ Reset(); + NodeDef node_def = get_strided_slice_nodedef( + /*begin_mask=*/0, /*end_mask=*/0, /*ellipsis_mask=*/2, + /*new_axis_mask=*/0, /*shrink_axis_mask=*/0); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "ellipsis_mask is not supported for StridedSlice, at " + "my_strided_slice"); + } + { + // Modify batch dim, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {0, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "StridedSlice can't modify batch dim, at my_strided_slice"); + } + { + // Stride is not 1, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 2, -1, 3}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "StridedSlice is only implemented for stride of " + "1, at my_strided_slice"); + } + { + // Begin out of bounds, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {1, 2, 3, 4}); + AddTestWeights("end", {4}, {0, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "begin value of 2 for StridedSlice is invalid, must be in the range " + "[-dim_size(i), dim_size(i)], at my_strided_slice"); + } + { + // End out of bounds, should fail. 
+ Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 2, 3, 4}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "end value of 2 for StridedSlice is invalid, must be in the range " + "[-dim_size(i), dim_size(i)], at my_strided_slice"); + } + { + // Size of sliced dim is negative, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 2, 0}); + AddTestWeights("end", {4}, {1, 1, 0, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "New size of sliced dimension is negative, at my_strided_slice"); + } + + struct TestParams { + TestParams(const std::vector& input_dims, + const std::vector& expected_output_dims, + const std::vector& begin, const std::vector& end, + const std::vector& begin_mask, + const std::vector& end_mask, + const std::vector& expected_output) + : input_dims(input_dims), + expected_output_dims(expected_output_dims), + begin(begin), + end(end), + expected_output(expected_output) { + // Masks are provided in terms of vectors for readability. Convert them to + // binary here. + this->begin_mask = 0; + for (int i = 0; i < begin_mask.size(); i++) { + if (begin_mask[i]) this->begin_mask |= (1 << i); + } + this->end_mask = 0; + for (int i = 0; i < end_mask.size(); i++) { + if (end_mask[i]) this->end_mask |= (1 << i); + } + } + + std::vector input_dims; + std::vector expected_output_dims; + std::vector begin; + std::vector end; + int begin_mask; + int end_mask; + std::vector expected_output; + }; + + // Ok. + const int kStridedSliceOKCases = 18; + TestParams ok_params[kStridedSliceOKCases] = { + // 2D Crop. 
+ TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 0, 1, 2}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 0, 0}, + /*expected_output=*/{1, 2}}, + TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2}, + /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 0, 0, 0}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 1}, + /*expected_output=*/{5, 6}}, + TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2}, + /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 1, 2, 3}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 0, 0}, + /*expected_output=*/{5, 6}}, + // 2D Crop, with transpose. + TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 2, 1}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 2, 1}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0}, + /*expected_output=*/{1, 2}}, + TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 2, 1}, + /*begin=*/{0, 1, 1, 0}, /*end=*/{0, 2, 3, 1}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0}, + /*expected_output=*/{5, 6}}, + TestParams{/*input_dims=*/{2, 1, 3}, /*expected_output_dims=*/{1, 1, 2}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 1, 2}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0}, + /*expected_output=*/{1, 2}}, + TestParams{/*input_dims=*/{2, 1, 3}, /*expected_output_dims=*/{1, 1, 2}, + /*begin=*/{0, 1, 0, 1}, /*end=*/{0, 2, 1, 3}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0}, + /*expected_output=*/{5, 6}}, + // 2D Crop, with reshape. + TestParams{/*input_dims=*/{2, 3}, /*expected_output_dims=*/{1, 2}, + /*begin=*/{0, 0, 0}, /*end=*/{0, 1, 2}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 0}, + /*expected_output=*/{1, 2}}, + TestParams{/*input_dims=*/{2, 3}, /*expected_output_dims=*/{1, 2}, + /*begin=*/{0, 1, 1}, /*end=*/{0, 0, 0}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 1, 1}, + /*expected_output=*/{5, 6}}, + // 1D Crop. 
+ TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 2, 2}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 0, 0, 2}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 0}, + /*expected_output=*/{1, 2, 4, 5}}, + TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 3}, + /*begin=*/{0, 0, 1, 0}, /*end=*/{0, 0, 0, 0}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 1}, + /*expected_output=*/{4, 5, 6}}, + // 1D Crop, with transpose. + TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 3, 1}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 0, 0}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 1, 1}, + /*expected_output=*/{1, 2, 3}}, + TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 3, 1}, + /*begin=*/{0, 1, 0, 0}, /*end=*/{0, 0, 0, 0}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 1}, + /*expected_output=*/{4, 5, 6}}, + // 1D Crop, with reshape. + TestParams{/*input_dims=*/{6}, /*expected_output_dims=*/{3}, + /*begin=*/{0, 0}, /*end=*/{0, 3}, + /*begin_mask=*/{0, 0}, /*end_mask=*/{1, 0}, + /*expected_output=*/{1, 2, 3}}, + TestParams{/*input_dims=*/{1, 6}, /*expected_output_dims=*/{1, 3}, + /*begin=*/{0, 0, 2}, /*end=*/{0, 0, 5}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 1, 0}, + /*expected_output=*/{3, 4, 5}}, + TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{3, 1}, + /*begin=*/{0, 2, 0}, /*end=*/{0, 5, 0}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 1}, + /*expected_output=*/{3, 4, 5}}, + // Negative axis. 
+ TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{3, 1}, + /*begin=*/{0, -6, 0}, /*end=*/{0, -3, 0}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 1}, + /*expected_output=*/{1, 2, 3}}, + TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{5, 1}, + /*begin=*/{0, 0, 0}, /*end=*/{0, -1, 0}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 1}, + /*expected_output=*/{1, 2, 3, 4, 5}}, + }; + + for (int i = 0; i < kStridedSliceOKCases; i++) { + Reset(); + NodeDef node_def = get_strided_slice_nodedef(ok_params[i].begin_mask, + ok_params[i].end_mask); + AddTestTensor("input", ok_params[i].input_dims); + AddTestWeights("begin", + {static_cast(ok_params[i].begin.size())}, + ok_params[i].begin); + AddTestWeights("end", {static_cast(ok_params[i].end.size())}, + ok_params[i].end); + std::vector strides(ok_params[i].input_dims.size(), 1); + AddTestWeights("strides", {static_cast(strides.size())}, + strides); + RunValidationAndConversion(node_def); + + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_strided_slice", &output)); + std::vector output_data(ok_params[i].expected_output.size()); + BuildAndRun({{"input", {1, 2, 3, 4, 5, 6}}}, "my_strided_slice", + &output_data); + EXPECT_THAT(output_data, ElementsAreArray(ok_params[i].expected_output)); + } +} + } // namespace convert } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc index c1688d4db88a270dcd202989f89a677ed10576d9..d57f2300f8e6e6ce79c538133da6bc5cf5ead2f5 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc @@ -226,8 +226,9 @@ tensorflow::Status TRTOptimizationPass::Optimize( tensorflow::tensorrt::convert::ConversionParams cp; if (use_calibration_ && precision_mode_ != INT8MODE) { - LOG(ERROR) << "Calibration with FP32 or FP16 is not implemented. 
" - << "Falling back to use_calibration = False."; + VLOG(1) << "Calibration with FP32 or FP16 is not implemented. " + << "Falling back to use_calibration = False. " + << "Note that the default value of use_calibration is True."; use_calibration_ = false; } diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc b/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc index ad6b1d7d4c57d696d3dee3b479733e152e669211..beb1284208e4c10ffe1d36ef411cf08f11dbcb78 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc +++ b/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc @@ -48,11 +48,14 @@ TEST(TRTAllocatorTest, Align) { 513ul, 700ul, 12345ul, 1ul << 32}) { for (uint64_t alignment = 1; alignment <= space * 4; alignment *= 2) { for (const uintptr_t ptr_val : - {1ul, alignment == 1 ? 1ul : alignment - 1, alignment, alignment + 1, - alignment + (alignment / 2)}) { + {static_cast(1), + alignment == 1 ? static_cast(1) : alignment - 1, + alignment, alignment + 1, alignment + (alignment / 2)}) { if (ptr_val % alignment == 0) { for (const uint64_t size : - {1ul, space == 1 ? 1ul : space - 1, space, space + 1}) { + {static_cast(1), + space == 1 ? static_cast(1) : space - 1, space, + space + 1}) { EXPECT_EQ(space >= size, RunTest(alignment, size, ptr_val, space)); } } else { @@ -62,8 +65,10 @@ TEST(TRTAllocatorTest, Align) { EXPECT_TRUE( RunTest(alignment, space - diff, ptr_val + diff, space - diff)); for (const uint64_t size : - {1ul, space - diff > 1 ? space - diff - 1 : 1ul, space - diff, - space - diff + 1, space - 1}) { + {static_cast(1), + space - diff > 1 ?
space - diff - 1 + : static_cast(1), + space - diff, space - diff + 1, space - 1}) { EXPECT_EQ(space - diff >= size, RunTest(alignment, size, ptr_val, space)); } diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc index 6abc5226ccf96e472df77269bee6186726e5768d..084a96e0fa5c97edc58adf2590ed94e5ef0e4d85 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.cc +++ b/tensorflow/contrib/tensorrt/segment/segment.cc @@ -225,6 +225,24 @@ SimpleGraph::~SimpleGraph() { for (auto x : edges_) delete x; } +// Define comparison functions for std::set with pointer keys so that behavior +// is deterministic. When using std::set with pointer key types, the items are +// sorted by pointer address which is non-deterministic. This can cause issues +// for INT8 mode because the graph is converted twice and non-determinism may +// cause a mismatch between the calibration tables of the conversions. +struct SimpleEdgePtrCompare { + bool operator()(const SimpleEdge* lhs, const SimpleEdge* rhs) const { + return lhs->id() < rhs->id(); + } +}; + +struct NodePtrCompare { + bool operator()(const tensorflow::Node* lhs, + const tensorflow::Node* rhs) const { + return lhs->name() < rhs->name(); + } +}; + namespace { // Copied from TF ReverseDFS, which only works for tensorflow::Graph. @@ -476,7 +494,7 @@ tensorflow::Status SegmentGraph( // nodes. Iterate since combining two nodes may unblock other // combining. while (true) { - std::set contract_edges; + std::set contract_edges; for (const SimpleEdge* out_edge : node->out_edges()) { VLOG(3) << "... out node " << out_edge->dst()->name() << " ( " << out_edge->dst()->id() << " <- " << node->id() << " )"; @@ -530,7 +548,7 @@ tensorflow::Status SegmentGraph( // A map from the segment identifier (currently the name of the root node of // the segment tree) to the segment nodes set. 
- std::map> sg_map; + std::map> sg_map; // A map from the segment identifier (currently the name of the root node of // the segment tree) to the device names that the nodes in the segment are @@ -566,7 +584,8 @@ tensorflow::Status SegmentGraph( // --------------------------------- Step 2 --------------------------------- // Remove ineligible input/output nodes. for (auto& itr : sg_map) { - std::set& segment_nodes = itr.second; + std::set& segment_nodes = + itr.second; VLOG(1) << "Segment original size: " << segment_nodes.size(); while (true) { std::deque in_nodes_que, out_nodes_que; @@ -618,8 +637,9 @@ tensorflow::Status SegmentGraph( bool is_input_nodes, std::deque* que) { // Run a BFS on the queue to find all the input/output nodes. - std::set visited; - std::set logged(que->begin(), que->end()); + std::set visited; + std::set logged(que->begin(), + que->end()); while (!que->empty()) { auto node = que->front(); que->pop_front(); @@ -653,7 +673,8 @@ tensorflow::Status SegmentGraph( // --------------------------------- Step 3 --------------------------------- // Convert the segments into the expected return format for (const auto& itr : sg_map) { - const std::set& segment_nodes = itr.second; + const std::set& segment_nodes = + itr.second; if (VLOG_IS_ON(1)) { string s = "parent=" + itr.first + ":"; for (auto node : segment_nodes) s += " " + node->name(); diff --git a/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py b/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py index 31cbef89e23949ba5ceaab34e0f683fd906bf0ce..e7d6ec4ad395d38a06f97020f2f363009f2286c7 100644 --- a/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py +++ b/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py @@ -191,7 +191,7 @@ class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase): batch_size=batch_size, num_parallel_calls=8)) dataset = dataset.repeat(count=1) - iterator = data.make_one_shot_iterator(dataset) + iterator = 
dataset.make_one_shot_iterator() features, labels = iterator.get_next() return features, labels @@ -205,7 +205,7 @@ class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase): batch_size=batch_size, num_parallel_calls=8)) dataset = dataset.repeat(count=num_epochs) - iterator = data.make_one_shot_iterator(dataset) + iterator = dataset.make_one_shot_iterator() features, labels = iterator.get_next() return features, labels diff --git a/tensorflow/contrib/tensorrt/test/rank_two_test.py b/tensorflow/contrib/tensorrt/test/rank_two_test.py index 0cd733dca13462ac8f4478544005ae4000f711f1..563232fc12675d9e1b32b7ab461591af57beadb9 100644 --- a/tensorflow/contrib/tensorrt/test/rank_two_test.py +++ b/tensorflow/contrib/tensorrt/test/rank_two_test.py @@ -51,8 +51,10 @@ class RankTwoTest(trt_test.TfTrtIntegrationTestBase): c = constant_op.constant(3.0, name="c%d_3" % i) q = math_ops.add(q, c, name="add%d_3" % i) if i == 0: + axis = constant_op.constant(-1, dtype=dtypes.int32, name="axis") for j in range(2): - q = array_ops.expand_dims(q, -1, name="expand%d_%d" % (i, j)) + q = array_ops.expand_dims(q, axis, name="expand%d_%d" % (i, j)) + q = self.trt_incompatible_op(q) q = gen_math_ops.reciprocal(q, name="reciprocal%d" % i) outputs.append(q) # Combine both paths @@ -70,7 +72,7 @@ class RankTwoTest(trt_test.TfTrtIntegrationTestBase): return { "TRTEngineOp_0": [ "add0_1", "add0_2", "add0_3", "c0_1", "c0_2", "c0_3", "abs0_1", - "abs0_2" + "abs0_2", "expand0_0", "expand0_1", "axis" ], "TRTEngineOp_1": [ "add", "add1_1", "add1_2", "add1_3", "c1_1", "c1_2", "c1_3", diff --git a/tensorflow/contrib/tensorrt/test/unary_test.py b/tensorflow/contrib/tensorrt/test/unary_test.py index 9fc50e05952abd335e196dce8fc8a81056d7007d..b6e5e32db1236684a06c2d44298b9a3d39667152 100644 --- a/tensorflow/contrib/tensorrt/test/unary_test.py +++ b/tensorflow/contrib/tensorrt/test/unary_test.py @@ -106,10 +106,7 @@ class UnaryTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, 
run_params): """Return the expected engines to build.""" - return [ - "TRTEngineOp_0", "TRTEngineOp_1", "TRTEngineOp_2", "TRTEngineOp_3", - "TRTEngineOp_4" - ] + return ["TRTEngineOp_0"] if __name__ == "__main__": diff --git a/tensorflow/contrib/tfprof/README.md b/tensorflow/contrib/tfprof/README.md index b29d1acacf17b57549558be45c853566817c1729..f40e76f554e8815aac96344d8cb0b911bafdd712 100644 --- a/tensorflow/contrib/tfprof/README.md +++ b/tensorflow/contrib/tfprof/README.md @@ -1,7 +1,5 @@ # tfprof: TensorFlow Profiler and Beyond -

Please use `tf.profiler.xxx` instead of `tf.contrib.tfprof.xxx`

-

Full Document in tensorflow/core/profiler/README.md

diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index 4bf3a0463d9046eea2f60e9154fca1357e728215..76641318134eac90dadc9b98c51f5bb2207c88d3 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -1,15 +1,15 @@ # Description: Operations defined for Cloud TPUs -licenses(["notice"]) # Apache 2.0 - load( "//tensorflow:tensorflow.bzl", "tf_custom_op_library", "tf_gen_op_libs", "tf_gen_op_wrapper_py", + "tf_py_test", ) load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") -load("//tensorflow:tensorflow.bzl", "tf_py_test") + +licenses(["notice"]) # Apache 2.0 package( default_visibility = [ @@ -102,6 +102,8 @@ tf_gen_op_libs( "replication_ops", "tpu_configuration_ops", "tpu_embedding_ops", + "tpu_ordinal_selector_op", + "functional_ops", ], deps = [ "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_cc", @@ -153,6 +155,33 @@ tf_gen_op_wrapper_py( ], ) +tf_gen_op_wrapper_py( + name = "tpu_ordinal_selector_op", + deps = [ + ":tpu_ordinal_selector_op_op_lib", + ], +) + +tf_gen_op_wrapper_py( + name = "gen_functional_ops", + out = "python/tpu/gen_functional_ops.py", + hidden = [ + "TPUPartitionedCall", + ], + deps = [":functional_ops_op_lib"], +) + +py_library( + name = "functional", + srcs = ["python/tpu/functional.py"], + visibility = [ + "//visibility:public", + ], + deps = [ + ":gen_functional_ops", + ], +) + py_library( name = "profiler", srcs = ["python/profiler/__init__.py"], @@ -193,6 +222,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":feature_column", ":keras_support", # split out to avoid cycle with tpu_strategy ":tpu_embedding", ":tpu_estimator", @@ -307,6 +337,7 @@ py_library( tf_py_test( name = "datasets_test", + size = "medium", srcs = ["python/tpu/datasets_test.py"], additional_deps = [ "//tensorflow/python:client_testlib", @@ -314,6 +345,7 @@ tf_py_test( ], flaky = 1, # TODO(b/117363808): fails 1/1000 OSS runs grpc_enabled = True, + shard_count = 4, ) tf_py_test( @@ -412,3 
+444,37 @@ py_library( "@six_archive//:six", ], ) + +py_library( + name = "feature_column", + srcs = ["python/tpu/feature_column.py"], + deps = [ + ":tpu_lib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python/feature_column", + "//tensorflow/python/feature_column:feature_column_py", + ], +) + +tf_py_test( + name = "feature_column_test", + srcs = [ + "python/tpu/feature_column_test.py", + ], + additional_deps = [ + ":feature_column", + "//third_party/py/numpy", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:lookup_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:session", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:variables", + "//tensorflow/python/feature_column", + "//tensorflow/python/feature_column:feature_column_py", + ], + main = "python/tpu/feature_column_test.py", +) diff --git a/tensorflow/core/platform/cuda_libdevice_path.cc b/tensorflow/contrib/tpu/ops/functional_ops.cc similarity index 58% rename from tensorflow/core/platform/cuda_libdevice_path.cc rename to tensorflow/contrib/tpu/ops/functional_ops.cc index 4d6532b983d52e7882ab540da31fb0b57183eb6f..aa81e8b24b5e303f5de5d2938b9474fc6b7af6c9 100644 --- a/tensorflow/core/platform/cuda_libdevice_path.cc +++ b/tensorflow/contrib/tpu/ops/functional_ops.cc @@ -1,4 +1,4 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,14 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/core/platform/cuda_libdevice_path.h" - -#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { -string LibdeviceRoot() { - return tensorflow::io::JoinPath(tensorflow::CudaRoot(), "nvvm/libdevice"); -} +REGISTER_OP("TPUPartitionedCall") + .Input("args: Tin") + .Input("device_ordinal: int32") + .Output("output: Tout") + .Attr("Tin: list(type) >= 0") + .Attr("Tout: list(type) >= 0") + .Attr("f: func") + .SetShapeFn(shape_inference::UnknownShape); } // namespace tensorflow diff --git a/tensorflow/contrib/tpu/ops/infeed_ops.cc b/tensorflow/contrib/tpu/ops/infeed_ops.cc index efc546f9a6077de9cac5a5acefa3fc7206547fc6..2ed16c2a2270a5399059d7e07f5903e11098bbf9 100644 --- a/tensorflow/contrib/tpu/ops/infeed_ops.cc +++ b/tensorflow/contrib/tpu/ops/infeed_ops.cc @@ -40,6 +40,7 @@ REGISTER_OP("InfeedEnqueue") .Input("input: dtype") .Attr("dtype: type") .Attr("shape: shape = {}") + .Attr("layout: list(int) = []") .Attr("device_ordinal: int = -1") .SetShapeFn(shape_inference::NoOutputs) .SetIsStateful() @@ -49,6 +50,9 @@ An op which feeds a single Tensor value into the computation. input: A tensor that will be provided using the infeed mechanism. dtype: The type of elements in the tensor. shape: The shape of the tensor. +layout: A vector holding the requested layout in minor-to-major sequence. +If a layout attribute is passed, but its values are all -1, the layout will +be computed by the infeed operation. device_ordinal: The TPU device to use. This should be -1 when the Op is running on a TPU device, and >= 0 when the Op is running on the CPU device. 
@@ -58,6 +62,7 @@ REGISTER_OP("InfeedEnqueueTuple") .Input("inputs: dtypes") .Attr("dtypes: list(type)") .Attr("shapes: list(shape)") + .Attr("layouts: list(int) = []") .Attr("device_ordinal: int = -1") .SetShapeFn(shape_inference::NoOutputs) .SetIsStateful() @@ -67,6 +72,10 @@ An op which feeds multiple Tensor values into the computation as an XLA tuple. inputs: A list of tensors that will be provided using the infeed mechanism. dtypes: The element types of each element in `inputs`. shapes: The shapes of each tensor in `inputs`. +layouts: A vector holding the requested layout in minor-to-major sequence for +all the tuple shapes, in the order the shapes appear in the "shapes" input. +The layout elements for a sub-shape can be set to -1, in which case the +corresponding layout will be computed by the infeed operation. device_ordinal: The TPU device to use. This should be -1 when the Op is running on a TPU device, and >= 0 when the Op is running on the CPU device. diff --git a/tensorflow/contrib/tpu/ops/tpu_ordinal_selector_op.cc b/tensorflow/contrib/tpu/ops/tpu_ordinal_selector_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..54e6b20f7f388b67a96ac8acfe814a4202b56a18 --- /dev/null +++ b/tensorflow/contrib/tpu/ops/tpu_ordinal_selector_op.cc @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("TPUOrdinalSelector") + .Output("device_ordinals: int32") + .SetIsStateful() + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, + c->Vector(shape_inference::InferenceContext::kUnknownDim)); + return Status::OK(); + }) + .Doc(R"doc( +A TPU core selector Op. + +This Op produces a set of TPU cores (for warm-up) or a single TPU core +(for regular inference) to execute the TPU program on. The output is +consumed by TPUPartitionedCall. + +device_ordinals: A vector of 1 or more TPU cores. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/tpu/python/tpu/datasets_test.py b/tensorflow/contrib/tpu/python/tpu/datasets_test.py index 52d87b800401c3e584da9843916cfc7a767c082a..8a94f527bb6dffa48e71e6500ae5e9e9589fbf5c 100644 --- a/tensorflow/contrib/tpu/python/tpu/datasets_test.py +++ b/tensorflow/contrib/tpu/python/tpu/datasets_test.py @@ -27,6 +27,7 @@ from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.lib.io import python_io from tensorflow.python.platform import test @@ -55,6 +56,7 @@ class DatasetsTest(test.TestCase): session_config = config_pb2.ConfigProto(cluster_def=self._cluster_def) self._sess = session.Session(self._worker.target, config=session_config) + self._worker_device = '/job:' + worker_job.name def testTextLineDataset(self): all_contents = [] @@ -70,7 +72,8 @@ class DatasetsTest(test.TestCase): dataset = datasets.StreamingFilesDataset( os.path.join(self.get_temp_dir(), 'text_line.*.txt'), filetype='text') - iterator = 
dataset_ops.make_initializable_iterator(dataset) + with ops.device(self._worker_device): + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() @@ -94,7 +97,8 @@ class DatasetsTest(test.TestCase): dataset = datasets.StreamingFilesDataset( os.path.join(self.get_temp_dir(), 'tf_record*'), filetype='tfrecord') - iterator = dataset_ops.make_initializable_iterator(dataset) + with ops.device(self._worker_device): + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() @@ -121,7 +125,8 @@ class DatasetsTest(test.TestCase): dataset = datasets.StreamingFilesDataset(filenames, filetype='tfrecord') - iterator = dataset_ops.make_initializable_iterator(dataset) + with ops.device(self._worker_device): + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() @@ -154,7 +159,8 @@ class DatasetsTest(test.TestCase): os.path.join(self.get_temp_dir(), 'fixed_length*'), filetype=FixedLengthFile) - iterator = dataset_ops.make_initializable_iterator(dataset) + with ops.device(self._worker_device): + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() @@ -177,7 +183,8 @@ class DatasetsTest(test.TestCase): dataset = datasets.StreamingFilesDataset( dataset_ops.Dataset.range(10), filetype=gen_dataset) - iterator = dataset_ops.make_initializable_iterator(dataset) + with ops.device(self._worker_device): + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() diff --git a/tensorflow/contrib/tpu/python/tpu/feature_column.py b/tensorflow/contrib/tpu/python/tpu/feature_column.py new file mode 100644 index 0000000000000000000000000000000000000000..d5d00d628d407bf3bb5312bd54f6ccd13dc37db4 --- /dev/null +++ 
b/tensorflow/contrib/tpu/python/tpu/feature_column.py @@ -0,0 +1,439 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =================================================================== +"""TPU Feature Column Library.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import math + +from tensorflow.contrib.tpu.python.tpu import tpu +from tensorflow.contrib.tpu.python.tpu import tpu_function +from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.feature_column import feature_column_lib as fc_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import variable_scope +# pylint: disable=protected-access + + +_TPU_FC_TO_SCOPE = '_tpu_feature_column_scope' +_SUPPORTED_CATEGORICAL_COLUMNS = (fc._IdentityCategoricalColumn, + fc._VocabularyFileCategoricalColumn, + fc._VocabularyListCategoricalColumn, + fc._WeightedCategoricalColumn, + fc_lib.IdentityCategoricalColumn, + fc_lib.VocabularyFileCategoricalColumn, + fc_lib.VocabularyListCategoricalColumn, + fc_lib.WeightedCategoricalColumn) + + +def embedding_column(categorical_column, + dimension, + combiner='mean', + initializer=None): + """TPU embedding_column for `tf.feature_column.embedding_column`. 
+ + Note that the interface for TPU embedding_column is different from the non-TPU + version. The following args available for the non-TPU version are NOT + supported: ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable. + + Args: + categorical_column: A categorical_column returned from + categorical_column_with_identity, weighted_categorical_column, + categorical_column_with_vocabulary_list or + categorical_column_with_vocabulary_file. + dimension: An integer specifying dimension of the embedding, must be > 0. + combiner: A string specifying how to reduce if there are multiple entries + in a single row. For more information, see + `tf.feature_column.embedding_column`. + initializer: A variable initializer function to be used in embedding + variable initialization. If not specified, defaults to + `tf.truncated_normal_initializer` with mean `0.0` and standard deviation + `1/sqrt(dimension)`. + + Returns: + A _TPUEmbeddingColumn. + + Raises: + ValueError: if `dimension` not > 0. + ValueError: if `initializer` is specified but not callable. + """ + if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS): + raise TypeError( + 'categorical_column for tpu ' + ' embedding_column must be type %s, got %s.' % (' or '.join([ + cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS + ]), type(categorical_column))) + if (dimension is None) or (dimension < 1): + raise ValueError('Invalid dimension {}.'.format(dimension)) + + if (initializer is not None) and (not callable(initializer)): + raise ValueError('initializer must be callable if specified. 
' + 'Embedding of column_name: {}'.format( + categorical_column.name)) + if initializer is None: + initializer = init_ops.truncated_normal_initializer( + mean=0.0, stddev=1 / math.sqrt(dimension)) + + embedding_shape = categorical_column._num_buckets, dimension # pylint: disable=protected-access + + def _creator(weight_collections, scope): + embedding_column_layer = fc._EmbeddingColumnLayer( + embedding_shape=embedding_shape, + initializer=initializer, + weight_collections=weight_collections, + trainable=True, + name='embedding_column_layer') + return embedding_column_layer(None, scope=scope) # pylint: disable=not-callable + + column = _TPUEmbeddingColumn( + categorical_column=categorical_column, + dimension=dimension, + combiner=combiner, + layer_creator=_creator, + ckpt_to_load_from=None, + tensor_name_in_ckpt=None, + max_norm=None, + trainable=True) + # For Embedding column, the initializer is hidden inside the creator Fn, which + # is not accessible later. So, we attach it to a special field. Also note + # that non-TPU Embedding column and non-TPU shared Embedding column handle the + # initializer differently. See shared_embedding_columns for details. + column._tpu_initializer = initializer + return column + + +def shared_embedding_columns(categorical_columns, + dimension, + combiner='mean', + initializer=None, + shared_embedding_collection_name=None): + """List of dense columns that convert from sparse, categorical input.""" + for categorical_column in categorical_columns: + if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS): + raise TypeError( + 'categorical_column for tpu ' + ' shared_embedding_columns must be type %s, got %s.' 
% (' or '.join([ + cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS + ]), type(categorical_column))) + columns = fc_lib.shared_embedding_columns( + categorical_columns, + dimension, + combiner=combiner, + initializer=initializer, + shared_embedding_collection_name=shared_embedding_collection_name, + ckpt_to_load_from=None, + tensor_name_in_ckpt=None, + max_norm=None, + trainable=True) + + # Use the initializer and shared_embedding_collection_name to create TPU + # version + initializer = columns[0].initializer + shared_embedding_collection_name = columns[0].shared_embedding_collection_name + tpu_columns = [] + + # Create the state (_SharedEmbeddingColumnLayer) here. + for categorical_column in categorical_columns: + column = _TPUSharedEmbeddingColumn( + categorical_column=categorical_column, + dimension=dimension, + combiner=combiner, + initializer=initializer, + shared_embedding_collection_name=shared_embedding_collection_name, + ckpt_to_load_from=None, + tensor_name_in_ckpt=None, + max_norm=None, + trainable=True) + tpu_columns.append(column) + + return tpu_columns + + +class _TPUBaseEmbeddingColumn(object): + """Base class for TPU Embedding Column.""" + + def __init__(self, categorical_column): + self._tpu_categorical_column = categorical_column + + def get_combiner(self): + """Returns the embedding combiner.""" + raise NotImplementedError('not implemented') + + def get_embedding_table_size(self): + """Returns the embedding table size, tuple of vocab size and dimension.""" + raise NotImplementedError('not implemented') + + def get_feature_key_name(self): + """Returns the feature key name in the features dict.""" + raise NotImplementedError('not impl') + + def get_weight_key_name(self): + """Return the key name for weights.""" + raise NotImplementedError('not impl') + + def get_embedding_var_name(self): + """Returns the embedding variable name. + + Feature key name and embedding variable name are usually one-to-one mapping. 
+ But for shared embedding columns, it is many-to-one mapping. + """ + raise NotImplementedError('not impl') + + def get_initializer(self): + """Returns the initializer.""" + raise NotImplementedError('not impl') + + def is_categorical_column_weighted(self): + """Check if the categorical column of the embedding column is weighted.""" + raise NotImplementedError('not impl') + + +class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn): + """Core Embedding Column.""" + + def __new__(cls, + categorical_column, + dimension, + combiner='mean', + layer_creator=None, + ckpt_to_load_from=None, + tensor_name_in_ckpt=None, + max_norm=None, + trainable=True): + # Note, args ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable + # are not supported on TPU. They are solely for matching the signature of + # __new__ of parent class fc._EmbeddingColumn. + return fc._EmbeddingColumn.__new__( + cls, + categorical_column, + dimension, + combiner=combiner, + layer_creator=layer_creator, + ckpt_to_load_from=ckpt_to_load_from, + tensor_name_in_ckpt=tensor_name_in_ckpt, + max_norm=max_norm, + trainable=trainable) + + def __init__(self, + categorical_column, + dimension, + combiner='mean', + layer_creator=None, + ckpt_to_load_from=None, + tensor_name_in_ckpt=None, + max_norm=None, + trainable=True): + _TPUBaseEmbeddingColumn.__init__(self, categorical_column) + self._key = None + + def get_combiner(self): + return self.combiner + + def get_embedding_table_size(self): + """Returns num_ids and width.""" + return (self.categorical_column._num_buckets, self.dimension) + + def get_feature_key_name(self): + """get_feature_key_name.""" + if self.is_categorical_column_weighted(): + return self.categorical_column.categorical_column.name + return self.categorical_column.name + + def get_weight_key_name(self): + """get_weight_key_name.""" + if self.is_categorical_column_weighted(): + return self.categorical_column.weight_feature_key + return None + + def 
get_embedding_var_name(self): + """get_embedding_var_name.""" + return self.categorical_column.name + + def get_initializer(self): + return self._tpu_initializer + + def is_categorical_column_weighted(self): + """Check if the categorical column of the embedding column is weighted.""" + if isinstance( + self.categorical_column, + ( + fc._WeightedCategoricalColumn, # pylint: disable=protected-access + fc_lib.WeightedCategoricalColumn)): + return True + return False + + def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): + if tpu.under_tpu_inference_context(): + # TODO(shizhiw, b/112012627, b/112336539): Replace _outside_all_rewrites() + # with outside compilation. + with _outside_all_rewrites(): + return fc._EmbeddingColumn._get_dense_tensor( + self, inputs, weight_collections, trainable) + + if _is_running_on_cpu(): + return fc._EmbeddingColumn._get_dense_tensor( + self, inputs, weight_collections, trainable) + + # TPU mode + # Get the embeddings from the LazyBuilder. 
+ tensor = inputs.get(self.get_feature_key_name()) + + # Add to collection for _create_tpu_embedding_variables_and_ops + _record_variable_scope_and_name(self.get_embedding_var_name(), + 'embedding_weights') + + return tensor + + +@contextlib.contextmanager +def _outside_all_rewrites(): + """'Break out' of a tpu.rewrite() (or shard(), etc.).""" + with ops.control_dependencies(None): + yield + + +class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn, + fc._SharedEmbeddingColumn): + """Core Shared Embedding Column.""" + + def __new__(cls, + categorical_column, + dimension, + combiner='mean', + initializer=None, + shared_embedding_collection_name=None, + ckpt_to_load_from=None, + tensor_name_in_ckpt=None, + max_norm=None, + trainable=True): + return fc._SharedEmbeddingColumn.__new__( + cls, + categorical_column, + dimension, + combiner=combiner, + initializer=initializer, + shared_embedding_collection_name=shared_embedding_collection_name, + ckpt_to_load_from=ckpt_to_load_from, + tensor_name_in_ckpt=tensor_name_in_ckpt, + max_norm=max_norm, + trainable=trainable) + + def __init__(self, + categorical_column, + dimension, + combiner='mean', + initializer=None, + shared_embedding_collection_name=None, + ckpt_to_load_from=None, + tensor_name_in_ckpt=None, + max_norm=None, + trainable=True): + + _TPUBaseEmbeddingColumn.__init__(self, categorical_column) + self._key = None + + def get_combiner(self): + return self.combiner + + def get_embedding_table_size(self): + """Returns num_ids and width.""" + return (self.categorical_column._num_buckets, self.dimension) + + def get_feature_key_name(self): + """get_feature_key_name.""" + if self.is_categorical_column_weighted(): + return self.categorical_column.categorical_column.name + return self.categorical_column.name + + def get_weight_key_name(self): + """get_weight_key_name.""" + if self.is_categorical_column_weighted(): + return self.categorical_column.weight_feature_key + return None + + def get_embedding_var_name(self): + 
"""get_embedding_var_name.""" + return self.shared_embedding_collection_name + + def get_initializer(self): + return self.initializer + + def is_categorical_column_weighted(self): + """Check if the categorical column of the embedding column is weighted.""" + if isinstance( + self.categorical_column, + ( + fc._WeightedCategoricalColumn, # pylint: disable=protected-access + fc_lib.WeightedCategoricalColumn)): + return True + return False + + def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): + if tpu.under_tpu_inference_context(): + # TODO(shizhiw, b/112012627, b/112336539): Replace _outside_all_rewrites() + # with outside compilation. + with _outside_all_rewrites(): + return fc._SharedEmbeddingColumn._get_dense_tensor( + self, inputs, weight_collections, trainable) + + if _is_running_on_cpu(): + return fc._SharedEmbeddingColumn._get_dense_tensor( + self, inputs, weight_collections, trainable) + + # TPU mode + # Get the embeddings from the LazyBuilder. + tensor = inputs.get(self.get_feature_key_name()) + + # Add to collection for _create_tpu_embedding_variables_and_ops + _record_variable_scope_and_name( + self.get_embedding_var_name(), + 'embedding_weights', + is_shared_embedding=True) + return tensor + + +def _record_variable_scope_and_name(embedding_var_name, + embedding_var_name_in_fc, + is_shared_embedding=False): + """Add embedding variable name and scope to collection.""" + g = ops.get_default_graph() + collection = g.get_collection_ref(_TPU_FC_TO_SCOPE) + if not collection: + collection.append({}) + + var_def_dict = collection[0] + + captured_scope = None + + if is_shared_embedding and (embedding_var_name in var_def_dict): + if var_def_dict[embedding_var_name][1] != embedding_var_name_in_fc: + raise ValueError( + 'For embedding var name {}, the shared embedding name is different, ' + 'got {}; expected {}'.format(embedding_var_name, + embedding_var_name_in_fc, + var_def_dict[embedding_var_name][1])) + else: + # scope contains 
var_scope_name. + captured_scope = variable_scope.get_variable_scope() + var_def_dict[embedding_var_name] = (captured_scope, + embedding_var_name_in_fc) + + +def _is_running_on_cpu(): + """Returns True if the current context is CPU mode.""" + return tpu_function.get_tpu_context().number_of_shards is None diff --git a/tensorflow/contrib/tpu/python/tpu/feature_column_test.py b/tensorflow/contrib/tpu/python/tpu/feature_column_test.py new file mode 100644 index 0000000000000000000000000000000000000000..75164cce4c261cc541dd6b01ee22699d286d9621 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/feature_column_test.py @@ -0,0 +1,286 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# =================================================================== +"""Tests for contrib.tpu.python.tpu.feature_column.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tpu.python.tpu import feature_column as tpu_fc +from tensorflow.python.client import session +from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.feature_column import feature_column_lib as fc_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import lookup_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import variables as variables_lib +from tensorflow.python.platform import test + + +def _initialized_session(): + sess = session.Session() + sess.run(variables_lib.global_variables_initializer()) + sess.run(lookup_ops.tables_initializer()) + return sess + + +class EmbeddingColumnTest(test.TestCase): + + def test_defaults(self): + categorical_column = fc_lib.categorical_column_with_identity( + key='aaa', num_buckets=3) + embedding_dimension = 2 + embedding_column = tpu_fc.embedding_column( + categorical_column, dimension=embedding_dimension) + self.assertIs(categorical_column, embedding_column.categorical_column) + self.assertEqual(embedding_dimension, embedding_column.dimension) + self.assertEqual('mean', embedding_column.combiner) + self.assertEqual('aaa_embedding', embedding_column.name) + self.assertEqual('aaa_embedding', embedding_column._var_scope_name) + self.assertEqual((embedding_dimension,), embedding_column._variable_shape) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column._parse_example_spec) + + def test_all_constructor_args(self): + categorical_column = fc_lib.categorical_column_with_identity( + key='aaa', num_buckets=3) + 
embedding_dimension = 2 + embedding_column = tpu_fc.embedding_column( + categorical_column, + dimension=embedding_dimension, + combiner='my_combiner', + initializer=lambda: 'my_initializer') + self.assertIs(categorical_column, embedding_column.categorical_column) + self.assertEqual(embedding_dimension, embedding_column.dimension) + self.assertEqual('my_combiner', embedding_column.combiner) + self.assertEqual('aaa_embedding', embedding_column.name) + self.assertEqual('aaa_embedding', embedding_column._var_scope_name) + self.assertEqual((embedding_dimension,), embedding_column._variable_shape) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column._parse_example_spec) + + def test_get_dense_tensor(self): + # Inputs. + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 4), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 5)) + + # Embedding variable. + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + # Expected lookup result, using combiner='mean'. + expected_lookups = ( + # example 0, ids [2], embedding = [7, 11] + (7., 11.), + # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] + (2., 3.5), + # example 2, ids [], embedding = [0, 0] + (0., 0.), + # example 3, ids [1], embedding = [3, 5] + (3., 5.), + ) + + # Build columns. 
+ categorical_column = fc_lib.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = tpu_fc.embedding_column( + categorical_column, + dimension=embedding_dimension, + initializer=_initializer) + + # Provide sparse input and get dense result. + embedding_lookup = embedding_column._get_dense_tensor( + fc._LazyBuilder({ + 'aaa': sparse_input + })) + + # Assert expected embedding variable and lookups. + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual(('embedding_weights:0',), + tuple([v.name for v in global_vars])) + with _initialized_session(): + self.assertAllEqual(embedding_values, global_vars[0].eval()) + self.assertAllEqual(expected_lookups, embedding_lookup.eval()) + + +class SharedEmbeddingColumnTest(test.TestCase): + + def test_defaults(self): + categorical_column_a = fc_lib.categorical_column_with_identity( + key='aaa', num_buckets=3) + categorical_column_b = fc_lib.categorical_column_with_identity( + key='bbb', num_buckets=3) + embedding_dimension = 2 + embedding_column_b, embedding_column_a = tpu_fc.shared_embedding_columns( + [categorical_column_b, categorical_column_a], + dimension=embedding_dimension) + self.assertIs(categorical_column_a, embedding_column_a.categorical_column) + self.assertIs(categorical_column_b, embedding_column_b.categorical_column) + self.assertEqual(embedding_dimension, embedding_column_a.dimension) + self.assertEqual(embedding_dimension, embedding_column_b.dimension) + self.assertEqual('mean', embedding_column_a.combiner) + self.assertEqual('mean', embedding_column_b.combiner) + self.assertIsNotNone(embedding_column_a.initializer) + self.assertIsNotNone(embedding_column_b.initializer) + self.assertEqual('aaa_bbb_shared_embedding', + embedding_column_a.shared_embedding_collection_name) + self.assertEqual('aaa_bbb_shared_embedding', + embedding_column_b.shared_embedding_collection_name) + self.assertEqual('aaa_shared_embedding', embedding_column_a.name) 
+ self.assertEqual('bbb_shared_embedding', embedding_column_b.name) + self.assertEqual('aaa_bbb_shared_embedding', + embedding_column_a._var_scope_name) + self.assertEqual('aaa_bbb_shared_embedding', + embedding_column_b._var_scope_name) + self.assertEqual((embedding_dimension,), embedding_column_a._variable_shape) + self.assertEqual((embedding_dimension,), embedding_column_b._variable_shape) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column_a._parse_example_spec) + self.assertEqual({ + 'bbb': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column_b._parse_example_spec) + + def test_all_constructor_args(self): + categorical_column_a = fc_lib.categorical_column_with_identity( + key='aaa', num_buckets=3) + categorical_column_b = fc_lib.categorical_column_with_identity( + key='bbb', num_buckets=3) + embedding_dimension = 2 + embedding_column_a, embedding_column_b = tpu_fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, + combiner='my_combiner', + initializer=lambda: 'my_initializer', + shared_embedding_collection_name='var_scope_name') + self.assertIs(categorical_column_a, embedding_column_a.categorical_column) + self.assertIs(categorical_column_b, embedding_column_b.categorical_column) + self.assertEqual(embedding_dimension, embedding_column_a.dimension) + self.assertEqual(embedding_dimension, embedding_column_b.dimension) + self.assertEqual('my_combiner', embedding_column_a.combiner) + self.assertEqual('my_combiner', embedding_column_b.combiner) + self.assertEqual('my_initializer', embedding_column_a.initializer()) + self.assertEqual('my_initializer', embedding_column_b.initializer()) + self.assertEqual('var_scope_name', + embedding_column_a.shared_embedding_collection_name) + self.assertEqual('var_scope_name', + embedding_column_b.shared_embedding_collection_name) + self.assertEqual('aaa_shared_embedding', embedding_column_a.name) + 
self.assertEqual('bbb_shared_embedding', embedding_column_b.name) + self.assertEqual('var_scope_name', embedding_column_a._var_scope_name) + self.assertEqual('var_scope_name', embedding_column_b._var_scope_name) + self.assertEqual((embedding_dimension,), embedding_column_a._variable_shape) + self.assertEqual((embedding_dimension,), embedding_column_b._variable_shape) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column_a._parse_example_spec) + self.assertEqual({ + 'bbb': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column_b._parse_example_spec) + + def test_get_dense_tensor(self): + # Inputs. + vocabulary_size = 3 + # -1 values are ignored. + input_a = np.array([ + [2, -1, -1], # example 0, ids [2] + [0, 1, -1] + ]) # example 1, ids [0, 1] + input_b = np.array([ + [0, -1, -1], # example 0, ids [0] + [-1, -1, -1] + ]) # example 1, ids [] + input_features = {'aaa': input_a, 'bbb': input_b} + + # Embedding variable. + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + # Expected lookup result, using combiner='mean'. + expected_lookups_a = ( + # example 0: + (7., 11.), # ids [2], embedding = [7, 11] + # example 1: + (2., 3.5), # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] + ) + expected_lookups_b = ( + # example 0: + (1., 2.), # ids [0], embedding = [1, 2] + # example 1: + (0., 0.), # ids [], embedding = [0, 0] + ) + + # Build columns. 
+ categorical_column_a = fc_lib.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = fc_lib.categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + embedding_column_a, embedding_column_b = tpu_fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, + initializer=_initializer) + + # Provide sparse input and get dense result. + embedding_lookup_a = embedding_column_a._get_dense_tensor( + fc._LazyBuilder(input_features)) + embedding_lookup_b = embedding_column_b._get_dense_tensor( + fc._LazyBuilder(input_features)) + + # Assert expected embedding variable and lookups. + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual(('embedding_weights:0',), + tuple([v.name for v in global_vars])) + embedding_var = global_vars[0] + with _initialized_session(): + self.assertAllEqual(embedding_values, embedding_var.eval()) + self.assertAllEqual(expected_lookups_a, embedding_lookup_a.eval()) + self.assertAllEqual(expected_lookups_b, embedding_lookup_b.eval()) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/tpu/python/tpu/functional.py b/tensorflow/contrib/tpu/python/tpu/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..1ec9b5b33d007eb2eaa557438f32ea69053261c6 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/functional.py @@ -0,0 +1,25 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Functional operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.tpu.python.tpu import gen_functional_ops + + +TPUPartitionedCall = gen_functional_ops._tpu_partitioned_call # pylint: disable=invalid-name,protected-access + diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index 4ce194590342555a7c4e9e119bf51e516a37a715..37fe9af8c4b154a2e20a957f6ca5d97df3d413be 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -1373,6 +1373,10 @@ class KerasTPUModel(models.Model): # not hashable. self._numpy_to_infeed_manager_list = [] + # Add distribution specific arguments since we don't call the Model init. + self._distribution_strategy = None + self._compile_distribution = None + self.predict_function = None self.test_function = None self.train_function = None @@ -2069,6 +2073,8 @@ class KerasTPUModel(models.Model): # tpu_model may not be compiled, e.g., loading weights and then predict. 
return for k, v in six.iteritems(cpu_optimizer_config): + if k == 'name': + continue opt_var = getattr(self._tpu_model.optimizer, k) if isinstance(opt_var, variables.Variable): logging.info('CPU -> TPU %s: %s {%s}', k, v, K.get_value(opt_var)) @@ -2097,6 +2103,8 @@ class KerasTPUModel(models.Model): self._cpu_model.set_weights(tpu_weights) for k, v in six.iteritems(tpu_optimizer_config): logging.info('TPU -> CPU %s: %s', k, v) + if k == 'name': + continue opt_var = getattr(self.cpu_optimizer, k) if isinstance(opt_var, variables.Variable): K.get_session().run(opt_var.assign(v)) diff --git a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py index 8b0b240dc7302c203a22349d583323327fc4480b..de425626c813784ef657d17eac0c7bb77599a155 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py @@ -69,6 +69,7 @@ class ReplicatedVariable(object): def __init__(self, name, variables): self._name = name self._primary_var = variables[0] + self._common_name = self._primary_var.name.split(":")[0] self._vars = variables self._cached_value = None self._dtype = variables[0].dtype diff --git a/tensorflow/contrib/tpu/python/tpu/session_support.py b/tensorflow/contrib/tpu/python/tpu/session_support.py index a95275487899c4770ef99b620a7671eec2bb81eb..3e463823c820a3ef8628324f77e1a9caf8d385d5 100644 --- a/tensorflow/contrib/tpu/python/tpu/session_support.py +++ b/tensorflow/contrib/tpu/python/tpu/session_support.py @@ -43,12 +43,19 @@ class CoordinatorShutdownException(Exception): pass +def _clone_session(session, graph=None): + return session_lib.Session( + target=session.sess_str, + config=session._config, # pylint: disable=protected-access + graph=graph if graph else session.graph) + + def _make_heartbeat_op(session, device, request_ph): """Return a heartbeat op or None if heartbeats are not supported by device.""" try: # Test if we can connect in a 
isolated graph + session with ops.Graph().as_default(): - with session_lib.Session(target=session.sess_str) as temp_session: + with _clone_session(session) as temp_session: with ops.device(device): heartbeat_op = tpu_ops.worker_heartbeat('') options = config_pb2.RunOptions(timeout_in_ms=5000) @@ -220,6 +227,7 @@ class WatchdogManager(threading.Thread): self.ping_interval = ping_interval self.shutdown_timeout = shutdown_timeout self.daemon = True + self._config = session._config # pylint: disable=protected-access self._target = session.sess_str self._running = False self._devices = devices @@ -234,6 +242,7 @@ class WatchdogManager(threading.Thread): self._session = session_lib.Session( target=self._target, graph=self._graph, + config=self._config, ) if self._devices is None: @@ -334,8 +343,7 @@ class GracefulShutdownHook(session_run_hook.SessionRunHook): with self._graph.as_default(): logging.info('Installing graceful shutdown hook.') - self._session = session_lib.Session( - target=training_session.sess_str, graph=self._graph) + self._session = _clone_session(training_session, self._graph) self._workers = WorkerHeartbeatManager.from_devices( self._session, all_worker_devices(self._session)) self._heartbeat_supported = self._workers.num_workers() > 0 diff --git a/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py index 70baea203cc6174bebc7d90646045efae5f2391d..a1494e3660bc09e3af45e81097151a35990810fb 100644 --- a/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py +++ b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py @@ -21,44 +21,56 @@ from __future__ import print_function import os import os.path import re +import sys from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import 
tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops +from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging _TRACER_LOG_PREFIX = ' [>>>TT>>>]' _DEVICE_TYPE_TPU = 'tpu' _DEVICE_TYPE_CPU = 'cpu' -_GLOBAL_STEP_OP_NAME = 'GLOBAL-STEP' _TRACE_MODE_NAN_INF = 'nan-inf' _TRACE_MODE_PART_TENSOR = 'part-tensor' _TRACE_MODE_PART_TENSOR_SIZE = 3 _TRACE_MODE_FULL_TENSOR = 'full-tensor' -_RECORD_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range' -_RECORD_SHOULD_NOT_TRACE = 'not-traced-should-not-trace' -_RECORD_FILTERED_OUT = 'not-traced-filtered-out' -_RECORD_SCALAR = 'not-traced-scalar' -_RECORD_DYNAMIC_SHAPE = 'not-traced-dynamic-shape' -_RECORD_GET_TRACED = 'get-traced' +_TRACE_MODE_NORM = 'norm' +_TRACE_MODE_MAX_ABS = 'max-abs' +_REASON_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range' +_REASON_UNSAFE_OP = 'not-traced-unsafe-op' +_REASON_UNSAFE_SCALAR = 'not-traced-unsafe-scalar' +_REASON_LESS_INTERESTING_OP = 'not-traced-less-interesting-op' +_REASON_DEVICE_MISMATCH = 'not-traced-device-mismatch' +_REASON_DYNAMIC_SHAPE = 'not-traced-dynamic-shape' +_REASON_SCALAR_GET_TRACED = 'traced-scalar' +_REASON_TENSOR_GET_TRACED = 'traced-tensor' +_REASON_USER_INCLUDED = 'traced-user-included' +_REASON_USER_EXCLUDED = 'not-traced-user-excluded' +_REASON_NON_NUMERIC_TENSOR = 'not-traced-non-numeric-tensor' _MARKER_SECTION_BEGIN = '!!!!!!! section-begin:' _MARKER_SECTION_END = '!!!!!!! 
section-end:' _SECTION_NAME_CONFIG = 'configuration' _SECTION_NAME_REASON = 'reason' _SECTION_NAME_OP_LIST = 'op-list' +_SECTION_NAME_TENSOR_LIST = 'tensor-list' _SECTION_NAME_GRAPH = 'graph' _FIELD_NAME_VERSION = 'version:' _FIELD_NAME_DEVICE = 'device:' _FIELD_NAME_TRACE_MODE = 'trace-mode:' _FIELD_NAME_NUM_REPLICAS = 'num-replicas:' _FIELD_NAME_NUM_OPS = 'number-of-ops:' +_FIELD_NAME_NUM_TENSORS = 'number-of-tensors:' _FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED = 'topological-sort-succeed:' _FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS' _FLAG_SINGLE_QUOTE_PAT = re.compile(r"\s*--([^=]+)='([^']*)'") @@ -66,13 +78,72 @@ _FLAG_DOUBLE_QUOTE_PAT = re.compile(r'\s*--([^=]+)="([^"]*)"') _FLAG_NO_QUOTE_PAT = re.compile(r'\s*--([^=]+)=(\S*)') _FLAG_NAME_ENABLE = 'enable' _FLAG_NAME_TRACE_MODE = 'trace_mode' -_FLAG_NAME_INTERESTING_OPS = 'interesting_ops' +_FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS = 'include_less_interesting_ops' +_FLAG_NAME_EXCLUDED_OPNAMES = 'excluded_opnames' +_FLAG_NAME_EXCLUDED_OPTYPES = 'excluded_optypes' +_FLAG_NAME_INCLUDED_OPNAMES = 'included_opnames' +_FLAG_NAME_INCLUDED_OPTYPES = 'included_optypes' _FLAG_NAME_TRACE_FILE = 'trace_file_path' +_FLAG_NAME_REPORT_FILE = 'report_file_path' _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR = 'use_test_undeclared_outputs_dir' _FLAG_NAME_OP_RANGE = 'op_range' _OP_RANGE_PAT = re.compile(r'(\d+):(\d+)') _OUTPUT_STREAM_ESCAPE = 'file://' _TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR' +_TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables' +_TENSOR_TRACER_CHECKPOINT = 'tensor_tracer_checkpoint' + + +def tensor_checkpoint(tensor, checkpoint_name): + """Adds a checkpoint with the given checkpoint name for the given tensor. + + The tensor will be added to the list of tensors that will be traced by the + tensor tracer. + + Args: + tensor: the tensor object for which the tracing is requested. + checkpoint_name: a string name for the checkpoint. This name has to be a + unique name if used within model comparison. 
The tensors that have the same + checkpoint identifier are compared in model comparison. + Returns: + The provided tensor. + """ + + tensor.graph.get_collection(_TENSOR_TRACER_COLLECTION) + tensor.graph.add_to_collection(_TENSOR_TRACER_COLLECTION, + (tensor, checkpoint_name)) + return tensor + + +def keras_layer_checkpoint(layer, checkpoint_name): + """An interface for adding the tensor outputs of a keras layer. + + Encapsulates tensor_checkpoint. + + Args: + layer: A keras layer. + checkpoint_name: a string name for the checkpoint. This name has to be a + unique name if used within model comparison. The tensors that have the same + checkpoint identifier are compared in model comparison. + + Returns: + The provided layer. + """ + try: + outputs = layer.output + if tensor_util.is_tensor(outputs): + tensor_checkpoint(outputs, '%s' % (checkpoint_name)) + else: + idx = 0 + for output_tensor in outputs: + if tensor_util.is_tensor(output_tensor): + tensor_checkpoint(output_tensor, '%s_%d' % (checkpoint_name, idx)) + idx += 1 + except AttributeError: + pass + except RuntimeError: + pass + return layer class TensorTracer(object): @@ -105,6 +176,34 @@ class TensorTracer(object): match = _FLAG_NO_QUOTE_PAT.match(flags, pos) return match + @staticmethod + def validate_flag_names(): + """Validates if the TensorTrace flags passed are valid.""" + valid_flag_names = [_FLAG_NAME_ENABLE, _FLAG_NAME_TRACE_MODE, + _FLAG_NAME_EXCLUDED_OPNAMES, + _FLAG_NAME_EXCLUDED_OPTYPES, + _FLAG_NAME_INCLUDED_OPNAMES, + _FLAG_NAME_INCLUDED_OPTYPES, + _FLAG_NAME_TRACE_FILE, _FLAG_NAME_REPORT_FILE, + _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR, + _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS, + _FLAG_NAME_OP_RANGE] + tensor_tracer_flags = os.environ.get(_FLAGS_ENV_VAR) + if not tensor_tracer_flags: + return + pos = 0 + while True: + match = TensorTracer._match_next_flag(tensor_tracer_flags, pos) + if not match: + break + flag_name = match.group(1) + if flag_name not in valid_flag_names: + raise ValueError( + 'The
flag name "%s" passed via the environment variable "%s" ' + 'is invalid. Valid flag names are:' + '\n%s'%(flag_name, _FLAGS_ENV_VAR, valid_flag_names)) + pos = match.end() + @staticmethod def print_flag_values(): """Prints all TensorTracer flags passed via environment variables.""" @@ -146,6 +245,20 @@ class TensorTracer(object): pos = match.end() return '' + @staticmethod + def flag_value_to_re_list(flag_name): + """Converts list of strings to compiled RE.""" + + re_list = [] + flag_value = TensorTracer.get_flag_value(flag_name) + if not flag_value: + return re_list + list_of_values = flag_value.split() + for v in list_of_values: + r = re.compile(v) + re_list.append(r) + return re_list + @staticmethod def is_enabled(): """Returns True if TensorTracer is enabled.""" @@ -186,29 +299,67 @@ class TensorTracer(object): """Checks if the given trace mode is valid.""" valid_trace_modes = [_TRACE_MODE_NAN_INF, _TRACE_MODE_PART_TENSOR, - _TRACE_MODE_FULL_TENSOR] + _TRACE_MODE_FULL_TENSOR, _TRACE_MODE_NORM, + _TRACE_MODE_MAX_ABS] if trace_mode not in valid_trace_modes: raise ValueError('Invalid trace mode "%s" given to the Tensor_Tracer.' 'Valid trace modes are: %s'%(trace_mode, valid_trace_modes)) @staticmethod - def should_trace(device_type, op): - """Returns True if the given Op should be traced.""" + def unsafe_op(op): + """Returns True if this op is not safe to be traced.""" - if device_type != _DEVICE_TYPE_TPU: - raise ValueError('Non TPU device type is not supported') if control_flow_util.IsInCond(op): + return True + # Reasons for not including following op types: + # Assign: cause incorrect result with CPU tracing. + # others: compilation problems. 
+ if op.type in ['Assign', 'Pack', 'Shape', 'Reshape', 'ArgMin', 'ArgMax']: + return True + return False + + @staticmethod + def device_mismatch(device_type, op): + if device_type == _DEVICE_TYPE_TPU: + # pylint: disable=protected-access + return tpu._TPU_REPLICATE_ATTR not in op.node_def.attr + # pylint: enable=protected-access + return False + + @staticmethod + def unsafe_scalar_trace(op): + """Return true if scalar output tensor from Op is not safe to be traced.""" + + # Tracing the following causes cycle in the graph on TPU. + if op.type in ['LoopCond', 'Enter', 'Merge', 'Const', + 'Switch', 'Less', 'ReadVariableOp']: + return True + # Tracing the following will cause casting-issue + # with the norm tracing mode or other compilation issues on CPU. + if op.type in ['VarHandleOp', 'IteratorToStringHandle', + 'IteratorGetNext', 'OneShotIterator', + 'IteratorV2', 'MakeIterator', + 'BatchDatasetV2', 'MapDataset', + 'FixedLengthRecordDataset', 'TakeDataset', 'ZipDataset', + 'Placeholder', 'PlaceholderWithDefault', 'StridedSlice']: + return True + return False + + @staticmethod + def less_interesting_op(op): + """Returns True if the given Op is not an interesting one to be traced.""" + + include_less_interesting = TensorTracer.get_flag_value( + _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS) + if include_less_interesting: return False - if op.type in ['Reshape', 'ArgMin', 'ArgMax']: - return False - # pylint: disable=protected-access - return tpu._TPU_REPLICATE_ATTR in op.node_def.attr - # pylint: enable=protected-access + return op.type in ['Const', 'Identity', 'Cast', 'Shape'] @staticmethod def reason(op_idx, details): - """Returns why the Op at op_idx is traced or not.""" + """Returns reason why the Op at op_idx is traced or not.""" + return '%d %s'%(op_idx, details) @staticmethod @@ -274,6 +425,33 @@ class TensorTracer(object): assert len(unsorted_ops) == len(sorted_ops) return (True, sorted_ops) + @staticmethod + def _make_op_and_tensor_maps(op_list): + """Creates 
various maps and lists from op_list. + + Args: + op_list: a list of Ops + + Returns: + opname_idx_map: a map from Op's name to its index in op_list. + tensor_list: a list of output tensors of the Ops in op_list. + tensorname_idx_map: a map from output tensor name to its index + in tensor_list. + """ + + opname_idx_map = {} + tensor_list = [] + tensorname_idx_map = {} + for op_id, op in enumerate(op_list): + if op.name in opname_idx_map: + raise ValueError('Duplicated Op name: %s'%op.name) + opname_idx_map[op.name] = op_id + for output_tensor in op.outputs: + if output_tensor.name not in tensorname_idx_map: + tensor_list.append(output_tensor) + tensorname_idx_map[output_tensor.name] = len(tensor_list)-1 + return (opname_idx_map, tensor_list, tensorname_idx_map) + def __init__(self): """Initializes a TensorTracer. @@ -281,16 +459,20 @@ class TensorTracer(object): """ self._version = 'use-outside-compilation' self._device_type = None + TensorTracer.validate_flag_names() self._trace_mode = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_MODE) if not self._trace_mode: self._trace_mode = _TRACE_MODE_NAN_INF TensorTracer.check_trace_mode(self._trace_mode) self._part_tensor_size = _TRACE_MODE_PART_TENSOR_SIZE self._instrument_records = {} - interesting_ops = TensorTracer.get_flag_value(_FLAG_NAME_INTERESTING_OPS) - self._selected_ops = interesting_ops.split() self._set_trace_file_path() + self._set_report_file() self._set_op_range() + self._set_excluded_opnames() + self._set_excluded_optypes() + self._set_included_opnames() + self._set_included_optypes() self._num_replicas = None self._replica_id = None @@ -318,10 +500,7 @@ class TensorTracer(object): """Sets the path of the output trace file.""" self._trace_file_path = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_FILE) - if not self._trace_file_path: - raise ValueError('--%s is not set in the environment variable %s' - %(_FLAG_NAME_TRACE_FILE, _FLAGS_ENV_VAR)) - elif TensorTracer.use_test_undeclared_outputs_dir(): + if 
self._trace_file_path and TensorTracer.use_test_undeclared_outputs_dir(): if os.path.isabs(self._trace_file_path): raise ValueError('If use_test_undeclared_outputs_dir is set,' 'trace_file_path cannot be an absolute path (%s)' @@ -330,6 +509,22 @@ class TensorTracer(object): self._trace_file_path = os.path.join(outputs_dir, self._trace_file_path) + def _set_report_file(self): + """Sets the path of the output report file.""" + + self._report_file_path = TensorTracer.get_flag_value(_FLAG_NAME_REPORT_FILE) + if not self._report_file_path: + self._report_file = None + return + try: + self._report_file = gfile.Open(self._report_file_path, 'w') + except IOError as e: + raise e + + def _close_report_file(self): + if self._report_file: + self._report_file.close() + def _set_op_range(self): """Sets the index range of the Ops that we will consider tracing.""" @@ -350,19 +545,48 @@ class TensorTracer(object): return False return self._op_range[1] < 0 or idx <= self._op_range[1] - def _write_report(self, content): - """Writes the given content to the report.""" + def _set_excluded_opnames(self): + self._excluded_opname_re_list = TensorTracer.flag_value_to_re_list( + _FLAG_NAME_EXCLUDED_OPNAMES) + + def _set_excluded_optypes(self): + self._excluded_optype_re_list = TensorTracer.flag_value_to_re_list( + _FLAG_NAME_EXCLUDED_OPTYPES) + + def _set_included_opnames(self): + self._included_opname_re_list = TensorTracer.flag_value_to_re_list( + _FLAG_NAME_INCLUDED_OPNAMES) + + def _set_included_optypes(self): + self._included_optype_re_list = TensorTracer.flag_value_to_re_list( + _FLAG_NAME_INCLUDED_OPTYPES) + + def _is_user_included_op(self, op): + for opname_re in self._included_opname_re_list: + if opname_re.match(op.name): + return True + for optype_re in self._included_optype_re_list: + if optype_re.match(op.type): + return True + return False - logging.info('%s %s'%(_TRACER_LOG_PREFIX, content)) + def _is_user_excluded_op(self, op): + for opname_re in 
self._excluded_opname_re_list: + if opname_re.match(op.name): + return True + for optype_re in self._excluded_optype_re_list: + if optype_re.match(op.type): + return True + return False - def _is_selected_op(self, op_name): - """Returns True if the Op with op_name is selected to be traced.""" + def _write_report(self, content): + """Writes the given content to the report.""" - if not self._selected_ops: - return True - if op_name in self._selected_ops: - return True - return False + line = '%s %s'%(_TRACER_LOG_PREFIX, content) + if self._report_file: + self._report_file.write(line) + else: + logging.info(line) def _write_config_section(self): """Writes the config section of the report.""" @@ -382,15 +606,42 @@ class TensorTracer(object): self._write_report('"%s" %s\n'%(key, self._instrument_records[key])) self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_REASON)) - def _write_op_list_section(self, op_list): + def _write_op_list_section(self, op_list, tensorname_idx_map): """Writes the Op-list section of the report.""" self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_OP_LIST)) self._write_report('%s %d\n'%(_FIELD_NAME_NUM_OPS, len(op_list))) for i in range(0, len(op_list)): - self._write_report('%d "%s" %s\n'%(i, op_list[i].name, op_list[i].type)) + op = op_list[i] + line = '%d "%s" %s'%(i, op.name, op.type) + for out_tensor in op.outputs: + if out_tensor.name not in tensorname_idx_map: + raise ValueError( + 'out_tensor %s is not in tensorname_idx_map'%out_tensor.name) + line += ' %d'%tensorname_idx_map[out_tensor.name] + line += '\n' + self._write_report(line) self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_OP_LIST)) + def _write_tensor_list_section(self, tensor_list, opname_idx_map): + """Writes the tensor-list section of the report.""" + + self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, + _SECTION_NAME_TENSOR_LIST)) + self._write_report('%s %d\n'%(_FIELD_NAME_NUM_TENSORS, len(tensor_list))) + for i in 
range(0, len(tensor_list)): + tensor = tensor_list[i] + line = '%d "%s"'%(i, tensor.name) + for consumer_op in tensor.consumers(): + if consumer_op.name not in opname_idx_map: + raise ValueError( + 'consumer_op %s is not in opname_idx_map'%consumer_op.name) + line += ' %d'%opname_idx_map[consumer_op.name] + line += '\n' + self._write_report(line) + self._write_report('%s %s\n'%(_MARKER_SECTION_END, + _SECTION_NAME_TENSOR_LIST)) + def _write_graph_section(self, succeed, sorted_or_cycle): """Writes the graph section of the report.""" @@ -422,7 +673,7 @@ class TensorTracer(object): Args: op_name: the name of the Op that outputs the tensor to be printed. output_idx: which output of the Op it is (0 means the first output). - num_elements: number of elements to print. + num_elements: number of elements to print (-1 means print all). tensor: the tensor needs to be returned. output_tensor: the tensor needs to be printed. @@ -430,10 +681,13 @@ class TensorTracer(object): The same tensor passed via the "tensor" argument. """ msg = '"%s:%d" '%(op_name, output_idx) - output_stream = _OUTPUT_STREAM_ESCAPE + self._trace_file_path + if self._trace_file_path: + output_stream = _OUTPUT_STREAM_ESCAPE + self._trace_file_path + else: + output_stream = sys.stderr print_op = logging_ops.print_v2(msg, array_ops.shape(output_tensor), ' @', self._replica_id, - '\n', output_tensor, + '\n', output_tensor, '\n', summarize=num_elements, output_stream=output_stream) with ops.control_dependencies([print_op]): @@ -442,7 +696,8 @@ class TensorTracer(object): def _detect_nan_inf(tensor): """Trace function for detecting any NaN/Inf in the tensor.""" - if tensor.dtype.is_floating: + if tensor.dtype.__eq__(dtypes.bfloat16) or tensor.dtype.__eq__( + dtypes.float16): # Since host can't handle bf16, always convert tensor to f32. 
tensor = math_ops.cast(tensor, dtypes.float32) output_tensor = math_ops.reduce_any( @@ -450,12 +705,19 @@ class TensorTracer(object): gen_math_ops.is_inf(tensor))) else: output_tensor = constant_op.constant(0) - return _print_tensor(op_name, output_idx, 1, tensor, output_tensor) + return _print_tensor(op_name, output_idx, -1, tensor, output_tensor) - def _show_global_step(tensor): - """Trace function for printing the global step count.""" + def _show_norm(tensor): + tensor = math_ops.cast(tensor, dtypes.float64) + output_tensor = linalg_ops.norm(tensor) + return _print_tensor(op_name, output_idx, -1, tensor, output_tensor) - return _print_tensor(op_name, output_idx, 1, tensor, tensor) + def _show_max_abs(tensor): + output_tensor = math_ops.cast(math_ops.reduce_max(math_ops.abs(tensor)), + dtypes.float64) + zero = constant_op.constant(0, dtypes.float64) + output_tensor = gen_math_ops.maximum(zero, output_tensor) + return _print_tensor(op_name, output_idx, -1, tensor, output_tensor) def _show_part_tensor(tensor): """Trace function for printing part of the tensor.""" @@ -468,23 +730,139 @@ class TensorTracer(object): return _print_tensor(op_name, output_idx, -1, tensor, tensor) - if op_name == _GLOBAL_STEP_OP_NAME: - return _show_global_step if self._trace_mode == _TRACE_MODE_NAN_INF: return _detect_nan_inf if self._trace_mode == _TRACE_MODE_PART_TENSOR: return _show_part_tensor if self._trace_mode == _TRACE_MODE_FULL_TENSOR: return _show_full_tensor + if self._trace_mode == _TRACE_MODE_NORM: + return _show_norm + if self._trace_mode == _TRACE_MODE_MAX_ABS: + return _show_max_abs raise RuntimeError('Tensor trace fun for %s is not yet implemented' %self._trace_mode) + def _skip_op(self, op_id, op, user_included, user_excluded): + """Returns True if we should not trace Op.""" + + if user_included: + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _REASON_USER_INCLUDED) + return False + if user_excluded: + self._instrument_records[op.name] = 
TensorTracer.reason( + op_id, _REASON_USER_EXCLUDED) + return True + if not self._inside_op_range(op_id): + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _REASON_OUTSIDE_OP_RANGE) + return True + if TensorTracer.unsafe_op(op): + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _REASON_UNSAFE_OP) + return True + if TensorTracer.device_mismatch(self._device_type, op): + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _REASON_DEVICE_MISMATCH) + return True + if TensorTracer.less_interesting_op(op): + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _REASON_LESS_INTERESTING_OP) + return True + return False + + def _skip_tensor(self, op_id, out_tensor, user_included, + user_excluded): + """Returns True if we should not trace out_tensor.""" + + # Skips a tensor if the tensor has a non-numeric type. + # Note: we cannot use check_ops.is_numeric_tensor(out_tensor) + # because it also excludes tensors with dtypes, bool, and + # float32_ref, which we actually want to trace. 
+ non_numeric_tensor_types = set([dtypes.variant, dtypes.resource, + dtypes.string]) + if out_tensor.dtype in non_numeric_tensor_types: + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_NON_NUMERIC_TENSOR) + return True + + if user_included: + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_USER_INCLUDED) + return False + if user_excluded: + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_USER_EXCLUDED) + return True + if not out_tensor.get_shape().is_fully_defined(): + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_DYNAMIC_SHAPE) + return True + rank = len(out_tensor.shape) + if rank < 1: + # scalar + if TensorTracer.unsafe_scalar_trace(out_tensor.op): + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_UNSAFE_SCALAR) + return True + else: + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_SCALAR_GET_TRACED) + return False + else: + # tensor + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_TENSOR_GET_TRACED) + return False + + def _pre_tracing(self, graph): + """Work needs to be done prior to TPU or CPU tracing.""" + + operations = graph.get_operations() + (opname_idx_map, tensor_list, tensorname_idx_map) = ( + TensorTracer._make_op_and_tensor_maps(operations)) + self._write_config_section() + self._write_op_list_section(operations, tensorname_idx_map) + self._write_tensor_list_section(tensor_list, opname_idx_map) + # Does the topological sort before adding any nodes to the graph. 
+ (succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph) + return (operations, succeed, sorted_or_cycle) + + def _post_tracing(self, succeed, sorted_or_cycle): + """Work needs to be done after TPU or CPU tracing.""" + + self._write_reason_section() + self._write_graph_section(succeed, sorted_or_cycle) + self._close_report_file() + + def _get_checkpoints(self, graph): + """Returns the list of Ops that produce the tensors traced with API. + + Args: + graph: the graph of Ops. + + Returns: + A set of operation names which should be traced. + """ + + self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, + _TENSOR_TRACER_CHECKPOINT)) + checkpoint_operations = set() + tensor_tracer_variables = graph.get_collection(_TENSOR_TRACER_COLLECTION) + for (tensor, checkpoint_name) in tensor_tracer_variables: + self._write_report('%s %s\n'%(tensor.name, checkpoint_name)) + checkpoint_operations.add(tensor.op.name) + self._write_report('%s %s\n'%(_MARKER_SECTION_END, + _TENSOR_TRACER_CHECKPOINT)) + return checkpoint_operations + def trace_tpu(self, graph, result_tensor, num_replicas=None): """Traces the tensors generated by TPU Ops in a TF graph. Args: - graph: the graph of Ops. + graph: the graph of Ops executed on the TPU. result_tensor: a result tensor of evaluating the graph. num_replicas: number of replicas used on the TPU. @@ -502,38 +880,22 @@ class TensorTracer(object): TensorTracer.check_device_type(self._device_type) result_tensor_copy = self._add_replica_id_to_graph(num_replicas, result_tensor) - self._write_config_section() + (operations, succeed, sorted_or_cycle) = self._pre_tracing(graph) tracing_ops = [] - operations = graph.get_operations() - self._write_op_list_section(operations) - # Does the topological sort before adding any nodes to the graph. 
- (succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph) + checkpoint_operations = self._get_checkpoints(graph) + for op_id, op in enumerate(operations): - if not self._inside_op_range(op_id): - self._instrument_records[op.name] = TensorTracer.reason( - op_id, _RECORD_OUTSIDE_OP_RANGE) + if checkpoint_operations and op.name not in checkpoint_operations: continue - if not TensorTracer.should_trace(self._device_type, op): - self._instrument_records[op.name] = TensorTracer.reason( - op_id, _RECORD_SHOULD_NOT_TRACE) - continue - if not self._is_selected_op(op.name): - self._instrument_records[op.name] = TensorTracer.reason( - op_id, _RECORD_FILTERED_OUT) + user_included = self._is_user_included_op(op) + user_excluded = self._is_user_excluded_op(op) + if self._skip_op(op_id, op, user_included, user_excluded): continue for i in range(len(op.outputs)): out_tensor = op.outputs[i] - if not out_tensor.get_shape().is_fully_defined(): - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _RECORD_DYNAMIC_SHAPE) - continue # cannot trace tensors with dynamic shape. - rank = len(out_tensor.shape) - if rank < 1: - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _RECORD_SCALAR) - continue # cannot trace scalar. - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _RECORD_GET_TRACED) + if self._skip_tensor(op_id, out_tensor, user_included, + user_excluded): + continue consumers = out_tensor.consumers() trace_op = tpu.outside_compilation( self._make_tensor_trace_fun(op.name, i), out_tensor) @@ -546,8 +908,45 @@ class TensorTracer(object): # if there is no consumer, we will add the control dependence later # when we add the control dependency to the output operations. 
tracing_ops.append(trace_op) + self._post_tracing(succeed, sorted_or_cycle) + return (result_tensor_copy, tracing_ops) - self._write_reason_section() - self._write_graph_section(succeed, sorted_or_cycle) + def trace_cpu(self, graph): + """Traces the tensors generated by CPU Ops in a TF graph. - return (result_tensor_copy, tracing_ops) + Args: + graph: the graph of Ops executed on the CPU. + + Returns: + tracing_calls: a map from keys to trace calls. + A key is constructed from an Op's name. + A trace call consists of a function and a tensor ( + the function will be invoked with the tensor). + """ + + self._device_type = _DEVICE_TYPE_CPU + TensorTracer.check_device_type(self._device_type) + self._num_replicas = 1 + self._replica_id = 0 + (operations, succeed, sorted_or_cycle) = self._pre_tracing(graph) + tracing_calls = {} + checkpoint_operations = self._get_checkpoints(graph) + + for op_id, op in enumerate(operations): + if checkpoint_operations and op.name not in checkpoint_operations: + continue + user_included = self._is_user_included_op(op) + user_excluded = self._is_user_excluded_op(op) + if self._skip_op(op_id, op, user_included, user_excluded): + continue + for i in range(len(op.outputs)): + out_tensor = op.outputs[i] + if self._skip_tensor(op_id, out_tensor, user_included, + user_excluded): + continue + trace_fun = self._make_tensor_trace_fun(op.name, i) + trace_call = (trace_fun, [out_tensor]) + trace_call_key = 'tensor_tracing_cpu-%s:%d'%(op.name, i) + tracing_calls[trace_call_key] = trace_call + self._post_tracing(succeed, sorted_or_cycle) + return tracing_calls diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index def57da20d6018dcf27ccb7a9d04592f38ce2f7c..9266d81cf5fc035790062f0e307a5da0b01a9fc1 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -646,6 +646,10 @@ def split_compile_and_replicate(computation, array_ops.identity(x, 
name="replicated_input_{}".format(i)) for i, x in enumerate(computation_inputs) ] + for i in computation_inputs: + # pylint: disable=protected-access + i.op._set_attr("_tpu_input_identity", attr_value_pb2.AttrValue(b=True)) + # pylint: enable=protected-access # If there is an infeed queue, adds the dequeued values to the # computation's inputs. @@ -726,7 +730,11 @@ def split_compile_and_replicate(computation, new_output_tensors = [] for t in output_tensors: with ops.device(t.device if t.device else core(0)): - new_output_tensors.append(array_ops.identity(t)) + o = array_ops.identity(t) + # pylint: disable=protected-access + o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True)) + # pylint: enable=protected-access + new_output_tensors.append(o) output_tensors = new_output_tensors context.ExitResult(output_tensors) finally: @@ -777,15 +785,15 @@ def split_compile_and_replicate(computation, ] -def shard(computation, - inputs=None, - num_shards=1, - input_shard_axes=None, - outputs_from_all_shards=True, - output_shard_axes=None, - infeed_queue=None, - device_assignment=None, - name=None): +def split_compile_and_shard(computation, + inputs=None, + num_shards=1, + input_shard_axes=None, + outputs_from_all_shards=True, + output_shard_axes=None, + infeed_queue=None, + device_assignment=None, + name=None): """Shards `computation` for parallel execution. `inputs` must be a list of Tensors or None (equivalent to an empty list), each @@ -839,7 +847,7 @@ def shard(computation, is equal to the number of cores in the TPU system. name: (Deprecated) Does nothing. Returns: - A list of output tensors. + A tuple of (compile op, [output tensors]). 
Raises: ValueError: If num_shards <= 0 ValueError: If len(input_shard_axes) != len(inputs) @@ -874,7 +882,7 @@ def shard(computation, else: transposed_inputs = [[]] * num_shards - outputs = replicate( + compile_op, outputs = split_compile_and_replicate( computation, transposed_inputs, infeed_queue=infeed_queue, @@ -891,7 +899,7 @@ def shard(computation, # one so it can be used as a control dependency or fetch node. # TODO(b/36647078) remove disable when pylint bug is fixed. # pylint: disable=indexing-exception - return [outputs[0]] + return compile_op, [outputs[0]] # pylint: enable=indexing-exception # TODO(b/36647078) remove disable when pylint bug is fixed. @@ -925,7 +933,87 @@ def shard(computation, # TODO(phawkins): use a smarter policy, e.g., round-robin across shards. results.append(x[0]) - return results + return compile_op, results + + +def shard(computation, + inputs=None, + num_shards=1, + input_shard_axes=None, + outputs_from_all_shards=True, + output_shard_axes=None, + infeed_queue=None, + device_assignment=None, + name=None): + """Shards `computation` for parallel execution. + + `inputs` must be a list of Tensors or None (equivalent to an empty list), each + of which has a corresponding split axis (from `input_shard_axes`). Each input + is split into `num_shards` pieces along the corresponding axis, and + computation is applied to each shard in parallel. + + Tensors are broadcast to all shards if they are lexically captured by + `computation`. e.g., + + x = tf.constant(7) + def computation(): + return x + 3 + ... = shard(computation, ...) + + TODO(phawkins): consider adding support for broadcasting Tensors passed + as inputs. + + If `outputs_from_all_shards` is true, the outputs from all shards of + `computation` are concatenated back together along their `output_shards_axes`. + Otherwise, each output is taken from an arbitrary shard. + + Inputs and outputs of the computation must be at least rank-1 Tensors. 
+ + Args: + computation: A Python function that builds a computation to apply to each + shard of the input. + inputs: A list of input tensors or None (equivalent to an empty list). Each + input tensor has a corresponding shard axes, given by `input_shard_axes`, + which must have size divisible by `num_shards`. + num_shards: The number of shards. + input_shard_axes: A list of dimensions along which to shard `inputs`, or + `None`. `None` means "shard all inputs along dimension 0". If not `None`, + there must be one dimension per input. + outputs_from_all_shards: Boolean or list of boolean. For each output, if + `True`, outputs from all shards are concatenated along the corresponding + `output_shard_axes` entry. Otherwise, each output is taken + from an arbitrary shard. If the argument is a boolean, the argument's + value is used for each output. + output_shard_axes: A list of dimensions along which to concatenate the + outputs of `computation`, or `None`. `None` means "concatenate all outputs + along dimension 0". If not `None`, there must be one dimension per output. + Ignored if `outputs_from_all_shards` is False. + infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs + of `computation`. + device_assignment: If not `None`, a `DeviceAssignment` describing the + mapping between logical cores in the computation with physical cores in + the TPU topology. Uses a default device assignment if `None`. The + `DeviceAssignment` may be omitted if each shard of the computation uses + only one core, and there is either only one shard, or the number of shards + is equal to the number of cores in the TPU system. + name: (Deprecated) Does nothing. + Returns: + A list of output tensors. 
+ Raises: + ValueError: If num_shards <= 0 + ValueError: If len(input_shard_axes) != len(inputs) + ValueError: If len(output_shard_axes) != len(outputs from `computation`) + """ + return split_compile_and_shard( + computation, + inputs=inputs, + num_shards=num_shards, + input_shard_axes=input_shard_axes, + outputs_from_all_shards=outputs_from_all_shards, + output_shard_axes=output_shard_axes, + infeed_queue=infeed_queue, + device_assignment=device_assignment, + name=name)[1] def batch_parallel(computation, diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index fb1316cf33dee86aba9e6f1ae15cb54298c25d7c..87a970f0523363426b0da5b12838b797d7f8bebb 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -31,6 +31,7 @@ import six from six.moves import queue as Queue # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result from tensorflow.contrib.tpu.python.tpu import tensor_tracer from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import error_handling @@ -336,6 +337,16 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote hooks = None if self.host_call is not None: hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])] + if tensor_tracer.TensorTracer.is_enabled(): + tt = tensor_tracer.TensorTracer() + tracing_calls = tt.trace_cpu(ops.get_default_graph()) + tracing_call_ret = _OutfeedHostCall.create_cpu_hostcall(tracing_calls) + tracing_functions = tracing_call_ret.values() + if tracing_functions: + if hooks: + hooks.extend([_OutfeedHostCallHook(tracing_functions)]) + else: + hooks = [_OutfeedHostCallHook(tracing_functions)] hooks = tuple(hooks or []) scaffold = self.scaffold_fn() if self.scaffold_fn else None return 
model_fn_lib.EstimatorSpec( @@ -412,6 +423,7 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): ctx, enqueue_ops, dequeue_ops, + tpu_compile_op, run_infeed_loop_on_coordinator=True, rendezvous=None, master=None, @@ -429,6 +441,7 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): self._feed_error = None self._finished = False self._should_initialize_tpu = True + self._tpu_compile_op = tpu_compile_op def begin(self): logging.info('TPU job name %s', self._master_job) @@ -477,6 +490,15 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): def _create_infeed_controller(self, name, target, args): return _OpQueueContext(name=name, target=target, args=args) + def _assertCompilationSucceeded(self, result, coord): + proto = tpu_compilation_result.CompilationResultProto() + proto.ParseFromString(result) + if proto.status_error_message: + logging.error('Compilation failed: {}'.format(proto.status_error_message)) + coord.request_stop() + else: + logging.info('Compilation succeeded') + def after_create_session(self, session, coord): if self._should_initialize_tpu: logging.info('Init TPU system') @@ -490,6 +512,10 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): session.run(self._init_ops, options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000)) + if os.environ.get('TPU_SPLIT_COMPILE_AND_EXECUTE', '') == '1': + logging.info('Compiling user program: this may take a while...') + self._assertCompilationSucceeded(session.run(self._tpu_compile_op), coord) + self._infeed_controller = self._create_infeed_controller( name='InfeedController', target=self._run_infeed, args=(session,)) @@ -530,13 +556,17 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): class TPUInfeedOutfeedSessionHookForPrediction(TPUInfeedOutfeedSessionHook): - def __init__(self, ctx, enqueue_ops, dequeue_ops, rendezvous=None): + def __init__(self, ctx, enqueue_ops, dequeue_ops, tpu_compile_op, + rendezvous=None, 
master=None, session_config=None): super(TPUInfeedOutfeedSessionHookForPrediction, self).__init__( ctx, enqueue_ops, dequeue_ops, + tpu_compile_op=tpu_compile_op, run_infeed_loop_on_coordinator=False, - rendezvous=rendezvous) + rendezvous=rendezvous, + master=master, + session_config=session_config) def _create_infeed_controller(self, name, target, args): return _OpSignalOnceQueueContext(name=name, target=target, args=args) @@ -1642,7 +1672,7 @@ class _OutfeedHostCall(object): 'Exception while calling %s: %s. It is likely the tensors ' '(%s[1]) do not match the ' 'function\'s arguments', name, e, name) - raise e + raise return ret def record(self, host_calls): @@ -1748,9 +1778,22 @@ class _OutfeedHostCall(object): raise RuntimeError( 'All tensors outfed from TPU should preserve batch size ' 'dimension, but got scalar {}'.format(dequeue_ops[i][0])) - # TODO(xiejw): Allow users to specify the axis for batch size - # dimension. - dequeue_ops[i] = array_ops.concat(dequeue_ops[i], axis=0) + # TODO(xiejw): Make the specification of the outfeed combinaton + # function more explicit and well-documented. We may want to give the + # user the option of concatenating along any axis. + if (self._ctx.config.tpu_config.per_host_input_for_training is + tpu_config.InputPipelineConfig.BROADCAST): + # If the infeed is in BROADCAST mode (each core recieving the same + # input), then we assume that the cores also produce identical + # copies of the same output, and we simply take the output from + # the first core. This mode is used by Mesh-TensorFlow. + with ops.control_dependencies(dequeue_ops[i]): + dequeue_ops[i] = array_ops.identity(dequeue_ops[i][0]) + else: + # Assume that the input has been batch-split and that axis 0 of the + # output tensors represents the batch size. Concatenate along + # the axis 0 to re-combine the batch. + dequeue_ops[i] = array_ops.concat(dequeue_ops[i], axis=0) if self._tensor_keys[name] is not None: # The user-provided eval_metrics[1] is a dict. 
@@ -1762,7 +1805,7 @@ class _OutfeedHostCall(object): 'Exception while calling %s: %s. It is likely the tensors ' '(%s[1]) do not match the ' 'function\'s arguments', name, e, name) - raise e + raise else: ret[name] = self._host_fns[name](*dequeue_ops) @@ -2250,7 +2293,7 @@ class TPUEstimator(estimator_lib.Estimator): (k, _export_output_to_tensors(v)) for k, v in six.iteritems(estimator_spec.export_outputs)) tensors = nest.flatten(tensors_dict) - tpu_tensors = [t for t in tensors if _is_tpu_tensor(t)] + tpu_tensors = [t for t in tensors if t is not None] # We cannot return anything other than `tpu_tensors` here so we capture # the rest for later use. @@ -2264,18 +2307,10 @@ class TPUEstimator(estimator_lib.Estimator): # `tpu_tensors_on_cpu`. new_tensors = [] for t in tensors: - if _is_tpu_tensor(t): - new_tensors.append(tpu_tensors_on_cpu.pop(0)) - elif t is None: + if t is None: new_tensors.append(None) else: - # Only fetching `tpu_tensors_on_cpu` does not trigger - # TPU computation and blocks, so we add the control dependency here. - control_inputs = ( - tpu_tensors_on_cpu if _is_iterable(tpu_tensors_on_cpu) else - (tpu_tensors_on_cpu,)) - with ops.control_dependencies(control_inputs): - new_tensors.append(array_ops.identity(t)) + new_tensors.append(tpu_tensors_on_cpu.pop(0)) # Reconstruct `tensors_dict`. 
new_tensors_dict = nest.pack_sequence_as(tensors_dict, new_tensors) @@ -2532,7 +2567,7 @@ class TPUEstimator(estimator_lib.Estimator): graph.add_to_collection(_TPU_ENQUEUE_OPS, enqueue_op) if mode == model_fn_lib.ModeKeys.TRAIN: - loss, host_call, scaffold, training_hooks = ( + compile_op, loss, host_call, scaffold, training_hooks = ( _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn)) host_ops = host_call.create_tpu_hostcall() if host_ops is None: @@ -2567,6 +2602,7 @@ class TPUEstimator(estimator_lib.Estimator): ctx, enqueue_ops, host_ops, + tpu_compile_op=compile_op, run_infeed_loop_on_coordinator=( run_infeed_loop_on_coordinator), rendezvous=self._rendezvous[mode], @@ -2624,8 +2660,8 @@ class TPUEstimator(estimator_lib.Estimator): scaffold=scaffold) if mode == model_fn_lib.ModeKeys.EVAL: - total_loss, host_calls, scaffold, eval_hooks = _eval_on_tpu_system( - ctx, model_fn_wrapper, dequeue_fn) + compile_op, total_loss, host_calls, scaffold, eval_hooks = ( + _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn)) iterations_per_loop_var = _create_or_get_iterations_per_loop() mean_loss = math_ops.div( total_loss, @@ -2672,6 +2708,7 @@ class TPUEstimator(estimator_lib.Estimator): ctx, enqueue_ops, eval_update_ops + host_ops, + tpu_compile_op=compile_op, run_infeed_loop_on_coordinator=( run_infeed_loop_on_coordinator), rendezvous=self._rendezvous[mode], @@ -2692,7 +2729,7 @@ class TPUEstimator(estimator_lib.Estimator): # Predict assert mode == model_fn_lib.ModeKeys.PREDICT - (dummy_predict_op, host_calls, + (compile_op, dummy_predict_op, host_calls, scaffold, prediction_hooks) = _predict_on_tpu_system( ctx, model_fn_wrapper, dequeue_fn) with ops.control_dependencies([dummy_predict_op]): @@ -2748,7 +2785,10 @@ class TPUEstimator(estimator_lib.Estimator): hooks = [ _StoppingPredictHook(scalar_stopping_signal), TPUInfeedOutfeedSessionHookForPrediction( - ctx, enqueue_ops, host_ops, rendezvous=self._rendezvous[mode]), + ctx, enqueue_ops, host_ops, 
rendezvous=self._rendezvous[mode], + tpu_compile_op=compile_op, + master=self._config.master, + session_config=self._session_config), ] + input_hooks if prediction_hooks: @@ -2763,17 +2803,6 @@ class TPUEstimator(estimator_lib.Estimator): return _model_fn -def _is_tpu_tensor(tensor): - if not isinstance(tensor, ops.Tensor): - return False - try: - tensor.op.get_attr(tpu._OUTSIDE_COMPILATION_ATTR) # pylint: disable=protected-access - except ValueError: - return True - else: - return False - - def _export_output_to_tensors(export_output): """Get a list of `Tensors` used in `export_output`. @@ -2845,15 +2874,16 @@ def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): return training_loop.repeat(iterations_per_loop_var, single_tpu_eval_step, [_ZERO_LOSS]) - (loss,) = tpu.shard( + (compile_op, loss,) = tpu.split_compile_and_shard( multi_tpu_eval_steps_on_single_shard, inputs=[], num_shards=ctx.num_replicas, outputs_from_all_shards=False, device_assignment=ctx.device_assignment) + loss = loss[0] scaffold = _get_scaffold(captured_scaffold_fn) - return loss, host_calls, scaffold, captured_eval_hooks.get() + return compile_op, loss, host_calls, scaffold, captured_eval_hooks.get() def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): @@ -2868,15 +2898,16 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): return training_loop.repeat(iterations_per_loop_var, single_tpu_train_step, [_INITIAL_LOSS]) - (loss,) = tpu.shard( + (compile_op, loss,) = tpu.split_compile_and_shard( multi_tpu_train_steps_on_single_shard, inputs=[], num_shards=ctx.num_replicas, outputs_from_all_shards=False, device_assignment=ctx.device_assignment) + loss = loss[0] scaffold = _get_scaffold(captured_scaffold_fn) - return loss, host_call, scaffold, captured_training_hooks.get() + return compile_op, loss, host_call, scaffold, captured_training_hooks.get() def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): @@ -2896,15 +2927,17 @@ def _predict_on_tpu_system(ctx, 
model_fn_wrapper, dequeue_fn): cond, single_tpu_predict_step, inputs=inputs, name=b'loop') return outputs - (dummy_predict_op,) = tpu.shard( + (compile_op, dummy_predict_op,) = tpu.split_compile_and_shard( multi_tpu_predict_steps_on_single_shard, inputs=[], num_shards=ctx.num_replicas, outputs_from_all_shards=False, device_assignment=ctx.device_assignment) + dummy_predict_op = dummy_predict_op[0] scaffold = _get_scaffold(captured_scaffold_fn) - return dummy_predict_op, host_calls, scaffold, captured_predict_hooks.get() + return (compile_op, dummy_predict_op, host_calls, scaffold, + captured_predict_hooks.get()) def _wrap_computation_in_while_loop(device, op_fn): diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py index 55235556de0214a8e04fb85469cd1d8e4656fb56..e3ea983abfd24d03c964fbc647b56262e15e0a96 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py @@ -21,8 +21,8 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.tpu.python.tpu import tpu_estimator -from tensorflow.python import data as dataset_lib from tensorflow.python.client import session +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.platform import test @@ -34,10 +34,10 @@ def make_input_fn(num_samples): def input_fn(params): batch_size = params['batch_size'] - da1 = dataset_lib.Dataset.from_tensor_slices(a) - da2 = dataset_lib.Dataset.from_tensor_slices(b) + da1 = dataset_ops.Dataset.from_tensor_slices(a) + da2 = dataset_ops.Dataset.from_tensor_slices(b) - dataset = dataset_lib.Dataset.zip((da1, da2)) + dataset = dataset_ops.Dataset.zip((da1, da2)) dataset = dataset.map(lambda fa, fb: {'a': fa, 'b': fb}) dataset = dataset.batch(batch_size) return dataset @@ -50,10 +50,10 @@ 
def make_input_fn_with_labels(num_samples): def input_fn(params): batch_size = params['batch_size'] - da1 = dataset_lib.Dataset.from_tensor_slices(a) - da2 = dataset_lib.Dataset.from_tensor_slices(b) + da1 = dataset_ops.Dataset.from_tensor_slices(a) + da2 = dataset_ops.Dataset.from_tensor_slices(b) - dataset = dataset_lib.Dataset.zip((da1, da2)) + dataset = dataset_ops.Dataset.zip((da1, da2)) dataset = dataset.map(lambda fa, fb: ({'a': fa}, fb)) dataset = dataset.batch(batch_size) return dataset @@ -71,7 +71,7 @@ class TPUEstimatorStoppingSignalsTest(test.TestCase): with ops.Graph().as_default(): dataset = input_fn(params) - features = dataset_lib.make_one_shot_iterator(dataset).get_next() + features = dataset_ops.make_one_shot_iterator(dataset).get_next() # With tf.data.Dataset.batch, the batch is None, i.e., dynamic shape. self.assertIsNone(features['a'].shape.as_list()[0]) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py index ec682e5829c4df536a043334b74200f0b6259df3..d66ecfcf4a56b8da1c2d2f518bebe4baa76b315e 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py @@ -52,6 +52,7 @@ def _query_tpu_system_metadata(master_address, cluster_def=None, devices = [] device_dict = collections.defaultdict(list) + # TODO(b/120564445): Replace with standard library for retries. 
retry_count = 1 while True: logging.info('Querying Tensorflow master (%s) for TPU system metadata.', diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py index 3beb7bfe3048a8f0294f7e9149b5a07b5fcc7d17..bcc177601b95172b05d327247bd370c2f8b65d59 100644 --- a/tensorflow/contrib/training/python/training/hparam.py +++ b/tensorflow/contrib/training/python/training/hparam.py @@ -187,7 +187,7 @@ def _cast_to_type_if_compatible(name, param_type, value): return param_type(value) -def parse_values(values, type_map): +def parse_values(values, type_map, ignore_unknown=False): """Parses hyperparameter values from a string into a python map. `values` is a string containing comma-separated `name=value` pairs. @@ -233,6 +233,9 @@ def parse_values(values, type_map): type T if either V has type T, or V is a list of elements of type T. Hence, for a multidimensional parameter 'x' taking float values, 'x=[0.1,0.2]' will parse successfully if type_map['x'] = float. + ignore_unknown: Bool. Whether values that are missing a type in type_map + should be ignored. If set to True, a ValueError will not be raised for + unknown hyperparameter type. 
Returns: A python map mapping each name to either: @@ -260,6 +263,8 @@ def parse_values(values, type_map): m_dict = m.groupdict() name = m_dict['name'] if name not in type_map: + if ignore_unknown: + continue raise ValueError('Unknown hyperparameter type for %s' % name) type_ = type_map[name] diff --git a/tensorflow/contrib/training/python/training/hparam_test.py b/tensorflow/contrib/training/python/training/hparam_test.py index 660c97f25e8458c345c8914bcaf98f37d047e50e..a990e04711ce68bd928a508484f0d6f657dd2f8c 100644 --- a/tensorflow/contrib/training/python/training/hparam_test.py +++ b/tensorflow/contrib/training/python/training/hparam_test.py @@ -216,6 +216,14 @@ class HParamsTest(test.TestCase): self.assertTrue(isinstance(parse_dict['arr'], dict)) self.assertDictEqual(parse_dict['arr'], {1: 10}) + def testParseValuesWithIndexAssigment1_IgnoreUnknown(self): + """Assignment to an index position.""" + parse_dict = hparam.parse_values( + 'arr[1]=10,b=5', {'arr': int}, ignore_unknown=True) + self.assertEqual(len(parse_dict), 1) + self.assertTrue(isinstance(parse_dict['arr'], dict)) + self.assertDictEqual(parse_dict['arr'], {1: 10}) + def testParseValuesWithIndexAssigment2(self): """Assignment to multiple index positions.""" parse_dict = hparam.parse_values('arr[0]=10,arr[5]=20', {'arr': int}) @@ -223,6 +231,14 @@ class HParamsTest(test.TestCase): self.assertTrue(isinstance(parse_dict['arr'], dict)) self.assertDictEqual(parse_dict['arr'], {0: 10, 5: 20}) + def testParseValuesWithIndexAssigment2_IgnoreUnknown(self): + """Assignment to multiple index positions.""" + parse_dict = hparam.parse_values( + 'arr[0]=10,arr[5]=20,foo=bar', {'arr': int}, ignore_unknown=True) + self.assertEqual(len(parse_dict), 1) + self.assertTrue(isinstance(parse_dict['arr'], dict)) + self.assertDictEqual(parse_dict['arr'], {0: 10, 5: 20}) + def testParseValuesWithIndexAssigment3(self): """Assignment to index positions in multiple names.""" parse_dict = 
hparam.parse_values('arr[0]=10,arr[1]=20,L[5]=100,L[10]=200', @@ -234,6 +250,17 @@ class HParamsTest(test.TestCase): self.assertTrue(isinstance(parse_dict['L'], dict)) self.assertDictEqual(parse_dict['L'], {5: 100, 10: 200}) + def testParseValuesWithIndexAssigment3_IgnoreUnknown(self): + """Assignment to index positions in multiple names.""" + parse_dict = hparam.parse_values( + 'arr[0]=10,C=5,arr[1]=20,B[0]=kkk,L[5]=100,L[10]=200', + {'arr': int, 'L': int}, ignore_unknown=True) + self.assertEqual(len(parse_dict), 2) + self.assertTrue(isinstance(parse_dict['arr'], dict)) + self.assertDictEqual(parse_dict['arr'], {0: 10, 1: 20}) + self.assertTrue(isinstance(parse_dict['L'], dict)) + self.assertDictEqual(parse_dict['L'], {5: 100, 10: 200}) + def testParseValuesWithIndexAssigment4(self): """Assignment of index positions and scalars.""" parse_dict = hparam.parse_values('x=10,arr[1]=20,y=30', @@ -246,6 +273,17 @@ class HParamsTest(test.TestCase): self.assertEqual(parse_dict['x'], 10) self.assertEqual(parse_dict['y'], 30) + def testParseValuesWithIndexAssigment4_IgnoreUnknown(self): + """Assignment of index positions and scalars.""" + parse_dict = hparam.parse_values( + 'x=10,foo[0]=bar,arr[1]=20,zzz=78,y=30', + {'x': int, 'y': int, 'arr': int}, ignore_unknown=True) + self.assertEqual(len(parse_dict), 3) + self.assertTrue(isinstance(parse_dict['arr'], dict)) + self.assertDictEqual(parse_dict['arr'], {1: 20}) + self.assertEqual(parse_dict['x'], 10) + self.assertEqual(parse_dict['y'], 30) + def testParseValuesWithIndexAssigment5(self): """Different variable types.""" parse_dict = hparam.parse_values('a[0]=5,b[1]=true,c[2]=abc,d[3]=3.14', { @@ -264,24 +302,55 @@ class HParamsTest(test.TestCase): self.assertTrue(isinstance(parse_dict['d'], dict)) self.assertDictEqual(parse_dict['d'], {3: 3.14}) + def testParseValuesWithIndexAssigment5_IgnoreUnknown(self): + """Different variable types.""" + parse_dict = hparam.parse_values( + 'a[0]=5,cc=4,b[1]=true,c[2]=abc,mm=2,d[3]=3.14', 
+ {'a': int, 'b': bool, 'c': str, 'd': float}, + ignore_unknown=True) + self.assertEqual(set(parse_dict.keys()), {'a', 'b', 'c', 'd'}) + self.assertTrue(isinstance(parse_dict['a'], dict)) + self.assertDictEqual(parse_dict['a'], {0: 5}) + self.assertTrue(isinstance(parse_dict['b'], dict)) + self.assertDictEqual(parse_dict['b'], {1: True}) + self.assertTrue(isinstance(parse_dict['c'], dict)) + self.assertDictEqual(parse_dict['c'], {2: 'abc'}) + self.assertTrue(isinstance(parse_dict['d'], dict)) + self.assertDictEqual(parse_dict['d'], {3: 3.14}) + def testParseValuesWithBadIndexAssigment1(self): """Reject assignment of list to variable type.""" with self.assertRaisesRegexp(ValueError, r'Assignment of a list to a list index.'): hparam.parse_values('arr[1]=[1,2,3]', {'arr': int}) + def testParseValuesWithBadIndexAssigment1_IgnoreUnknown(self): + """Reject assignment of list to variable type.""" + with self.assertRaisesRegexp(ValueError, + r'Assignment of a list to a list index.'): + hparam.parse_values( + 'arr[1]=[1,2,3],c=8', {'arr': int}, ignore_unknown=True) + def testParseValuesWithBadIndexAssigment2(self): """Reject if type missing.""" with self.assertRaisesRegexp(ValueError, r'Unknown hyperparameter type for arr'): hparam.parse_values('arr[1]=5', {}) + def testParseValuesWithBadIndexAssigment2_IgnoreUnknown(self): + """Ignore missing type.""" + hparam.parse_values('arr[1]=5', {}, ignore_unknown=True) + def testParseValuesWithBadIndexAssigment3(self): """Reject type of the form name[index].""" with self.assertRaisesRegexp(ValueError, 'Unknown hyperparameter type for arr'): hparam.parse_values('arr[1]=1', {'arr[1]': int}) + def testParseValuesWithBadIndexAssigment3_IgnoreUnknown(self): + """Ignore type of the form name[index].""" + hparam.parse_values('arr[1]=1', {'arr[1]': int}, ignore_unknown=True) + def testWithReusedVariables(self): with self.assertRaisesRegexp(ValueError, 'Multiple assignments to variable \'x\''): diff --git 
a/tensorflow/contrib/verbs/rdma_mgr.cc b/tensorflow/contrib/verbs/rdma_mgr.cc index 2784bf124ceaacd8e01f0653287fa7f006d0d608..2f2375427862ad1e99a0e6bfc506382d200e9b1d 100644 --- a/tensorflow/contrib/verbs/rdma_mgr.cc +++ b/tensorflow/contrib/verbs/rdma_mgr.cc @@ -277,9 +277,18 @@ void RdmaMgr::InitAllocators() { ProcessState::singleton()->AddCPUFreeVisitor(free_visitor); #if GOOGLE_CUDA + GPUProcessState::singleton()->AddCUDAHostAllocVisitor(0, alloc_visitor); + GPUProcessState::singleton()->AddCUDAHostFreeVisitor(0, free_visitor); + if (IsGDRAvailable()) { // Note we don't free allocated GPU memory so there is no free visitor - int32_t bus_id = TryToReadNumaNode(rdma_adapter_->context_->device) + 1; + + // TODO: This is to fix the 'invalid use of member in static member function + // bug'. + // Waiting for better implementation. + // int32_t bus_id = TryToReadNumaNode(rdma_adapter_->context_->device) + // + 1; + int32_t bus_id = 0; SubAllocator::Visitor cuda_alloc_visitor = [](void* ptr, int gpu_id, size_t num_bytes) { @@ -288,9 +297,6 @@ void RdmaMgr::InitAllocators() { }; GPUProcessState::singleton()->AddGPUAllocVisitor(bus_id, cuda_alloc_visitor); - GPUProcessState::singleton()->AddCUDAHostAllocVisitor(bus_id, - alloc_visitor); - GPUProcessState::singleton()->AddCUDAHostFreeVisitor(bus_id, free_visitor); LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id; } #endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/verbs/verbs_server_lib.cc b/tensorflow/contrib/verbs/verbs_server_lib.cc index 5b72b1604aca2e0c593978c6104322372788eb3c..19ef109f671ee57ce2aceb55110c50aa44352223 100644 --- a/tensorflow/contrib/verbs/verbs_server_lib.cc +++ b/tensorflow/contrib/verbs/verbs_server_lib.cc @@ -33,6 +33,8 @@ RendezvousMgrInterface* NewRdmaRendezvousMgr(const WorkerEnv* env) { return new RdmaRendezvousMgr(env); } +std::once_flag reg_mem_visitors_call; + } // namespace VerbsServer::VerbsServer(const ServerDef& server_def, Env* env) @@ -76,10 +78,6 @@ Status 
VerbsServer::ChannelCacheFactory(const ServerDef& server_def, return Status::OK(); } -namespace { -std::once_flag reg_mem_visitors_call; -} // namespace - Status VerbsServer::Init(ServiceInitFunction service_func, RendezvousMgrCreationFunction rendezvous_mgr_func) { std::call_once(reg_mem_visitors_call, []() { RdmaMgr::RegMemVisitors(); }); diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 575edfe7a936df2a76fd43f76b47b7ac8da3c2e7..8bf1480d33b2d2117fb5c7ddf046262cfeb8a8ab 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -49,7 +49,7 @@ # filegroup ":android_proto_srcs" - Protos # filegroup ":android_srcs" - Core sources # cc_library ":android_tensorflow_lib" - Native library -# cc_library ":android_tensorflow_lib_selective_registration" - Native library +# cc_library ":android_tensorflow_lib_lite" - Native library, without ops, # supporting SELECTIVE_REGISTRATION feature. # portable_proto_library ":android_proto_lib" (Google-internal) # @@ -113,7 +113,6 @@ load( "tf_additional_device_tracer_test_flags", "tf_additional_gdr_lib_defines", "tf_additional_human_readable_json_deps", - "tf_additional_logger_deps", "tf_additional_lib_defines", "tf_additional_lib_deps", "tf_additional_lib_hdrs", @@ -446,15 +445,31 @@ cc_library( ) cc_library( - name = "logger", - srcs = tf_platform_srcs(["logger.cc"]), - hdrs = ["platform/logger.h"] + tf_platform_hdrs(["logger.h"]), + name = "logger_interface", + hdrs = ["platform/logger.h"], copts = tf_copts(), visibility = ["//visibility:public"], deps = [ - ":lib", - ":lib_internal", - ] + tf_additional_logger_deps(), + ":lib_proto_parsing", + "@protobuf_archive//:protobuf", + ], +) + +cc_library( + name = "default_logger", + srcs = ["platform/default/logger.cc"], + hdrs = ["platform/logger.h"], + deps = [ + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:logger_interface", + ], +) + +cc_library( + name = "logger", + hdrs = ["platform/logger.h"], + visibility = ["//visibility:public"], + deps = 
["//tensorflow/core/platform/default/build_config:logger"], ) filegroup( @@ -492,7 +507,10 @@ cc_library( ":platform_env_internal_hdrs", ], copts = tf_copts(), - visibility = ["//tensorflow/core:__subpackages__"], + visibility = [ + "//tensorflow/c:__subpackages__", + "//tensorflow/core:__subpackages__", + ], deps = [ ":error_codes_proto_cc", ":lib", @@ -1608,6 +1626,9 @@ filegroup( "**/*main.cc", "debug/**/*", "framework/op_gen_*", + "framework/node_def_util.*", + "framework/op_kernel.*", + "framework/dataset.*", "lib/jpeg/**/*", "lib/png/**/*", "lib/gif/**/*", @@ -1616,7 +1637,6 @@ filegroup( "util/reporter.*", "platform/**/cuda_libdevice_path.*", "platform/**/logger.cc", - "platform/**/logger.h", "platform/default/test_benchmark.*", "platform/cuda.h", "platform/google/**/*", @@ -1651,6 +1671,9 @@ filegroup( "common_runtime/**/*.cc", "graph/**/*.h", "graph/**/*.cc", + "framework/node_def_util.*", + "framework/op_kernel.*", + "framework/dataset.*", ], exclude = [ "**/*test.*", @@ -1679,6 +1702,9 @@ filegroup( # operators, use :android_tensorflow_lib if you want full operator # support. # +# If you just need TensorFlow types, e.g. Tensors, use +# :android_tensorflow_lib_lite_no_runtime. +# # Compiles to a trivial library on non-Android to prevent irrelevant # build errors. 
If not building this as part of an android_binary, # a command such as the following must be used: @@ -1689,7 +1715,33 @@ filegroup( cc_library( name = "android_tensorflow_lib_lite", srcs = if_android(["//tensorflow/core:android_srcs"]), - copts = tf_copts(android_optimization_level_override = None), + copts = tf_copts(android_optimization_level_override = None) + [ + "-DSUPPORT_SELECTIVE_REGISTRATION", + ], + linkopts = ["-lz"], + tags = [ + "manual", + "notap", + ], + visibility = ["//visibility:public"], + deps = [ + ":mobile_additional_lib_deps", + ":protos_all_cc_impl", + ":stats_calculator_portable", + "//third_party/eigen3", + "@double_conversion//:double-conversion", + "@nsync//:nsync_cpp", + "@protobuf_archive//:protobuf", + ], + alwayslink = 1, +) + +cc_library( + name = "android_tensorflow_lib_lite_nortti", + srcs = if_android(["//tensorflow/core:android_srcs"]), + copts = tf_copts(android_optimization_level_override = None) + [ + "-DSUPPORT_SELECTIVE_REGISTRATION", + ] + tf_opts_nortti_if_android(), linkopts = ["-lz"], tags = [ "manual", @@ -1797,52 +1849,6 @@ cc_library( alwayslink = 1, ) -# Android library for use with the SELECTIVE_REGISTRATION feature. -# Does not contain operators. In contrast to android_tensorflow_lib_lite, -# this links in framework support for all types, relying on selective -# registration of ops to prune code size. -cc_library( - name = "android_tensorflow_lib_selective_registration", - srcs = if_android(["//tensorflow/core:android_srcs_only_runtime"]), - copts = tf_copts(android_optimization_level_override = None) + [ - "-DSUPPORT_SELECTIVE_REGISTRATION", - ], - linkopts = if_android(["-lz"]), - tags = [ - "manual", - "notap", - ], - visibility = ["//visibility:public"], - deps = [ - ":protos_all_cc_impl", - "@com_google_absl//absl/container:flat_hash_set", - "@protobuf_archive//:protobuf", - ], - alwayslink = 1, -) - -# Android library for use with the SELECTIVE_REGISTRATION feature with -# no proto_rtti. 
-cc_library( - name = "android_tensorflow_lib_selective_registration_nortti", - srcs = if_android(["//tensorflow/core:android_srcs_only_runtime"]), - copts = tf_copts(android_optimization_level_override = None) + tf_opts_nortti_if_android() + [ - "-DSUPPORT_SELECTIVE_REGISTRATION", - ], - linkopts = if_android(["-lz"]), - tags = [ - "manual", - "notap", - ], - visibility = ["//visibility:public"], - deps = [ - ":protos_all_cc_impl", - "@com_google_absl//absl/container:flat_hash_set", - "@protobuf_archive//:protobuf", - ], - alwayslink = 1, -) - filegroup( name = "android_op_registrations_and_gradients", srcs = glob( @@ -4052,20 +4058,6 @@ tf_cuda_cc_test( ], ) -tf_cc_test_gpu( - name = "cuda_libdevice_path_test", - size = "small", - srcs = ["platform/cuda_libdevice_path_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - tags = tf_cuda_tests_tags(), - deps = [ - ":cuda_libdevice_path", - ":lib", - ":test", - ":test_main", - ], -) - tf_cuda_only_cc_test( name = "util_cuda_kernel_helper_test", srcs = [ @@ -4921,7 +4913,7 @@ filegroup( cc_library( name = "cuda_libdevice_path", - srcs = ["platform/cuda_libdevice_path.cc"] + tf_additional_libdevice_srcs(), + srcs = tf_additional_libdevice_srcs(), hdrs = ["platform/cuda_libdevice_path.h"], copts = tf_copts(), data = tf_additional_libdevice_data(), diff --git a/tensorflow/core/api_def/base_api/api_def_CudnnRNNBackpropV3.pbtxt b/tensorflow/core/api_def/base_api/api_def_CudnnRNNBackpropV3.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7967ca7c5d17abd6451f0cd05c8154c3eaf4766b --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_CudnnRNNBackpropV3.pbtxt @@ -0,0 +1,49 @@ +op { + graph_op_name: "CudnnRNNBackpropV3" + visibility: HIDDEN + summary: "Backprop step of CudnnRNNV3." + description: <