diff --git a/tools/bazel.rc b/.bazelrc similarity index 95% rename from tools/bazel.rc rename to .bazelrc index 1fdf51f53e29c7111cf89c016400b710051cf9c6..cd7e13ddfc146208f79be900917b05b694869d72 100644 --- a/tools/bazel.rc +++ b/.bazelrc @@ -76,7 +76,6 @@ build:nonccl --define=no_nccl_support=true build --define=use_fast_cpp_protos=true build --define=allow_oversize_protos=true -build --define=grpc_no_ares=true build --spawn_strategy=standalone build --genrule_strategy=standalone @@ -93,3 +92,11 @@ build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS build --define=PREFIX=/usr build --define=LIBDIR=$(PREFIX)/lib build --define=INCLUDEDIR=$(PREFIX)/include + +# Default options should come above this line + +# Options from ./configure +try-import %workspace%/.tf_configure.bazelrc + +# Put user-specific options in .bazelrc.user +try-import %workspace%/.bazelrc.user diff --git a/.gitignore b/.gitignore index 90324058600bee46af56e49028977971848a80de..e1d352c238a1b2d4febe0f5d4a30cfa0c942f7e7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,7 @@ .DS_Store .ipynb_checkpoints node_modules -/.bazelrc +/.bazelrc.user /.tf_configure.bazelrc /bazel-* /bazel_pip diff --git a/README.md b/README.md index 044174947a094d43a51f7140dd40ec0f17801d40..519815d006cc33be10132909baf414a4bd843435 100644 --- a/README.md +++ b/README.md @@ -113,11 +113,12 @@ The TensorFlow project strives to abide by generally accepted best practices in Build Type | Status | Artifacts ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA -**IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | TBA -**IBM ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) -**IBM ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) +**Linux ppc64le CPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/) +**Linux ppc64le CPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) +**Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) +**Linux ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) **Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) -**Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.4
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.11.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp27-cp27mu-linux_x86_64.whl)
[1.11.0 py3.4](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp34-cp34m-linux_x86_64.whl)
[1.11.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp35-cp35m-linux_x86_64.whl)
[1.11.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp36-cp36m-linux_x86_64.whl) +**Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.4
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.12.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp27-cp27mu-linux_x86_64.whl)
[1.12.0 py3.4](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp34-cp34m-linux_x86_64.whl)
[1.12.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp35-cp35m-linux_x86_64.whl)
[1.12.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp36-cp36m-linux_x86_64.whl) ## For more information diff --git a/RELEASE.md b/RELEASE.md index b13b071bd6cf4d3a260c8e248a67d23e1a688498..32abdcea497618918964174a661a6ba872598f65 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -7,6 +7,8 @@ Serving. * Keras models now support evaluating with a `tf.data.Dataset`. * TensorFlow binaries are built with XLA support linked in by default. +* Ignite Dataset added to contrib/ignite that allows to work with Apache + Ignite. ## Bug Fixes and Other Changes diff --git a/WORKSPACE b/WORKSPACE index 7cc08e0164a202581ad7ebbe107a9e19410e70e4..7057d3f149e766cd2983ecc89509f84c37075602 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -16,30 +16,27 @@ load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories") closure_repositories() -http_archive( - name = "base_images_docker", - sha256 = "e2b1b7254270bb7605e814a9dbf6d1e4ae04a11136ff1714fbfdabe3f87f7cf9", - strip_prefix = "base-images-docker-12801524f867e657fbb5d1a74f31618aff181ac6", - urls = ["https://github.com/GoogleCloudPlatform/base-images-docker/archive/12801524f867e657fbb5d1a74f31618aff181ac6.tar.gz"], -) +load("//third_party/toolchains/preconfig/generate:archives.bzl", + "bazel_toolchains_archive") -http_archive( - name = "bazel_toolchains", - sha256 = "15b5858b1b5541ec44df31b94c3b8672815b31d71215a98398761ea9f4c4eedb", - strip_prefix = "bazel-toolchains-6200b238c9c2d137c0d9a7262c80cc71d98e692b", - urls = [ - "https://github.com/bazelbuild/bazel-toolchains/archive/6200b238c9c2d137c0d9a7262c80cc71d98e692b.tar.gz", - ], +bazel_toolchains_archive() + +load( + "@bazel_toolchains//repositories:repositories.bzl", + bazel_toolchains_repositories = "repositories", ) -http_archive( - name = "io_bazel_rules_docker", - sha256 = "29d109605e0d6f9c892584f07275b8c9260803bf0c6fcb7de2623b2bedc910bd", - strip_prefix = "rules_docker-0.5.1", - urls = ["https://github.com/bazelbuild/rules_docker/archive/v0.5.1.tar.gz"], +bazel_toolchains_repositories() + +load( + "@io_bazel_rules_docker//container:container.bzl", + container_repositories = "repositories", ) -load("//third_party/toolchains/preconfig/generate:workspace.bzl", "remote_config_workspace") +container_repositories() + +load("//third_party/toolchains/preconfig/generate:workspace.bzl", + "remote_config_workspace") remote_config_workspace() @@ -47,7 +44,7 @@ remote_config_workspace() # files, in case the parsing of those build files depends on the bazel # version we require here. load("//tensorflow:version_check.bzl", "check_bazel_version_at_least") -check_bazel_version_at_least("0.15.0") +check_bazel_version_at_least("0.18.0") load("//tensorflow:workspace.bzl", "tf_workspace") diff --git a/tensorflow/opensource_only/arm_compiler.BUILD b/arm_compiler.BUILD similarity index 100% rename from tensorflow/opensource_only/arm_compiler.BUILD rename to arm_compiler.BUILD diff --git a/configure.py b/configure.py index 6c905a0be3d685b5921dfbc5bddfbe6471a82625..1e732db26404906901a9eeab97a5e75137ee8388 100644 --- a/configure.py +++ b/configure.py @@ -255,18 +255,6 @@ def setup_python(environ_cp): def reset_tf_configure_bazelrc(): """Reset file that contains customized config settings.""" open(_TF_BAZELRC, 'w').close() - bazelrc_path = os.path.join(_TF_WORKSPACE_ROOT, '.bazelrc') - - data = [] - if os.path.exists(bazelrc_path): - with open(bazelrc_path, 'r') as f: - data = f.read().splitlines() - with open(bazelrc_path, 'w') as f: - for l in data: - if _TF_BAZELRC_FILENAME in l: - continue - f.write('%s\n' % l) - f.write('import %%workspace%%/%s\n' % _TF_BAZELRC_FILENAME) def cleanup_makefile(): """Delete any leftover BUILD files from the Makefile build. @@ -488,11 +476,12 @@ def check_bazel_version(min_version, max_version): if curr_version_int < min_version_int: print('Please upgrade your bazel installation to version %s or higher to ' 'build TensorFlow!' % min_version) - sys.exit(0) - if curr_version_int > max_version_int: + sys.exit(1) + if (curr_version_int > max_version_int and + 'TF_IGNORE_MAX_BAZEL_VERSION' not in os.environ): print('Please downgrade your bazel installation to version %s or lower to ' 'build TensorFlow!' % max_version) - sys.exit(0) + sys.exit(1) return curr_version @@ -1565,11 +1554,9 @@ def main(): # environment variables. environ_cp = dict(os.environ) - check_bazel_version('0.15.0', '0.20.0') + check_bazel_version('0.19.0', '0.20.0') reset_tf_configure_bazelrc() - # Explicitly import tools/bazel.rc, this is needed for Bazel 0.19.0 or later - write_to_bazelrc('import %workspace%/tools/bazel.rc') cleanup_makefile() setup_python(environ_cp) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index fd4b94202aad24a82abef8abd16431f61a8326f0..449a1372edb031c68786d8672e2a1499c2b3d047 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -267,6 +267,15 @@ config_setting( visibility = ["//visibility:public"], ) +# By default, XLA GPU is compiled into tensorflow when building with +# --config=cuda even when `with_xla_support` is false. The config setting +# here allows us to override the behavior if needed. +config_setting( + name = "no_xla_deps_in_cuda", + define_values = {"no_xla_deps_in_cuda": "true"}, + visibility = ["//visibility:public"], +) + config_setting( name = "with_gdr_support", define_values = {"with_gdr_support": "true"}, @@ -606,9 +615,11 @@ py_library( name = "tensorflow_py", srcs_version = "PY2AND3", visibility = ["//visibility:public"], - deps = [ + deps = select({ + "api_version_2": [], + "//conditions:default": ["//tensorflow/contrib:contrib_py"], + }) + [ ":tensorflow_py_no_contrib", - "//tensorflow/contrib:contrib_py", "//tensorflow/python/estimator:estimator_py", ], ) diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index d81cf067eb07e88e2b8a86cf5643674235eb3f3b..4eba763129a6aef40e3c130d56bf8ab19638b7ca 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -20,14 +20,14 @@ from __future__ import print_function as _print_function import os as _os +# API IMPORTS PLACEHOLDER + # pylint: disable=g-bad-import-order from tensorflow.python.tools import component_api_helper as _component_api_helper _component_api_helper.package_hook( parent_package_str=__name__, child_package_str=('tensorflow_estimator.python.estimator.api.estimator')) -# API IMPORTS PLACEHOLDER - # Make sure directory containing top level submodules is in # the __path__ so that "from tensorflow.foo import bar" works. # We're using bitwise, but there's nothing special about that. diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index 65bdb6cb1b5e6fb0656a12b932d767aeacfccd29..21b5277614667bdbd7271ac3e57f5b69d5a19264 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -23,13 +23,13 @@ import os as _os # pylint: disable=g-bad-import-order from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import +# API IMPORTS PLACEHOLDER + from tensorflow.python.tools import component_api_helper as _component_api_helper _component_api_helper.package_hook( parent_package_str=__name__, child_package_str=('tensorflow_estimator.python.estimator.api.estimator')) -# API IMPORTS PLACEHOLDER - from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') del LazyLoader diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 94d18eb8b04e3534be547aca5cfbb32da40ffbf6..9580215a317b1a6b1cdacbd430a1764af61be990 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -488,6 +488,7 @@ static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) { // Non-static for testing. TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); if (!src.IsInitialized()) { status->status = FailedPrecondition( "attempt to use a tensor with an uninitialized value"); diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index f80ae5a6d02d4d613c95cf8486e0fc0aeed3affc..120748ab763a3358b6e38e64bb3b6fd2ea32f7c3 100755 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -170,23 +170,11 @@ TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, TF_Status* status); -// Returns the device of the operation that produced `h`. -// If `h` was produced by a copy, returns the destination device of -// the copy. Note that returned device name is not always the device -// holding the tensor handle's memory. If you want the latter, use -// TFE_TensorHandleBackingDeviceName. -// This function will block till the operation that produces `h` has completed. -// -// Device on which the kernel of the operation that produced `h` ran. -// -// If `h` was produced by a copy, returns the destination device of -// the copy. -// -// Note that returned device name is not always the device that owns the memory -// that backs the tensor handle. For the latter see -// TFE_TensorHandleBackingDeviceName. -// -// This function will block till the operation that produces `h` has completed. +// Returns the device of the operation that produced `h`. If `h` was produced by +// a copy, returns the destination device of the copy. Note that the returned +// device name is not always the device holding the tensor handle's memory. If +// you want the latter, use TFE_TensorHandleBackingDeviceName. This function +// will block till the operation that produces `h` has completed. TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName( TFE_TensorHandle* h, TF_Status* status); diff --git a/tensorflow/c/env.cc b/tensorflow/c/env.cc index 07b9e8b940c55caf62ae0b81b884bf313d335459..1c35ff9001d0ee1ab0fbae9e1bcc07116fab1065 100644 --- a/tensorflow/c/env.cc +++ b/tensorflow/c/env.cc @@ -159,3 +159,25 @@ TF_CAPI_EXPORT extern uint64_t TF_NowMicros(void) { TF_CAPI_EXPORT extern uint64_t TF_NowSeconds(void) { return ::tensorflow::Env::Default()->NowSeconds(); } + +void TF_DefaultThreadOptions(TF_ThreadOptions* options) { + options->stack_size = 0; + options->guard_size = 0; + options->numa_node = -1; +} + +TF_Thread* TF_StartThread(const TF_ThreadOptions* options, + const char* thread_name, void (*work_func)(void*), + void* param) { + ::tensorflow::ThreadOptions cc_options; + cc_options.stack_size = options->stack_size; + cc_options.guard_size = options->guard_size; + cc_options.numa_node = options->numa_node; + return reinterpret_cast(::tensorflow::Env::Default()->StartThread( + cc_options, thread_name, [=]() { (*work_func)(param); })); +} + +void TF_JoinThread(TF_Thread* thread) { + // ::tensorflow::Thread joins on destruction + delete reinterpret_cast<::tensorflow::Thread*>(thread); +} diff --git a/tensorflow/c/env.h b/tensorflow/c/env.h index 9d27c5da37735042c7476b591e57486dbde33152..15652353cd7e1f1e7d7a4c665703c0166682d790 100644 --- a/tensorflow/c/env.h +++ b/tensorflow/c/env.h @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + #ifndef TENSORFLOW_C_ENV_H_ #define TENSORFLOW_C_ENV_H_ @@ -23,6 +26,7 @@ limitations under the License. struct TF_WritableFileHandle; struct TF_StringStream; +struct TF_Thread; #ifdef __cplusplus extern "C" { @@ -37,6 +41,20 @@ typedef struct TF_FileStatistics { bool is_directory; } TF_FileStatistics; +typedef struct TF_ThreadOptions { + // Thread stack size to use (in bytes), zero implies that the system default + // will be used. + size_t stack_size; + + // Guard area size to use near thread stacks to use (in bytes), zero implies + // that the system default will be used. + size_t guard_size; + + // The NUMA node to use, -1 implies that there should be no NUMA affinity for + // this thread. + int numa_node; +} TF_ThreadOptions; + // Creates the specified directory. Typical status code are: // * TF_OK - successfully created the directory // * TF_ALREADY_EXISTS - directory already exists @@ -150,6 +168,25 @@ TF_CAPI_EXPORT extern uint64_t TF_NowMicros(void); // Returns the number of seconds since the Unix epoch. TF_CAPI_EXPORT extern uint64_t TF_NowSeconds(void); +// Populates a TF_ThreadOptions struct with system-default values. +TF_CAPI_EXPORT extern void TF_DefaultThreadOptions(TF_ThreadOptions* options); + +// Returns a new thread that is running work_func and is identified +// (for debugging/performance-analysis) by thread_name. +// +// The given param (which may be null) is passed to work_func when the thread +// starts. In this way, data may be passed from the thread back to the caller. +// +// Caller takes ownership of the result and must call TF_JoinThread on it +// eventually. +TF_CAPI_EXPORT extern TF_Thread* TF_StartThread(const TF_ThreadOptions* options, + const char* thread_name, + void (*work_func)(void*), + void* param); + +// Waits for the given thread to finish execution, then deletes it. +TF_CAPI_EXPORT extern void TF_JoinThread(TF_Thread* thread); + #ifdef __cplusplus } #endif diff --git a/tensorflow/c/env_test.cc b/tensorflow/c/env_test.cc index e2206c6befd2167346c64032940d6e8c631e4a3e..687ad024137352662759ec1f43df87e89faca353 100644 --- a/tensorflow/c/env_test.cc +++ b/tensorflow/c/env_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -98,3 +99,29 @@ TEST(TestEnv, TestTimeFunctions) { ASSERT_GE(TF_NowMicros(), 946684800 * 1e6); ASSERT_GE(TF_NowNanos(), 946684800 * 1e9); } + +namespace { + +struct SomeThreadData { + ::tensorflow::mutex mu; + bool did_work = false; +}; + +void SomeThreadFunc(void* data) { + auto* real_data = static_cast(data); + ::tensorflow::mutex_lock l(real_data->mu); + real_data->did_work = true; +} + +} // namespace + +TEST(TestEnv, TestThreads) { + TF_ThreadOptions options; + TF_DefaultThreadOptions(&options); + SomeThreadData data; + TF_Thread* thread = + TF_StartThread(&options, "SomeThreadName", &SomeThreadFunc, &data); + TF_JoinThread(thread); + ::tensorflow::mutex_lock l(data.mu); + ASSERT_TRUE(data.did_work); +} diff --git a/tensorflow/compat_template_v1.__init__.py b/tensorflow/compat_template_v1.__init__.py index 7df80ec01245a7fe820c79d5879458c4cd0a93cb..d58acde09f007bc9df40b08b0ef79c6031ca7941 100644 --- a/tensorflow/compat_template_v1.__init__.py +++ b/tensorflow/compat_template_v1.__init__.py @@ -23,12 +23,12 @@ import os as _os # pylint: disable=g-bad-import-order from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import +# API IMPORTS PLACEHOLDER + from tensorflow.python.tools import component_api_helper as _component_api_helper _component_api_helper.package_hook( parent_package_str=__name__, child_package_str=('tensorflow_estimator.python.estimator.api.estimator')) -# API IMPORTS PLACEHOLDER - from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top app.flags = flags # pylint: disable=undefined-variable diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 2dc3e8c9113b37bf9d575ad66783f4ab49478af4..4051664c24cacad4a2d151ad3ac9009015900609 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -283,7 +283,7 @@ def tf_library( ) # Variables used for gen_test and gen_benchmark. - cpp_class_split = cpp_class.rsplit("::", maxsplit = 2) + cpp_class_split = cpp_class.rsplit("::", 2) if len(cpp_class_split) == 1: no_ns_name = cpp_class_split[0] else: diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 15dcbb2641eca031e82db9aa58dee6a14ab0a2cc..d8c88a9fca2db74265b4962e07a66ab214b1d994 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -515,6 +515,7 @@ cc_library( "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:resource_operation_table", + "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/cc:xla_jit_ops", @@ -613,6 +614,7 @@ tf_cc_test( "//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", + "//tensorflow/cc:functional_ops", "//tensorflow/cc:ops", "//tensorflow/cc:resource_variable_ops", "//tensorflow/cc:scope", @@ -625,6 +627,7 @@ tf_cc_test( "//tensorflow/compiler/tf2xla/cc:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/compiler/xla:test", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index f478832781cb1dc045d9163d4a6f5e5f64a8a705..03aba97bbe81a11f6366d118ee5bc573d0c6b31b 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -779,7 +779,8 @@ Status Encapsulator::Subgraph::RecordArg( if (inserted) { NodeDef arg_def; NodeDefBuilder builder( - absl::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp); + absl::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp, + NodeDebugInfo(src_node->def())); DataType dtype = edge->dst()->input_type(edge->dst_input()); builder.Attr("T", dtype); builder.Attr("index", arg_index); @@ -814,7 +815,8 @@ Status Encapsulator::Subgraph::RecordResult( if (inserted) { NodeDef ret_def; NodeDefBuilder builder( - absl::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp); + absl::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp, + NodeDebugInfo(src_node->def())); DataType dtype = src_node->output_type(src_slot); builder.Attr("T", dtype); builder.Attr("index", ret_index); @@ -974,6 +976,7 @@ Status Encapsulator::Subgraph::AddHostComputes( } NodeDef host_compute_def; + // TODO(shikharagarwal): What source node should we use for errors? NodeDefBuilder builder(absl::StrCat("outside_compilation_", oc_subgraph_name, "_host_compute"), kHostComputeOp); @@ -1040,6 +1043,7 @@ Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name, Graph* graph_out) { if (sequencer_ == nullptr) { NodeDef seq_def; + // TODO(shikharagarwal): What source node should we use for errors? NodeDefBuilder builder(absl::StrCat(subgraph_name, "_sequencer"), "NoOp"); builder.Attr(kXlaHostTransferSequencerAttr, subgraph_name); builder.Device(device_); @@ -1214,7 +1218,8 @@ Status Encapsulator::Subgraph::AddHostComputeKeyPlaceholder( GraphDefBuilder::Options options(graph_out, /*status=*/nullptr); NodeDef key_def; NodeDefBuilder builder( - absl::StrCat(call_node_def_.name(), "_key_placeholder"), "Placeholder"); + absl::StrCat(call_node_def_.name(), "_key_placeholder"), "Placeholder", + NodeDebugInfo(call_node_def_)); builder.Attr("dtype", DT_STRING); builder.Attr("shape", shape_proto); builder.Attr("_host_compute_call_node", call_node_def_.name()); @@ -1248,6 +1253,7 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode( } NodeDef recv_def; + // TODO(shikharagarwal): What source node should we use for errors? NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name, "_", oc_subgraph_name, "_recv"), kRecvAtHostOp); @@ -1303,6 +1309,7 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( } NodeDef send_def; + // TODO(shikharagarwal): What source node should we use for errors? NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name, "_", oc_subgraph_name, "_send"), kSendFromHostOp); @@ -1833,8 +1840,9 @@ Node* AddDummyShapedNode(const Node* src_node, int src_port, // Add any Enter nodes required to bring the constant to the correct control // flow frame. while (!control_flow_info[src_node->id()].frame_name.empty()) { + NodeDebugInfo debug_info(*src_node); NodeBuilder enter_builder(options.GetNameForOp("Enter"), "Enter", - options.op_registry()); + options.op_registry(), &debug_info); enter_builder.Attr("frame_name", control_flow_info[src_node->id()].frame_name); enter_builder.Attr("is_constant", true); @@ -2018,7 +2026,8 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( return errors::InvalidArgument( "Shape inference is not possible for outside_compilation " "SendFromHost node ", - send_node->name(), " because shape of node ", n->name(), + send_node->name(), " because shape of node ", + FormatNodeForError(*n), " will not be known at compilation time."); } } @@ -2047,8 +2056,7 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( return errors::Internal( "Internal assumption failed while rewriting an outside_compilation " "cluster that contains a while loop. Logic assumes back-edge is to " - "port 1 of a 2-input " - "Merge node."); + "port 1 of a 2-input Merge node."); } // Connect the existing edge to both inputs of the Merge node so that the // graph will be well-formed. diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index de89be9a3555960dabe7bacd17226c15ae888ae6..8617beec004d0fe912155f054442c5b6249bb6b5 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -299,7 +299,7 @@ REGISTER_OP("XlaHostCompute") .Attr("Toutputs: list(type) >= 0") .Attr("ancestors: list(string) >= 0") .Attr("key: string") - .Attr("shape_inference_graph: string = ''") + .Attr("shape_inference_graph: func") .Attr("shapes: list(shape) >= 0") .SetShapeFn(::tensorflow::shape_inference::UnknownShape); @@ -510,11 +510,7 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, s = ConvertGraphDefToGraph(options, *graphdef, graph.get()); if (!s.ok()) return s; - s = PerformStaticShapeInferenceBeforeEncapsulation( - graph.get(), "_encapsulate", "_outside"); - if (!s.ok()) return s; - - s = PreprocessForEncapsulation(graph.get(), "_encapsulate", "_outside"); + s = PerformStaticShapeInferenceBeforeEncapsulation(graph.get()); if (!s.ok()) return s; std::unique_ptr graph_out; @@ -550,6 +546,14 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, graphdef->Swap(&graphdef_out); *library = lib_def->ToProto(); + // Remove "_xla_inferred_shapes" attr. They are added by + // `PerformStaticShapeInferenceBeforeEncapsulation`. + for (FunctionDef& fdef : *library->mutable_function()) { + for (NodeDef& node_def : *fdef.mutable_node_def()) { + node_def.mutable_attr()->erase("_xla_inferred_shapes"); + } + } + return s; } @@ -901,18 +905,22 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { { GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape.opts()); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, shape.opts()); + Node* recv = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + shape.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), shape.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected)); } + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = test::function::XTimesTwo(); *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {}, @@ -931,8 +939,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O1"}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}, {"c"}}, @@ -948,16 +955,18 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, b2.opts()); + Node* recv = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), b2.opts() .WithName("E") - .WithControlInputs({recv, b}) + .WithControlInputs({recv}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, - b2.opts().WithControlInput(e)); + b2.opts().WithControlInput(e).WithAttr( + kXlaHasHostTransferAttrName, true)); Node* s = Sequencer( b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), @@ -966,9 +975,9 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { NodeBuilder node_builder("F1", "F1", lib_def.get()); node_builder.Input(a).Input(b); Node* call = - b2.opts().WithControlInputs({s}).FinalizeBuilder(&node_builder); + b2.opts().WithControlInputs({s, b}).FinalizeBuilder(&node_builder); - Binary(a, call, b2.opts().WithName("G").WithControlInputs({e})); + Binary(a, call, b2.opts().WithName("G").WithControlInputs({call})); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1022,14 +1031,16 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { { GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape1.opts()); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, shape1.opts()); + Node* recv = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), shape1.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); } @@ -1037,33 +1048,45 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { { GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape2.opts()); - Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, shape2.opts()); + Node* recv1 = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), shape2.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", - {DT_FLOAT, DT_FLOAT}, shape2.opts()); + Node* recv2 = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT, DT_FLOAT}, + shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* g = Binary(e, ops::NodeOut(recv2, 0), + shape2.opts() + .WithName("G") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2")); Node* h = Binary(ops::NodeOut(recv2, 1), e, shape2.opts() .WithName("H") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O2")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {h}, shape2.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g, h}, + shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape2, "F1_O2", &library_expected)); } + NameAttrList shape_inference_graph1, shape_inference_graph2; + shape_inference_graph1.set_name("_outside_compilation_shape_inference_F1_O1"); + shape_inference_graph2.set_name("_outside_compilation_shape_inference_F1_O2"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"g_0_retval_retval:float", "i_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}}, {{"I"}, "UnaryTest", - {"outside_compilation_O2_host_compute:outputs:0"}}, + {"outside_compilation_O2_host_compute:outputs:1"}}, {{"F"}, "BinaryTest", {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"}, @@ -1073,11 +1096,10 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { "XlaHostCompute", {"F:o:0", "D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, - {"Toutputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O2"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O2"}, + {"shape_inference_graph", shape_inference_graph2}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}}, {"F"}}, @@ -1088,13 +1110,13 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O1"}, + {"shape_inference_graph", shape_inference_graph1}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, - {{"i_0_retval_retval", "I:o:0"}}); + {{"g_0_retval_retval", "outside_compilation_O2_host_compute:outputs:0"}, + {"i_0_retval_retval", "I:o:0"}}); { std::unique_ptr lib_def( @@ -1105,19 +1127,22 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, b2.opts()); + Node* recv1 = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), b2.opts() .WithName("E") - .WithControlInputs({recv1, b}) + .WithControlInputs({recv1}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, - b2.opts().WithControlInput(e)); + b2.opts().WithControlInput(e).WithAttr( + kXlaHasHostTransferAttrName, true)); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", - {DT_FLOAT, DT_FLOAT}, b2.opts()); + Node* recv2 = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* g = Binary(e, ops::NodeOut(recv2, 0), b2.opts() .WithName("G") @@ -1130,7 +1155,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O2")); Node* send2 = - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {h}, b2.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g, h}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* s = Sequencer(b2.opts() .WithName("F1_sequencer") @@ -1139,12 +1165,13 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { NodeBuilder node_builder("F1", "F1", lib_def.get()); node_builder.Input(a).Input(b); - Node* call = b2.opts().WithControlInput(s).FinalizeBuilder(&node_builder); + Node* call = + b2.opts().WithControlInputs({s, b}).FinalizeBuilder(&node_builder); - Binary(g, call, b2.opts().WithName("J")); + Binary(ops::NodeOut(call, 0), ops::NodeOut(call, 1), + b2.opts().WithName("J")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } - TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); } @@ -1196,7 +1223,9 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"a_0_arg:float", "b_0_arg:float"}, - {"f_0_retval_retval:float", "d_0_retval_retval:float"}, {}, + {"e_0_retval_retval:float", "f_0_retval_retval:float", + "d_0_retval_retval:float"}, + {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1212,35 +1241,37 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, - {{"d_0_retval_retval", "D:o:0"}, {"f_0_retval_retval", "F:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"d_0_retval_retval", "D:o:0"}, + {"f_0_retval_retval", "F:o:0"}}); *library_expected.add_function() = FunctionDefHelper::Create( - "F2", {"f_0_arg:float", "bridge_e_g_0_arg:float"}, - {"i_0_retval_retval:float", "g_0_retval_retval:float"}, {}, + "F2", {"e_0_arg:float", "f_0_arg:float", "d_0_arg:float"}, + {"g_0_retval_retval:float", "i_0_retval_retval:float"}, {}, { - {{"G"}, "BinaryTest", {"bridge_e_g_0_arg", "f_0_arg"}}, + {{"G"}, "BinaryTest", {"e_0_arg", "f_0_arg"}}, {{"I"}, "BinaryTest", {"f_0_arg", "outside_compilation_O1_host_compute:outputs:0"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", - {"G:o:0"}, - {{"Tinputs", absl::Span({DT_FLOAT})}, + {"d_0_arg", "G:o:0"}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F2_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}}, }, - {{"i_0_retval_retval", "I:o:0"}, {"g_0_retval_retval", "G:o:0"}}); + {{"g_0_retval_retval", "G:o:0"}, {"i_0_retval_retval", "I:o:0"}}); { std::unique_ptr lib_def( @@ -1251,16 +1282,18 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { Node* key_constant1 = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv1 = RecvAtHost(ops::NodeOut(key_constant1, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, b2.opts()); + Node* recv1 = RecvAtHost( + ops::NodeOut(key_constant1, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), b2.opts() .WithName("E") - .WithControlInputs({recv1, b}) + .WithControlInputs({recv1}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send1 = SendFromHost(ops::NodeOut(key_constant1, 0), "F1", "O1", {e}, - b2.opts().WithControlInput(e)); + b2.opts().WithControlInput(e).WithAttr( + kXlaHasHostTransferAttrName, true)); Node* s1 = Sequencer( b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}), "F1"); @@ -1268,29 +1301,33 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); Node* call1 = - b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); + b2.opts().WithControlInputs({s1, b}).FinalizeBuilder(&node_builder1); Node* key_constant2 = KeyPlaceholder("F2", b2.opts().WithName("F2_key_placeholder")); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant2, 0), "F2", "O1", - {DT_FLOAT}, b2.opts()); - Node* h = Binary(ops::NodeOut(call1, 1), recv2, + Node* recv2 = RecvAtHost( + ops::NodeOut(key_constant2, 0), "F2", "O1", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* h = Binary(recv2, ops::NodeOut(recv2, 1), b2.opts() .WithName("H") .WithAttr("_encapsulate", "F2") .WithAttr("_outside", "O1")); - Node* send2 = SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "O1", {h}, - b2.opts()); + Node* send2 = + SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "O1", {h}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* s2 = Sequencer( b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2}), "F2"); NodeBuilder node_builder2("F2", "F2", lib_def.get()); - node_builder2.Input(call1).Input(e); + node_builder2.Input(call1) + .Input(ops::NodeOut(call1, 1)) + .Input(ops::NodeOut(call1, 2)); Node* call2 = b2.opts() - .WithControlInputs({s2, e, call1}) + .WithControlInputs({s2, call1}) .FinalizeBuilder(&node_builder2); - Binary(ops::NodeOut(call2, 1), call2, b2.opts().WithName("J")); + Binary(call2, ops::NodeOut(call2, 1), b2.opts().WithName("J")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1326,8 +1363,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { Node* h = Unary(g, b1.opts() .WithName("H") .WithAttr("_encapsulate", "F2") - .WithAttr("_outside", "O1") - .WithControlInput(e)); + .WithAttr("_outside", "O1")); Node* i = Unary(h, b1.opts().WithName("I").WithAttr("_encapsulate", "F2")); Binary(f, i, b1.opts().WithName("J")); TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); @@ -1358,7 +1394,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}, @@ -1380,7 +1416,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F2_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}}, @@ -1401,7 +1437,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), b2.opts() .WithName("E") - .WithControlInputs({recv1, b}) + .WithControlInputs({recv1}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send1 = SendFromHost(ops::NodeOut(key_constant1, 0), "F1", "O1", {e}, @@ -1413,7 +1449,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); Node* call1 = - b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); + b2.opts().WithControlInputs({s1, b}).FinalizeBuilder(&node_builder1); Node* key_constant2 = KeyPlaceholder("F2", b2.opts().WithName("F2_key_placeholder")); @@ -1422,8 +1458,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { Node* h = Unary(recv2, b2.opts() .WithName("H") .WithAttr("_encapsulate", "F2") - .WithAttr("_outside", "O1") - .WithControlInput(e)); + .WithAttr("_outside", "O1")); Node* send2 = SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "O1", {h}, b2.opts()); @@ -1484,12 +1519,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", - {}, - {{"Tinputs", absl::Span({})}, + {"a_0_arg"}, + {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}}, @@ -1503,16 +1538,19 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { Node* a = InputShaped(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - Node* e = Unary(a, b2.opts() - .WithName("E") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O1")); Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, b2.opts()); + Node* e = Unary(recv1, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); Node* s1 = Sequencer( - b2.opts().WithName("F1_sequencer").WithControlInput(send1), "F1"); + b2.opts().WithName("F1_sequencer").WithControlInputs({send1, recv1}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); Node* call1 = @@ -1569,12 +1607,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", - {}, - {{"Tinputs", absl::Span({})}, + {"a_0_arg"}, + {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}, @@ -1591,13 +1629,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv1 = - RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {}, b2.opts()); - Node* e = Unary(a, b2.opts() - .WithName("E") - .WithControlInput(recv1) - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O1")); + Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, b2.opts()); + Node* e = Unary(recv1, b2.opts() + .WithName("E") + .WithControlInput(recv1) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); Node* s1 = Sequencer( @@ -1644,8 +1682,27 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + { + GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); + Node* key_constant = KeyPlaceholder("F1", shape1.opts()); + Node* recv1 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); + } + + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"e_0_retval_retval:float", "f_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1654,14 +1711,15 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { "XlaHostCompute", {"D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT})}, - {"Toutputs", absl::Span({})}, + {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, }, - {{"f_0_retval_retval", "F:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"f_0_retval_retval", "F:o:0"}}); { std::unique_ptr lib_def( @@ -1678,14 +1736,17 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); + Node* send1 = + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); Node* s1 = Sequencer( - b2.opts().WithName("F1_sequencer").WithControlInput(recv1), "F1"); + b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); Node* call1 = b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("G")); + Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1722,8 +1783,27 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + { + GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); + Node* key_constant = KeyPlaceholder("F1", shape1.opts()); + Node* recv1 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); + } + + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"e_0_retval_retval:float", "f_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1736,14 +1816,15 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { "XlaHostCompute", {"D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT})}, - {"Toutputs", absl::Span({})}, + {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, }, - {{"f_0_retval_retval", "F:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"f_0_retval_retval", "F:o:0"}}); { std::unique_ptr lib_def( @@ -1760,7 +1841,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {}, + Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts().WithControlInput(e)); Node* s1 = Sequencer( b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}), @@ -1770,7 +1851,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { Node* call1 = b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("G")); + Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1813,22 +1894,45 @@ TEST(EncapsulateSubgraphsTest, FunctionDefLibrary library_expected; GraphDef graphdef_expected; + { + GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); + Node* key_constant = KeyPlaceholder("F1", shape1.opts()); + Node* recv1 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); + } + { GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape2.opts()); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", - {DT_FLOAT}, shape2.opts()); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT}, + shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* g = Unary(ops::NodeOut(recv2, 0), shape2.opts() .WithName("G") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O2")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, shape2.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, + shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape2, "F1_O2", &library_expected)); } + NameAttrList shape_inference_graph1; + shape_inference_graph1.set_name("_outside_compilation_shape_inference_F1_O1"); + NameAttrList shape_inference_graph2; + shape_inference_graph2.set_name("_outside_compilation_shape_inference_F1_O2"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"e_0_retval_retval:float", "h_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1836,6 +1940,16 @@ TEST(EncapsulateSubgraphsTest, {{"H"}, "UnaryTest", {"outside_compilation_O2_host_compute:outputs:0"}}, + {{"outside_compilation_O1_host_compute"}, + "XlaHostCompute", + {"a_0_arg"}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", shape_inference_graph1}, + {"shapes", absl::Span({})}, + {"_outside_compilation_subgraph", "O1"}}}, {{"outside_compilation_O2_host_compute"}, "XlaHostCompute", {"F:o:0"}, @@ -1843,12 +1957,12 @@ TEST(EncapsulateSubgraphsTest, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O2"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O2"}, + {"shape_inference_graph", shape_inference_graph2}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}}}, }, - {{"h_0_retval_retval", "H:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"h_0_retval_retval", "H:o:0"}}); { std::unique_ptr lib_def( @@ -1856,30 +1970,39 @@ TEST(EncapsulateSubgraphsTest, GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - - Node* e = Unary(a, b2.opts() - .WithName("E") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O1")); Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", - {DT_FLOAT}, b2.opts()); - Node* g = Unary(recv, b2.opts() - .WithName("G") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O2") - .WithControlInput(e)); - Node* send = - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, b2.opts()); - Node* s1 = Sequencer( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), - "F1"); + Node* recv1 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + + Node* e = Unary(recv1, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send1 = + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* g = Unary(recv2, b2.opts() + .WithName("G") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2") + .WithControlInput(e)); + Node* send2 = + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* s1 = Sequencer(b2.opts() + .WithName("F1_sequencer") + .WithControlInputs({recv1, send1, recv2, send2}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b).ControlInput(s1); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("I")); + Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("I")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1925,19 +2048,24 @@ TEST(EncapsulateSubgraphsTest, { GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape1.opts()); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, shape1.opts()); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); } + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"e_0_retval_retval:float", "h_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1945,6 +2073,16 @@ TEST(EncapsulateSubgraphsTest, "UnaryTest", {"outside_compilation_O1_host_compute:outputs:0"}}, {{"H"}, "UnaryTest", {"F:o:0"}}, + {{"outside_compilation_O2_host_compute"}, + "XlaHostCompute", + {"a_0_arg"}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({})}, + {"ancestors", absl::Span({})}, + {"key", "host_compute_channel_F1_O2"}, + {"shape_inference_graph", NameAttrList()}, + {"shapes", absl::Span({})}, + {"_outside_compilation_subgraph", "O2"}}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"D:o:0"}, @@ -1952,12 +2090,12 @@ TEST(EncapsulateSubgraphsTest, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O1"}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, }, - {{"h_0_retval_retval", "H:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"h_0_retval_retval", "H:o:0"}}); { std::unique_ptr lib_def( @@ -1968,27 +2106,33 @@ TEST(EncapsulateSubgraphsTest, Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, b2.opts()); - Node* e = Unary(recv, b2.opts() - .WithName("E") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O1")); + Node* recv1 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(recv1, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); Node* send = - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); - /*Node* g =*/Unary(a, b2.opts() - .WithName("G") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O2") - .WithControlInput(e)); - Node* s1 = Sequencer( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), - "F1"); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + /*Node* g =*/Unary(recv2, b2.opts() + .WithName("G") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2") + .WithControlInput(e)); + Node* s1 = Sequencer(b2.opts() + .WithName("F1_sequencer") + .WithControlInputs({recv1, recv2, send}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b).ControlInput(s1); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("I")); + Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("I")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -2039,19 +2183,24 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { { GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape1.opts()); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, shape1.opts()); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); } + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"e_0_retval_retval:float", "h_0_retval_retval:float"}, {}, {{{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, {{"F"}, "UnaryTest", {"outside_compilation_O1_host_compute:outputs:0"}}, @@ -2063,8 +2212,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O1"}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, {{"outside_compilation_O2_host_compute"}, @@ -2074,7 +2222,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"Toutputs", absl::Span({})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O2"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}}, {}}, @@ -2085,11 +2233,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"Toutputs", absl::Span({})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O3"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O3"}}, {}}}, - {{"h_0_retval_retval", "H:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"h_0_retval_retval", "H:o:0"}}); { std::unique_ptr lib_def( @@ -2100,23 +2249,27 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, b2.opts()); + Node* recv1 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Unary(recv1, b2.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send = - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", - {DT_FLOAT}, b2.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* g = Unary(recv2, b2.opts() .WithName("G") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O2") .WithControlInput(e)); - Node* recv3 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O3", - {DT_FLOAT}, b2.opts()); + Node* recv3 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O3", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); /*Node* i =*/Binary(recv3, e, b2.opts() .WithName("I") @@ -2131,7 +2284,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { node_builder1.Input(a).Input(b).ControlInput(s1); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("J")); + Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("J")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -2167,14 +2320,44 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + { + GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); + Node* key_constant = KeyPlaceholder("F1", shape1.opts()); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); + } + + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"e_0_retval_retval:float", "f_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, {{"F"}, "UnaryTest", {"D:o:0"}}, + {{"outside_compilation_O1_host_compute"}, + "XlaHostCompute", + {"a_0_arg"}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", shape_inference_graph}, + {"shapes", absl::Span({})}, + {"_outside_compilation_subgraph", "O1"}}}, }, - {{"f_0_retval_retval", "F:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"f_0_retval_retval", "F:o:0"}}); { std::unique_ptr lib_def( @@ -2183,15 +2366,26 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - Node* e = Unary(a, b2.opts() - .WithName("E") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O1")); + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(recv, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send = + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* s = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); - node_builder1.Input(a).Input(b); + node_builder1.Input(a).Input(b).ControlInput(s); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("G")); + Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -2236,20 +2430,22 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { { GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape.opts()); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, shape.opts()); - Node* a = InputShaped(shape.opts().WithName("A")); - Node* c = Unary(a, shape.opts().WithName("C")); - Node* e = BinaryUnknownShape(c, recv, + Node* recv = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + shape.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = BinaryUnknownShape(recv, ops::NodeOut(recv, 1), shape.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected)); } + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = test::function::XTimesTwo(); *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"b_0_arg:float", "c_0_arg:float"}, {"f_0_retval_retval:float"}, {}, @@ -2262,13 +2458,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", - {"c:o:0"}, - {{"Tinputs", absl::Span({DT_FLOAT})}, + {"c_0_arg", "c:o:0"}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O1"}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}, {"c"}}, @@ -2285,16 +2480,18 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, b2.opts()); - Node* e = BinaryUnknownShape(c, ops::NodeOut(recv, 0), + Node* recv = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = BinaryUnknownShape(recv, ops::NodeOut(recv, 1), b2.opts() .WithName("E") - .WithControlInputs({recv, b}) + .WithControlInputs({recv}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, - b2.opts().WithControlInput(e)); + b2.opts().WithControlInput(e).WithAttr( + kXlaHasHostTransferAttrName, true)); Node* s = Sequencer( b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), @@ -2303,9 +2500,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { NodeBuilder node_builder("F1", "F1", lib_def.get()); node_builder.Input(b).Input(c); Node* call = - b2.opts().WithControlInputs({s, c}).FinalizeBuilder(&node_builder); + b2.opts().WithControlInputs({s, b, c}).FinalizeBuilder(&node_builder); - Binary(a, call, b2.opts().WithName("G").WithControlInputs({e})); + Binary(a, call, b2.opts().WithName("G").WithControlInputs({call})); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } diff --git a/tensorflow/compiler/jit/encapsulate_util.cc b/tensorflow/compiler/jit/encapsulate_util.cc index 1f4b9c90a4ff0b1166cdb7b5942771b350740ef3..2264806d6bdabd9f26d9f83b681524399f996317 100644 --- a/tensorflow/compiler/jit/encapsulate_util.cc +++ b/tensorflow/compiler/jit/encapsulate_util.cc @@ -62,517 +62,6 @@ void ReplaceAttr(Node* n, const string& attr_name, const T& value) { n->AddAttr(attr_name, value); } -// Step 1a ~ 1d for PreprocessForEncapsulation(). See comments of -// PreprocessForEncapsulation() for details. -Status ProcessControlEdges(Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name) { - // Gather edges to remove. We should not remove the edge while iterating. - std::vector edges_to_remove; - for (const Edge* e : g->edges()) { - if (!e->IsControlEdge()) { - continue; - } - - auto src_xla_computation = - GetStringAttr(*e->src(), xla_computation_attr_name); - auto dst_xla_computation = - GetStringAttr(*e->dst(), xla_computation_attr_name); - auto src_outside_compilation = - GetStringAttr(*e->src(), outside_compilation_attr_name); - auto dst_outside_compilation = - GetStringAttr(*e->dst(), outside_compilation_attr_name); - - if (!src_xla_computation && !dst_xla_computation) { - continue; - } else if (src_xla_computation && !dst_xla_computation) { - if (src_outside_compilation) { - // Case 1c: outside compilation to host computation control edge. - edges_to_remove.push_back(e); - - TF_RETURN_IF_ERROR(AppendToListAttr( - e->dst(), kXlaControlDependenciesAttrName, e->src()->name())); - } - } else if (!src_xla_computation && dst_xla_computation) { - if (dst_outside_compilation) { - // Case 1c: host computation control to outside compilation edge. - edges_to_remove.push_back(e); - - TF_RETURN_IF_ERROR(AppendToListAttr( - e->dst(), kXlaControlDependenciesAttrName, e->src()->name())); - } - } else { // src_xla_computation && dst_xla_computation - if (*src_xla_computation != *dst_xla_computation) { - if (src_outside_compilation && dst_outside_compilation) { - // Case 1b: outside compilation to outside compilation control edge. - edges_to_remove.push_back(e); - - TF_RETURN_IF_ERROR(AppendToListAttr( - e->dst(), kXlaControlDependenciesAttrName, e->src()->name())); - } else if (src_outside_compilation && !dst_outside_compilation) { - // Case 1a: outside compilation to another XLA computaition control - // edge. - TF_RETURN_IF_ERROR(AppendToListAttr( - e->src(), kXlaConnectedToOtherXlaComputationAttrName, - *dst_xla_computation)); - } else if (!src_outside_compilation && dst_outside_compilation) { - // Case 1a: another XLA computaition to outside compilation control - // edge. - TF_RETURN_IF_ERROR(AppendToListAttr( - e->dst(), kXlaConnectedFromOtherXlaComputationAttrName, - *src_xla_computation)); - } - } - } - } - - for (auto e : edges_to_remove) { - g->RemoveEdge(e); - } - return Status::OK(); -} - -// Step 2 for PreprocessForEncapsulation(). See comments of -// PreprocessForEncapsulation() for details. -Status ProcessXlaToXlaDataEdges(Graph* g, - const string& xla_computation_attr_name, - const string& outside_compilation_attr_name) { - // Gather edges between XLA computations. Notice that we do not store `Edge*` - // directly because we remove some nodes while adding Identity nodes, and - // those Edge pointers might be invalidated. - struct EdgeInfo { - int dst_input, dst_node_id; - }; - std::vector edges; - for (const Edge* e : g->edges()) { - if (e->IsControlEdge()) { - continue; - } - - auto src_xla_computation = - GetStringAttr(*e->src(), xla_computation_attr_name); - auto dst_xla_computation = - GetStringAttr(*e->dst(), xla_computation_attr_name); - auto src_outside_compilation = - GetStringAttr(*e->src(), outside_compilation_attr_name); - auto dst_outside_compilation = - GetStringAttr(*e->dst(), outside_compilation_attr_name); - if (!src_xla_computation || !dst_xla_computation) { - continue; - } - - if (*src_xla_computation != *dst_xla_computation) { - if (src_outside_compilation || dst_outside_compilation) { - edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()}); - VLOG(4) << "XLA -> XLA edge: " << e->DebugString(); - } - } - } - - // For each XLA -> XLA edge, add an Identity node between src and dst. - for (int i = 0; i < edges.size(); i++) { - Node* dst = g->FindNodeId(edges[i].dst_node_id); - const Edge* e; - TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e)); - Node* src = e->src(); - int src_output = e->src_output(), dst_input = e->dst_input(); - g->RemoveEdge(e); - - // Create Identity node, and connect it between `src` and `dst`. - string identity_node_name = - absl::StrCat("bridge_", src->name(), "_", dst->name()); - DataType dtype = src->output_type(src_output); - TF_ASSIGN_OR_RETURN(Node * identity_node, - BuildIdentityNode(g, identity_node_name, dtype, src, - /*requested_device=*/absl::nullopt)); - identity_node->AddAttr(kBridgeSourceNodeAttrName, src->name()); - g->AddEdge(src, src_output, identity_node, 0); - g->AddEdge(identity_node, 0, dst, dst_input); - - // Replace `e->dst()` because its input node changed. - NodeDef new_def = dst->def(); - *new_def.mutable_input(dst_input) = identity_node->name(); - TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def)); - - // Other edge in `edges` might have `e->dst()` as src or dst - // node. Before removing `e->dst()`, replace those edges with corresponding - // edges for `dst_replace_node`. - for (int j = i + 1; j < edges.size(); j++) { - if (edges[j].dst_node_id == edges[i].dst_node_id) { - edges[j].dst_node_id = dst_replace_node->id(); - } - } - } - return Status::OK(); -} - -// Step 3 for PreprocessForEncapsulation(). See comments of -// PreprocessForEncapsulation() for details. -Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation( - Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name) { - // Gather edges between outside compilation and host computation. Notice that - // we do not store `Edge*` directly because we remove some nodes while adding - // Identity nodes, and those Edge pointers might be invalidated. - struct EdgeInfo { - int dst_input, dst_node_id; - bool is_host_to_outside_compilation; - }; - std::vector edges; - for (const Edge* e : g->edges()) { - if (e->IsControlEdge()) { - continue; - } - - if (e->src()->attrs().Find(xla_computation_attr_name) == nullptr && - e->dst()->attrs().Find(xla_computation_attr_name) != nullptr && - e->dst()->attrs().Find(outside_compilation_attr_name) != nullptr) { - edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id(), - /*is_host_to_outside_compilation=*/true}); - VLOG(4) << "Host -> oc edge: " << e->DebugString(); - } else if (e->dst()->attrs().Find(xla_computation_attr_name) == nullptr && - e->src()->attrs().Find(xla_computation_attr_name) != nullptr && - e->src()->attrs().Find(outside_compilation_attr_name) != - nullptr) { - edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id(), - /*is_host_to_outside_compilation=*/false}); - VLOG(4) << "Oc -> host edge: " << e->DebugString(); - } - } - - // Remove the edge from host to outside compilation. Add a placeholder as - // outside compilation node input. - std::map, Node*> placeholders; - for (int i = 0; i < edges.size(); i++) { - Node* dst = g->FindNodeId(edges[i].dst_node_id); - const Edge* e; - TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e)); - Node* src = e->src(); - int src_output = e->src_output(), dst_input = e->dst_input(); - g->RemoveEdge(e); - - // Find or create placeholder node. - string new_name = - edges[i].is_host_to_outside_compilation - ? absl::StrCat(src->name(), "_host_to_oc_placeholder_", src_output) - : absl::StrCat(src->name(), "_oc_to_host_placeholder_", src_output); - auto placeholder_index = std::make_pair(src->name(), src_output); - auto iter = placeholders.find(placeholder_index); - Node* placeholder_node; - if (iter == placeholders.end()) { - NodeDefBuilder placeholder_builder(new_name, "Placeholder"); - placeholder_builder.Attr("dtype", src->output_type(src_output)); - if (edges[i].is_host_to_outside_compilation) { - placeholder_builder.Attr(kHostToOutsideCompilationOriginalNodeAttrName, - src->name()); - placeholder_builder.Attr(kHostToOutsideCompilationSrcOutputAttrName, - src_output); - // If this placeholder node is in outside compilation, we need to set - // `xla_computation_attr_name` and `outside_compilation_attr_name`. - string xla_computation_attr, outside_compilation_attr; - TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(), xla_computation_attr_name, - &xla_computation_attr)); - TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(), - outside_compilation_attr_name, - &outside_compilation_attr)); - placeholder_builder.Attr(xla_computation_attr_name, - xla_computation_attr); - placeholder_builder.Attr(outside_compilation_attr_name, - outside_compilation_attr); - } else { - placeholder_builder.Attr(kOutsideCompilationToHostOriginalNodeAttrName, - src->name()); - placeholder_builder.Attr(kOutsideCompilationToHostSrcOutputAttrName, - src_output); - } - NodeDef placeholder_def; - TF_RETURN_IF_ERROR(placeholder_builder.Finalize(&placeholder_def)); - Status s; - placeholder_node = g->AddNode(placeholder_def, &s); - TF_RETURN_IF_ERROR(s); - placeholders[placeholder_index] = placeholder_node; - } else { - placeholder_node = iter->second; - } - g->AddEdge(placeholder_node, 0, dst, dst_input); - - // Replace `e->dst()` because its input node changed. - NodeDef new_def = dst->def(); - *new_def.mutable_input(dst_input) = placeholder_node->name(); - TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def)); - - // Other edge in `edges` might have `e->dst()` as src or dst - // node. Before removing `e->dst()`, replace those edges with corresponding - // edges for `dst_replace_node`. - for (int j = i + 1; j < edges.size(); j++) { - if (edges[j].dst_node_id == edges[i].dst_node_id) { - edges[j].dst_node_id = dst_replace_node->id(); - } - } - } - return Status::OK(); -} - -// Step 1 for `PostprocessForEncapsulation`. See comments of -// `PostprocessForEncapsulation` for details. -Status RemovePlaceholderBetweenOutsideCompilationAndHostComputation(Graph* g) { - // Gather all outside compilation to host computation nodes. - struct PlaceHolderNodeInfo { - Node* n; - bool is_host_to_oc; - }; - std::vector placeholder_nodes; - for (Node* n : g->nodes()) { - if (n->type_string() == "Placeholder") { - if (HasNodeAttr(n->def(), - kOutsideCompilationToHostOriginalNodeAttrName)) { - placeholder_nodes.push_back({n, false}); - } else if (HasNodeAttr(n->def(), - kHostToOutsideCompilationOriginalNodeAttrName)) { - placeholder_nodes.push_back({n, true}); - } - } - } - - // Remove the placeholder nodes, and reconnect original edge. - auto node_name_index = g->BuildNodeNameIndex(); - for (auto placeholder_iter : placeholder_nodes) { - Node* n = placeholder_iter.n; - - string node_name; - int node_src_output; - if (placeholder_iter.is_host_to_oc) { - TF_RETURN_IF_ERROR( - GetNodeAttr(n->attrs(), kHostToOutsideCompilationOriginalNodeAttrName, - &node_name)); - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), - kHostToOutsideCompilationSrcOutputAttrName, - &node_src_output)); - } else { - TF_RETURN_IF_ERROR( - GetNodeAttr(n->attrs(), kOutsideCompilationToHostOriginalNodeAttrName, - &node_name)); - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), - kOutsideCompilationToHostSrcOutputAttrName, - &node_src_output)); - } - auto iter = node_name_index.find(node_name); - if (iter == node_name_index.end()) { - return errors::Internal( - "Cannot find original node for oc -> host placeholder node ", - node_name); - } - - // Change all usage node to use the original node instead. - Node* original_node = iter->second; - std::vector control_edges; - std::vector data_edges; - for (auto e : n->out_edges()) { - if (e->IsControlEdge()) { - control_edges.push_back(e); - } else { - data_edges.push_back({e->dst(), e->src_output(), e->dst_input()}); - } - } - for (const Edge* e : control_edges) { - g->AddControlEdge(original_node, e->dst()); - g->RemoveEdge(e); - } - for (int i = 0; i < data_edges.size(); i++) { - Node* dst = data_edges[i].dst; - NodeDef new_def = dst->def(); - int dst_input = data_edges[i].dst_input; - *new_def.mutable_input(dst_input) = - absl::StrCat(original_node->name(), ":", node_src_output); - TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def)); - - const Edge* edge_to_replace = nullptr; - TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace)); - g->RemoveEdge(edge_to_replace); - g->AddEdge(original_node, node_src_output, replace_node, dst_input); - - // Other edges might have `dst` as dst node. Update those edges with - // `replace_node`. - for (int j = i + 1; j < data_edges.size(); j++) { - if (data_edges[j].dst == dst) { - data_edges[j].dst = replace_node; - } - } - - // Other placeholder node might have `dst` as original node. Update - // `node_name_index` with `replace_node`. - node_name_index[replace_node->name()] = replace_node; - } - - // Remove placeholder node. - g->RemoveNode(n); - } - return Status::OK(); -} - -// Step 2 for `PostprocessForEncapsulation`. See comments of -// `PostprocessForEncapsulation` for details. -Status RemoveIdentityBetweenDifferentXlaComputation(Graph* g) { - // Gather Identity nodes to remove. - std::vector bridge_nodes; - for (Node* n : g->nodes()) { - if (n->type_string() == "Identity" && - HasNodeAttr(n->def(), kBridgeSourceNodeAttrName)) { - bridge_nodes.push_back(n); - } - } - - // Remove the identity nodes, and reconnect the original edge. - for (int i = 0; i < bridge_nodes.size(); i++) { - Node* n = bridge_nodes[i]; - const Edge* src_edge = nullptr; - TF_RETURN_IF_ERROR(n->input_edge(0, &src_edge)); - - // Change all usage node to use the original node instead. - std::vector control_edges; - std::vector data_edges; - for (auto e : n->out_edges()) { - if (e->IsControlEdge()) { - control_edges.push_back(e); - } else { - data_edges.push_back({e->dst(), e->src_output(), e->dst_input()}); - } - } - for (const Edge* e : control_edges) { - g->AddControlEdge(src_edge->src(), e->dst()); - g->RemoveEdge(e); - } - for (int j = 0; j < data_edges.size(); j++) { - Node* dst = data_edges[j].dst; - NodeDef new_def = dst->def(); - int dst_input = data_edges[j].dst_input; - *new_def.mutable_input(dst_input) = - absl::StrCat(src_edge->src()->name(), ":", src_edge->src_output()); - TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def)); - - const Edge* edge_to_replace = nullptr; - TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace)); - g->RemoveEdge(edge_to_replace); - g->AddEdge(src_edge->src(), src_edge->src_output(), replace_node, - dst_input); - - // Other edges might have `dst` as dst node. Update those edges with - // `replace_node`. - for (int k = j + 1; k < data_edges.size(); k++) { - if (data_edges[k].dst == dst) { - data_edges[k].dst = replace_node; - } - } - - // The node we replaced might be in `bridge_nodes`. If so, update - // `bridge_nodes` to use the replaced node. - for (int k = i + 1; k < bridge_nodes.size(); k++) { - if (bridge_nodes[k] == dst) { - bridge_nodes[k] = replace_node; - } - } - } - - // Remove Identity node. - g->RemoveNode(n); - } - return Status::OK(); -} - -// Step 3 for `PostprocessForEncapsulation`. See comments of -// `PostprocessForEncapsulation` for details. -// We do not need to worry about removed nodes in step 1 and 2; -// `PreprocessForEncapsulation` will not record control dependencies for those -// remvoed nodes in the first place. -Status AddControlDependencies( - Graph* g, const std::unordered_map& cluster_node_names) { - auto node_name_index = g->BuildNodeNameIndex(); - - // Reconnect outside compilation to outside compilation control edge. - for (Node* n : g->nodes()) { - std::vector control_deps; - Status s = - GetNodeAttr(n->attrs(), kXlaControlDependenciesAttrName, &control_deps); - if (!s.ok()) { - if (s.code() != error::NOT_FOUND) { - return s; - } else { - continue; - } - } else { - n->ClearAttr(kXlaControlDependenciesAttrName); - for (const string& control_input : control_deps) { - auto iter = node_name_index.find(control_input); - if (iter == node_name_index.end()) { - return errors::Internal("Cannot find original node for ", - control_input); - } - g->AddControlEdge(iter->second, n); - } - } - } - - // Reconnect outside compilation to XLA computation control edge. - for (Node* n : g->nodes()) { - std::vector control_deps; - Status s = GetNodeAttr( - n->attrs(), kXlaConnectedToOtherXlaComputationAttrName, &control_deps); - if (!s.ok()) { - if (s.code() != error::NOT_FOUND) { - return s; - } else { - continue; - } - } else { - n->ClearAttr(kXlaConnectedToOtherXlaComputationAttrName); - for (const string& control_input : control_deps) { - auto iter = cluster_node_names.find(control_input); - if (iter == cluster_node_names.end()) { - return errors::Internal("Cannot find cluster node for ", - control_input); - } - auto iter2 = node_name_index.find(iter->second); - if (iter2 == node_name_index.end()) { - return errors::Internal("Cannot find cluster node for ", - iter->second); - } - g->AddControlEdge(n, iter2->second); - } - } - } - - // Reconnect XLA computation to outside compilation control edge. - for (Node* n : g->nodes()) { - std::vector control_deps; - Status s = - GetNodeAttr(n->attrs(), kXlaConnectedFromOtherXlaComputationAttrName, - &control_deps); - if (!s.ok()) { - if (s.code() != error::NOT_FOUND) { - return s; - } else { - continue; - } - } else { - n->ClearAttr(kXlaConnectedFromOtherXlaComputationAttrName); - for (const string& control_input : control_deps) { - auto iter = cluster_node_names.find(control_input); - if (iter == cluster_node_names.end()) { - return errors::Internal("Cannot find cluster node for ", - control_input); - } - auto iter2 = node_name_index.find(iter->second); - if (iter2 == node_name_index.end()) { - return errors::Internal("Cannot find cluster node for ", - iter->second); - } - g->AddControlEdge(iter2->second, n); - } - } - } - - return Status::OK(); -} - // Step 1 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of // `PreprocessEdgesBetweenOutsideCompilations` for details. Status PreprocessControlEdgesBetweenOutsideCompilations( @@ -811,20 +300,6 @@ Status PostprocessControlEdgesBetweenOutsideCompilations( const char kXlaInferredShapesAttrName[] = "_xla_inferred_shapes"; -const char kXlaConnectedToOtherXlaComputationAttrName[] = - "_xla_connected_to_other_xla_computation"; -const char kXlaConnectedFromOtherXlaComputationAttrName[] = - "_xla_connected_from_other_xla_computation"; -const char kXlaControlDependenciesAttrName[] = "_xla_control_dependencies"; -const char kBridgeSourceNodeAttrName[] = "_xla_bridge_src"; -const char kOutsideCompilationToHostOriginalNodeAttrName[] = - "_xla_oc_to_host_node_name"; -const char kOutsideCompilationToHostSrcOutputAttrName[] = - "_xla_oc_to_host_src_output"; -const char kHostToOutsideCompilationOriginalNodeAttrName[] = - "_xla_host_to_oc_node_name"; -const char kHostToOutsideCompilationSrcOutputAttrName[] = - "_xla_host_to_oc_src_output"; const char kXlaConnectedToXlaComputationAttrName[] = "_xla_connected_to_xla_computation"; const char kXlaConnectedFromXlaComputationAttrName[] = @@ -835,32 +310,7 @@ const char kOutsideCompilationSrcOutputAttrName[] = "_xla_oc_to_oc_src_output"; const char kXlaControlDependenciesWithinXlaClusterAttrName[] = "_xla_control_dependencies_within_xla_cluster"; -Status PerformStaticShapeInferenceBeforeEncapsulation( - Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name) { - // Find all outside compilation to XLA computation data edges. - std::unordered_set outside_compilation_send_nodes; - for (auto e : g->edges()) { - if (e->IsControlEdge()) { - continue; - } - - auto src_computation = GetStringAttr(*e->src(), xla_computation_attr_name); - auto dst_computation = GetStringAttr(*e->dst(), xla_computation_attr_name); - if (!src_computation || !dst_computation || - *src_computation != *dst_computation) { - continue; - } - - auto src_outside_compilation = - GetStringAttr(*e->src(), outside_compilation_attr_name); - auto dst_outside_compilation = - GetStringAttr(*e->dst(), outside_compilation_attr_name); - if (src_outside_compilation && !dst_outside_compilation) { - outside_compilation_send_nodes.insert(e->src()); - } - } - +Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g) { // Perform shape inference. std::map arg_shapes; GraphShapeInfo shape_info; @@ -868,55 +318,21 @@ Status PerformStaticShapeInferenceBeforeEncapsulation( InferShapes(g, arg_shapes, /*fnlib_def=*/nullptr, &shape_info)); // Add attribute for output shapes. - for (Node* n : outside_compilation_send_nodes) { - auto iter = shape_info.find(n->name()); - if (iter == shape_info.end()) { - continue; - } - + auto node_name_index = g->BuildNodeNameIndex(); + for (auto iter : shape_info) { std::vector output_shapes; - std::transform(iter->second.begin(), iter->second.end(), + std::transform(iter.second.begin(), iter.second.end(), std::back_inserter(output_shapes), [](const InferredShape& inferred_shape) { return inferred_shape.shape; }); + Node* n = node_name_index[iter.first]; n->AddAttr(kXlaInferredShapesAttrName, output_shapes); } return Status::OK(); } -Status PreprocessForEncapsulation(Graph* g, - const string& xla_computation_attr_name, - const string& outside_compilation_attr_name) { - TF_RETURN_IF_ERROR(ProcessControlEdges(g, xla_computation_attr_name, - outside_compilation_attr_name)); - TF_RETURN_IF_ERROR(ProcessXlaToXlaDataEdges(g, xla_computation_attr_name, - outside_compilation_attr_name)); - TF_RETURN_IF_ERROR(ProcessDataEdgeBetweenOutsideCompilationAndHostComputation( - g, xla_computation_attr_name, outside_compilation_attr_name)); - return Status::OK(); -} - -Status PostprocessForEncapsulation( - Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name, - const std::unordered_map& clusters) { - // The `node` pointer in `XlaClusterInfo` might be invalidated in step 1/2, - // but the node name won't change. Record cluster node name for - // `AddControlDependencies`. - std::unordered_map cluster_node_names; - for (const auto& iter : clusters) { - cluster_node_names[iter.first] = iter.second.node->name(); - } - - TF_RETURN_IF_ERROR( - RemovePlaceholderBetweenOutsideCompilationAndHostComputation(g)); - TF_RETURN_IF_ERROR(RemoveIdentityBetweenDifferentXlaComputation(g)); - TF_RETURN_IF_ERROR(AddControlDependencies(g, cluster_node_names)); - return Status::OK(); -} - Status PreprocessEdgesBetweenOutsideCompilations( Graph* g, const string& outside_compilation_attr_name) { // Remove edges from source node to outside compilation nodes, and edges diff --git a/tensorflow/compiler/jit/encapsulate_util.h b/tensorflow/compiler/jit/encapsulate_util.h index e363bc5754ac395bae262dc67a780a0173efaf5e..c9f16d14168163e11bb19092f566f1de8724aca3 100644 --- a/tensorflow/compiler/jit/encapsulate_util.h +++ b/tensorflow/compiler/jit/encapsulate_util.h @@ -27,51 +27,13 @@ namespace tensorflow { // a list of PartialTensorShape objects. extern const char kXlaInferredShapesAttrName[]; -// Infer output shapes for outside compilation nodes which have output data -// edges to XLA computation nodes. These shapes will be used later by XLA -// compiler as output shapes of the outside compilation's XlaHostCompute op. -// XLA computation nodes will be mark by attr `xla_computation_attr_name`; -// outside compilation nodes will be marked by both attr -// `xla_computation_attr_name` and `outside_compilation_attr_name`. -// -// Those outside compilation nodes will be marked with attribute -// `kXlaInferredShapesAttrName`. +// Infers output shapes for all nodes in graph `g`. The output shapes will be +// stored in node attribute `kXlaInferredShapesAttrName`. // // We have to perform shape inference before encapsulation because after // encapsulation, some nodes will be encapsulated into function call, and shape // inference does not handle function call at the moment. -Status PerformStaticShapeInferenceBeforeEncapsulation( - Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name); - -// Attribute indicating that some ops in other XLA computation has control -// dependency on this node. Attribute value will be a list of string (XLA -// computation names). -extern const char kXlaConnectedToOtherXlaComputationAttrName[]; - -// Attribute indicating that this node has control dependency on some ops in -// other XLA computation. Attribute value will be a list of string (XLA -// computation names). -extern const char kXlaConnectedFromOtherXlaComputationAttrName[]; - -// Attribute indicating that this node has control dependencies on some other -// nodes. Attribute value will be a list of string (node names). -extern const char kXlaControlDependenciesAttrName[]; - -// Attribute indicating that this is an Identity node added to act as a bridge -// between different XLA computations. Attribute value will be string (source -// node name). -extern const char kBridgeSourceNodeAttrName[]; - -// Attribute indicating that this is an Placeholder node added to act as a -// temporary input node for an outside compilation node. Attribute value will be -// string (original input node name). -extern const char kOutsideCompilationToHostOriginalNodeAttrName[]; - -// Attribute indicating that this is an Placeholder node added to act as a -// temporary input node for an outside compilation node. Attribute value will be -// int (src_output for original edge). -extern const char kOutsideCompilationToHostSrcOutputAttrName[]; +Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g); // Attribute indicating that some ops in this node's XLA computation has control // dependency on this node. Attribute value will always be "true". @@ -81,16 +43,6 @@ extern const char kXlaConnectedToXlaComputationAttrName[]; // this node's XLA computation. Attribute value will always be "true". extern const char kXlaConnectedFromXlaComputationAttrName[]; -// Attribute indicating that this is an Placeholder node added to act as a -// temporary input node for an host node. Attribute value will be string -// (original input node name). -extern const char kHostToOutsideCompilationOriginalNodeAttrName[]; - -// Attribute indicating that this is an Placeholder node added to act as a -// temporary input node for a host node. Attribute value will be int (src_output -// for original edge). -extern const char kHostToOutsideCompilationSrcOutputAttrName[]; - // Attribute indicating that this is an Placeholder node added to act as a // temporary input node for an outside compilation node. Attribute value will be // string (original input node name). @@ -106,27 +58,6 @@ extern const char kOutsideCompilationSrcOutputAttrName[]; // (node names). extern const char kXlaControlDependenciesWithinXlaClusterAttrName[]; -// Preprocesses edges between different XLA clusters for encapsulation. It will -// perform the following operations in order: -// -// 1a. For control edges between outside compilation and another XLA -// computation, add attr "kXlaConnected{From, To}OtherXlaComputationAttrName -// = XLA computation node name" to the outside compilation node. -// 1b. For control edges between different outside compilations (in different -// XLA computations), remove the edge and add attr -// "kXlaControlDependenciesAttrName = src node name" to dst node. -// 1c. For control edges between outside compilation and host computation, -// remove the edge and add attr "kXlaControlDependenciesAttrName = src node -// name" to dst node. -// 2. For data edges between different XLA computations, if either src or dst -// is outside compilation, add an Identity node in between the edge. The -// identity node will have attr kBridgeSourceNodeAttrName. -// 3. For data edges between outside compilation and host computation, remove -// the edge and create a Placeholder node as dst node's input. -Status PreprocessForEncapsulation(Graph* g, - const string& xla_computation_attr_name, - const string& outside_compilation_attr_name); - // Information for XLA computation. struct XlaClusterInfo { // Add an explicitly-defined default constructor for this class. @@ -158,24 +89,6 @@ struct XlaClusterInfo { const std::map host_compute_core; }; -// Postprocesses edges between different XLA clusters for encapsulation. This -// function reverts what `PreprocessForEncapsulation` did. It will perform the -// following operations in order: -// -// 1. Remove Placeholder nodes between outside compilation and host computation -// (created in `PreprocessForEncapsulation` step 3). -// 2. Remove Identity nodes created in `PreprocessForEncapsulation` step 2. -// 3a. Reconnect control edges between outside compilation and another XLA -// computation (marked by `PreprocessForEncapsulation` step 1a). -// 3b. Reconnect control edges between different outside compilations (marked by -// `PreprocessForEncapsulation` step 1b). -// 3c. Reconnect control edges between outside compilation and host computation -// (marked by `PreprocessForEncapsulation` step 1c). -Status PostprocessForEncapsulation( - Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name, - const std::unordered_map& clusters); - // Preprocesses edges within the same XLA cluster. It will perform the following // operations in order: // diff --git a/tensorflow/compiler/jit/encapsulate_util_test.cc b/tensorflow/compiler/jit/encapsulate_util_test.cc index 3b8b49cb92f3e453883a8e64e12ce3748a5173f6..3bb979e0698d2d6be42ed5bae66c25267928192c 100644 --- a/tensorflow/compiler/jit/encapsulate_util_test.cc +++ b/tensorflow/compiler/jit/encapsulate_util_test.cc @@ -38,24 +38,11 @@ TEST(PerformStaticShapeInferenceBeforeEncapsulationTest, Basic) { Graph g(OpRegistry::Global()); TF_CHECK_OK(s.ToGraph(&g)); - // "add" node is outside compilation node, "identity" node is XLA node. - auto node_index = g.BuildNodeNameIndex(); - Node *add_node = node_index["add"], *identity_node = node_index["identity"]; - add_node->AddAttr("_xla", "cluster"); - add_node->AddAttr("_oc", "cluster"); - identity_node->AddAttr("_xla", "cluster"); - TF_CHECK_OK( - PerformStaticShapeInferenceBeforeEncapsulation(&g, "_xla", "_oc")); + TF_CHECK_OK(PerformStaticShapeInferenceBeforeEncapsulation(&g)); - // Check that only "add" node now has _xla_inferred_shapes attr. - std::vector nodes_with_inferred_shape; - for (Node *n : g.nodes()) { - if (HasNodeAttr(n->def(), kXlaInferredShapesAttrName)) { - nodes_with_inferred_shape.push_back(n); - } - } - EXPECT_EQ(nodes_with_inferred_shape.size(), 1); - EXPECT_EQ(nodes_with_inferred_shape[0], add_node); + // Check that "add" node now has _xla_inferred_shapes attr. + auto node_index = g.BuildNodeNameIndex(); + Node *add_node = node_index["add"]; std::vector output_shapes; TF_CHECK_OK(GetNodeAttr(add_node->attrs(), kXlaInferredShapesAttrName, &output_shapes)); @@ -66,329 +53,4 @@ TEST(PerformStaticShapeInferenceBeforeEncapsulationTest, Basic) { EXPECT_EQ(shape_proto.dim(0).size(), 2); } -TEST(PreprocessForEncapsulationTest, ControlEdges) { - // Build the graph: - // "const_0" and "const_1" in host computation - // "add" = "const_0" + "const_1" in XLA computation 0 - // "identity0" = "add" in XLA computation 0 & outside compilation 0 - // "identity1" = "identity0" in XLA computation 0 - // "identity2" = "identity1" in host computation - // "identity3" = "identity2" in XLA computation 1 - // "identity4" = "identity3" in XLA computation 1 & outside compilation 1 - // "identity5" = "identity4" in XLA computation 1 - // "identity6" = "identity5" in host computation - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {}); - Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {}); - Output add = ops::Add(s.WithOpName("add"), const_0, const_1); - Output identity0 = ops::Identity(s.WithOpName("identity0"), add); - Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0); - Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1); - Output identity3 = ops::Identity(s.WithOpName("identity3"), identity2); - Output identity4 = ops::Identity(s.WithOpName("identity4"), identity3); - Output identity5 = ops::Identity(s.WithOpName("identity5"), identity4); - Graph g(OpRegistry::Global()); - TF_CHECK_OK(s.ToGraph(&g)); - auto node_index = g.BuildNodeNameIndex(); - - // Set XLA computation/outside compilation attr, and add control edges. - Node *const0_node = node_index["const_0"], *add_node = node_index["add"], - *identity0_node = node_index["identity0"], - *identity1_node = node_index["identity1"], - *identity2_node = node_index["identity2"], - *identity3_node = node_index["identity3"], - *identity4_node = node_index["identity4"], - *identity5_node = node_index["identity5"]; - add_node->AddAttr("_xla", "0"); - identity0_node->AddAttr("_xla", "0"); - identity0_node->AddAttr("_oc", "0"); - identity1_node->AddAttr("_xla", "0"); - identity3_node->AddAttr("_xla", "1"); - identity4_node->AddAttr("_xla", "1"); - identity4_node->AddAttr("_oc", "0"); - identity5_node->AddAttr("_xla", "1"); - // Case 1a: control edges between outside compilation and another XLA - // computation. - g.AddControlEdge(identity0_node, identity3_node); - g.AddControlEdge(identity1_node, identity4_node); - // Case 1b: control edges between different outside compilations. - g.AddControlEdge(identity0_node, identity4_node); - // Case 1c: control edges between outside compilation and host computation. - g.AddControlEdge(const0_node, identity0_node); - g.AddControlEdge(identity0_node, identity2_node); - - TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc")); - - // Case 1a: add attr "_xla_control_deps_{from/to} = XLA computation node name" - // to the outside compilation node. - std::vector attr; - TF_CHECK_OK(GetNodeAttr(identity0_node->def(), - kXlaConnectedToOtherXlaComputationAttrName, &attr)); - EXPECT_EQ(attr.size(), 1); - EXPECT_EQ(attr[0], "1"); - attr.clear(); - TF_CHECK_OK(GetNodeAttr(identity4_node->def(), - kXlaConnectedFromOtherXlaComputationAttrName, &attr)); - EXPECT_EQ(attr.size(), 1); - EXPECT_EQ(attr[0], "0"); - // Case 1b: add attr "_xla_control_deps = src node name" to dst node. - attr.clear(); - TF_CHECK_OK(GetNodeAttr(identity4_node->def(), - kXlaControlDependenciesAttrName, &attr)); - EXPECT_EQ(attr.size(), 1); - EXPECT_EQ(attr[0], "identity0"); - // Case 1c: add attr "_xla_control_deps = src node name" to dst node. - attr.clear(); - TF_CHECK_OK(GetNodeAttr(identity0_node->def(), - kXlaControlDependenciesAttrName, &attr)); - EXPECT_EQ(attr.size(), 1); - EXPECT_EQ(attr[0], "const_0"); - attr.clear(); - TF_CHECK_OK(GetNodeAttr(identity2_node->def(), - kXlaControlDependenciesAttrName, &attr)); - EXPECT_EQ(attr.size(), 1); - EXPECT_EQ(attr[0], "identity0"); -} - -TEST(PreprocessForEncapsulationTest, DataEdges) { - // Build the graph: - // "const_0" and "const_1" in host computation - // "identityn0" = ("const_0", "const_1") in host computation 0 - // "add0" = "const_0" + "const_1" in XLA computation 0 - // "add1" = "add0" + "const_0" in XLA computation 0 & outside compilation 0 - // "identity0" = "add1" in XLA computation 0 - // "add2" = "add1" + "identity0" in host computation - // "add3" = "add1" + "add2" in XLA computation 1 - // "add4" = "identity0" + "add2" in XLA computation 1 & outside compilation 0 - // "add5" = "identityn0"[0] + "identityn0"[1] in XLA computation 1 & - // outside compilation 0 - // "identityn1" = ("identityn0"[0], "identityn0"[1]) in XLA computation 1 & - // outside compilation 0 - // "identity1" = "add4" in XLA computation 1 - // "identity2" = "identity1" in host computation - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {}); - Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {}); - auto identityn0 = - ops::IdentityN(s.WithOpName("identityn_0"), {const_0, const_1}); - Output add0 = ops::Add(s.WithOpName("add0"), const_0, const_1); - Output add1 = ops::Add(s.WithOpName("add1"), add0, const_0); - Output identity0 = ops::Identity(s.WithOpName("identity0"), add1); - Output add2 = ops::Add(s.WithOpName("add2"), add1, identity0); - Output add3 = ops::Add(s.WithOpName("add3"), add1, add2); - Output add4 = ops::Add(s.WithOpName("add4"), identity0, add2); - Output add5 = ops::Add(s.WithOpName("add5"), identityn0[0], identityn0[1]); - auto identityn1 = ops::IdentityN(s.WithOpName("identityn_1"), - {identityn0[0], identityn0[1]}); - Output identity1 = ops::Identity(s.WithOpName("identity1"), add4); - Output identity2 = ops::Identity(s.WithOpName("identity2"), add4); - Graph g(OpRegistry::Global()); - TF_CHECK_OK(s.ToGraph(&g)); - auto node_index = g.BuildNodeNameIndex(); - - // Set XLA computation/outside compilation attr. - Node *add0_node = node_index["add0"], *add1_node = node_index["add1"], - *identity0_node = node_index["identity0"], - *add3_node = node_index["add3"], *add4_node = node_index["add4"], - *add5_node = node_index["add5"], - *identityn1_node = node_index["identityn_1"], - *identity1_node = node_index["identity1"]; - add0_node->AddAttr("_xla", "0"); - add1_node->AddAttr("_xla", "0"); - add1_node->AddAttr("_oc", "0"); - identity0_node->AddAttr("_xla", "0"); - add3_node->AddAttr("_xla", "1"); - add4_node->AddAttr("_xla", "1"); - add4_node->AddAttr("_oc", "0"); - add5_node->AddAttr("_xla", "1"); - add5_node->AddAttr("_oc", "0"); - identityn1_node->AddAttr("_xla", "1"); - identityn1_node->AddAttr("_oc", "0"); - identity1_node->AddAttr("_xla", "1"); - - TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc")); - - // Check input nodes for related data edges. - node_index = g.BuildNodeNameIndex(); - // Step 2: add an Identity node between different XLA computations. - Node *bridge_add1_add3 = node_index["bridge_add1_add3"]; - EXPECT_NE(bridge_add1_add3, nullptr); - string str; - TF_CHECK_OK( - GetNodeAttr(bridge_add1_add3->attrs(), kBridgeSourceNodeAttrName, &str)); - EXPECT_EQ(str, "add1"); - Node *bridge_identity0_add4 = node_index["bridge_identity0_add4"]; - EXPECT_NE(bridge_identity0_add4, nullptr); - // Step 3: add placeholder for edges between host computation and outside - // compilation. - EXPECT_EQ(bridge_add1_add3->def().input(0), "add1_oc_to_host_placeholder_0"); - Node *add1_oc_to_host_placeholder = - node_index["add1_oc_to_host_placeholder_0"]; - TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(), - kOutsideCompilationToHostOriginalNodeAttrName, &str)); - EXPECT_EQ(str, "add1"); - int i; - TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(), - kOutsideCompilationToHostSrcOutputAttrName, &i)); - EXPECT_EQ(i, 0); - add4_node = node_index["add4"]; - ASSERT_NE(add4_node, nullptr); - EXPECT_EQ(add4_node->def().input(0), - "bridge_identity0_add4_host_to_oc_placeholder_0"); - Node *identity0_host_to_oc_placeholder = - node_index["bridge_identity0_add4_host_to_oc_placeholder_0"]; - TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(), - kHostToOutsideCompilationOriginalNodeAttrName, &str)); - EXPECT_EQ(str, "bridge_identity0_add4"); - TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(), - kHostToOutsideCompilationSrcOutputAttrName, &i)); - EXPECT_EQ(i, 0); - - // Check different placeholder nodes are created for different src_output. - Node *placeholder0 = node_index["identityn_0_host_to_oc_placeholder_0"], - *placeholder1 = node_index["identityn_0_host_to_oc_placeholder_1"]; - EXPECT_NE(placeholder0, nullptr); - EXPECT_NE(placeholder1, nullptr); - // Check we only have 2 placeholder nodes created for "identityn_0". - int placeholder_count = 0; - for (Node *n : g.nodes()) { - if (HasNodeAttr(n->def(), kHostToOutsideCompilationOriginalNodeAttrName)) { - string attr; - TF_CHECK_OK(GetNodeAttr( - n->attrs(), kHostToOutsideCompilationOriginalNodeAttrName, &attr)); - if (attr == "identityn_0") { - ++placeholder_count; - } - } - } - EXPECT_EQ(placeholder_count, 2); -} - -TEST(PostprocessForEncapsulationTest, ControlEdges) { - // Build the graph: - // "const0" - // "identity0" = "const0" (XLA computation 0) - // "identity1" = "identity0" - // "identity2" = "identity1" (XLA computation 1) - // "identity3" = "identity2" - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output const0 = ops::Const(s.WithOpName("const0"), 1, {}); - Output identity0 = ops::Identity(s.WithOpName("identity0"), const0); - Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0); - Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1); - Output identity3 = ops::Identity(s.WithOpName("identity3"), identity2); - Graph g(OpRegistry::Global()); - TF_CHECK_OK(s.ToGraph(&g)); - auto node_index = g.BuildNodeNameIndex(); - - // Set XLA computation/outside compilation attr, and add control edges. - Node *const0_node = node_index["const0"], - *identity0_node = node_index["identity0"], - *identity1_node = node_index["identity1"], - *identity2_node = node_index["identity2"], - *identity3_node = node_index["identity3"]; - identity1_node->AddAttr(kXlaConnectedFromOtherXlaComputationAttrName, - std::vector{"0"}); - identity1_node->AddAttr(kXlaConnectedToOtherXlaComputationAttrName, - std::vector{"1"}); - identity3_node->AddAttr(kXlaControlDependenciesAttrName, - std::vector{"const0", "identity1"}); - - std::unordered_map clusters; - clusters["0"].node = identity0_node; - clusters["1"].node = identity2_node; - TF_CHECK_OK(PostprocessForEncapsulation(&g, "_xla", "_oc", clusters)); - - // Case 3a: we have control edge identity0 -> identity1, and identity1 -> - // identity2. - bool edge_identity0_identity1 = false, edge_identity1_identity2 = false; - for (const Edge *e : g.edges()) { - if (!e->IsControlEdge()) { - continue; - } - if (e->src() == identity0_node && e->dst() == identity1_node) { - edge_identity0_identity1 = true; - } else if (e->src() == identity1_node && e->dst() == identity2_node) { - edge_identity1_identity2 = true; - } - } - EXPECT_TRUE(edge_identity0_identity1); - EXPECT_TRUE(edge_identity1_identity2); - // Case 3b: we have control edge const0 -> identity3, and identity1 -> - // identity3. - bool edge_const0_identity3 = false, edge_identity1_identity3 = false; - for (const Edge *e : g.edges()) { - if (!e->IsControlEdge()) { - continue; - } - if (e->src() == const0_node && e->dst() == identity3_node) { - edge_const0_identity3 = true; - } else if (e->src() == identity1_node && e->dst() == identity3_node) { - edge_identity1_identity3 = true; - } - } - EXPECT_TRUE(edge_const0_identity3); - EXPECT_TRUE(edge_identity1_identity3); -} - -TEST(PostprocessForEncapsulationTest, DataEdges) { - // Build the graph: - // "const0" in outside compilation "0" - // "placeholder0" (for "const0") in host computation - // "add0" = "placeholder0" + "placeholder0" in host computation - // "placeholder1" (for "add0") in outside compilation 1 - // "add1" = "placeholder1" + "placeholder1" in outside compilation 1 - // - // "bridge" = "placeholder0" in host computation - // "placeholder2" (for "bridge") in outside compilation 1 - // "add2" = "placeholder2" + "placeholder2" in outside compilation 1 - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output const0 = ops::Const(s.WithOpName("const0"), 1, {}); - Output placeholder0 = - ops::Placeholder(s.WithOpName("placeholder0"), DT_INT32); - Output add0 = ops::Add(s.WithOpName("add0"), placeholder0, placeholder0); - Output placeholder1 = - ops::Placeholder(s.WithOpName("placeholder1"), DT_INT32); - Output add1 = ops::Add(s.WithOpName("add1"), placeholder1, placeholder1); - Output bridge = ops::Identity(s.WithOpName("bridge"), placeholder0); - Output placeholder2 = - ops::Placeholder(s.WithOpName("placeholder2"), DT_INT32); - Output add2 = ops::Add(s.WithOpName("add2"), placeholder2, placeholder2); - Graph g(OpRegistry::Global()); - TF_CHECK_OK(s.ToGraph(&g)); - auto node_index = g.BuildNodeNameIndex(); - - // Set related attributes. - Node *placeholder0_node = node_index["placeholder0"]; - placeholder0_node->AddAttr(kOutsideCompilationToHostOriginalNodeAttrName, - "const0"); - placeholder0_node->AddAttr(kOutsideCompilationToHostSrcOutputAttrName, 0); - Node *placeholder1_node = node_index["placeholder1"]; - placeholder1_node->AddAttr(kHostToOutsideCompilationOriginalNodeAttrName, - "add0"); - placeholder1_node->AddAttr(kHostToOutsideCompilationSrcOutputAttrName, 0); - Node *bridge_node = node_index["bridge"]; - bridge_node->AddAttr(kBridgeSourceNodeAttrName, "const0"); - Node *placeholder2_node = node_index["placeholder2"]; - placeholder2_node->AddAttr(kHostToOutsideCompilationOriginalNodeAttrName, - "bridge"); - placeholder2_node->AddAttr(kHostToOutsideCompilationSrcOutputAttrName, 0); - - std::unordered_map clusters; - TF_CHECK_OK(PostprocessForEncapsulation(&g, "_xla", "_oc", clusters)); - - // Result graph should be: - // "add0" = "const0" + "const0" - // "add1" = "add0" + "add0" - // "add2" = "const0" + "const0" - node_index = g.BuildNodeNameIndex(); - EXPECT_EQ(node_index.size(), 6); - EXPECT_EQ(node_index["add0"]->def().input(0), "const0:0"); - EXPECT_EQ(node_index["add0"]->def().input(1), "const0:0"); - EXPECT_EQ(node_index["add1"]->def().input(0), "add0:0"); - EXPECT_EQ(node_index["add1"]->def().input(1), "add0:0"); - EXPECT_EQ(node_index["add2"]->def().input(0), "const0:0"); - EXPECT_EQ(node_index["add2"]->def().input(1), "const0:0"); -} - } // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index d334100aa4a915a87fb05d371e0e3379a7ee05f2..ec745cdbb7e237f8b4935dd41e9791fc75f5355d 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -297,6 +297,7 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, NodeDef def; def.set_name(launch->name()); + MergeDebugInfo(NodeDebugInfo(launch->def()), &def); // Target the XLA CPU/GPU backends. VLOG(2) << "Replacing with XlaLaunch"; diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc index e3c7e2f89be9b37b51a633dabb099969c181013f..1906f1ac850095a27add10d6b22d3bbb0f811ce9 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc @@ -20,8 +20,10 @@ limitations under the License. #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/encapsulate_util.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" @@ -98,9 +100,12 @@ xla::StatusOr BuildRecvAtHostNode( recv_at_host_builder.Attr("Toutputs", recv_at_host_dtypes); // The correct device_ordinal will be inserted during replication in a // subsequent rewrite. - recv_at_host_builder.Attr("device_ordinal", 0); + AttrValue device_ordinal_value; + device_ordinal_value.set_placeholder("device_ordinal"); + recv_at_host_builder.Attr("device_ordinal", device_ordinal_value); recv_at_host_builder.Attr( "key", absl::StrCat("host_compute_channel_", oc_cluster_name)); + recv_at_host_builder.Attr(kXlaHasHostTransferAttrName, true); recv_at_host_builder.Input(key_placeholder->name(), 0, DT_STRING); TF_RETURN_IF_ERROR(recv_at_host_builder.Finalize(&recv_at_host_def)); Status s; @@ -197,9 +202,12 @@ xla::StatusOr BuildSendFromHostNode( send_from_host_builder.Attr("Tinputs", send_from_host_dtypes); // The correct device_ordinal will be inserted during replication in a // subsequent rewrite. - send_from_host_builder.Attr("device_ordinal", 0); + AttrValue device_ordinal_value; + device_ordinal_value.set_placeholder("device_ordinal"); + send_from_host_builder.Attr("device_ordinal", device_ordinal_value); send_from_host_builder.Attr( "key", absl::StrCat("host_compute_channel_", oc_cluster_name)); + send_from_host_builder.Attr(kXlaHasHostTransferAttrName, true); std::vector inputs(send_from_host_dtypes.size()); for (auto* n : ret_nodes) { int index; @@ -357,6 +365,47 @@ Status ReplaceOrRemoveOutsideCompilationCallNode( return Status::OK(); } +// Resets "device_ordinal" attr to placeholder value for related nodes +// (XlaRecvAtHost nodes; XlaSendFromHost nodes; If nodes containing +// XlaRecvAtHost/XlaSendFromHost). +Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) { + AttrValue device_ordinal_value; + device_ordinal_value.set_placeholder("device_ordinal"); + for (Node* n : g->nodes()) { + if (!HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) { + continue; + } + + if (n->type_string() == "_XlaRecvAtHost" || + n->type_string() == "_XlaSendFromHost") { + n->ClearAttr("device_ordinal"); + n->AddAttr("device_ordinal", device_ordinal_value); + } else if (n->type_string() == "If") { + for (const string& attr_name : + std::vector{"then_branch", "else_branch"}) { + NameAttrList branch_func; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func)); + (*branch_func.mutable_attr())["device_ordinal"] = device_ordinal_value; + n->ClearAttr(attr_name); + n->AddAttr(attr_name, branch_func); + } + } else if (n->type_string() == "While") { + for (const string& attr_name : std::vector{"cond", "body"}) { + NameAttrList branch_func; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func)); + (*branch_func.mutable_attr())["device_ordinal"] = device_ordinal_value; + n->ClearAttr(attr_name); + n->AddAttr(attr_name, branch_func); + } + } else { + return errors::Internal("Unknown node marked with ", + kXlaHasHostTransferAttrName, ": ", + n->DebugString()); + } + } + return Status::OK(); +} + // For an XLA computation, builds host side graph given all outside compilation // graphs inside it. The host side graph contains: // 1) a "sequencer" node (we will add control edge between XlaRecvAtHost and @@ -368,8 +417,8 @@ Status ReplaceOrRemoveOutsideCompilationCallNode( Status ConstructHostGraph( const string& xla_cluster_name, const string& outside_compilation_attr_name, const std::vector& outside_compilation_host_graphs, - FunctionLibraryDefinition* fld, std::unique_ptr* host_graph) { - host_graph->reset(new Graph(fld)); + FunctionLibraryDefinition* fld, const string& host_graph_func_name) { + Graph host_graph(fld); // Create sequencer node in host graph. NodeDefBuilder sequencer_builder(absl::StrCat(xla_cluster_name, "_sequencer"), @@ -378,24 +427,34 @@ Status ConstructHostGraph( NodeDef sequencer_def; TF_RETURN_IF_ERROR(sequencer_builder.Finalize(&sequencer_def)); Status s; - Node* sequencer = (*host_graph)->AddNode(sequencer_def, &s); + Node* sequencer = host_graph.AddNode(sequencer_def, &s); TF_RETURN_IF_ERROR(s); // Create key placeholder in host graph. TF_ASSIGN_OR_RETURN( Node * key_placeholder, - AddHostComputeKeyPlaceholder(xla_cluster_name, host_graph->get())); + AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph)); // For each outside compilation graph, copy them to host graph with the // following changes: // a) Use key_placeholder in host graph instead of its own. - // b) Add control edge from RecvAtHost/SendFromHost to sequencer. + // b) Add control edge from host transfer nodes (XlaRecvAtHost, + // XlaSendFromHost, If/While nodes containing + // XlaRecvAtHost/XlaSendFromHost) to sequencer node. // c) Clear node_def.device(), so device placer won't get confused. for (const string& host_func : outside_compilation_host_graphs) { VLOG(4) << "Expanding host graph " << host_func; + // Temporarily use "0" as "device_ordinal". It will be reset to placeholder + // value after we expanded all host graphs. We cannot just use placeholder + // value here because FunctionDef instantiation does not allow placeholder + // value for attributes. + AttrValue device_ordinal_attr; + device_ordinal_attr.set_i(0); + protobuf::Map attrs; + attrs["device_ordinal"] = device_ordinal_attr; FunctionBody* host_fbody = nullptr; TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( - *fld->Find(host_func), AttrSlice(), fld, + *fld->Find(host_func), AttrSlice(&attrs), fld, [&](const string& op, const OpDef** sig) { return fld->LookUpOpDef(op, sig); }, @@ -408,8 +467,8 @@ Status ConstructHostGraph( FixupSourceAndSinkEdges(host_fbody->graph); std::map node_map; - node_map[host_fbody->graph->source_node()] = (*host_graph)->source_node(); - node_map[host_fbody->graph->sink_node()] = (*host_graph)->sink_node(); + node_map[host_fbody->graph->source_node()] = host_graph.source_node(); + node_map[host_fbody->graph->sink_node()] = host_graph.sink_node(); Status s; ReverseDFS( *host_fbody->graph, /*enter=*/nullptr, @@ -431,7 +490,7 @@ Status ConstructHostGraph( NodeDef copy_def = n->def(); // Change c). copy_def.clear_device(); - copy = (*host_graph)->AddNode(copy_def, &s); + copy = host_graph.AddNode(copy_def, &s); if (!s.ok()) { return; } @@ -446,22 +505,23 @@ Status ConstructHostGraph( e->src()->DebugString()); return; } - (*host_graph) - ->AddEdge(node_map[e->src()], e->src_output(), copy, - e->dst_input()); + host_graph.AddEdge(node_map[e->src()], e->src_output(), copy, + e->dst_input()); } // Change b). - if (copy->type_string() == "_XlaRecvAtHost" || - copy->type_string() == "_XlaSendFromHost") { - (*host_graph)->AddControlEdge(copy, sequencer); + if (HasNodeAttr(copy->def(), kXlaHasHostTransferAttrName)) { + host_graph.AddControlEdge(copy, sequencer); } }, NodeComparatorID()); + if (!s.ok()) { return s; } } + // Reset "device_ordinal" to placeholder value. + TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(&host_graph)); // sequencer and key_placeholder might be dead nodes. Prune them if necessary. // - sequencer should be pruned iff it has no input control edges from @@ -470,21 +530,30 @@ Status ConstructHostGraph( // - key_placeholder should be pruned iff there's no RecvAtHost/SendFromHost. // We don't need to do anything special. if (!sequencer->in_edges().empty()) { - (*host_graph)->AddControlEdge(sequencer, (*host_graph)->sink_node()); + host_graph.AddControlEdge(sequencer, host_graph.sink_node()); } PruneForReverseReachability( - host_graph->get(), - std::unordered_set{(*host_graph)->sink_node()}); + &host_graph, std::unordered_set{host_graph.sink_node()}); // Postprocess edges between different outside compilations. TF_RETURN_IF_ERROR(PostprocessEdgesBetweenOutsideCompilations( - host_graph->get(), outside_compilation_attr_name)); + &host_graph, outside_compilation_attr_name)); if (VLOG_IS_ON(4)) { dump_graph::DumpGraphToFile( absl::StrCat("extract_outside_compilation_host_graph_for_", xla_cluster_name), - **host_graph, fld); + host_graph, fld); + } + + FunctionDef host_graph_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(host_graph, host_graph_func_name, &host_graph_fdef)); + if (fld->Find(host_graph_func_name)) { + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(host_graph_func_name, host_graph_fdef)); + } else { + TF_RETURN_IF_ERROR(fld->AddFunctionDef(host_graph_fdef)); } return Status::OK(); @@ -492,8 +561,28 @@ Status ConstructHostGraph( // Expand XLA computation's outside compilation host side graph into main graph. // Add a control edge between sequencer node and the XLA computation node. -Status ExpandHostGraphIntoMainGraph(Graph* main_graph, Graph* host_graph, +Status ExpandHostGraphIntoMainGraph(Graph* main_graph, + FunctionLibraryDefinition* fld, + const string& host_graph_func_name, Node* xla_computation_node) { + // Temporarily use "0" as "device_ordinal". It will be rewritten with the + // correct value in a later pass. We cannot just use placeholder value here + // because FunctionDef instantiation does not allow placeholder value for + // attributes. + AttrValue device_ordinal_attr; + device_ordinal_attr.set_i(0); + protobuf::Map attrs; + attrs["device_ordinal"] = device_ordinal_attr; + FunctionBody* fbody = nullptr; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fld->Find(host_graph_func_name), AttrSlice(&attrs), fld, + [&](const string& op, const OpDef** sig) { + return fld->LookUpOpDef(op, sig); + }, + &fbody)); + std::unique_ptr fbody_deleter(fbody); + Graph* host_graph = fbody->graph; + // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse // reachable from sink node so all nodes will be copied. // TODO(b/77601805): consolidate copy graph functions. @@ -545,23 +634,25 @@ Status ExpandHostGraphIntoMainGraph(Graph* main_graph, Graph* host_graph, return s; } -// Rewrites shape inference graph for outside compilation. -// 1. If the outside compilation is a "top-level" one (not in a function of any -// If/While/etc.), this shape inference graph might have host computation to -// outside compilation placeholder nodes, which will cause shape inference to -// fail. However, those nodes are not in `host_graph` any more (because we -// have executed `PostprocessForEncapsultion`). In this case, we clear the -// graph, and copy SendFromHost with all its predecessors from `host_graph`. -// This case is detected by whether the SendFromHost node exists in -// `host_graph` as well. -// 2. Remove control edges, and prune nodes that are not useful for shape -// inference. +// Rewrites shape inference graph for outside compilation: +// 1) If XlaSendFromHost also exists in `host_graph`, copy nodes from +// `host_graph`. Because we might still have outside compilation to outside +// compilation placeholder nodes in shape inference graph, which will prevent +// us from inferring XlaSendFromHost shape. But in `host_graph`, we already +// removed those placeholder nodes. +// 2) Remove control edges. +// 3) Prune nodes that are not useful for shape inference. Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name, Graph* host_graph, FunctionLibraryDefinition* fld) { + // Use "0" as "device_ordinal". It does not matter for shape inference. + AttrValue device_ordinal_attr; + device_ordinal_attr.set_i(0); + protobuf::Map attrs; + attrs["device_ordinal"] = device_ordinal_attr; FunctionBody* fbody = nullptr; TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( - *fld->Find(shape_inference_graph_name), AttrSlice(), fld, + *fld->Find(shape_inference_graph_name), AttrSlice(&attrs), fld, [&](const string& op, const OpDef** sig) { return fld->LookUpOpDef(op, sig); }, @@ -650,6 +741,7 @@ Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name, g->RemoveEdge(e); } } + // Nodes that are not reverse reachable from SendFromHost are not useful for // shape inference. Prune them. PruneForReverseReachability(g, @@ -669,6 +761,572 @@ Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name, return Status::OK(); } +// Builds XlaSendToHost node which sends cond predicate to host. +xla::StatusOr BuildSendIfPredNode(const string& name, + const string& host_transfer_key, + Node* pred_node, Graph* g) { + NodeDefBuilder send_pred_builder(name, "XlaSendToHost"); + send_pred_builder.Attr("Tinput", DT_BOOL); + send_pred_builder.Attr("key", absl::StrCat(host_transfer_key, "_dtoh_0")); + send_pred_builder.Attr(kXlaTokenInputNodesAttrName, + std::vector{kXlaTokenArgNodeName}); + send_pred_builder.Input(pred_node->name(), 0, DT_BOOL); + NodeDef send_pred_def; + TF_RETURN_IF_ERROR(send_pred_builder.Finalize(&send_pred_def)); + Status s; + Node* send_pred_node = g->AddNode(send_pred_def, &s); + TF_RETURN_IF_ERROR(s); + g->AddEdge(pred_node, 0, send_pred_node, 0); + return send_pred_node; +} + +// Replaces key placeholder node with an _Arg node. +Status ReplaceKeyPlaceholderWithArgNode(const string& xla_cluster_name, + const string& func_name, + FunctionLibraryDefinition* fld) { + // Temporarily use "0" as "device_ordinal". It will be reset to placeholder + // value after rewriting. + AttrValue device_ordinal_attr; + device_ordinal_attr.set_i(0); + protobuf::Map attrs; + attrs["device_ordinal"] = device_ordinal_attr; + FunctionBody* fbody = nullptr; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fld->Find(func_name), AttrSlice(&attrs), fld, + [&](const string& op, const OpDef** sig) { + return fld->LookUpOpDef(op, sig); + }, + &fbody)); + std::unique_ptr fbody_deleter(fbody); + Graph* g = fbody->graph; + + // Find or create the key placeholder node. + Node* key_placeholder = nullptr; + for (Node* n : g->nodes()) { + if (IsKeyPlaceholderNode(*n)) { + key_placeholder = n; + break; + } + } + if (!key_placeholder) { + TF_ASSIGN_OR_RETURN(key_placeholder, + AddHostComputeKeyPlaceholder(xla_cluster_name, g)); + } + + // Build the _Arg node, and replace key placeholder node with it. + NodeDefBuilder arg_builder("key_arg", FunctionLibraryDefinition::kArgOp); + arg_builder.Attr("T", DT_STRING); + arg_builder.Attr("index", 0); + NodeDef arg_def; + TF_RETURN_IF_ERROR(arg_builder.Finalize(&arg_def)); + TF_RETURN_IF_ERROR(ReplaceNode(g, key_placeholder, arg_def).status()); + + // Reset "device_ordinal" to placeholder value. + TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(g)); + + FunctionDef replace_fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(*g, func_name, &replace_fdef)); + TF_RETURN_IF_ERROR(fld->ReplaceFunction(func_name, replace_fdef)); + return Status::OK(); +} + +// Builds host side graph for If node. +Status BuildHostGraphForIfNode(const string& xla_cluster_attr_name, + const string& outside_compilation_attr_name, + const string& xla_cluster_name, + const string& if_node_name, + const string& host_transfer_key, + const string& host_graph_func_name, + FunctionLibraryDefinition* fld, + const string& then_branch_host_func_name, + const string& else_branch_host_func_name) { + Graph host_graph(fld); + string outside_compilation_name = absl::StrCat("oc_if_", if_node_name); + AttrValue device_ordinal_value; + device_ordinal_value.set_placeholder("device_ordinal"); + + // Step 1: add key placeholder node. + TF_ASSIGN_OR_RETURN( + Node * key_placeholder, + AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph)); + + // Step 2: build XlaRecvAtHost node to recv predicate. + NodeDefBuilder recv_pred_builder( + absl::StrCat("recv_oc_if_pred_", if_node_name), "_XlaRecvAtHost"); + recv_pred_builder.Attr("Toutputs", std::vector{DT_BOOL}); + recv_pred_builder.Attr("key", host_transfer_key); + recv_pred_builder.Attr("device_ordinal", device_ordinal_value); + recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name); + recv_pred_builder.Attr(outside_compilation_attr_name, + outside_compilation_name); + recv_pred_builder.Attr(kXlaHasHostTransferAttrName, true); + recv_pred_builder.Input(key_placeholder->name(), 0, DT_STRING); + NodeDef recv_pred_def; + TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def)); + Status s; + Node* recv_pred_node = host_graph.AddNode(recv_pred_def, &s); + TF_RETURN_IF_ERROR(s); + host_graph.AddEdge(key_placeholder, 0, recv_pred_node, 0); + + // Step 3: rewrite `{then, else}_branch_host_func_name`, replace key + // placeholder with an _Arg node. + TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( + xla_cluster_name, then_branch_host_func_name, fld)); + TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( + xla_cluster_name, else_branch_host_func_name, fld)); + + // Step 4: build If node to choose between `{then, else}_branch_host_graph`. + NodeDefBuilder if_builder(absl::StrCat("oc_if_", if_node_name), "If"); + if_builder.Attr("Tcond", DT_BOOL); + if_builder.Attr("Tin", std::vector{DT_STRING}); + if_builder.Attr("Tout", std::vector{}); + NameAttrList host_then_branch, host_else_branch; + host_then_branch.set_name(then_branch_host_func_name); + (*host_then_branch.mutable_attr())["device_ordinal"] = device_ordinal_value; + host_else_branch.set_name(else_branch_host_func_name); + (*host_else_branch.mutable_attr())["device_ordinal"] = device_ordinal_value; + if_builder.Attr("then_branch", host_then_branch); + if_builder.Attr("else_branch", host_else_branch); + if_builder.Attr(kXlaHasHostTransferAttrName, true); + if_builder.Attr(xla_cluster_attr_name, xla_cluster_name); + if_builder.Attr(outside_compilation_attr_name, outside_compilation_name); + if_builder.Input(recv_pred_node->name(), 0, DT_BOOL); + std::vector if_inputs{ + {key_placeholder->name(), 0, DT_STRING}}; + if_builder.Input(if_inputs); + NodeDef if_def; + TF_RETURN_IF_ERROR(if_builder.Finalize(&if_def)); + Node* if_node = host_graph.AddNode(if_def, &s); + TF_RETURN_IF_ERROR(s); + host_graph.AddEdge(recv_pred_node, 0, if_node, 0); + host_graph.AddEdge(key_placeholder, 0, if_node, 1); + + // Convert `host_graph` to function, and add a "device_ordinal" attr. + FunctionDef oc_host_graph_fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name, + &oc_host_graph_fdef)); + if (fld->Find(host_graph_func_name)) { + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef)); + } else { + TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef)); + } + + return Status::OK(); +} + +// Rewrites loop cond to add a node which sends loop cond to host. +Status AddSendLoopPredToLoopCond(FunctionLibraryDefinition* fld, + const NameAttrList& loop_cond_func, + const string& while_node_name, + const string& host_transfer_key) { + // Instantiate the loop cond function. + FunctionBody* fbody = nullptr; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fld->Find(loop_cond_func.name()), AttrSlice(&loop_cond_func.attr()), fld, + [&](const string& op, const OpDef** sig) { + return fld->LookUpOpDef(op, sig); + }, + &fbody)); + std::unique_ptr fbody_deleter(fbody); + Graph* g = fbody->graph; + + // Find the _Retval node and the loop cond node. + Node* ret_node = nullptr; + for (Node* n : g->nodes()) { + if (n->type_string() == "_Retval") { + if (ret_node) { + return errors::Internal("Multiple return node for loop cond function ", + loop_cond_func.name(), ": ", + ret_node->DebugString(), " and ", + n->DebugString()); + } else { + ret_node = n; + } + } + } + if (!ret_node) { + return errors::Internal("No _Retval node for loop cond function ", + loop_cond_func.name()); + } + Node* loop_cond; + TF_RETURN_IF_ERROR(ret_node->input_node(0, &loop_cond)); + + // Build the XlaSendToHost node. + NodeDefBuilder send_loop_cond_builder( + absl::StrCat("send_oc_while_cond_", while_node_name), "XlaSendToHost"); + send_loop_cond_builder.Attr("Tinput", DT_BOOL); + send_loop_cond_builder.Attr("key", + absl::StrCat(host_transfer_key, "_dtoh_0")); + send_loop_cond_builder.Attr(kXlaTokenInputNodesAttrName, + std::vector{kXlaTokenArgNodeName}); + send_loop_cond_builder.Input(loop_cond->name(), 0, DT_BOOL); + NodeDef send_loop_cond_def; + TF_RETURN_IF_ERROR(send_loop_cond_builder.Finalize(&send_loop_cond_def)); + Status s; + Node* send_loop_cond_node = g->AddNode(send_loop_cond_def, &s); + TF_RETURN_IF_ERROR(s); + g->AddEdge(loop_cond, 0, send_loop_cond_node, 0); + + // Replace original function. + FunctionDef replace_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*g, loop_cond_func.name(), &replace_fdef)); + TF_RETURN_IF_ERROR(fld->ReplaceFunction(loop_cond_func.name(), replace_fdef)); + + return Status::OK(); +} + +// Rewrites while loop cond function for host. +Status RewriteHostWhileLoopCond( + const string& cond_host_func_name, const string& while_node_name, + const string& host_transfer_key, const string& xla_cluster_attr_name, + const string& xla_cluster_name, const string& outside_compilation_attr_name, + const string& outside_compilation_name, FunctionLibraryDefinition* fld) { + // Replace key placeholder node with _Arg node. + TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( + xla_cluster_name, cond_host_func_name, fld)); + + // Instantiate cond function. + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map attrs; + attrs["device_ordinal"] = device_ordinal_temp_value; + FunctionBody* cond_fbody = nullptr; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fld->Find(cond_host_func_name), AttrSlice(&attrs), fld, + [&](const string& op, const OpDef** sig) { + return fld->LookUpOpDef(op, sig); + }, + &cond_fbody)); + std::unique_ptr cond_fbody_deleter(cond_fbody); + Graph* cond_graph = cond_fbody->graph; + Node* key_arg = nullptr; + for (Node* n : cond_graph->nodes()) { + if (n->type_string() == "_Arg") { + key_arg = n; + } + } + if (!key_arg) { + return errors::Internal( + "No _Arg node found for host compute key in function ", + cond_host_func_name); + } + + // Add an XlaRecvAtHost node to use as cond function return value. + // We don't need to set kXlaHasHostTransferAttrName for this node, because + // it's already added for the "While" node on the host. + NodeDefBuilder recv_pred_builder( + absl::StrCat("recv_oc_while_cond_", while_node_name), "_XlaRecvAtHost"); + recv_pred_builder.Attr("Toutputs", std::vector{DT_BOOL}); + recv_pred_builder.Attr("key", host_transfer_key); + AttrValue device_ordinal_value; + device_ordinal_value.set_placeholder("device_ordinal"); + recv_pred_builder.Attr("device_ordinal", device_ordinal_value); + recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name); + recv_pred_builder.Attr(outside_compilation_attr_name, + outside_compilation_name); + recv_pred_builder.Input(key_arg->name(), 0, DT_STRING); + NodeDef recv_pred_def; + TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def)); + Status s; + Node* recv_pred_node = cond_graph->AddNode(recv_pred_def, &s); + TF_RETURN_IF_ERROR(s); + cond_graph->AddEdge(key_arg, 0, recv_pred_node, 0); + NodeDefBuilder ret_builder( + absl::StrCat("recv_oc_while_cond_ret_", while_node_name), "_Retval"); + ret_builder.Attr("T", DT_BOOL); + ret_builder.Attr("index", 0); + ret_builder.Input(recv_pred_node->name(), 0, DT_BOOL); + NodeDef ret_def; + TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def)); + Node* ret_node = cond_graph->AddNode(ret_def, &s); + TF_RETURN_IF_ERROR(s); + cond_graph->AddEdge(recv_pred_node, 0, ret_node, 0); + + // Reset device_ordinal to placeholder value. + TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(cond_graph)); + + // Replace original function. + FunctionDef cond_replace_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*cond_graph, cond_host_func_name, &cond_replace_fdef)); + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(cond_host_func_name, cond_replace_fdef)); + + return Status::OK(); +} + +// Rewrites while loop body function for host. +Status RewriteHostWhileLoopBody( + const string& body_host_func_name, const string& while_node_name, + const string& host_transfer_key, const string& xla_cluster_attr_name, + const string& xla_cluster_name, const string& outside_compilation_attr_name, + const string& outside_compilation_name, FunctionLibraryDefinition* fld) { + // Replace key placeholder node with _Arg node. + TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( + xla_cluster_name, body_host_func_name, fld)); + + // Instantiate body function. + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map attrs; + attrs["device_ordinal"] = device_ordinal_temp_value; + FunctionBody* body_fbody = nullptr; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fld->Find(body_host_func_name), AttrSlice(&attrs), fld, + [&](const string& op, const OpDef** sig) { + return fld->LookUpOpDef(op, sig); + }, + &body_fbody)); + std::unique_ptr body_fbody_deleter(body_fbody); + Graph* body_graph = body_fbody->graph; + Node* key_arg = nullptr; + for (Node* n : body_graph->nodes()) { + if (n->type_string() == "_Arg") { + key_arg = n; + } + } + if (!key_arg) { + return errors::Internal( + "No _Arg node found for host compute key in function ", + body_host_func_name); + } + + // Add a _Retval node to loop body. + NodeDefBuilder ret_builder( + absl::StrCat("recv_oc_while_body_ret_", while_node_name), "_Retval"); + ret_builder.Attr("T", DT_STRING); + ret_builder.Attr("index", 0); + ret_builder.Input(key_arg->name(), 0, DT_STRING); + NodeDef ret_def; + TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def)); + Status s; + Node* ret_node = body_graph->AddNode(ret_def, &s); + TF_RETURN_IF_ERROR(s); + body_graph->AddEdge(key_arg, 0, ret_node, 0); + + // Reset device_ordinal to placeholder value. + TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(body_graph)); + + // Replace original function. + FunctionDef body_replace_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*body_graph, body_host_func_name, &body_replace_fdef)); + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(body_host_func_name, body_replace_fdef)); + + return Status::OK(); +} + +// Builds host side graph for while node. +Status BuildHostGraphForWhileNode( + const string& xla_cluster_attr_name, + const string& outside_compilation_attr_name, const string& xla_cluster_name, + const string& while_node_name, const string& host_transfer_key, + const string& host_graph_func_name, FunctionLibraryDefinition* fld, + const string& cond_host_func_name, const string& body_host_func_name) { + Graph host_graph(fld); + string outside_compilation_name = absl::StrCat("oc_while_", while_node_name); + + // Step 1: add key placeholder node. + TF_ASSIGN_OR_RETURN( + Node * key_placeholder, + AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph)); + + // Step 2: rewrite cond function. + TF_RETURN_IF_ERROR(RewriteHostWhileLoopCond( + cond_host_func_name, while_node_name, host_transfer_key, + xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name, + outside_compilation_name, fld)); + + // Step 3: rewrite body function. + TF_RETURN_IF_ERROR(RewriteHostWhileLoopBody( + body_host_func_name, while_node_name, host_transfer_key, + xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name, + outside_compilation_name, fld)); + + // Step 4: build While node. + NodeDefBuilder while_builder(absl::StrCat("oc_while_", while_node_name), + "While"); + while_builder.Attr("T", std::vector{DT_STRING}); + NameAttrList func; + AttrValue device_ordinal_value; + device_ordinal_value.set_placeholder("device_ordinal"); + (*func.mutable_attr())["device_ordinal"] = device_ordinal_value; + func.set_name(cond_host_func_name); + while_builder.Attr("cond", func); + func.set_name(body_host_func_name); + while_builder.Attr("body", func); + while_builder.Attr(kXlaHasHostTransferAttrName, true); + while_builder.Attr(xla_cluster_attr_name, xla_cluster_name); + while_builder.Attr(outside_compilation_attr_name, outside_compilation_name); + std::vector while_inputs{ + {key_placeholder->name(), 0, DT_STRING}}; + while_builder.Input(while_inputs); + NodeDef while_def; + TF_RETURN_IF_ERROR(while_builder.Finalize(&while_def)); + Status s; + Node* while_node = host_graph.AddNode(while_def, &s); + TF_RETURN_IF_ERROR(s); + host_graph.AddEdge(key_placeholder, 0, while_node, 0); + + // Convert `host_graph` to function. + FunctionDef oc_host_graph_fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name, + &oc_host_graph_fdef)); + if (fld->Find(host_graph_func_name)) { + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef)); + } else { + TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef)); + } + + return Status::OK(); +} + +Status ExtractOutsideCompilationForNodesWithAssociatedFunctions( + Graph* g, const string& xla_cluster_attr_name, + const string& outside_compilation_attr_name, const string& xla_cluster_name, + const std::map& host_compute_core, + FunctionLibraryDefinition* fld, std::vector* host_graphs, + std::vector* shape_inference_graphs, + bool* has_outside_compilation) { + std::vector if_nodes, while_nodes; + for (Node* n : g->nodes()) { + if (n->type_string() == "If") { + if_nodes.push_back(n); + } else if (n->type_string() == "While") { + while_nodes.push_back(n); + } + } + + for (Node* n : if_nodes) { + // Instantiate "then_branch" and "else_branch". + NameAttrList then_branch, else_branch; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "then_branch", &then_branch)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "else_branch", &else_branch)); + + // Extract outside compilation for then_branch and else_branch. + bool then_branch_has_outside_compilation = false; + bool else_branch_has_outside_compilation = false; + string then_branch_host_func_name = + absl::StrCat("oc_then_branch_host_if_", n->name()), + else_branch_host_func_name = + absl::StrCat("oc_else_branch_host_if_", n->name()); + string then_branch_xla_func_name = absl::StrCat(then_branch.name(), "_oc"), + else_branch_xla_func_name = absl::StrCat(else_branch.name(), "_oc"); + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + then_branch, then_branch_xla_func_name, then_branch_host_func_name, + host_compute_core, fld, shape_inference_graphs, + &then_branch_has_outside_compilation)); + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + else_branch, else_branch_xla_func_name, else_branch_host_func_name, + host_compute_core, fld, shape_inference_graphs, + &else_branch_has_outside_compilation)); + + // If then/else branch do not have outside compilation, nothing to do. + if (!then_branch_has_outside_compilation && + !else_branch_has_outside_compilation) { + continue; + } + + *has_outside_compilation = true; + + // Change If node to call the new functions. + then_branch.set_name(then_branch_xla_func_name); + n->ClearAttr("then_branch"); + n->AddAttr("then_branch", then_branch); + else_branch.set_name(else_branch_xla_func_name); + n->ClearAttr("else_branch"); + n->AddAttr("else_branch", else_branch); + + string host_transfer_key = absl::StrCat("oc_if_pred_", n->name()); + + // XLA computation: add a SendToHost node to send cond predicate. + Node* pred_node; + TF_RETURN_IF_ERROR(n->input_node(0, &pred_node)); + TF_ASSIGN_OR_RETURN( + Node * send_pred_node, + BuildSendIfPredNode(absl::StrCat("send_oc_if_pred_", n->name()), + host_transfer_key, pred_node, g)); + n->AddAttr(kXlaTokenInputNodesAttrName, + std::vector{send_pred_node->name()}); + + // Add a control edge from `send_pred_node` to If node, so XlaCompiler will + // visit If node after `send_pred_node`, thus the token output for + // `send_pred_node` has been generated. + g->AddControlEdge(send_pred_node, n); + + // Build host side graph for the "If" node. + string oc_host_graph_name = absl::StrCat("oc_if_host_graph_", n->name()); + TF_RETURN_IF_ERROR(BuildHostGraphForIfNode( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + n->name(), host_transfer_key, oc_host_graph_name, fld, + then_branch_host_func_name, else_branch_host_func_name)); + host_graphs->push_back(oc_host_graph_name); + } + + for (Node* n : while_nodes) { + // Instantiate "cond" and "body". + NameAttrList cond, body; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "cond", &cond)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "body", &body)); + + // Extract outside compilation for cond and body. + bool cond_has_outside_compilation = false; + bool body_has_outside_compilation = false; + string cond_host_func_name = absl::StrCat("oc_cond_host_while_", n->name()), + body_host_func_name = absl::StrCat("oc_body_host_while_", n->name()); + string cond_xla_func_name = absl::StrCat(cond.name(), "_oc"), + body_xla_func_name = absl::StrCat(body.name(), "_oc"); + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + cond, cond_xla_func_name, cond_host_func_name, host_compute_core, fld, + shape_inference_graphs, &cond_has_outside_compilation)); + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + body, body_xla_func_name, body_host_func_name, host_compute_core, fld, + shape_inference_graphs, &body_has_outside_compilation)); + + // If cond/body do not have outside compilation, nothing to do. + if (!cond_has_outside_compilation && !body_has_outside_compilation) { + continue; + } + + *has_outside_compilation = true; + + // Change While node to call the new functions. + cond.set_name(cond_xla_func_name); + n->ClearAttr("cond"); + n->AddAttr("cond", cond); + body.set_name(body_xla_func_name); + n->ClearAttr("body"); + n->AddAttr("body", body); + + string host_transfer_key = absl::StrCat("oc_while_pred_", n->name()); + + // XLA computation: rewrite cond function to add a SendToHost node to send + // loop predicate. + TF_RETURN_IF_ERROR( + AddSendLoopPredToLoopCond(fld, cond, n->name(), host_transfer_key)); + n->AddAttr(kXlaTokenInputNodesAttrName, + std::vector{kXlaTokenArgNodeName}); + + // Build host side graph for the "While" node. + string oc_host_graph_name = absl::StrCat("oc_while_host_graph_", n->name()); + TF_RETURN_IF_ERROR(BuildHostGraphForWhileNode( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + n->name(), host_transfer_key, oc_host_graph_name, fld, + cond_host_func_name, body_host_func_name)); + host_graphs->push_back(oc_host_graph_name); + } + + return Status::OK(); +} + } // namespace Status RewriteOutsideCompilationSubgraphFn::operator()( @@ -755,12 +1413,15 @@ Status RewriteOutsideCompilationSubgraphFn::operator()( // it with HostCompute node later. AddNodeAttr("_outside_compilation_subgraph", old_name, node_def); if (shapes) { - AddNodeAttr("shape_inference_graph", "", node_def); + NameAttrList shape_inference_graph; + AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def); AddNodeAttr("shapes", *shapes, node_def); } else { string shape_inference_func_name = absl::StrCat("_outside_compilation_shape_inference_", new_name); - AddNodeAttr("shape_inference_graph", shape_inference_func_name, node_def); + NameAttrList shape_inference_graph; + shape_inference_graph.set_name(shape_inference_func_name); + AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def); AddNodeAttr("shapes", std::vector{}, node_def); } AddNodeAttr("ancestors", std::vector{}, node_def); @@ -775,11 +1436,10 @@ Status ExtractOutsideCompilationForFunction( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const string& xla_cluster_name, const NameAttrList& func_name_attrs, const string& new_func_name, + const string& host_graph_func_name, const std::map& host_compute_core, - FunctionLibraryDefinition* fld, std::unique_ptr* host_graph, - std::vector* shape_inference_graphs, + FunctionLibraryDefinition* fld, std::vector* shape_inference_graphs, bool* has_outside_compilation) { - // Early return if function does not have any outside compilation nodes. const string& func_name = func_name_attrs.name(); const FunctionDef* fdef = fld->Find(func_name); if (!fdef) { @@ -792,9 +1452,8 @@ Status ExtractOutsideCompilationForFunction( break; } } - if (!has_outside_compilation) { - return Status::OK(); - } + // We cannot early return here, because we might have outside compilation in + // If/While function body. // Convert the function to graph. FunctionBody* fbody = nullptr; @@ -835,11 +1494,11 @@ Status ExtractOutsideCompilationForFunction( // If we could not infer shapes for XlaSendFromHost inputs statically, we // will set the "shape_inference_graph" attribute. In that case, copy // outside compilation subgraph as shape inference graph in `fld`. - string shape_inference_graph; + NameAttrList shape_inference_graph; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "shape_inference_graph", &shape_inference_graph)); - if (!shape_inference_graph.empty()) { - shape_inference_graphs->push_back(shape_inference_graph); + if (!shape_inference_graph.name().empty()) { + shape_inference_graphs->push_back(shape_inference_graph.name()); const FunctionDef* xla_fdef = fld->Find(n->name()); if (!xla_fdef) { @@ -847,9 +1506,9 @@ Status ExtractOutsideCompilationForFunction( } FunctionDef shape_inference_fdef = *xla_fdef; shape_inference_fdef.mutable_signature()->set_name( - shape_inference_graph); - if (fld->Find(shape_inference_graph)) { - TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph, + shape_inference_graph.name()); + if (fld->Find(shape_inference_graph.name())) { + TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph.name(), shape_inference_fdef)); } else { TF_RETURN_IF_ERROR(fld->AddFunctionDef(shape_inference_fdef)); @@ -867,12 +1526,17 @@ Status ExtractOutsideCompilationForFunction( *graph_out, fld); } + // Handle nodes with associated functions. + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForNodesWithAssociatedFunctions( + graph_out.get(), xla_cluster_attr_name, outside_compilation_attr_name, + xla_cluster_name, host_compute_core, fld, + &outside_compilation_host_graphs, shape_inference_graphs, + has_outside_compilation)); + // Construct host graph. - if (!outside_compilation_host_graphs.empty()) { - TF_RETURN_IF_ERROR( - ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name, - outside_compilation_host_graphs, fld, host_graph)); - } + TF_RETURN_IF_ERROR(ConstructHostGraph( + xla_cluster_name, outside_compilation_attr_name, + outside_compilation_host_graphs, fld, host_graph_func_name)); // Remove the outside compilation graphs from function library. for (const string& func : outside_compilation_host_graphs) { @@ -909,24 +1573,17 @@ Status ExtractOutsideCompilation( auto const& host_compute_core = iter.second.host_compute_core; bool has_outside_compilation; - std::unique_ptr host_graph; + string host_graph_func_name = absl::StrCat("oc_host_graph_", n->name()); TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, - func_name_attrs, func_name_attrs.name(), host_compute_core, fld, - &host_graph, &shape_inference_graphs, &has_outside_compilation)); - if (host_graph) { - TF_RETURN_IF_ERROR(ExpandHostGraphIntoMainGraph(g, host_graph.get(), n)); - } - } - - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile("extract_outside_compilation_expanded", *g, - fld); + func_name_attrs, func_name_attrs.name(), host_graph_func_name, + host_compute_core, fld, &shape_inference_graphs, + &has_outside_compilation)); + TF_RETURN_IF_ERROR( + ExpandHostGraphIntoMainGraph(g, fld, host_graph_func_name, n)); + TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name)); } - TF_RETURN_IF_ERROR(PostprocessForEncapsulation( - g, xla_cluster_attr_name, outside_compilation_attr_name, clusters)); - for (auto shape_inference_graph_name : shape_inference_graphs) { TF_RETURN_IF_ERROR( RewriteShapeInferenceGraph(shape_inference_graph_name, g, fld)); diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.h b/tensorflow/compiler/jit/extract_outside_compilation_pass.h index 2a4f07cca213d999202024294f5d8f94527059c3..e07e7c5dd0cd42ddd4d643d8b36583c82056bbb2 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.h +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.h @@ -88,9 +88,10 @@ Status ExtractOutsideCompilationForFunction( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const string& xla_cluster_name, const NameAttrList& func_name_attrs, const string& new_func_name, + const string& host_graph_func_name, const std::map& host_compute_core, - FunctionLibraryDefinition* fld, std::unique_ptr* host_graph, - std::vector* shape_inference_graphs, bool* has_outside_compilation); + FunctionLibraryDefinition* fld, std::vector* shape_inference_graphs, + bool* has_outside_compilation); // Rewrites XLA computation in `clusters` to replace outside compilation nodes // with XlaHostCompute, and moves those outside compilations into `g`. If shapes diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc index bff956100da661b679b4557fce53671e6cef88c5..e9a89e34e0c7b04b4be34e367b2d0bf627c0061a 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc @@ -19,8 +19,10 @@ limitations under the License. #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/functional_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/encapsulate_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/function.h" @@ -109,10 +111,10 @@ TEST(RewriteOutsideCompilationSubgraphFnTest, Basic) { } EXPECT_TRUE(has_control_edge_to_send_from_host); // Verify step 7: necessary attrs added to call_node_def. - string shape_inference_graph; + NameAttrList shape_inference_graph; TF_CHECK_OK(GetNodeAttr(AttrSlice(&call_node_def.attr()), "shape_inference_graph", &shape_inference_graph)); - EXPECT_EQ(shape_inference_graph, + EXPECT_EQ(shape_inference_graph.name(), "_outside_compilation_shape_inference_cluster_0"); } @@ -249,27 +251,26 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) { protobuf::Map attrs; std::map host_compute_core = {{"0", 1}, {"1", 0}}; - std::unique_ptr host_graph; std::vector shape_inference_graphs; bool has_outside_compilation; NameAttrList name_attrs; name_attrs.set_name("cluster"); *name_attrs.mutable_attr() = attrs; TF_CHECK_OK(ExtractOutsideCompilationForFunction( - "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", - host_compute_core, &fld, &host_graph, &shape_inference_graphs, + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, &has_outside_compilation)); // Get rewritten XLA computation function. - FunctionBody *fbody = nullptr; - TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"), - AttrSlice(), &fld, - [&](const string &op, const OpDef **sig) { - return fld.LookUpOpDef(op, sig); - }, - &fbody)); - std::unique_ptr fbody_deleter(fbody); - auto node_name_index = fbody->graph->BuildNodeNameIndex(); + FunctionBody *xla_fbody = nullptr; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("cluster_rewritten"), AttrSlice(), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &xla_fbody)); + std::unique_ptr xla_fbody_deleter(xla_fbody); + auto node_name_index = xla_fbody->graph->BuildNodeNameIndex(); // Check XlaHostCompute nodes. Node *host_compute_0 = node_name_index["outside_compilation_0_host_compute"]; @@ -292,18 +293,31 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) { EXPECT_EQ(shapes[0].dim_size(), 1); // Check XlaHostCompute nodes' "shape_inference_graph" attr. Both should have // empty values. - string shape_inference_graph; + NameAttrList shape_inference_graph; TF_CHECK_OK(GetNodeAttr(host_compute_0->attrs(), "shape_inference_graph", &shape_inference_graph)); - EXPECT_EQ(shape_inference_graph, ""); + EXPECT_EQ(shape_inference_graph.name(), ""); TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "shape_inference_graph", &shape_inference_graph)); - EXPECT_EQ(shape_inference_graph, ""); + EXPECT_EQ(shape_inference_graph.name(), ""); // Check `shape_inference_graphs`. EXPECT_EQ(shape_inference_graphs.size(), 0); - // Check `host_graph`: verify we have key placeholder and sequencer. + // Check host graph: verify we have key placeholder and sequencer. + FunctionBody *host_fbody = nullptr; + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map host_func_attrs; + host_func_attrs["device_ordinal"] = device_ordinal_temp_value; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + Graph *host_graph = host_fbody->graph; Node *key_placeholder = nullptr, *sequencer = nullptr; for (Node *n : host_graph->nodes()) { if (n->type_string() == "Placeholder" && @@ -365,25 +379,37 @@ TEST(ExtractOutsideCompilationForFunctionTest, NoHostGraph) { protobuf::Map attrs; std::map host_compute_core = {{"0", 1}, {"1", 0}}; - std::unique_ptr host_graph; std::vector shape_inference_graphs; bool has_outside_compilation; NameAttrList name_attrs; name_attrs.set_name("cluster"); *name_attrs.mutable_attr() = attrs; TF_CHECK_OK(ExtractOutsideCompilationForFunction( - "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", - host_compute_core, &fld, &host_graph, &shape_inference_graphs, + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, &has_outside_compilation)); - // Check `host_graph` is empty. - EXPECT_FALSE(host_graph); + // Check host graph is empty. + FunctionBody *host_fbody = nullptr; + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map host_func_attrs; + host_func_attrs["device_ordinal"] = device_ordinal_temp_value; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + Graph *host_graph = host_fbody->graph; + EXPECT_EQ(host_graph->num_nodes(), 2); } TEST(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) { // Build the XLA computation func. // "const0" - // "const1" (outside compilation clsuter "0") + // "const1" (outside compilation cluster "0") FunctionDefLibrary fdl; { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); @@ -401,31 +427,43 @@ TEST(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) { protobuf::Map attrs; std::map host_compute_core = {{"0", 1}, {"1", 0}}; - std::unique_ptr host_graph; std::vector shape_inference_graphs; bool has_outside_compilation; NameAttrList name_attrs; name_attrs.set_name("cluster"); *name_attrs.mutable_attr() = attrs; TF_CHECK_OK(ExtractOutsideCompilationForFunction( - "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", - host_compute_core, &fld, &host_graph, &shape_inference_graphs, + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, &has_outside_compilation)); // Check rewritten XLA graph: verify that we have no XlaHostCompute. - FunctionBody *fbody = nullptr; - TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"), - AttrSlice(), &fld, - [&](const string &op, const OpDef **sig) { - return fld.LookUpOpDef(op, sig); - }, - &fbody)); - std::unique_ptr fbody_deleter(fbody); - for (Node *n : fbody->graph->nodes()) { + FunctionBody *xla_fbody = nullptr; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("cluster_rewritten"), AttrSlice(), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &xla_fbody)); + std::unique_ptr xla_fbody_deleter(xla_fbody); + for (Node *n : xla_fbody->graph->nodes()) { EXPECT_NE(n->type_string(), "XlaHostCompute"); } - // Check `host_graph`: verify we have no placeholder, but we have "const1". + // Check host graph: verify we have no placeholder, but we have "const1". + FunctionBody *host_fbody = nullptr; + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map host_func_attrs; + host_func_attrs["device_ordinal"] = device_ordinal_temp_value; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + Graph *host_graph = host_fbody->graph; int num_key_placeholders = 0; for (Node *n : host_graph->nodes()) { if (n->type_string() == "Placeholder" && @@ -438,4 +476,310 @@ TEST(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) { EXPECT_NE(node_name_index.find("const1"), node_name_index.end()); } +REGISTER_OP("XlaSendToHost") + .Input("input: Tinput") + .Attr("Tinput: type") + .Attr("key: string") + .SetIsStateful(); + +REGISTER_OP("XlaRecvFromHost") + .Output("output: Toutput") + .Attr("Toutput: type") + .Attr("shape: shape") + .Attr("key: string") + .SetIsStateful(); + +TEST(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInIf) { + // Build the XLA computation func. + // "const0" (bool) + // "const1" (int32) + // "if0" (pred = "const0", input = "const1", then_branch = "true_fn", + // else_branch = "false_fn") + FunctionDefLibrary fdl; + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg = ops::_Arg(s.WithOpName("arg"), DT_INT32, 0); + Output identity = ops::Identity(s.WithOpName("identity_true_fn"), arg); + ops::_Retval retval(s.WithOpName("retval"), identity, 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + auto node_name_image = g->BuildNodeNameIndex(); + node_name_image["identity_true_fn"]->AddAttr("_oc", "0"); + PartialTensorShape shape({2}); + node_name_image["identity_true_fn"]->AddAttr( + kXlaInferredShapesAttrName, std::vector{shape}); + + FunctionDef *true_fn_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "true_fn", true_fn_fdef)); + } + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg = ops::_Arg(s.WithOpName("arg"), DT_INT32, 0); + Output identity = ops::Identity(s.WithOpName("identity_false_fn"), arg); + ops::_Retval retval(s.WithOpName("retval"), identity, 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + auto node_name_image = g->BuildNodeNameIndex(); + node_name_image["identity_false_fn"]->AddAttr("_oc", "0"); + PartialTensorShape shape({2}); + node_name_image["identity_false_fn"]->AddAttr( + kXlaInferredShapesAttrName, std::vector{shape}); + + FunctionDef *false_fn_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "false_fn", false_fn_fdef)); + } + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output cond = ops::Const(s.WithOpName("const0"), true, {2}); + Output input = ops::Const(s.WithOpName("const1"), 1, {2}); + NameAttrList true_fn; + true_fn.set_name("true_fn"); + NameAttrList false_fn; + false_fn.set_name("false_fn"); + auto if_op = ops::If(s.WithOpName("if"), cond, + std::initializer_list{cond, input}, {DT_INT32}, + true_fn, false_fn); + ops::_Retval retval(s.WithOpName("retval"), if_op.output[0], 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + + FunctionDef *xla_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef)); + } + FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); + + protobuf::Map attrs; + std::map host_compute_core; + std::vector shape_inference_graphs; + bool has_outside_compilation; + NameAttrList name_attrs; + name_attrs.set_name("cluster"); + *name_attrs.mutable_attr() = attrs; + TF_CHECK_OK(ExtractOutsideCompilationForFunction( + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, + &has_outside_compilation)); + + // Check host graph. + { + FunctionBody *host_fbody = nullptr; + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map host_func_attrs; + host_func_attrs["device_ordinal"] = device_ordinal_temp_value; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + Graph *host_graph = host_fbody->graph; + auto node_name_index = host_graph->BuildNodeNameIndex(); + + // Verify we have XlaRecvAtHost to receive "If" predicate. + Node *recv_if_pred_node = node_name_index["recv_oc_if_pred_if"]; + EXPECT_NE(recv_if_pred_node, nullptr); + + // Verify we have an "If" to choose outside compilation between then_branch + // and else_branch, and it has `recv_if_pred_node` as cond input. + Node *if_oc_node = node_name_index["oc_if_if"]; + EXPECT_NE(if_oc_node, nullptr); + Node *if_oc_node_cond_input; + TF_CHECK_OK(if_oc_node->input_node(0, &if_oc_node_cond_input)); + EXPECT_EQ(if_oc_node_cond_input, recv_if_pred_node); + + // Check that then_branch outside compilation has node "identity_true_fn". + const FunctionDef *true_def = fld.Find("oc_then_branch_host_if_if"); + EXPECT_NE(true_def, nullptr); + bool has_identity_true_fn_node = false; + for (const auto &node_def : true_def->node_def()) { + if (node_def.name() == "identity_true_fn") { + has_identity_true_fn_node = true; + break; + } + } + EXPECT_TRUE(has_identity_true_fn_node); + + // Check that else_branch outside compilation has node "identity_false_fn". + const FunctionDef *false_def = fld.Find("oc_else_branch_host_if_if"); + EXPECT_NE(false_def, nullptr); + bool has_identity_false_fn_node = false; + for (const auto &node_def : false_def->node_def()) { + if (node_def.name() == "identity_false_fn") { + has_identity_false_fn_node = true; + break; + } + } + EXPECT_TRUE(has_identity_false_fn_node); + } + + // Check XLA graph. + { + FunctionBody *xla_fbody = nullptr; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("cluster_rewritten"), AttrSlice(), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &xla_fbody)); + std::unique_ptr xla_fbody_deleter(xla_fbody); + Graph *xla_graph = xla_fbody->graph; + auto node_name_index = xla_graph->BuildNodeNameIndex(); + + // Check that we have XlaSendToHost to send cond predicate to host, and + // there is a control edge to If node. + Node *send_if_pred_node = node_name_index["send_oc_if_pred_if"]; + EXPECT_NE(send_if_pred_node, nullptr); + bool has_control_edge_to_if = false; + for (const Edge *e : send_if_pred_node->out_edges()) { + if (e->IsControlEdge() && e->dst()->name() == "if") { + has_control_edge_to_if = true; + break; + } + } + EXPECT_TRUE(has_control_edge_to_if); + + // Check that the "If" node now has `send_if_pred_node` as attribute + // _xla_token_input_nodes. + Node *if_node = node_name_index["if"]; + EXPECT_NE(if_node, nullptr); + std::vector token_inputs; + TF_CHECK_OK( + GetNodeAttr(if_node->def(), "_xla_token_input_nodes", &token_inputs)); + EXPECT_THAT(token_inputs, ::testing::ElementsAre("send_oc_if_pred_if")); + } +} + +TEST(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInWhile) { + // Build the XLA computation func. + // "const0" (bool) + // "while0" (input = "const0", cond = "cond_fn", body = "body_fn") + FunctionDefLibrary fdl; + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg = ops::_Arg(s.WithOpName("arg"), DT_BOOL, 0); + Output identity = ops::Identity(s.WithOpName("identity_cond_fn"), arg); + ops::_Retval retval(s.WithOpName("retval"), identity, 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + auto node_name_image = g->BuildNodeNameIndex(); + node_name_image["identity_cond_fn"]->AddAttr("_oc", "0"); + PartialTensorShape shape({2}); + node_name_image["identity_cond_fn"]->AddAttr( + kXlaInferredShapesAttrName, std::vector{shape}); + + FunctionDef *cond_fn_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "cond_fn", cond_fn_fdef)); + } + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg = ops::_Arg(s.WithOpName("arg"), DT_BOOL, 0); + Output identity = ops::Identity(s.WithOpName("identity_body_fn"), arg); + ops::_Retval retval(s.WithOpName("retval"), identity, 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + auto node_name_image = g->BuildNodeNameIndex(); + node_name_image["identity_body_fn"]->AddAttr("_oc", "0"); + PartialTensorShape shape({2}); + node_name_image["identity_body_fn"]->AddAttr( + kXlaInferredShapesAttrName, std::vector{shape}); + + FunctionDef *body_fn_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "body_fn", body_fn_fdef)); + } + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output input = ops::Const(s.WithOpName("const0"), true, {2}); + NameAttrList cond_fn; + cond_fn.set_name("cond_fn"); + NameAttrList body_fn; + body_fn.set_name("body_fn"); + auto while_op = + ops::While(s.WithOpName("while"), std::initializer_list{input}, + cond_fn, body_fn); + ops::_Retval retval(s.WithOpName("retval"), while_op.output[0], 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + + FunctionDef *xla_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef)); + } + FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); + + protobuf::Map attrs; + std::map host_compute_core; + std::vector shape_inference_graphs; + bool has_outside_compilation; + NameAttrList name_attrs; + name_attrs.set_name("cluster"); + *name_attrs.mutable_attr() = attrs; + TF_CHECK_OK(ExtractOutsideCompilationForFunction( + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, + &has_outside_compilation)); + + // Check host graph. + { + FunctionBody *host_fbody = nullptr; + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map host_func_attrs; + host_func_attrs["device_ordinal"] = device_ordinal_temp_value; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + Graph *host_graph = host_fbody->graph; + auto node_name_index = host_graph->BuildNodeNameIndex(); + + // Verify we have an "While" to execute outside compilation. + Node *while_oc_node = node_name_index["oc_while_while"]; + EXPECT_NE(while_oc_node, nullptr); + + // Check that cond outside compilation has node "identity_cond_fn". + const FunctionDef *cond_def = fld.Find("oc_cond_host_while_while"); + EXPECT_NE(cond_def, nullptr); + bool has_identity_cond_fn_node = false; + for (const auto &node_def : cond_def->node_def()) { + if (node_def.name() == "identity_cond_fn") { + has_identity_cond_fn_node = true; + break; + } + } + EXPECT_TRUE(has_identity_cond_fn_node); + + // Check that body outside compilation has node "identity_body_fn". + const FunctionDef *body_def = fld.Find("oc_body_host_while_while"); + EXPECT_NE(body_def, nullptr); + bool has_identity_body_fn_node = false; + for (const auto &node_def : body_def->node_def()) { + if (node_def.name() == "identity_body_fn") { + has_identity_body_fn_node = true; + break; + } + } + EXPECT_TRUE(has_identity_body_fn_node); + } + + // Check XLA graph. + { + // Verify that rewritten cond fn has XlaSendToHost to send loop predicate to + // host. + const FunctionDef *cond_def = fld.Find("cond_fn_oc"); + EXPECT_NE(cond_def, nullptr); + bool has_send_oc_while_cond_node = false; + for (const auto &node_def : cond_def->node_def()) { + if (node_def.name() == "send_oc_while_cond_while") { + has_send_oc_while_cond_node = true; + break; + } + } + EXPECT_TRUE(has_send_oc_while_cond_node); + } +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index 42ea3926e16ae791dbe1bede3b8742383db7667c..e1fd2aaee2822daeffb415d053c9c4f56002a856 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -120,6 +120,7 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) { NodeDef ndef = n->def(); ndef.set_name(absl::StrCat(n->name(), "/declustered")); + MergeDebugInfo(NodeDebugInfo(n->def()), &ndef); RemoveFromXlaCluster(&ndef); Status s; Node* cloned_node = graph->AddNode(ndef, &s); diff --git a/tensorflow/compiler/jit/shape_inference.cc b/tensorflow/compiler/jit/shape_inference.cc index 80c691fe490c1092315708a2da754d367d585300..a27e0d9f2a6ecddfdbdb29be673084d77a178d8a 100644 --- a/tensorflow/compiler/jit/shape_inference.cc +++ b/tensorflow/compiler/jit/shape_inference.cc @@ -53,7 +53,15 @@ Status PropagateShapes(const Graph& graph, // shapes, even if no shape function is registered for a node. Status status = shape_refiner->AddNode(n); if (!status.ok()) { - VLOG(1) << "Shape inference failed for node: " << status; + VLOG(1) << "Shape inference failed for node " << n->name() << ": " + << status; + } else { + shape_inference::InferenceContext* context = shape_refiner->GetContext(n); + for (int i = 0; i < n->num_outputs(); i++) { + shape_inference::ShapeHandle handle = context->output(i); + VLOG(4) << "Output " << i << " for node " << n->name() << ": " + << context->DebugString(handle); + } } if (n->type_string() == "_Arg") { diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 6e6532731e64bd42ee56aa719748988f321e0f17..1f3afe8822d441a5ce37617fe18d7767e9bc72e4 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -79,6 +79,13 @@ XlaDeviceContext::XlaDeviceContext( } } +void XlaDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor, + Device* device, + Tensor* output_tensor, + StatusCallback done) const { + done(errors::Unimplemented("XLA->XLA same-device copies not implemented.")); +} + void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index 1e18df197a2dd65590c5181b4dae4481dca36641..e45db989fac720df6c3458c93a6b8dbb0919f930 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -62,6 +62,9 @@ class XlaDeviceContext : public DeviceContext { void CopyDeviceTensorToCPU(const Tensor* device_tensor, absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) override; + void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device, + Tensor* output_tensor, + StatusCallback done) const override; xla::LocalClient* client() const { return client_; } se::Stream* stream() const { return stream_.get(); } diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py index 174bfa9efbcd7dcb4f895237eb01c17bc4a3a6b4..90146e6b27ca31304a2549ec247412341efe390c 100644 --- a/tensorflow/compiler/tests/depthwise_conv_op_test.py +++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py @@ -350,8 +350,13 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): self._CompareBackpropInput(input_size, filter_size, output_size, stride, padding) - def _CompareBackpropFilter(self, input_sizes, filter_sizes, output_sizes, - stride, padding): + def _CompareBackpropFilter(self, + input_sizes, + filter_sizes, + output_sizes, + stride, + padding, + data_format="NHWC"): x0 = np.random.rand(*input_sizes).astype(np.float32) x2 = np.random.rand(*output_sizes).astype(np.float32) @@ -360,13 +365,30 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): t0 = array_ops.placeholder(np.float32, shape=input_sizes) t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)]) t2 = array_ops.placeholder(np.float32, shape=output_sizes) + native_t0 = t0 + native_t2 = t2 + strides = [1, stride, stride, 1] + if use_xla: + if data_format == "NCHW": + # Transpose from NWHC input to NCHW + # Ex. [4, 5, 5, 48] to [4, 48, 5, 5] + native_t0 = array_ops.transpose(t0, [0, 3, 1, 2]) + native_t2 = array_ops.transpose(t2, [0, 3, 1, 2]) + strides = [1, 1, stride, stride] with self.test_scope(): backprop = nn_ops.depthwise_conv2d_native_backprop_filter( - t0, t1, t2, strides=[1, stride, stride, 1], padding=padding) + native_t0, + t1, + native_t2, + strides=strides, + padding=padding, + data_format=data_format) else: + # For CPU, the format NCHW is not supported. Therefore we always use + # NHWC here. backprop = nn_ops.depthwise_conv2d_native_backprop_filter( - t0, t1, t2, strides=[1, stride, stride, 1], padding=padding) + native_t0, t1, native_t2, strides=strides, padding=padding) ret = backprop.eval({t0: x0, t2: x2}) self.assertShapeEqual(ret, backprop) return ret @@ -379,11 +401,24 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): for index, (input_size, filter_size, output_size, stride, padding) in enumerate(ConfigsToTest()): print("Testing DepthwiseConv2DFilterGradCompare,", index, "th config:", - input_size, "*", filter_size, "stride:", stride, "padding:", - padding) + input_size, "*", filter_size, "producing output", output_size, + "stride:", stride, "padding:", padding) self._CompareBackpropFilter(input_size, filter_size, output_size, stride, padding) + def testDepthwiseConv2DFilterGradFormatNCHWCompare(self): + for index, (input_size, filter_size, output_size, stride, + padding) in enumerate(ConfigsToTest()): + print("Testing DepthwiseConv2DFilterGradFormatNCHWCompare,", index, + "th config:", input_size, "*", filter_size, "producing output", + output_size, "stride:", stride, "padding:", padding) + self._CompareBackpropFilter( + input_size, + filter_size, + output_size, + stride, + padding, + data_format="NCHW") if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index c693e42d26712d55852f45c806215fc1f1b9a030..7ae96e1d484900e28e8c23c3bb2232401144ad82 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -640,7 +640,8 @@ Status Conditional::ExtractBodies(Graph* graph) { Status Conditional::BuildIfNode(Graph* graph, FunctionLibraryDefinition* library) { VLOG(2) << "Build cond function for " << name(); - NodeDefBuilder builder(name(), "If", library); + NodeDebugInfo debug_info((*merges_.begin())->def()); + NodeDefBuilder builder(name(), "If", library, &debug_info); const string branch_name[] = {"else_branch", "then_branch"}; for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { int branch_index = static_cast(branch); diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 8bc329229648c5aced8d06c99b170803bb3a90f8..a18a4e92d62787051f6ab92e72ee8bf0d1060dca 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -1,16 +1,11 @@ +load("//tensorflow:tensorflow.bzl", "tf_copts", "tf_kernel_library") + licenses(["notice"]) # Apache 2.0 package( default_visibility = ["//tensorflow/compiler/tf2xla:internal"], ) -load("//tensorflow:tensorflow.bzl", "tf_copts") -load("//tensorflow:tensorflow.bzl", "tf_kernel_library") -load( - "//third_party/mkl:build_defs.bzl", - "if_mkl", -) - tf_kernel_library( name = "xla_ops", srcs = [ @@ -122,12 +117,9 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/lib:broadcast", - "//tensorflow/compiler/tf2xla/lib:cholesky", - "//tensorflow/compiler/tf2xla/lib:qr", "//tensorflow/compiler/tf2xla/lib:random", "//tensorflow/compiler/tf2xla/lib:scatter", "//tensorflow/compiler/tf2xla/lib:util", - "//tensorflow/compiler/tf2xla/lib:while_loop", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal", @@ -140,11 +132,14 @@ tf_kernel_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:cholesky", "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:loops", "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/lib:matrix", "//tensorflow/compiler/xla/client/lib:pooling", "//tensorflow/compiler/xla/client/lib:prng", + "//tensorflow/compiler/xla/client/lib:qr", "//tensorflow/compiler/xla/client/lib:sorting", "//tensorflow/compiler/xla/client/lib:triangular_solve", "//tensorflow/core:framework", diff --git a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc index 9fcbc86adc0967cbb7fb73da8bdabc58b60953da..0ed3044efa5b1060d2b0ad2d5563b0e02ebf66ec 100644 --- a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/cholesky.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/cholesky.h" namespace tensorflow { namespace { @@ -24,7 +24,7 @@ class CholeskyOp : public XlaOpKernel { public: explicit CholeskyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - ctx->SetOutput(0, Cholesky(ctx->Input(0))); + ctx->SetOutput(0, xla::Cholesky(ctx->Input(0))); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index 641fefafb357f6ad10483c454600f3dadd4f8cb7..4124b258c7788e3850f07cbf4d53930784c635fd 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -392,23 +392,31 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( builder->GetShape(activations)); TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape, builder->GetShape(gradients)); + xla::XlaOp filter_backprop; + + xla::Shape input_shape = activations_shape; + xla::Shape output_shape = out_backprop_shape; + + TensorShape input_tensor_shape, filter_tensor_shape, output_tensor_shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape)); + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape)); + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(output_shape, &output_tensor_shape)); + const xla::Shape expanded_filter_shape = attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape) : filter_shape; - // Reuse dimension computation logic from conv_grad_ops.cc. ConvBackpropDimensions dims; - TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes( - type_string, attrs.num_spatial_dims, activations_shape, - expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides, - attrs.padding, attrs.data_format, &dims)); - // The filter gradients are computed by a convolution of the input // activations and the output gradients, with some appropriate padding. // See the comment at the top of conv_grad_ops.h for details. - xla::ConvolutionDimensionNumbers dnums; + TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes( + type_string, attrs.num_spatial_dims, activations_shape, + expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides, + attrs.padding, attrs.data_format, &dims)); + // The activations (inputs) form the LHS of the convolution. // Activations have shape: [batch, in_rows, in_cols, ..., in_depth] // For the gradient computation, we flip the roles of the batch and @@ -420,29 +428,99 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( int n_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format); int c_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format); - // Swap n_dim and c_dim in the activations. - dnums.set_input_batch_dimension(c_dim); - dnums.set_input_feature_dimension(n_dim); + int64 total_spatial_size = 1; + for (int i = 0; i < attrs.num_spatial_dims; ++i) { + total_spatial_size *= dims.input_size(i); + } - // The gradients become the RHS of the convolution. - // The gradients have shape [batch, out_rows, out_cols, ..., out_depth] - // where the batch becomes the input feature for the convolution. - dnums.set_kernel_input_feature_dimension(n_dim); - dnums.set_kernel_output_feature_dimension(c_dim); + // We use this approach only for depthwise convolutions where feature counts + // are large but space dimensions are small. The conversion logic below + // assumes that the data format is NHWC, so we also check that here. + bool should_perform_depthwise_conv = + attrs.data_format == FORMAT_NHWC && + (total_spatial_size < dims.in_depth) && + filter_tensor_shape.dim_size(num_dims - 1) == 1 && attrs.depthwise; + + int64 num_spatial_dims = + attrs.num_spatial_dims + (should_perform_depthwise_conv ? 1 : 0); + + std::vector> padding(num_spatial_dims); + std::vector rhs_dilation(num_spatial_dims); + std::vector window_strides(num_spatial_dims); + std::vector ones(num_spatial_dims, 1); + + if (should_perform_depthwise_conv) { + // This approach is similar to handling of grouped convolutions in + // the convolution_feature_group_converter.cc. Please refer to it for + // details. + + // Add spatial dimension to the activation, and reshape. + std::vector activations_reshape_sizes, gradients_reshape_sizes; + + activations_reshape_sizes.push_back(dims.batch_size); + gradients_reshape_sizes.push_back(dims.batch_size); + for (int i = 0; i < attrs.num_spatial_dims; i++) { + activations_reshape_sizes.push_back(dims.input_size(i)); + gradients_reshape_sizes.push_back(dims.output_size(i)); + } + activations_reshape_sizes.push_back(dims.in_depth); + activations_reshape_sizes.push_back(1); + gradients_reshape_sizes.push_back(dims.out_depth); + gradients_reshape_sizes.push_back(1); + + activations = xla::Reshape(activations, activations_reshape_sizes); + gradients = xla::Reshape(gradients, gradients_reshape_sizes); + + int64 new_spatial_dim = activations_reshape_sizes.size() - 1; + + // Set the newly added dimension to be the batch. + dnums.set_input_batch_dimension(new_spatial_dim); + dnums.set_input_feature_dimension(c_dim); + + // The gradients become the RHS of the convolution. + // The gradients have shape [batch, out_rows, out_cols, ..., out_depth, 1] + // where the batch becomes a spatial dimension, and 1 becomes + // the input feature for the convolution. + dnums.set_kernel_input_feature_dimension(new_spatial_dim); + dnums.set_kernel_output_feature_dimension(c_dim); + + // Treat original batch dimension as a spatial dimension. + dnums.add_input_spatial_dimensions(n_dim); + dnums.add_kernel_spatial_dimensions(n_dim); + } else { + // The activations (inputs) form the LHS of the convolution. + // Activations have shape: [batch, in_rows, in_cols, ..., in_depth] + // For the gradient computation, we flip the roles of the batch and + // feature dimensions. + // Each spatial entry has size in_depth * batch + + // Swap n_dim and c_dim in the activations. + dnums.set_input_batch_dimension(c_dim); + dnums.set_input_feature_dimension(n_dim); + + // The gradients become the RHS of the convolution. + // The gradients have shape [batch, out_rows, out_cols, ..., out_depth] + // where the batch becomes the input feature for the convolution. + dnums.set_kernel_input_feature_dimension(n_dim); + dnums.set_kernel_output_feature_dimension(c_dim); + } - std::vector> padding(attrs.num_spatial_dims); - std::vector rhs_dilation(attrs.num_spatial_dims); - std::vector window_strides(attrs.num_spatial_dims); - std::vector ones(attrs.num_spatial_dims, 1); + dnums.set_output_batch_dimension(num_spatial_dims); + dnums.set_output_feature_dimension(num_spatial_dims + 1); // Tensorflow filter shape is [ H, W, ..., inC, outC ]. - for (int i = 0; i < attrs.num_spatial_dims; ++i) { + for (int i = 0; i < num_spatial_dims; ++i) { dnums.add_output_spatial_dimensions(i); } - dnums.set_output_batch_dimension(attrs.num_spatial_dims); - dnums.set_output_feature_dimension(attrs.num_spatial_dims + 1); - for (int i = 0; i < attrs.num_spatial_dims; ++i) { + if (should_perform_depthwise_conv) { + // Set the right parameters for the newly created spatial dimension. + padding[0] = {0, 0}; + rhs_dilation[0] = 1; + window_strides[0] = 1; + } + + for (int64 i = 0; i < attrs.num_spatial_dims; ++i) { int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i); dnums.add_input_spatial_dimensions(dim); dnums.add_kernel_spatial_dimensions(dim); @@ -483,9 +561,10 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( const int64 pad_before = attrs.padding == Padding::SAME ? std::max(pad_total / 2, 0) : 0; - padding[i] = {pad_before, pad_total - pad_before}; - rhs_dilation[i] = dims.spatial_dims[i].stride; - window_strides[i] = attrs.dilations[dim]; + int64 dim_being_operated = should_perform_depthwise_conv ? i + 1 : i; + padding[dim_being_operated] = {pad_before, pad_total - pad_before}; + rhs_dilation[dim_being_operated] = dims.spatial_dims[i].stride; + window_strides[dim_being_operated] = attrs.dilations[dim]; } // Besides padding the input, we will also expand output_rows to @@ -496,13 +575,19 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( // // This is done by specifying the window dilation factors in the // convolution HLO below. - auto filter_backprop = - xla::ConvGeneralDilated(activations, gradients, window_strides, padding, - /*lhs_dilation=*/ones, rhs_dilation, dnums); - - if (attrs.depthwise) { - filter_backprop = ContractFilterForDepthwiseBackprop( - filter_shape, filter_backprop, activations.builder()); + filter_backprop = xla::ConvGeneralDilated( + activations, gradients, window_strides, padding, + /*lhs_dilation=*/ones, rhs_dilation, dnums, + /*feature_group_count=*/ + should_perform_depthwise_conv ? dims.in_depth : 1); + + if (should_perform_depthwise_conv) { + filter_backprop = xla::Reshape(filter_backprop, filter_shape.dimensions()); + } else { + if (attrs.depthwise) { + filter_backprop = ContractFilterForDepthwiseBackprop( + filter_shape, filter_backprop, activations.builder()); + } } return filter_backprop; diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 20b0de193dc060197f3062d3be0b8d45f7dcb9b1..41c31d0ed58fe9bc9bbde0bd58993c975f04fd60 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index b5e083912555c865b5eadc7697075c9ca4451ca9..4f0f0fd9aefecc3d31f8bd9c8ca40ebb0860c82d 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -56,6 +56,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { VLOG(1) << "Building If: " << input_types_.size() << " inputs"; std::vector arguments(input_types_.size()); + int num_resource_args = 0; for (int i = 0; i < input_types_.size(); ++i) { XlaCompiler::Argument& arg = arguments[i]; DataType type = ctx->input_type(i + 1); @@ -81,6 +82,8 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { << " type: " << DataTypeString(arg.type) << " shape: " << arg.shape.DebugString() << " initialized: " << arg.initialized; + + num_resource_args++; } else { arg.kind = XlaCompiler::Argument::kParameter; arg.type = input_types_[i]; @@ -236,9 +239,13 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { ctx->SetOutput(i, output_handle); } if (has_token_input_output_) { - // Set token output for this "if" op. + // Set token output for this "If" op. Token output is the last output of + // XLA computation, which comes after all "normal" TF outputs and resource + // updates. For "If" node, num of resource updates equals to number of + // resource args because we set `return_updated_values_for_all_resources` + // to true in XlaCompiler option. xla::XlaOp token_output = - xla::GetTupleElement(outputs, output_types_.size()); + xla::GetTupleElement(outputs, output_types_.size() + num_resource_args); auto shape_or = b->GetShape(token_output); OP_REQUIRES_OK(ctx, shape_or.status()); OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()), diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index e9bb0a77e99d144863b027bd214081316d61c314..96ddd42e2ae04d454e4fb85628d139e17a543d2e 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -15,12 +15,12 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/lib/sorting.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -505,9 +505,9 @@ class NonMaxSuppressionOp : public XlaOpKernel { init_values.push_back(included_iou); auto suppress_loop_result = - XlaWhileLoop(WhileCondFn(num_boxes, output_size), - SuppressBodyFn(num_boxes), init_values, "suppress_loop", - builder) + xla::WhileLoopHelper(WhileCondFn(num_boxes, output_size), + SuppressBodyFn(num_boxes), init_values, + "suppress_loop", builder) .ValueOrDie(); xla::XlaOp included_score = diff --git a/tensorflow/compiler/tf2xla/kernels/qr_op.cc b/tensorflow/compiler/tf2xla/kernels/qr_op.cc index 7ea0afc1f53cbe4cfcc3f6121a4ecd55864c1b52..66ec40a946b8a063d84acd33daf81f52ea2c35ed 100644 --- a/tensorflow/compiler/tf2xla/kernels/qr_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/qr_op.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/qr.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/qr.h" namespace tensorflow { namespace { @@ -26,7 +26,7 @@ class QROp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices_)); } void Compile(XlaOpKernelContext* ctx) override { - auto result = QRDecomposition(ctx->Input(0), full_matrices_); + auto result = xla::QRDecomposition(ctx->Input(0), full_matrices_); if (!result.ok()) { ctx->SetStatus(result.status()); return; diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 8822e29f7e77b1cbc6fa6ca61d0062d9b1b0c36e..2d92056e4f522f6206e7d632f0fa1e8b793fd6e3 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -20,12 +20,12 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/lib/random.h" #include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -175,8 +175,8 @@ class RandomShuffleOp : public XlaOpKernel { }; // for i in range(n): auto swap_loop_result = - XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices}, - "indices_swap_loop", builder) + xla::ForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices}, + "indices_swap_loop", builder) .ValueOrDie(); auto swapped_indices = swap_loop_result[1]; diff --git a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc index 54d34a38abc4948a1a08197d72e3e7f763649093..f9985d526033ca675c701a508a3d1576e46bc5f7 100644 --- a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc @@ -125,7 +125,7 @@ XlaOp ConcatenateIota(xla::XlaBuilder* b, XlaOp indices, dimensions.back() = 1; auto batch_indices = - xla::Iota(b, xla::ShapeUtil::MakeShape(xla::U32, dimensions), + xla::Iota(b, xla::ShapeUtil::MakeShape(xla::S32, dimensions), /*iota_dimension=*/0); return xla::ConcatInDim(b, {batch_indices, indices}, dimensions.size() - 1); @@ -189,11 +189,53 @@ XlaOp ScatterToGradData(XlaOpKernelContext* ctx, XlaOp grad_data, XlaOp indices, scatter_dim_numbers); } +// Bounds samples to 0 if the warp image indices are out of the (-1, image_size) +// bound. +// The resulting dimension is given by 'result_dims'. +XlaOp BoundSamples(XlaOpKernelContext* ctx, XlaOp warp, + xla::PrimitiveType warp_type, TensorShape warp_shape, + std::vector result_dims, + std::vector broadcasted_dims, int64 last_warp_dim, + xla::Shape data_shape, XlaOp sample) { + auto is_gt_minus_one = + xla::Gt(warp, + xla::ConvertElementType( + xla::ConstantR1(ctx->builder(), {-1, -1}), warp_type), + /*broadcast_dimensions=*/{warp_shape.dims() - 1}); + auto is_lt_image_size = xla::Lt( + warp, + xla::ConvertElementType( + xla::ConstantR1( + ctx->builder(), + {/*width=*/static_cast(data_shape.dimensions(2)), + /*height=*/static_cast(data_shape.dimensions(1))}), + warp_type), + /*broadcast_dimensions=*/{warp_shape.dims() - 1}); + + auto is_in_bound_padded_x_y = xla::And(is_gt_minus_one, is_lt_image_size); + // Reduce along last dimension. The resulting dimension is: + // [batch, dim_0, ...dim_n]. + auto is_in_bound = xla::Reduce( + is_in_bound_padded_x_y, xla::ConstantR0(ctx->builder(), true), + xla::CreateScalarAndComputation(xla::PrimitiveType::PRED, ctx->builder()), + {last_warp_dim}); + + // Broadcast 'is_in_bound' to the same dimension as 'result_dims'. + auto broadcasted_is_in_bound = + xla::BroadcastInDim(is_in_bound, result_dims, broadcasted_dims); + + // Set out of bound samples to zero. + auto zeros = + xla::Broadcast(xla::Zero(ctx->builder(), warp_type), result_dims); + return xla::Select(broadcasted_is_in_bound, sample, zeros); +} + // Build computation the backprop into input 'data'. // Where input: // grad_output is of dimension [batch, dim_0, ...dim_n, channel] // ratio is of dimension [batch, dim_0, ...dim_n, 2] // gather_indices is of dimension [batch, dim_0, ...dim_n, 3] +// data_shape is of dimension [batch, x(width), y(height), channel] // // Output: // scatter-add to each 2x2 grad_data neighbor: @@ -201,10 +243,12 @@ XlaOp ScatterToGradData(XlaOpKernelContext* ctx, XlaOp grad_data, XlaOp indices, // grad_data[cx, fy, chan] += output_grad * (1 - dx) * dy // grad_data[fx, cy, chan] += output_grad * dx * (1 - dy) // grad_data[cx, cy, chan] += output_grad * (1 - dx) * (1 - dy) -// where (dx, dy) is (1 - ratio). +// where (dx, dy) is (1 - ratio). If (dx, dy) is out of bound, then the their +// contribution is 0 to 'grad_data'. XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, - XlaOp gather_indices, xla::PrimitiveType warp_type, - TensorShape warp_shape, int64 data_channels, + XlaOp gather_indices, XlaOp warp, + xla::PrimitiveType warp_type, TensorShape warp_shape, + int64 last_warp_dim, int64 data_channels, xla::Shape data_shape) { // Weights tensor has dimension [batch, dim_0, ... dim_n, 4]. auto weights = BilinearWeights(ctx, ratio, warp_shape, warp_type); @@ -229,6 +273,18 @@ XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, std::iota(reshaped_weights_indices.begin(), reshaped_weights_indices.end(), 0); + // Set out of bound weights to 0. + // The dimension of the reshaped_weight: [batch, dim_0, ...dim_n, 2, 2]. + std::vector reshaped_result_dims(warp_dims.begin(), + warp_dims.end() - 1); + reshaped_result_dims.push_back(2); + reshaped_result_dims.push_back(2); + std::vector broadcasted_dims(warp_dims.size() - 1); + std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0); + reshaped_weights = BoundSamples(ctx, warp, warp_type, warp_shape, + reshaped_result_dims, broadcasted_dims, + last_warp_dim, data_shape, reshaped_weights); + // The dimension is [batch, dim_0, ..., dim_n, 2, 2, data_channel]. auto broadcast_reshaped_weights = xla::BroadcastInDim( reshaped_weights, weights_with_channels_dims, reshaped_weights_indices); @@ -245,18 +301,41 @@ XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, auto grad_data = xla::ConstantLiteral( ctx->builder(), xla::Literal::CreateFromShape(data_shape)); - return ScatterToGradData(ctx, grad_data, gather_indices, - grad_output_multiply_weights, warp_shape.dims(), - warp_type); + // Pad grad data then slice it back. + // + // After left and right column 0-padding, the new dimension of padded data + // will be [batch, x+2, y+2, channel]. + auto padded_grad_data = + xla::Pad(grad_data, xla::Zero(ctx->builder(), warp_type), + xla::MakeEdgePaddingConfig({{0, 0}, {1, 1}, {1, 1}, {0, 0}})); + + auto shifting_value = xla::ConstantR1( + ctx->builder(), {/*batch=*/0, /*x(width)=*/1, /*y(height)=*/1}); + auto shifted_gather_indices = + xla::Add(gather_indices, shifting_value, {last_warp_dim}); + + auto updated_grad_data = ScatterToGradData( + ctx, padded_grad_data, shifted_gather_indices, + grad_output_multiply_weights, warp_shape.dims(), warp_type); + + const int64 batch_size = data_shape.dimensions(0); + const int64 width = data_shape.dimensions(1); + const int64 height = data_shape.dimensions(2); + // Slice out the result accounting for the padding. + return xla::Slice( + updated_grad_data, /*start_indices=*/{0, 1, 1, 0}, + /*limit_indices=*/{batch_size, width + 1, height + 1, data_channels}, + /*strides=*/{1, 1, 1, 1}); } // Build computation for the backprop into input 'warp'. // Where input: -// warp is of dimension [batch, dim_0, ...dim_n, 2] -// grad_output is of dimension [batch, dim_0, ...dim_n, channel] -// ratio is of dimension [batch, dim_0, ...dim_n, 2] -// gather_indices is of dimension [batch, dim_0, ...dim_n, 3] -// data is of dimension [batch, x, y, channel] +// warp is of dimension [batch, dim_0, ...dim_n, 2] +// grad_output is of dimension [batch, dim_0, ...dim_n, channel] +// ratio is of dimension [batch, dim_0, ...dim_n, 2] +// gather_indices is of dimension [batch, dim_0, ...dim_n, 3] where the last +// dimension of size 3 is for {batch, x(width), y(height)}. +// data is of dimension [batch, x, y, channel] // // Output (simplified by ignoring the batch dimensions): // Since the forward path has: @@ -275,12 +354,12 @@ XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, // grad_warp_x = py * (img_cxcy - img_fxcy) + (1-py) * (img_cxfy-img_fxfy) // grad_warp_y = px * (img_cxcy - img_cxfy) + (1-px) * (img_fxcy-img_fxfy) // -// where (px, py) is warp, (fx, fy) is the left top corner and (cx, cy) is the +// where (px, py) is warp, (fx, fy) is the top left corner and (cx, cy) is the // bottom right corner in a 2x2 neighborhood. XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, XlaOp gather_indices, XlaOp data, TensorShape warp_shape, int64 data_channels, - xla::PrimitiveType data_type) { + xla::PrimitiveType data_type, xla::Shape data_shape) { auto warp_dims = warp_shape.dim_sizes(); std::vector warp_dims_without_last_dims(warp_dims.begin(), warp_dims.end() - 1); @@ -289,12 +368,30 @@ XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, std::vector neighbor_broadcast_dims = warp_dims_without_last_dims; neighbor_broadcast_dims.push_back(4); - // The dimension is [batch, dim_0, ... dim_n, 4, data_channels] - auto neighbors_data = Gather2by2Neighbors( - ctx->builder(), data, gather_indices, data_channels, warp_shape.dims()); + // With dimension [batch, dim_0, ...dim_n, 4] + auto neighbor_broadcast_shape = + xla::ShapeUtil::MakeShape(data_type, neighbor_broadcast_dims); const int64 last_warp_dim = warp_shape.dims() - 1; + // Pad data with 0, before gathering such that 0 will be returned for samples + // in the range of (-1, 0) or (image_dimension-1, image_dimension). + // After left and right column 0-padding, the new dimension of padded data + // will be [batch, x+2, y+2, channel]. + auto padded_data = + xla::Pad(data, xla::Zero(ctx->builder(), data_type), + xla::MakeEdgePaddingConfig({{0, 0}, {1, 1}, {1, 1}, {0, 0}})); + + auto shifting_value = xla::ConstantR1( + ctx->builder(), {/*batch=*/0, /*x(width)=*/1, /*y(height)=*/1}); + auto shifted_gather_indices = + xla::Add(gather_indices, shifting_value, {last_warp_dim}); + + // The dimension is [batch, dim_0, ... dim_n, 4, data_channels] + auto neighbors_data = + Gather2by2Neighbors(ctx->builder(), padded_data, shifted_gather_indices, + data_channels, warp_shape.dims()); + // Since we will be creating the dot product of: // lhs: [batch, dim_0, ...dim_n, 4] // and @@ -417,7 +514,7 @@ class ResamplerOp : public XlaOpKernel { // Find the coordinates of the top left corner for the 2x2 region to be // sampled from. The dimensions are [batch, dim_0, ... dim_n, 2] where the // last dimension of size 2 in turn is [x, y]. - XlaOp top_left = xla::ConvertElementType(warp, xla::U32); + XlaOp top_left = xla::ConvertElementType(warp, xla::S32); auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape); @@ -526,7 +623,8 @@ class ResamplerGradOp : public XlaOpKernel { size, "]")); } // Last dimension of warp shape must be of size 2. - OP_REQUIRES(ctx, warp_shape.dim_size(warp_shape.dims() - 1) == 2, + const int64 last_warp_dim = warp_shape.dims() - 1; + OP_REQUIRES(ctx, warp_shape.dim_size(last_warp_dim) == 2, errors::InvalidArgument( "the last dimension of warp must be exactly size 2.")); xla::PrimitiveType warp_type = ctx->input_xla_type(1); @@ -549,24 +647,32 @@ class ResamplerGradOp : public XlaOpKernel { // Find the top left corner coordinate for the region to be sampled from. // The dimensions are [batch, dim_0, ... dim_n, 2] where the last dimension // of size 2 in turn is [x, y]. - XlaOp top_left = xla::ConvertElementType(warp, xla::U32); + XlaOp top_left = xla::ConvertElementType(xla::Floor(warp), xla::S32); - // Dimensions are [batch, dim_0, ... dim_n, 2] + // Dimensions are [batch, dim_0, ... dim_n, 2]. XlaOp ratio = warp - xla::ConvertElementType(top_left, warp_type); // Indices for gathering neighboring pixels. auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape); - auto grad_data = - CalculateGradData(ctx, grad_output, ratio, gather_indices, warp_type, - warp_shape, data_channels, data_shape); + auto grad_data = CalculateGradData( + ctx, grad_output, ratio, gather_indices, warp, warp_type, warp_shape, + last_warp_dim, data_channels, data_shape); auto grad_warp = CalculateGradWarp(ctx, grad_output, ratio, gather_indices, data, - warp_shape, data_channels, data_type); + warp_shape, data_channels, data_type, data_shape); + auto warp_dims = warp_shape.dim_sizes(); + std::vector result_dims(warp_dims.begin(), warp_dims.end() - 1); + result_dims.push_back(2); + std::vector broadcasted_dims(warp_dims.size() - 1); + std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0); + auto grad_warp_bounded = + BoundSamples(ctx, warp, warp_type, warp_shape, result_dims, + broadcasted_dims, last_warp_dim, data_shape, grad_warp); ctx->SetOutput(0, grad_data); - ctx->SetOutput(1, grad_warp); + ctx->SetOutput(1, grad_warp_bounded); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index 960c1462ceb8c00a2d6c96564f6c985fd1caef0f..26d4214099d1d07c1b2e275d783654d9cd948e28 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -172,6 +172,65 @@ class ResourceApplyMomentum : public XlaOpKernel { REGISTER_XLA_OP(Name("ResourceApplyMomentum").TypeConstraint("T", kFloatTypes), ResourceApplyMomentum); +class ResourceApplyKerasMomentum : public XlaOpKernel { + public: + explicit ResourceApplyKerasMomentum(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + DataType type = ctx->input_type(2); + + TensorShape var_shape, accum_shape; + xla::XlaOp var, accum; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum)); + + OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), + errors::InvalidArgument( + "var and accum do not have the same shape", + var_shape.DebugString(), " ", accum_shape.DebugString())); + + TensorShape lr_shape = ctx->InputShape(2); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), + errors::InvalidArgument("lr is not a scalar: ", + lr_shape.DebugString())); + + TensorShape grad_shape = ctx->InputShape(3); + OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), + errors::InvalidArgument( + "var and grad do not have the same shape", + var_shape.DebugString(), " ", grad_shape.DebugString())); + + TensorShape momentum_shape = ctx->InputShape(4); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape), + errors::InvalidArgument("momentum is not a scalar: ", + momentum_shape.DebugString())); + + xla::XlaOp lr = ctx->Input(2); + xla::XlaOp grad = ctx->Input(3); + xla::XlaOp momentum = ctx->Input(4); + + accum = accum * momentum - grad * lr; + if (use_nesterov_) { + // See https://github.com/tensorflow/tensorflow/pull/2798 for an + // explanation of the reparameterization used here. + var = var + accum * momentum - grad * lr; + } else { + var = var + accum; + } + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum)); + } + + private: + bool use_nesterov_; +}; +REGISTER_XLA_OP( + Name("ResourceApplyKerasMomentum").TypeConstraint("T", kFloatTypes), + ResourceApplyKerasMomentum); + class ResourceApplyAdagrad : public XlaOpKernel { public: explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index ce007fc04a818869686b9936a1607cee42665e87..89b577bfc05b4665d492f4ea5cf6f869af2fa9a9 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -41,8 +41,7 @@ Status MakeXlaCompilerArgumentsFromInputs( *has_uninitialized_vars = false; *has_tensor_arrays = false; for (int i = 0; i < ctx->num_inputs(); ++i) { - VLOG(2) << " Input " << i - << " type: " << DataTypeString(ctx->input_type(i)) + VLOG(2) << " Input " << i << " type: " << DataTypeString(ctx->input_type(i)) << " shape: " << ctx->InputShape(i).DebugString(); XlaCompiler::Argument& arg = (*args)[i]; DataType type = ctx->input_type(i); @@ -233,13 +232,22 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { xla::ShapeUtil::HumanString(body_input_shape), " vs. ", xla::ShapeUtil::HumanString(body.xla_output_shape))); - xla::Shape expected_cond_output_shape = xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::PRED, {})}); + xla::Shape expected_cond_output_shape_without_side_effect = + xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::PRED, {})}); + xla::Shape expected_cond_output_shape_with_side_effect = + xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::PRED, {}), + xla::ShapeUtil::MakeTokenShape()}); OP_REQUIRES(ctx, - xla::ShapeUtil::Compatible(cond.xla_output_shape, - expected_cond_output_shape), + xla::ShapeUtil::Compatible( + cond.xla_output_shape, + expected_cond_output_shape_without_side_effect) || + xla::ShapeUtil::Compatible( + cond.xla_output_shape, + expected_cond_output_shape_with_side_effect), errors::InvalidArgument( - "Output shape of loop condition should be (pred[]), got: ", + "Output shape of loop condition should be (pred[]) or " + "(pred[], token[]), got: ", xla::ShapeUtil::HumanString(cond.xla_output_shape))); int num_inputs = body.input_mapping.size(); diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 3e7a761120317ff85947559b7b2e52be9232afb7..3d7b0bc959f9dbf3c1b9749379e2ea0d285b302b 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -15,8 +15,6 @@ filegroup( ]), ) -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") - cc_library( name = "broadcast", srcs = ["broadcast.cc"], @@ -33,27 +31,6 @@ cc_library( ], ) -cc_library( - name = "cholesky", - srcs = ["cholesky.cc"], - hdrs = ["cholesky.h"], - deps = [ - ":util", - ":while_loop", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:matrix", - "//tensorflow/compiler/xla/client/lib:slicing", - "//tensorflow/compiler/xla/client/lib:triangular_solve", - "//tensorflow/core:lib", - ], -) - cc_library( name = "random", srcs = ["random.cc"], @@ -69,35 +46,12 @@ cc_library( ], ) -cc_library( - name = "qr", - srcs = ["qr.cc"], - hdrs = ["qr.h"], - deps = [ - ":util", - ":while_loop", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:math", - "//tensorflow/compiler/xla/client/lib:matrix", - "//tensorflow/compiler/xla/client/lib:slicing", - "//tensorflow/core:lib", - ], -) - cc_library( name = "scatter", srcs = ["scatter.cc"], hdrs = ["scatter.h"], deps = [ ":util", - ":while_loop", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -128,19 +82,3 @@ cc_library( "@com_google_absl//absl/types:span", ], ) - -cc_library( - name = "while_loop", - srcs = ["while_loop.cc"], - hdrs = ["while_loop.h"], - deps = [ - ":util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ], -) diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index 2b1c2ced925d9fee7392986015a6e716a94d356f..688056791f9750e6b22df4b2cd4643de0b780651 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -20,7 +20,6 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc index 72b240996fb4d9dcb5f5dfd919da618cbae08c16..ff9f1b9ccba2c4f3307890d5aac4ddb6cfaafcd9 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc @@ -65,6 +65,7 @@ CreateResourceOpInfoMap() { add("ResourceApplyFtrlV2" , kReadWrite, kVariable); add("ResourceApplyGradientDescent" , kReadWrite, kVariable); add("ResourceApplyMomentum" , kReadWrite, kVariable); + add("ResourceApplyKerasMomentum" , kReadWrite, kVariable); add("ResourceApplyPowerSign" , kReadWrite, kVariable); add("ResourceApplyProximalAdagrad" , kReadWrite, kVariable); add("ResourceApplyProximalGradientDescent" , kReadWrite, kVariable); diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc index b233e6b2c28e1968bb74901fc684e808ae45ab60..b62f8e9115229ac35c657d374c68336f1168ff77 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.cc +++ b/tensorflow/compiler/tf2xla/side_effect_util.cc @@ -24,6 +24,8 @@ const char kXlaTokenInputNodesAttrName[] = "_xla_token_input_nodes"; const char kXlaTokenArgNodeName[] = "_xla_token_arg_node"; +const char kXlaHasHostTransferAttrName[] = "_xla_has_host_transfer"; + std::set CalculateTokenInputsForOutputToken(const Graph& g) { std::set results; Node* first_side_effecting_node_on_path = nullptr; diff --git a/tensorflow/compiler/tf2xla/side_effect_util.h b/tensorflow/compiler/tf2xla/side_effect_util.h index f22ddb2f58e1fa5c10ca0fdb956d9136942388b7..7081b362c36c4785164b29003a5f89cd73bcf3af 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.h +++ b/tensorflow/compiler/tf2xla/side_effect_util.h @@ -35,6 +35,9 @@ extern const char kXlaTokenInputNodesAttrName[]; // node has side-effect dependency on current graph's token input. extern const char kXlaTokenArgNodeName[]; +// This node have XlaRecvAtHost/XlaSendFromHost in its associated functions. +extern const char kXlaHasHostTransferAttrName[]; + // Calculates side-effect dependencies for the graph's token output. // Returns a set of node names representing these dependencies. std::set CalculateTokenInputsForOutputToken(const Graph& g); diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index ab26d939ccba75ce58609ffd71c7ccadbe90cfa8..24afe595b18b823818bd8fe65bc599af8bce040a 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -91,7 +91,7 @@ TEST(ConvertGraphDefToXla, Sum) { client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()}); TF_EXPECT_OK(result_or.status()); xla::Literal result = std::move(result_or.ValueOrDie()); - EXPECT_EQ("(s32[]) (\n42\n)", result.ToString()); + EXPECT_EQ("(\ns32[] 42\n)", result.ToString()); config.mutable_feed(0)->mutable_id()->set_output_index( 123); /* invalid output_index */ diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index cc81772e8c5da710bc733f7e4f5fe820b2c2d110..18d87727c500619bf386be7d8c7085724f44aba3 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -364,6 +364,7 @@ Status AddPlaceholdersForFeeds( GraphDef gd; *gd.mutable_versions() = graph_def->versions(); *gd.add_node() = *existing; + MergeDebugInfo(NodeDebugInfo(*existing), gd.mutable_node(0)); TF_RETURN_IF_ERROR( AddDefaultAttrsToGraphDef(&gd, *op_registry, 0 /*node_offset*/)); @@ -390,6 +391,7 @@ Status AddPlaceholdersForFeeds( // in this code. for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) { const PlaceholderInfo& info = it->second; + // TODO(shikharagarwal): Add original node information. NodeDef* d = graph_def->add_node(); d->set_name(info.placeholder_name); d->set_op("PlaceholderV2"); @@ -557,6 +559,12 @@ bool HasAssociatedFunction(const NodeDef& node_def, return true; } + if (node_def.op() == "XlaHostCompute") { + // XlaHostCompute has "shape_inference_graph" func attr, but that's not + // related to graph execution. + return false; + } + for (const auto& iter : node_def.attr()) { if (iter.second.has_func()) { return true; @@ -578,6 +586,9 @@ std::vector GetAssociatedFunctions( // This is a SymbolicGradient op. AttrValueMap attrs(node.attrs().begin(), node.attrs().end()); results.emplace_back(AssociatedFunctionInfo::SymbolicGradient(op, attrs)); + } else if (node.type_string() == "XlaHostCompute") { + // XlaHostCompute has "shape_inference_graph" func attr, but that's not + // related to graph execution. } else { // Collect all function attrs for the node. for (auto& iter : node.attrs()) { @@ -599,7 +610,9 @@ Status RewriteAssociatedFunction( switch (associated_function.type()) { case AssociatedFunctionInfo::kFunctionCallNode: { // Change this node to call the new function. - NodeDefBuilder builder(node->name(), rewritten_function_name, fld); + NodeDebugInfo debug_info(*node); + NodeDefBuilder builder(node->name(), rewritten_function_name, fld, + &debug_info); for (auto attr : node->attrs()) { builder.Attr(attr.first, attr.second); } diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 4360e0857964b0ac63fc887e269b04a4b00d854a..722d1376687efa1c04158e3fd9ce539aac9d0122 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -109,7 +109,7 @@ cc_library( name = "status_macros", srcs = ["status_macros.cc"], hdrs = ["status_macros.h"], - visibility = [":friends"], + visibility = ["//visibility:public"], deps = [ ":statusor", ":types", @@ -224,6 +224,7 @@ cc_library( name = "shape_util", srcs = [ "index_util.cc", + "layout.cc", "layout_util.cc", "primitive_util.cc", "shape.cc", @@ -231,6 +232,7 @@ cc_library( ], hdrs = [ "index_util.h", + "layout.h", "layout_util.h", "primitive_util.h", "shape.h", @@ -290,6 +292,22 @@ tf_cc_test( ], ) +tf_cc_test( + name = "primitive_util_test", + srcs = ["primitive_util_test.cc"], + deps = [ + ":shape_util", + ":status_macros", + ":test", + ":test_helpers", + ":types", + ":util", + ":xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "layout_util_test", srcs = ["layout_util_test.cc"], @@ -301,6 +319,22 @@ tf_cc_test( ], ) +tf_cc_test( + name = "layout_test", + srcs = ["layout_test.cc"], + deps = [ + ":shape_util", + ":status_macros", + ":test", + ":test_helpers", + ":types", + ":util", + ":xla_data_proto", + "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", + ], +) + tf_cc_test( name = "index_util_test", srcs = ["index_util_test.cc"], @@ -575,6 +609,7 @@ cc_library( ":types", ":util", ":xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/memory", @@ -705,7 +740,6 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_evaluator", "//tensorflow/compiler/xla/service:shape_inference", - "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index fe99564d3c671cd7890e1fa26fcd2e3384972983..e61d9d2520366f3f21a18b6c62ba924fba23308a 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -3,7 +3,7 @@ licenses(["notice"]) # Apache 2.0 -package(default_visibility = [":friends"]) +package(default_visibility = ["//visibility:public"]) package_group( name = "friends", diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 74b76f929949d3300a5d0ff45d5fa4cd9f162642..43127cae1e5d81521003a28288e27d291e33c9b9 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -186,7 +186,7 @@ StatusOr Client::ComputeConstant(const XlaComputation& computation, ComputeConstantGraphRequest request; *request.mutable_computation() = computation.proto(); if (output_layout != nullptr) { - *request.mutable_output_layout() = *output_layout; + *request.mutable_output_layout() = output_layout->ToProto(); } ComputeConstantResponse response; diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 41db8de29ff0085a30847ff41db4ffbfc774e2a1..970f00759f630f30f1c1321231fd9e0199026142 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -1,5 +1,7 @@ # Common computation builders for XLA. +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites", "xla_test") + licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow/compiler/xla/client:friends"]) @@ -13,9 +15,6 @@ filegroup( ]), ) -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites") - # Generate test_suites for all backends, named "${backend}_tests". generate_backend_suites() @@ -35,6 +34,48 @@ cc_library( ], ) +cc_library( + name = "cholesky", + srcs = ["cholesky.cc"], + hdrs = ["cholesky.h"], + deps = [ + ":math", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:loops", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/client/lib:slicing", + "//tensorflow/compiler/xla/client/lib:triangular_solve", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "cholesky_test", + srcs = ["cholesky_test.cc"], + tags = ["optonly"], + deps = [ + ":arithmetic", + ":cholesky", + ":matrix", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "constants", srcs = ["constants.cc"], @@ -75,6 +116,22 @@ cc_library( ], ) +cc_library( + name = "loops", + srcs = ["loops.cc"], + hdrs = ["loops.h"], + deps = [ + ":constants", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "math", srcs = ["math.cc"], @@ -177,6 +234,48 @@ cc_library( ], ) +cc_library( + name = "qr", + srcs = ["qr.cc"], + hdrs = ["qr.h"], + deps = [ + ":arithmetic", + ":constants", + ":loops", + ":math", + ":matrix", + ":slicing", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "qr_test", + srcs = ["qr_test.cc"], + tags = ["optonly"], + deps = [ + ":matrix", + ":qr", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array3d", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "slicing", srcs = ["slicing.cc"], @@ -237,6 +336,34 @@ xla_test( ], ) +cc_library( + name = "quantize", + hdrs = ["quantize.h"], + deps = [ + ":constants", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "quantize_test", + srcs = ["quantize_test.cc"], + tags = ["enable_for_xla_interpreter"], + deps = [ + ":quantize", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "testing", srcs = ["testing.cc"], diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/xla/client/lib/cholesky.cc similarity index 68% rename from tensorflow/compiler/tf2xla/lib/cholesky.cc rename to tensorflow/compiler/xla/client/lib/cholesky.cc index 550ab5b05693b79e60e49577309328ac6846d3f9..fd98049968491d80b9717a2de1f34997bd9d18c1 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/xla/client/lib/cholesky.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/cholesky.h" +#include "tensorflow/compiler/xla/client/lib/cholesky.h" #include #include -#include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/lib/triangular_solve.h" @@ -31,7 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/errors.h" -namespace tensorflow { +namespace xla { namespace { @@ -50,26 +50,25 @@ namespace { // l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) / // l[..., j, j] // return l -xla::XlaOp CholeskyUnblocked(xla::XlaOp a, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int n_dims = xla::ShapeUtil::Rank(a_shape); - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); - auto major_dims = xla::AsInt64Slice(a_shape.dimensions()) +XlaOp CholeskyUnblocked(XlaOp a, PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int n_dims = ShapeUtil::Rank(a_shape); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); + auto major_dims = AsInt64Slice(a_shape.dimensions()) .subspan( /*pos=*/0, /*len=*/n_dims - 2); - xla::XlaOp l = xla::ZerosLike(a); + XlaOp l = ZerosLike(a); // Construct the for loop body to iterate over rows. - auto body_fn = [&](xla::XlaOp i, absl::Span loop_vars, - xla::XlaBuilder* body_builder) - -> xla::StatusOr> { - xla::Shape col_shape; - xla::Shape row_shape; + auto body_fn = + [&](XlaOp i, absl::Span loop_vars, + XlaBuilder* body_builder) -> StatusOr> { + Shape col_shape; + Shape row_shape; for (int64 d : major_dims) { row_shape.add_dimensions(d); col_shape.add_dimensions(d); @@ -77,43 +76,40 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, row_shape.add_dimensions(1); row_shape.add_dimensions(n); row_shape.set_element_type(a_shape.element_type()); - auto mask_zeros_row = xla::Zeros(body_builder, row_shape); + auto mask_zeros_row = Zeros(body_builder, row_shape); col_shape.add_dimensions(n); col_shape.add_dimensions(1); col_shape.set_element_type(a_shape.element_type()); - auto mask_zeros_col = xla::Zeros(body_builder, col_shape); + auto mask_zeros_col = Zeros(body_builder, col_shape); std::vector mask_vector(n); std::iota(mask_vector.begin(), mask_vector.end(), 0); - auto mask_range = xla::ConstantR1(body_builder, mask_vector); + auto mask_range = ConstantR1(body_builder, mask_vector); auto mask_range_row = - xla::Broadcast(xla::Reshape(mask_range, {0}, {1, n}), major_dims); + Broadcast(Reshape(mask_range, {0}, {1, n}), major_dims); auto mask_range_col = - xla::Broadcast(xla::Reshape(mask_range, {0}, {n, 1}), major_dims); + Broadcast(Reshape(mask_range, {0}, {n, 1}), major_dims); auto body_a = loop_vars[0]; auto body_l = loop_vars[1]; // row = l[..., i, :i] // select the whole i-th row, then mask out all columns past i-1 - auto zero = xla::ConstantR0(body_builder, 0); + auto zero = ConstantR0(body_builder, 0); auto l_i = DynamicSliceInMinorDims(body_l, {i, zero}, {1, n}); - auto row = xla::Select(xla::Ge(mask_range_row, i), mask_zeros_row, l_i); + auto row = Select(Ge(mask_range_row, i), mask_zeros_row, l_i); // a[..., i, i] auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1}); // np.dot(row, np.swapaxes(row, -1, -2)) auto diag_dot = BatchDot(row, TransposeInMinorDims(row), precision); // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, // np.swapaxes(row, -1, -2))) - auto l_ii = - xla::Pow(a_ii - diag_dot, - FloatLiteral(body_builder, a_shape.element_type(), 0.5)); + auto l_ii = Sqrt(a_ii - diag_dot); // a[..., i+1:, i] // select the whole i-th column, then mask out all rows above i+1 auto a_0i = DynamicSliceInMinorDims(body_a, {i}, {1}); - auto a_ip1i = - xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, a_0i); + auto a_ip1i = Select(Le(mask_range_col, i), mask_zeros_col, a_0i); // l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) / // l[..., i, i] @@ -122,8 +118,7 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, // r.T) auto dot = BatchDot(body_l, TransposeInMinorDims(row), precision); // np.dot(l[..., i+1:, :i], r.T) - auto dot_ip1 = - xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot); + auto dot_ip1 = Select(Le(mask_range_col, i), mask_zeros_col, dot); body_l = DynamicUpdateSliceInMinorDims(body_l, (a_ip1i - dot_ip1) / l_ii, {i}); @@ -131,12 +126,12 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, // column assign will wrap around and overwrite the diagonal assign. body_l = DynamicUpdateSliceInMinorDims(body_l, l_ii, {i, i}); - return std::vector{body_a, body_l}; + return std::vector{body_a, body_l}; }; TF_ASSIGN_OR_RETURN( auto cholesky_while, - XlaForEachIndex(n, xla::S32, body_fn, {a, l}, "unblocked", builder)); + ForEachIndex(n, S32, body_fn, {a, l}, "unblocked", builder)); return cholesky_while[1]; }); @@ -144,34 +139,35 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, } // namespace -xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int ndims = xla::ShapeUtil::Rank(a_shape); +XlaOp Cholesky(XlaOp a, int64 block_size, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int ndims = ShapeUtil::Rank(a_shape); if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to Cholesky must have rank >= 2: ", ndims); + return InvalidArgument( + "Argument to Cholesky must have rank >= 2; shape was %s", + a_shape.ToString()); } - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); - if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) { - return errors::InvalidArgument( - "Arguments to Cholesky must be square matrices: ", - xla::ShapeUtil::HumanString(a_shape)); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); + if (n != ShapeUtil::GetDimension(a_shape, -2)) { + return InvalidArgument( + "Argument to Cholesky must be batched square matrices; got shape %s", + ShapeUtil::HumanString(a_shape)); } if (block_size < 1) { - return errors::InvalidArgument( - "block_size argument to Cholesky must be >= 1; got ", block_size); + return InvalidArgument( + "block_size argument to Cholesky must be >= 1; got %d", block_size); } // Blocked left-looking Cholesky factorization. // Algorithm 1 from // Haidar, Azzam, et al. "High-performance Cholesky factorization for // GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017. - xla::XlaOp l = xla::ZerosLike(a); + XlaOp l = ZerosLike(a); for (int64 i = 0; i < n; i += block_size) { int64 k = std::min(block_size, n - i); if (i > 0) { @@ -207,4 +203,4 @@ xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size, }); } -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/xla/client/lib/cholesky.h similarity index 87% rename from tensorflow/compiler/tf2xla/lib/cholesky.h rename to tensorflow/compiler/xla/client/lib/cholesky.h index 9a561c34b92ee45059f2a05336e682838f8e36e2..0bae26837c0f14dd0cfab82cf426becc787ec11c 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.h +++ b/tensorflow/compiler/xla/client/lib/cholesky.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -namespace tensorflow { +namespace xla { // Computes the Cholesky decompositions of a batch of symmetric positive // definite matrices. @@ -34,6 +34,6 @@ xla::XlaOp Cholesky( xla::XlaOp a, int64 block_size = 256, xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); -} // namespace tensorflow +} // namespace xla -#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_ diff --git a/tensorflow/compiler/xla/client/lib/cholesky_test.cc b/tensorflow/compiler/xla/client/lib/cholesky_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ba9580a3d32225625acc1447344b7d2c16c5d8a5 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/cholesky_test.cc @@ -0,0 +1,166 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/cholesky.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace { + +using xla::int64; + +using CholeskyTest = xla::ClientLibraryTestBase; + +XLA_TEST_F(CholeskyTest, Simple) { + xla::XlaBuilder builder(TestName()); + + xla::Array2D a_vals({ + {4, 6, 8, 10}, + {6, 45, 54, 63}, + {8, 54, 146, 166}, + {10, 63, 166, 310}, + }); + + xla::XlaOp a; + auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); + xla::Cholesky(a, /*block_size=*/2); + + xla::Array2D expected({ + {2, 0, 0, 0}, + {3, 6, 0, 0}, + {4, 7, 9, 0}, + {5, 8, 10, 11}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +XLA_TEST_F(CholeskyTest, Simple2) { + xla::XlaBuilder builder(TestName()); + + xla::Array2D a_vals({ + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 456, 106}, + {12, 48, 106, 62}, + }); + + xla::XlaOp a; + auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); + xla::Cholesky(a); + + xla::Array2D expected( + {{4, 0, 0, 0}, {6, 5, 0, 0}, {2, 14, 16, 0}, {3, 6, 1, 4}}); + + ComputeAndCompareR2(&builder, expected, {a_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +XLA_TEST_F(CholeskyTest, SimpleBatched) { + xla::XlaBuilder builder(TestName()); + + xla::Array3D a_vals({ + { + {4, 6, 8, 10}, + {6, 45, 54, 63}, + {8, 54, 146, 166}, + {10, 63, 166, 310}, + }, + { + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 456, 106}, + {12, 48, 106, 62}, + }, + }); + + xla::XlaOp a; + auto a_data = CreateR3Parameter(a_vals, 0, "a", &builder, &a); + xla::Cholesky(a); + + xla::Array3D expected({ + { + {2, 0, 0, 0}, + {3, 6, 0, 0}, + {4, 7, 9, 0}, + {5, 8, 10, 11}, + }, + {{4, 0, 0, 0}, {6, 5, 0, 0}, {2, 14, 16, 0}, {3, 6, 1, 4}}, + }); + + ComputeAndCompareR3(&builder, expected, {a_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +using CholeskyTestCase = std::tuple; + +class RandomCholeskyTest + : public xla::ClientLibraryTestBase, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(RandomCholeskyTest, Random) { + xla::XlaBuilder builder(TestName()); + + auto test_params = GetParam(); + std::vector dimensions = {std::get<0>(test_params), + std::get<1>(test_params), + std::get<1>(test_params)}; + xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, dimensions); + TF_ASSERT_OK_AND_ASSIGN( + auto literal, + xla::LiteralUtil::CreateRandomLiteral(shape, 0.0, 1.0)); + + auto input = xla::Parameter(&builder, 0, shape, "input"); + // Form a random positive definite matrix. + auto matrix = xla::BatchDot(input, TransposeInMinorDims(input), + xla::PrecisionConfig::HIGHEST); + + auto cholesky = xla::Cholesky(matrix, /*block_size=*/4); + + // Verify that ||matrix - cholesky * cholesky_t||_2 ~= 0 + auto verification = xla::BatchDot(cholesky, TransposeInMinorDims(cholesky), + xla::PrecisionConfig::HIGHEST); + auto delta = matrix - verification; + xla::Reduce(delta * delta, xla::ConstantR0(&builder, 0.0), + CreateScalarAddComputation(xla::F32, &builder), {0, 1, 2}); + + TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(literal)); + ComputeAndCompareR0(&builder, 0.0, {input_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +INSTANTIATE_TEST_CASE_P(RandomCholeskyTestInstance, RandomCholeskyTest, + ::testing::Values(CholeskyTestCase{1, 1}, + CholeskyTestCase{1, 2}, + CholeskyTestCase{10, 5}, + CholeskyTestCase{2, 20})); + +} // namespace diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/xla/client/lib/loops.cc similarity index 50% rename from tensorflow/compiler/tf2xla/lib/while_loop.cc rename to tensorflow/compiler/xla/client/lib/loops.cc index 594ab1dfd0700f47501712183f6efe62d17e15e7..721f987628a8ac7da3f3f872939c3f0457d6bbe2 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.cc +++ b/tensorflow/compiler/xla/client/lib/loops.cc @@ -13,44 +13,43 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" -#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" + +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" -namespace tensorflow { +namespace xla { -xla::StatusOr> XlaWhileLoop( - const LoopConditionFunction& condition_function, - const LoopBodyFunction& body_function, - absl::Span initial_values, absl::string_view name, - xla::XlaBuilder* builder) { +StatusOr> WhileLoopHelper( + const WhileLoopHelperConditionFunction& condition_function, + const WhileLoopHelperBodyFunction& body_function, + absl::Span initial_values, absl::string_view name, + XlaBuilder* builder) { int arity = initial_values.size(); - std::vector var_shapes; + std::vector var_shapes; var_shapes.reserve(arity); - for (const xla::XlaOp& input : initial_values) { + for (const XlaOp& input : initial_values) { TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(input)); var_shapes.push_back(std::move(shape)); } - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(var_shapes); + Shape tuple_shape = ShapeUtil::MakeTupleShape(var_shapes); // Unpacks a tuple into its component parts. - auto unpack_tuple = [](xla::XlaOp tuple, int arity, - xla::XlaBuilder* builder) { - std::vector elements(arity); + auto unpack_tuple = [](XlaOp tuple, int arity, XlaBuilder* builder) { + std::vector elements(arity); for (int i = 0; i < arity; ++i) { - elements[i] = xla::GetTupleElement(tuple, i); + elements[i] = GetTupleElement(tuple, i); } return elements; }; // Build the condition. - std::unique_ptr cond_builder = + std::unique_ptr cond_builder = builder->CreateSubBuilder(absl::StrCat(name, "_condition")); { - auto parameter = - xla::Parameter(cond_builder.get(), 0, tuple_shape, "parameter"); + auto parameter = Parameter(cond_builder.get(), 0, tuple_shape, "parameter"); TF_RETURN_IF_ERROR( condition_function(unpack_tuple(parameter, arity, cond_builder.get()), @@ -60,11 +59,10 @@ xla::StatusOr> XlaWhileLoop( TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build()); // Build the body. - std::unique_ptr body_builder = + std::unique_ptr body_builder = builder->CreateSubBuilder(absl::StrCat(name, "_body")); { - auto parameter = - xla::Parameter(body_builder.get(), 0, tuple_shape, "parameter"); + auto parameter = Parameter(body_builder.get(), 0, tuple_shape, "parameter"); TF_ASSIGN_OR_RETURN( auto result, @@ -72,56 +70,54 @@ xla::StatusOr> XlaWhileLoop( body_builder.get())); TF_RET_CHECK(result.size() == initial_values.size()); - xla::Tuple(body_builder.get(), result); + Tuple(body_builder.get(), result); } TF_ASSIGN_OR_RETURN(auto body, body_builder->Build()); - auto outputs = xla::While(cond, body, xla::Tuple(builder, initial_values)); + auto outputs = While(cond, body, Tuple(builder, initial_values)); return unpack_tuple(outputs, arity, builder); } -xla::StatusOr> XlaForEachIndex( - int64 num_iterations, xla::PrimitiveType num_iterations_type, +StatusOr> ForEachIndex( + int64 num_iterations, PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - absl::Span initial_values, absl::string_view name, - xla::XlaBuilder* builder) { - auto while_cond_fn = - [&](absl::Span values, - xla::XlaBuilder* cond_builder) -> xla::StatusOr { - return xla::Lt(values[0], IntegerLiteral(cond_builder, num_iterations_type, - num_iterations)); + absl::Span initial_values, absl::string_view name, + XlaBuilder* builder) { + auto while_cond_fn = [&](absl::Span values, + XlaBuilder* cond_builder) -> StatusOr { + return Lt(values[0], ConstantR0WithType(cond_builder, num_iterations_type, + num_iterations)); }; - auto while_body_fn = [&](absl::Span values, - xla::XlaBuilder* body_builder) - -> xla::StatusOr> { - xla::XlaOp iteration = values[0]; + auto while_body_fn = + [&](absl::Span values, + XlaBuilder* body_builder) -> StatusOr> { + XlaOp iteration = values[0]; - std::vector updated_values; + std::vector updated_values; updated_values.reserve(values.size()); - updated_values.push_back(xla::Add( + updated_values.push_back(Add( iteration, - xla::ConstantLiteral(body_builder, - xla::LiteralUtil::One(num_iterations_type)))); + ConstantLiteral(body_builder, LiteralUtil::One(num_iterations_type)))); values.remove_prefix(1); - TF_ASSIGN_OR_RETURN(std::vector body_outputs, + TF_ASSIGN_OR_RETURN(std::vector body_outputs, body_function(iteration, values, body_builder)); updated_values.insert(updated_values.end(), body_outputs.begin(), body_outputs.end()); return updated_values; }; - std::vector values; + std::vector values; values.reserve(initial_values.size() + 1); - values.push_back(xla::ConstantLiteral( - builder, xla::LiteralUtil::Zero(num_iterations_type))); + values.push_back( + ConstantLiteral(builder, LiteralUtil::Zero(num_iterations_type))); values.insert(values.end(), initial_values.begin(), initial_values.end()); - TF_ASSIGN_OR_RETURN(values, XlaWhileLoop(while_cond_fn, while_body_fn, values, - name, builder)); + TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn, + values, name, builder)); values.erase(values.begin(), values.begin() + 1); return values; } -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.h b/tensorflow/compiler/xla/client/lib/loops.h similarity index 62% rename from tensorflow/compiler/tf2xla/lib/while_loop.h rename to tensorflow/compiler/xla/client/lib/loops.h index f2134bb4495a12b8342961d96f70e7737f816c7d..e11de59493e9c1de51fbdb6c45dab6d82b85a62a 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.h +++ b/tensorflow/compiler/xla/client/lib/loops.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_ #include #include @@ -25,19 +25,18 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" -namespace tensorflow { +namespace xla { // Function that builds a loop condition. Takes as input a sequence of input // values, and returns a boolean value representing if the condition succeeds. -typedef std::function(absl::Span, - xla::XlaBuilder*)> - LoopConditionFunction; +typedef std::function(absl::Span, XlaBuilder*)> + WhileLoopHelperConditionFunction; // Function that builds a loop body. Takes as input a sequence of input values // and returns a sequence of output values. -typedef std::function>( - absl::Span, xla::XlaBuilder*)> - LoopBodyFunction; +typedef std::function>(absl::Span, + XlaBuilder*)> + WhileLoopHelperBodyFunction; // Helper function for building an XLA while loop, where the values carried by // the loop are a tuple of values, e.g., (a, b, c): @@ -47,27 +46,27 @@ typedef std::function>( // init: (a, b, c) // ) // 'name' is a descriptive name for the loop. -xla::StatusOr> XlaWhileLoop( - const LoopConditionFunction& condition_function, - const LoopBodyFunction& body_function, - absl::Span initial_values, absl::string_view name, - xla::XlaBuilder* builder); +StatusOr> WhileLoopHelper( + const WhileLoopHelperConditionFunction& condition_function, + const WhileLoopHelperBodyFunction& body_function, + absl::Span initial_values, absl::string_view name, + XlaBuilder* builder); // Builds an XLA loop that repeats a computation `num_iterations` times. // // The body function (ForEachIndexBodyFunction) takes as input a pair of // (current iteration number, loop-carried values), and returns an updated // vector of the loop-carried values. -typedef std::function>( - xla::XlaOp, absl::Span, xla::XlaBuilder*)> +typedef std::function>( + XlaOp, absl::Span, XlaBuilder*)> ForEachIndexBodyFunction; -xla::StatusOr> XlaForEachIndex( - int64 num_iterations, xla::PrimitiveType num_iterations_type, +StatusOr> ForEachIndex( + int64 num_iterations, PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - absl::Span initial_values, absl::string_view name, - xla::XlaBuilder* builder); + absl::Span initial_values, absl::string_view name, + XlaBuilder* builder); -} // namespace tensorflow +} // namespace xla -#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_ diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/xla/client/lib/qr.cc similarity index 62% rename from tensorflow/compiler/tf2xla/lib/qr.cc rename to tensorflow/compiler/xla/client/lib/qr.cc index d6007748609fdd161cb89692a167eb7ed12fe00c..72ca653173b78d9338f632c41779f2a30db1e978 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.cc +++ b/tensorflow/compiler/xla/client/lib/qr.cc @@ -13,15 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/qr.h" +#include "tensorflow/compiler/xla/client/lib/qr.h" #include #include -#include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" @@ -32,10 +31,18 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/errors.h" -namespace tensorflow { +namespace xla { namespace { +std::vector ConcatVectors(absl::Span xs, + absl::Span ys) { + std::vector output(xs.size() + ys.size()); + std::copy(xs.begin(), xs.end(), output.begin()); + std::copy(ys.begin(), ys.end(), output.begin() + xs.size()); + return output; +} + // Computes a Householder reflection of the form: // H = I - tau v v.T. // such that @@ -65,52 +72,47 @@ namespace { // return (v, tau, beta) // TODO(phawkins): LAPACK's xLARFG implementation has code for handling // overflows in the norm/beta calculations. Perhaps do the same here. -xla::Status House(xla::XlaOp x, xla::XlaOp k, - absl::Span batch_dims, const int64 m, - xla::XlaOp* v, xla::XlaOp* tau, xla::XlaOp* beta) { - xla::XlaBuilder* const builder = x.builder(); - TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); - const xla::PrimitiveType type = x_shape.element_type(); +Status House(XlaOp x, XlaOp k, absl::Span batch_dims, + const int64 m, XlaOp* v, XlaOp* tau, XlaOp* beta) { + XlaBuilder* const builder = x.builder(); + TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); + const PrimitiveType type = x_shape.element_type(); std::vector batch_dim_ids(batch_dims.size()); std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0); const int64 minor_dim = batch_dims.size(); - xla::XlaOp zero = xla::ScalarLike(x, 0.0); - xla::XlaOp one = xla::ScalarLike(x, 1.0); + XlaOp zero = ScalarLike(x, 0.0); + XlaOp one = ScalarLike(x, 1.0); // alpha = x[k] - xla::XlaOp alpha = - xla::Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims); + XlaOp alpha = Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims); // Compute x[k+1:] (padded with zeros in elements 0..k) - xla::XlaOp iota = xla::Iota(builder, xla::S32, m); - xla::XlaOp x_after_k = - xla::Mul(x, xla::ConvertElementType(xla::Gt(iota, k), type), - /*broadcast_dimensions=*/{minor_dim}); + XlaOp iota = Iota(builder, S32, m); + XlaOp x_after_k = Mul(x, ConvertElementType(Gt(iota, k), type), + /*broadcast_dimensions=*/{minor_dim}); // sigma = np.dot(x[k+1:], x[k+1:]) - auto sigma = - xla::Reduce(x_after_k * x_after_k, zero, - xla::CreateScalarAddComputation(type, builder), {minor_dim}); + auto sigma = Reduce(x_after_k * x_after_k, zero, + CreateScalarAddComputation(type, builder), {minor_dim}); // mu = np.sqrt(x[k]*x[k] + sigma) - auto mu = xla::Sqrt(xla::Square(alpha) + sigma); + auto mu = Sqrt(Square(alpha) + sigma); - auto sigma_is_zero = xla::Eq(sigma, zero); + auto sigma_is_zero = Eq(sigma, zero); - *beta = xla::Select(sigma_is_zero, alpha, -xla::Sign(alpha) * mu); - *tau = xla::Select(sigma_is_zero, xla::Broadcast(zero, batch_dims), - (*beta - alpha) / *beta); - auto divisor = xla::Select(sigma_is_zero, xla::Broadcast(one, batch_dims), - alpha - *beta); + *beta = Select(sigma_is_zero, alpha, -Sign(alpha) * mu); + *tau = Select(sigma_is_zero, Broadcast(zero, batch_dims), + (*beta - alpha) / *beta); + auto divisor = + Select(sigma_is_zero, Broadcast(one, batch_dims), alpha - *beta); - auto e_k = xla::Broadcast(xla::ConvertElementType(xla::Eq(iota, k), type), - std::vector(batch_dims.size(), 1)); + auto e_k = Broadcast(ConvertElementType(Eq(iota, k), type), + std::vector(batch_dims.size(), 1)); // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor. - *v = e_k + - xla::Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids); + *v = e_k + Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids); return Status::OK(); } @@ -143,90 +145,86 @@ xla::Status House(xla::XlaOp x, xla::XlaOp k, // return (q, vs, taus) struct QRBlockResult { // The factored R value - xla::XlaOp r; + XlaOp r; // Representation of the Householder matrices I - beta v v.T - xla::XlaOp taus; // Shape: [..., n] - xla::XlaOp vs; // Shape: [..., m, n] + XlaOp taus; // Shape: [..., n] + XlaOp vs; // Shape: [..., m, n] }; -xla::StatusOr QRBlock( - xla::XlaOp a, xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int num_dims = xla::ShapeUtil::Rank(a_shape); +StatusOr QRBlock(XlaOp a, PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int num_dims = ShapeUtil::Rank(a_shape); if (num_dims < 2) { - return errors::InvalidArgument("Arguments to QR must have rank >= 2: ", - num_dims); + return InvalidArgument("Argument to QR must have rank >= 2; got shape %s", + a_shape.ToString()); } - xla::PrimitiveType type = a_shape.element_type(); + PrimitiveType type = a_shape.element_type(); - const int64 m = xla::ShapeUtil::GetDimension(a_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + const int64 m = ShapeUtil::GetDimension(a_shape, -2); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); const int64 num_batch_dims = num_dims - 2; std::vector batch_dims(num_batch_dims); for (int i = 0; i < num_batch_dims; ++i) { - batch_dims[i] = xla::ShapeUtil::GetDimension(a_shape, i); + batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); } std::vector batch_dim_indices(num_batch_dims); std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); - auto qr_body_fn = - [&](xla::XlaOp j, absl::Span values, - xla::XlaBuilder* builder) -> xla::StatusOr> { + auto qr_body_fn = [&](XlaOp j, absl::Span values, + XlaBuilder* builder) -> StatusOr> { auto a = values[0]; auto vs = values[1]; auto taus = values[2]; // v, beta = house(a[:, j], j) auto x = DynamicSliceInMinorDims(a, {j}, {1}); - xla::XlaOp v, tau, beta; - TF_RETURN_IF_ERROR(House(xla::Collapse(x, {num_dims - 2, num_dims - 1}), j, + XlaOp v, tau, beta; + TF_RETURN_IF_ERROR(House(Collapse(x, {num_dims - 2, num_dims - 1}), j, batch_dims, m, &v, &tau, &beta)); std::vector shape = batch_dims; shape.push_back(1); shape.push_back(m); - auto v_broadcast = xla::Reshape(v, shape); + auto v_broadcast = Reshape(v, shape); // a[:, :] -= tau * np.dot(v[:, np.newaxis], // np.dot(v[np.newaxis, :], a[:, :])) auto vva = BatchDot(v_broadcast, a, precision); vva = BatchDot(TransposeInMinorDims(v_broadcast), vva, precision); - a = a - xla::Mul(tau, vva, - /*broadcast_dimensions=*/batch_dim_indices); + a = a - Mul(tau, vva, + /*broadcast_dimensions=*/batch_dim_indices); // It is more precise to populate column 'k' explicitly, rather than // computing it implicitly by applying the Householder transformation. // a[k,k] = beta // a[k+1:,k] = np.zeros([m-k-1], dtype=a.dtype) - auto iota = xla::Reshape(xla::Iota(a.builder(), xla::S32, m), {m, 1}); - auto predecessor_mask = xla::ConvertElementType(xla::Lt(iota, j), type); - auto mask = xla::Broadcast(xla::ConvertElementType(xla::Eq(iota, j), type), - std::vector(batch_dims.size(), 1)); - auto new_x = - xla::Mul(x, predecessor_mask, - /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) + - xla::Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices); + auto iota = Reshape(Iota(a.builder(), S32, m), {m, 1}); + auto predecessor_mask = ConvertElementType(Lt(iota, j), type); + auto mask = Broadcast(ConvertElementType(Eq(iota, j), type), + std::vector(batch_dims.size(), 1)); + auto new_x = Mul(x, predecessor_mask, + /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) + + Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices); a = DynamicUpdateSliceInMinorDims(a, new_x, {j}); // vs[:, j] = v vs = DynamicUpdateSliceInMinorDims( - vs, xla::Reshape(v, ConcatVectors(batch_dims, {m, 1})), {j}); + vs, Reshape(v, ConcatVectors(batch_dims, {m, 1})), {j}); // taus[j] = tau taus = DynamicUpdateSliceInMinorDims( - taus, xla::Reshape(tau, ConcatVectors(batch_dims, {1})), {j}); - return std::vector{a, vs, taus}; + taus, Reshape(tau, ConcatVectors(batch_dims, {1})), {j}); + return std::vector{a, vs, taus}; }; - auto vs = xla::Zeros(builder, xla::ShapeUtil::MakeShape( - type, ConcatVectors(batch_dims, {m, n}))); - auto taus = xla::Zeros( - builder, xla::ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n}))); + auto vs = Zeros( + builder, ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n}))); + auto taus = Zeros(builder, + ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n}))); - TF_ASSIGN_OR_RETURN(auto values, - XlaForEachIndex(std::min(m, n), xla::S32, qr_body_fn, - {a, vs, taus}, "qr", builder)); + TF_ASSIGN_OR_RETURN(auto values, ForEachIndex(std::min(m, n), S32, qr_body_fn, + {a, vs, taus}, "qr", builder)); QRBlockResult result; result.r = values[0]; @@ -250,24 +248,23 @@ xla::StatusOr QRBlock( // return W // There is no need to return Y since at termination of the loop it is equal to // vs. -xla::StatusOr ComputeWYRepresentation( - xla::PrimitiveType type, absl::Span batch_dims, xla::XlaOp vs, - xla::XlaOp taus, int64 m, int64 n, - xla::PrecisionConfig::Precision precision) { +StatusOr ComputeWYRepresentation(PrimitiveType type, + absl::Span batch_dims, + XlaOp vs, XlaOp taus, int64 m, int64 n, + PrecisionConfig::Precision precision) { std::vector batch_dim_indices(batch_dims.size()); std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); int64 n_index = batch_dims.size() + 1; - auto body_fn = - [&](xla::XlaOp j, absl::Span values, - xla::XlaBuilder* builder) -> xla::StatusOr> { + auto body_fn = [&](XlaOp j, absl::Span values, + XlaBuilder* builder) -> StatusOr> { auto w = values[0]; auto y = values[1]; const auto vs = values[2]; const auto taus = values[3]; // Want j values in range [1, ... n). - j = j + xla::ConstantR0(builder, 1); + j = j + ConstantR0(builder, 1); // vs has shape [..., m, 1] auto v = DynamicSliceInMinorDims(vs, {j}, {1}); // beta has shape [..., 1] @@ -278,31 +275,31 @@ xla::StatusOr ComputeWYRepresentation( // wyv has shape [..., m, 1] auto wyv = BatchDot(w, yv, precision); - auto z = xla::Mul( + auto z = Mul( -beta, v + wyv, /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index})); w = DynamicUpdateSliceInMinorDims(w, z, {j}); y = DynamicUpdateSliceInMinorDims(y, v, {j}); - return std::vector{w, y, vs, taus}; + return std::vector{w, y, vs, taus}; }; - xla::XlaBuilder* builder = vs.builder(); - auto w = xla::Zeros(builder, xla::ShapeUtil::MakeShape( - type, ConcatVectors(batch_dims, {m, n}))); + XlaBuilder* builder = vs.builder(); + auto w = Zeros(builder, + ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n}))); auto y = w; auto v = SliceInMinorDims(vs, {0}, {1}); auto beta = SliceInMinorDims(taus, {0}, {1}); y = UpdateSliceInMinorDims(y, v, {0}); - auto bv = xla::Mul( - -beta, v, - /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index})); + auto bv = + Mul(-beta, v, + /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index})); w = UpdateSliceInMinorDims(w, bv, {0}); TF_ASSIGN_OR_RETURN( - auto values, XlaForEachIndex(n - 1, xla::S32, body_fn, {w, y, vs, taus}, - "wy", builder)); + auto values, + ForEachIndex(n - 1, S32, body_fn, {w, y, vs, taus}, "wy", builder)); return values[0]; } @@ -323,34 +320,34 @@ xla::StatusOr ComputeWYRepresentation( // return (q, a) // TODO(phawkins): consider using UT transformations (in the form I - V U V') // rather than WY transformations. -xla::StatusOr QRDecomposition( - xla::XlaOp a, bool full_matrices, int64 block_size, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int num_dims = xla::ShapeUtil::Rank(a_shape); +StatusOr QRDecomposition( + XlaOp a, bool full_matrices, int64 block_size, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int num_dims = ShapeUtil::Rank(a_shape); if (num_dims < 2) { - return errors::InvalidArgument("Arguments to QR must have rank >= 2: ", - num_dims); + return InvalidArgument("Arguments to QR must have rank >= 2: got shape %s", + a_shape.ToString()); } - xla::PrimitiveType type = a_shape.element_type(); + PrimitiveType type = a_shape.element_type(); - const int64 m = xla::ShapeUtil::GetDimension(a_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + const int64 m = ShapeUtil::GetDimension(a_shape, -2); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); const int64 p = std::min(m, n); if (block_size < 1) { - return errors::InvalidArgument( - "block_size argument to QR must be >= 1; got ", block_size); + return InvalidArgument("block_size argument to QR must be >= 1; got %d", + block_size); } const int64 num_batch_dims = num_dims - 2; std::vector batch_dims(num_batch_dims); for (int i = 0; i < num_batch_dims; ++i) { - batch_dims[i] = xla::ShapeUtil::GetDimension(a_shape, i); + batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); } - auto q = xla::Broadcast(xla::IdentityMatrix(builder, type, m, m), batch_dims); + auto q = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims); for (int64 i = 0; i < p; i += block_size) { int64 k = std::min(block_size, p - i); @@ -393,4 +390,4 @@ xla::StatusOr QRDecomposition( return result; } -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/qr.h b/tensorflow/compiler/xla/client/lib/qr.h similarity index 74% rename from tensorflow/compiler/tf2xla/lib/qr.h rename to tensorflow/compiler/xla/client/lib/qr.h index 24b537ac8b63b93e734c3d0e335ea455f7d51a54..827c8eeca05ef09a0d77363eb3c40961b95813d8 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.h +++ b/tensorflow/compiler/xla/client/lib/qr.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QR_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QR_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -namespace tensorflow { +namespace xla { // Computes the QR decompositions of a batch of matrices. That is, // given a (batched) matrix a, computes an orthonormal matrix Q and an @@ -29,14 +29,14 @@ namespace tensorflow { // the block size to use. // TODO(phawkins): handle the complex case. struct QRDecompositionResult { - xla::XlaOp q; - xla::XlaOp r; + XlaOp q; + XlaOp r; }; -xla::StatusOr QRDecomposition( - xla::XlaOp a, bool full_matrices, int64 block_size = 128, - xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); +StatusOr QRDecomposition( + XlaOp a, bool full_matrices, int64 block_size = 128, + PrecisionConfig::Precision precision = PrecisionConfig::HIGHEST); -} // namespace tensorflow +} // namespace xla -#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QR_H_ diff --git a/tensorflow/compiler/xla/client/lib/qr_test.cc b/tensorflow/compiler/xla/client/lib/qr_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b27d364b62444d6d5fb1278b6e6461affc15b2e6 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/qr_test.cc @@ -0,0 +1,93 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/qr.h" + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace { + +using QrTest = xla::ClientLibraryTestBase; + +XLA_TEST_F(QrTest, Simple) { + xla::XlaBuilder builder(TestName()); + + xla::Array2D a_vals({ + {4, 6, 8, 10}, + {6, 45, 54, 63}, + {8, 54, 146, 166}, + {10, 63, 166, 310}, + }); + + xla::XlaOp a; + auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); + TF_ASSERT_OK_AND_ASSIGN( + auto result, + xla::QRDecomposition(a, /*full_matrices=*/true, /*block_size=*/2)); + + // Verifies that the decomposition composes back to the original matrix. + // + // This isn't a terribly demanding test, (e.g., we should verify that Q is + // orthonormal and R is upper-triangular) but it's awkward to write such tests + // without more linear algebra libraries. It's easier to test the numerics + // from Python, anyway, where we have access to numpy and scipy. + xla::BatchDot(result.q, result.r, xla::PrecisionConfig::HIGHEST); + + ComputeAndCompareR2(&builder, a_vals, {a_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +XLA_TEST_F(QrTest, SimpleBatched) { + xla::XlaBuilder builder(TestName()); + + xla::Array3D a_vals({ + { + {4, 6, 8, 10}, + {6, 45, 54, 63}, + {8, 54, 146, 166}, + {10, 63, 166, 310}, + }, + { + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 456, 106}, + {12, 48, 106, 62}, + }, + }); + + xla::XlaOp a; + auto a_data = CreateR3Parameter(a_vals, 0, "a", &builder, &a); + TF_ASSERT_OK_AND_ASSIGN( + auto result, + xla::QRDecomposition(a, /*full_matrices=*/true, /*block_size=*/2)); + + xla::BatchDot(result.q, result.r, xla::PrecisionConfig::HIGHEST); + + ComputeAndCompareR3(&builder, a_vals, {a_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +} // namespace diff --git a/tensorflow/compiler/xla/client/lib/quantize.h b/tensorflow/compiler/xla/client/lib/quantize.h new file mode 100644 index 0000000000000000000000000000000000000000..26dbbd5b00bd1a29f4047c9a4294fcac7340cf6c --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/quantize.h @@ -0,0 +1,186 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QUANTIZE_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QUANTIZE_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/bfloat16/bfloat16.h" + +namespace xla { + +constexpr int64 kBitsOfByte = 8; + +// Represents the range used for quantization +struct QuantizedRange { + QuantizedRange() = default; + QuantizedRange(float min_in, float max_in) : min(min_in), max(max_in) {} + + bool operator==(const QuantizedRange& rhs) const { + return this->min == rhs.min && this->max == rhs.max; + } + + bool operator!=(const QuantizedRange& rhs) const { return !(*this == rhs); } + + tensorflow::bfloat16 min = tensorflow::bfloat16(0.0f); + tensorflow::bfloat16 max = tensorflow::bfloat16(0.0f); +}; + +template +inline std::vector PackToUint32(absl::Span input) { + const int64 kElementsPerPack = sizeof(uint32) / sizeof(T); + const int64 input_size = input.size(); + const int64 output_size = CeilOfRatio(input_size, kElementsPerPack); + + std::vector output_vec; + constexpr int64 kShiftBits = sizeof(T) / sizeof(uint8) * kBitsOfByte; + + for (int64 i = 0; i < output_size; i++) { + uint32 result = 0; + for (int64 p = 0; p < kElementsPerPack; p++) { + int64 index = i * kElementsPerPack + p; + if (index < input_size) { + int64 total_shift_bits = kShiftBits * (kElementsPerPack - p - 1); + result |= (input[index] << total_shift_bits); + } + } + output_vec.push_back(result); + } + + return output_vec; +} + +// Dequantize the quantized input of packed uint32 to bfloat16. +// Only uint8 or uint16 is supported for the original unpacked input. +// Returns a tensor of shape [d0,..., dn * unpack_size] if +// input shape is [d0, ..., dn], where unpack_size = sizeof(unit32) / sizeof(T). +// If transpose_output is true, will return a tensor of shape +// [dn * unpack_size, dn-1, ..., d1, d0]. transpose_output is faster when +// input's rank higher than 1. The input needs to be transposed to use +// transpose_output feature. +template +inline XlaOp Dequantize(XlaOp input, const QuantizedRange& range, + absl::string_view mode_string = "MIN_COMBINED", + bool transpose_output = false) { + XlaBuilder* const builder = input.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + float half_range = + !std::is_signed::value + ? 0.0f + : (static_cast(std::numeric_limits::max()) - + std::numeric_limits::min() + 1) / + 2.0f; + const int64 unpack_size = sizeof(uint32) / sizeof(T); + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(input)); + + auto element_type = shape.element_type(); + if (element_type != U32) { + return InvalidArgument( + "Only U32 is supported for input type of xla::Dequantize Op."); + } + + // Broadcast the input to [unpack_size, d0, ..., dn] if input size is + // [d0, ..., dn]. + auto broadcast_input = Broadcast(input, {unpack_size}); + + XlaOp iota_r1 = Iota(builder, U32, unpack_size); + // Highest significant bytes needs to shift more bytes than lower + // significant bytes. + XlaOp shift_bytes = + xla::ConstantR0(builder, unpack_size - 1) - iota_r1; + + const int bytes_of_type = sizeof(T) / sizeof(uint8); + std::vector shift_vec(unpack_size, kBitsOfByte * bytes_of_type); + XlaOp shift_bits = + shift_bytes * xla::ConstantR1(builder, shift_vec); + + // Make bit_mask for different data type T. + uint32 bit_mask = 0x00000000; + for (int i = 0; i < bytes_of_type; i++) { + bit_mask <<= kBitsOfByte; + bit_mask |= 0x000000ff; + } + + std::vector shift_transpose_dimensions(shape.dimensions_size()); + std::iota(shift_transpose_dimensions.begin(), + shift_transpose_dimensions.end(), 0); + shift_transpose_dimensions.insert(shift_transpose_dimensions.begin(), 1, + shape.dimensions_size()); + + // Shift the input by sizeof(T) bytes and apply bit_mask to unpack. + XlaOp shifted_input = ShiftRightLogical( + broadcast_input, Transpose(Broadcast(shift_bits, shape.dimensions()), + shift_transpose_dimensions)); + XlaOp unpack_input = + And(shifted_input, xla::ConstantR0(builder, bit_mask)); + + XlaOp result; + + if (mode_string == "MIN_COMBINED") { + const tensorflow::bfloat16 scale_factor = + (range.max - range.min) / + (static_cast(std::numeric_limits::max() - + std::numeric_limits::min())); + // result = bfloat16(input + half_range) * scale_factor + range.min + XlaOp unpack_input_bf16 = ConvertElementType(unpack_input, BF16); + XlaOp half_range_bf16 = xla::ConstantR0( + builder, static_cast(half_range)); + XlaOp sum = unpack_input_bf16 + half_range_bf16; + + result = + sum * xla::ConstantR0(builder, scale_factor) + + xla::ConstantR0(builder, range.min); + } else { + // TODO(wangtao): support other modes. + return InvalidArgument( + "Only MIN_COMBINED mode is supported in xla::Dequantize Op."); + } + + std::vector transpose_dimensions(shape.dimensions_size()); + std::iota(transpose_dimensions.begin(), transpose_dimensions.end(), 1); + std::reverse(transpose_dimensions.begin(), transpose_dimensions.end()); + transpose_dimensions.insert(transpose_dimensions.begin() + 1, 1, 0); + + // Transpose the result to be [dn, unpack_size, dn-1, ..., d1, d0]. + XlaOp transposed_result = Transpose(result, transpose_dimensions); + + // Reshape to be [dn * unpack_size, dn-1, ..., d1, d0]. + XlaOp reshaped_result = Collapse(transposed_result, {0, 1}); + + // Return the transpose result if transpose_output is true. + if (transpose_output) { + return reshaped_result; + } + + // Transpose the result to be [d0, d1, ..., dn-1, dn * unpack_size]. + std::vector result_dimensions(shape.dimensions_size()); + std::iota(result_dimensions.begin(), result_dimensions.end(), 0); + std::reverse(result_dimensions.begin(), result_dimensions.end()); + + return Transpose(reshaped_result, result_dimensions); + }); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QUANTIZE_H_ diff --git a/tensorflow/compiler/xla/client/lib/quantize_test.cc b/tensorflow/compiler/xla/client/lib/quantize_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..be3603d9e11670913c21a834d2216a999306d582 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/quantize_test.cc @@ -0,0 +1,337 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/quantize.h" + +#include + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +namespace { + +using bfloat16 = tensorflow::bfloat16; + +template +std::vector GenerateInput() { + std::vector input; + + for (int64 i = std::numeric_limits::min(); + i < std::numeric_limits::max(); ++i) { + input.push_back(static_cast(i)); + } + + return input; +} + +template +Array2D GenerateLargeSizeInput(int num_columns, int num_rows) { + Array2D input(num_columns, num_rows); + + input.FillRandom(6, 128); + + return input; +} + +template +Array2D PackLargeInput(Array2D &input) { + const int64 size_per_pack = sizeof(uint32) / sizeof(NativeT); + int64 width = input.width(); + + int64 padded_output_width = CeilOfRatio(width, size_per_pack); + + Array2D pack_input(input.height(), padded_output_width); + + for (int h = 0; h < input.height(); h++) { + std::vector input_row; + for (int w = 0; w < width; w++) { + input_row.push_back(input({h, w})); + } + + auto pack_input_vec = PackToUint32(input_row); + + for (int w = 0; w < padded_output_width; w++) { + pack_input(h, w) = pack_input_vec[w]; + } + } + + return pack_input; +} + +template +Array2D GenerateLargeSizeMinCombinedOutput( + Array2D &input, const QuantizedRange &range, + bool transpose_output = false) { + const int64 size_per_pack = sizeof(uint32) / sizeof(NativeT); + int64 width = input.width(); + + int64 padded_output_width = CeilOfRatio(width, size_per_pack) * size_per_pack; + + int64 output_height; + int64 output_width; + + if (transpose_output) { + output_height = padded_output_width; + output_width = input.height(); + } else { + output_height = input.height(); + output_width = padded_output_width; + } + + Array2D output(output_height, output_width, bfloat16(0.0)); + + float half_range = + !std::is_signed::value + ? 0.0f + : (static_cast(std::numeric_limits::max() - + std::numeric_limits::min() + 1)) / + 2.0f; + const bfloat16 scale_factor = + (range.max - range.min) / + (static_cast(std::numeric_limits::max() - + std::numeric_limits::min())); + + for (int h = 0; h < input.height(); h++) { + std::vector input_row; + for (int w = 0; w < width; w++) { + bfloat16 result = + static_cast(input(h, w) + half_range) * scale_factor + + range.min; + if (transpose_output) { + output(w, h) = result; + } else { + output(h, w) = result; + } + } + } + + return output; +} + +template +std::vector GenerateMinCombinedOutput(const QuantizedRange &range) { + float half_range = + !std::is_signed::value + ? 0.0f + : (static_cast(std::numeric_limits::max() - + std::numeric_limits::min() + 1)) / + 2.0f; + const bfloat16 scale_factor = + (range.max - range.min) / + (static_cast(std::numeric_limits::max() - + std::numeric_limits::min())); + std::vector output; + for (int64 i = std::numeric_limits::min(); + i < std::numeric_limits::max(); ++i) { + bfloat16 result = + static_cast(i + half_range) * scale_factor + range.min; + output.push_back(result); + } + + const int64 pack_size = sizeof(uint32) / sizeof(NativeT); + const int64 output_size = output.size(); + + int64 num_tailing_zeros = + CeilOfRatio(output_size, pack_size) * pack_size - output_size; + + output.insert(output.end(), num_tailing_zeros, bfloat16(0.0)); + return output; +} + +// TODO(wangtao): add a test to make sure this op is the inverse of the existing +// TF quantize op defined in: third_party/tensorflow/core/kernels/quantize_op.cc + +using DequantizeTest = ClientLibraryTestBase; + +TEST(PackTest, PackUint8ToUint32) { + std::vector input = {0xAB, 0x0B, 0x00, 0xF0, 0x01}; + auto output = PackToUint32(input); + EXPECT_THAT(output, ::testing::ElementsAre(0xAB0B00F0, 0x01000000)); +} + +TEST(PackTest, PackInt8ToUint32) { + std::vector input = {static_cast(0x81), 0x0B, 0x00, 0x20, + 0x01}; + auto output = PackToUint32(input); + EXPECT_THAT(output, ::testing::ElementsAre(0x810B0020, 0x01000000)); +} + +TEST(PackTest, PackUint8ToUint32PerfectSize) { + std::vector input = {3, 2, 1, 0}; + auto output = PackToUint32(input); + EXPECT_THAT(output, ::testing::ElementsAre(0x03020100)); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint16R1) { + XlaBuilder builder(TestName()); + auto input = GenerateInput(); + auto x = ConstantR1(&builder, PackToUint32(input)); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED"); + auto expected = GenerateMinCombinedOutput(range); + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8R1) { + XlaBuilder builder(TestName()); + auto input = GenerateInput(); + auto x = ConstantR1(&builder, PackToUint32(input)); + QuantizedRange range(0, 127.0f); + xla::Dequantize(x, range, "MIN_COMBINED"); + auto expected = GenerateMinCombinedOutput(range); + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8R2) { + XlaBuilder builder(TestName()); + std::vector> input = { + {0, 1, 2, 3}, + {4, 5, 6, 7}, + {8, 9, 10, 11}, + {12, 13, 16, 15}, + }; + auto x = ConstantR2(&builder, {{PackToUint32(input[0])[0]}, + {PackToUint32(input[1])[0]}, + {PackToUint32(input[2])[0]}, + {PackToUint32(input[3])[0]}}); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED"); + const Array2D expected = { + {bfloat16(0.0), bfloat16(1.0), bfloat16(2.0), bfloat16(3.0)}, + {bfloat16(4.0), bfloat16(5.0), bfloat16(6.0), bfloat16(7.0)}, + {bfloat16(8.0), bfloat16(9.0), bfloat16(10.0), bfloat16(11.0)}, + {bfloat16(12.0), bfloat16(13.0), bfloat16(16.0), bfloat16(15.0)}, + }; + ComputeAndCompareR2(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8R2TransposeOutput) { + XlaBuilder builder(TestName()); + std::vector> input = { + {0, 1, 2, 3}, + {4, 5, 6, 7}, + {8, 9, 10, 11}, + {12, 13, 16, 15}, + }; + auto x = ConstantR2(&builder, {{PackToUint32(input[0])[0]}, + {PackToUint32(input[1])[0]}, + {PackToUint32(input[2])[0]}, + {PackToUint32(input[3])[0]}}); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED", /*transpose_output=*/true); + const Array2D expected = { + {bfloat16(0.0), bfloat16(4.0), bfloat16(8.0), bfloat16(12.0)}, + {bfloat16(1.0), bfloat16(5.0), bfloat16(9.0), bfloat16(13.0)}, + {bfloat16(2.0), bfloat16(6.0), bfloat16(10.0), bfloat16(16.0)}, + {bfloat16(3.0), bfloat16(7.0), bfloat16(11.0), bfloat16(15.0)}, + }; + ComputeAndCompareR2(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8R2TailingZero) { + XlaBuilder builder(TestName()); + std::vector> input = { + {0, 1, 2, 3, 16}, + {4, 5, 6, 7, 17}, + {8, 9, 10, 11, 18}, + {12, 13, 16, 15, 19}, + }; + auto x = ConstantR2( + &builder, + {{PackToUint32(input[0])[0], PackToUint32(input[0])[1]}, + {PackToUint32(input[1])[0], PackToUint32(input[1])[1]}, + {PackToUint32(input[2])[0], PackToUint32(input[2])[1]}, + {PackToUint32(input[3])[0], PackToUint32(input[3])[1]}}); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED"); + + const Array2D expected = { + {bfloat16(0.0), bfloat16(1.0), bfloat16(2.0), bfloat16(3.0), + bfloat16(16.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + {bfloat16(4.0), bfloat16(5.0), bfloat16(6.0), bfloat16(7.0), + bfloat16(17.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + {bfloat16(8.0), bfloat16(9.0), bfloat16(10.0), bfloat16(11.0), + bfloat16(18.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + {bfloat16(12.0), bfloat16(13.0), bfloat16(16.0), bfloat16(15.0), + bfloat16(19.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + }; + ComputeAndCompareR2(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8R2TailingZeroTransposeOutput) { + XlaBuilder builder(TestName()); + std::vector> input = { + {0, 1, 2, 3, 16}, + {4, 5, 6, 7, 17}, + {8, 9, 10, 11, 18}, + {12, 13, 16, 15, 19}, + }; + auto x = ConstantR2( + &builder, + {{PackToUint32(input[0])[0], PackToUint32(input[0])[1]}, + {PackToUint32(input[1])[0], PackToUint32(input[1])[1]}, + {PackToUint32(input[2])[0], PackToUint32(input[2])[1]}, + {PackToUint32(input[3])[0], PackToUint32(input[3])[1]}}); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED", /*transpose_output=*/true); + + const Array2D expected = { + {bfloat16(0.0), bfloat16(4.0), bfloat16(8.0), bfloat16(12.0)}, + {bfloat16(1.0), bfloat16(5.0), bfloat16(9.0), bfloat16(13.0)}, + {bfloat16(2.0), bfloat16(6.0), bfloat16(10.0), bfloat16(16.0)}, + {bfloat16(3.0), bfloat16(7.0), bfloat16(11.0), bfloat16(15.0)}, + {bfloat16(16.0), bfloat16(17.0), bfloat16(18.0), bfloat16(19.0)}, + {bfloat16(0.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + {bfloat16(0.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + {bfloat16(0.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + }; + ComputeAndCompareR2(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8LargeSizeTest) { + XlaBuilder builder(TestName()); + Array2D input = GenerateLargeSizeInput(500, 3547); + Array2D input_packed = PackLargeInput(input); + + auto x = ConstantR2FromArray2D(&builder, input_packed); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED"); + + const Array2D expected = + GenerateLargeSizeMinCombinedOutput(input, range); + ComputeAndCompareR2(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8LargeSizeTestTransposeOutput) { + XlaBuilder builder(TestName()); + Array2D input = GenerateLargeSizeInput(500, 3547); + Array2D input_packed = PackLargeInput(input); + + auto x = ConstantR2FromArray2D(&builder, input_packed); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED", /*transpose_output=*/true); + + const Array2D expected = GenerateLargeSizeMinCombinedOutput( + input, range, /*transpose_output=*/true); + ComputeAndCompareR2(&builder, expected, {}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index a95bbf2c8c860914877d3195b97342097dafc725..5db9d10dff4c50d71cde934b3f3c345bee571f29 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -59,22 +59,25 @@ XlaOp BuildFakeDataOpOnDevice(const Shape& shape, XlaBuilder* builder) { return Tuple(builder, parts); } -std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, - Client* client) { +std::unique_ptr MakeFakeDataViaDeviceOrDie( + const Shape& shape, Client* client, DebugOptions* debug_opts) { XlaBuilder b(absl::StrCat("make_fake_", ShapeUtil::HumanString(shape))); BuildFakeDataOpOnDevice(shape, &b); XlaComputation computation = b.Build().ConsumeValueOrDie(); auto execution_options = CreateDefaultExecutionOptions(); *execution_options.mutable_shape_with_output_layout() = shape.ToProto(); + if (debug_opts) { + *execution_options.mutable_debug_options() = *debug_opts; + } return client->Execute(computation, /*arguments=*/{}, &execution_options) .ConsumeValueOrDie(); } } // namespace -std::unique_ptr MakeFakeDataOrDie(const Shape& shape, - Client* client) { +std::unique_ptr MakeFakeDataOrDie( + const Shape& shape, Client* client, DebugOptions* debug_opts /*=nullptr*/) { if (DataSizeOfShape(shape) < (1LL << 20)) { StatusOr literal_status = MakeFakeLiteral(shape); if (!literal_status.ok()) { @@ -82,24 +85,25 @@ std::unique_ptr MakeFakeDataOrDie(const Shape& shape, // an on-device computation. CHECK_EQ(literal_status.status().code(), tensorflow::error::UNIMPLEMENTED); - return MakeFakeDataViaDeviceOrDie(shape, client); + return MakeFakeDataViaDeviceOrDie(shape, client, debug_opts); } return client->TransferToServer(literal_status.ValueOrDie()).ValueOrDie(); } // If the data is large, generate it on-device. - return MakeFakeDataViaDeviceOrDie(shape, client); + return MakeFakeDataViaDeviceOrDie(shape, client, debug_opts); } std::vector> MakeFakeArgumentsOrDie( - const XlaComputation& computation, Client* client) { + const XlaComputation& computation, Client* client, + DebugOptions* debug_opts /*=nullptr*/) { CHECK(computation.proto().has_host_program_shape()) << "Computation should have progran shape."; auto program_shape = computation.proto().host_program_shape(); std::vector> results; for (const ShapeProto& shape : program_shape.parameters()) { - results.push_back(MakeFakeDataOrDie(Shape(shape), client)); + results.push_back(MakeFakeDataOrDie(Shape(shape), client, debug_opts)); } return results; } diff --git a/tensorflow/compiler/xla/client/lib/testing.h b/tensorflow/compiler/xla/client/lib/testing.h index 03695ce2a339735e3e49522f4fe1bbf2d83a3834..428fa3e93d1b46983aae60176e7c2242d2552fdb 100644 --- a/tensorflow/compiler/xla/client/lib/testing.h +++ b/tensorflow/compiler/xla/client/lib/testing.h @@ -29,14 +29,19 @@ namespace xla { // Generates fake data of the given shape on the device or dies. The fake data // is created by performing a computation on the device rather than transferring // data from the host to the device. -std::unique_ptr MakeFakeDataOrDie(const Shape& shape, - Client* client); +// +// The optional DebugOptions are used when generating fake data on the device. +std::unique_ptr MakeFakeDataOrDie( + const Shape& shape, Client* client, DebugOptions* debug_opts = nullptr); // Returns vector of GlobalData handles of fake data (created using // MakeFakeDataOrDie) that are correctly shaped arguments for the given // xla computation. +// +// The optional DebugOptions are used when generating fake data on the device. std::vector> MakeFakeArgumentsOrDie( - const XlaComputation& computation, Client* client); + const XlaComputation& computation, Client* client, + DebugOptions* debug_opts = nullptr); } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/triangular_solve.cc b/tensorflow/compiler/xla/client/lib/triangular_solve.cc index c5a1d34cc66e6f8c1a832f8a8437163b846a5431..ac58090dfe33a8ae350019771e0b970d6f26e476 100644 --- a/tensorflow/compiler/xla/client/lib/triangular_solve.cc +++ b/tensorflow/compiler/xla/client/lib/triangular_solve.cc @@ -393,6 +393,12 @@ XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, block_size); } + if (ShapeUtil::IsZeroElementArray(b_shape)) { + // The output has the same shape as 'b', and since the output has zero + // elements, any such array will do. + return b; + } + // We find the diagonal blocks of the coefficient matrix auto diag_blocks = DiagonalBlocks(a, block_size); diff --git a/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc b/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc index f6a70d64a788d95a456774ccbbcf67f2e5cac98b..d0188e8ea06d0edacdba330f46647af201747abf 100644 --- a/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc +++ b/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc @@ -33,56 +33,68 @@ limitations under the License. namespace xla { namespace { -using TriangularSolveTest = xla::ClientLibraryTestBase; -using TriangularSolveLeftLookingTest = xla::ClientLibraryTestBase; -using complex64 = xla::complex64; +using TriangularSolveTest = ClientLibraryTestBase; +using TriangularSolveLeftLookingTest = ClientLibraryTestBase; -xla::Array2D AValsLower() { +Array2D AValsLower() { return {{2, 0, 0, 0}, {3, 6, 0, 0}, {4, 7, 9, 0}, {5, 8, 10, 11}}; } -xla::Array2D AValsUpper() { +Array2D AValsUpper() { return {{2, 3, 4, 5}, {0, 6, 7, 8}, {0, 0, 9, 10}, {0, 0, 0, 11}}; } -xla::Array2D BValsRight() { +Array2D BValsRight() { return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; } -xla::Array2D BValsLeft() { +Array2D BValsLeft() { return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; } -xla::Array2D AValsLowerComplex() { +Array2D AValsLowerComplex() { return {{2, 0, 0, 0}, {complex64(3, 1), 6, 0, 0}, {4, complex64(7, 2), 9, 0}, {5, 8, complex64(10, 3), 11}}; } -xla::Array2D AValsUpperComplex() { +Array2D AValsUpperComplex() { return {{2, 3, complex64(4, 3), 5}, {0, 6, complex64(7, 2), 8}, {0, 0, complex64(9, 1), 10}, {0, 0, 0, 11}}; } -xla::Array2D BValsRightComplex() { +Array2D BValsRightComplex() { return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; } -xla::Array2D BValsLeftComplex() { +Array2D BValsLeftComplex() { return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; } -xla::Array2D AValsFull() { - return {{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 7, 9, 0}, {5, 8, 10, 11}}; +XLA_TEST_F(TriangularSolveTest, EmptyArrays) { + XlaBuilder builder(TestName()); + + XlaOp a, b; + auto a_data = + CreateR2Parameter(Array2D(0, 0), 0, "a", &builder, &a); + auto b_data = + CreateR2Parameter(Array2D(0, 10), 1, "b", &builder, &b); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/true, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); + + ComputeAndCompareR2(&builder, Array2D(0, 10), + {a_data.get(), b_data.get()}); } XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -90,20 +102,20 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { /*transpose_a=*/true, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {0.5, 0.08333334, 0.04629629, 0.03367003}, {2.5, -0.25, -0.1388889, -0.1010101}, {4.5, -0.58333331, -0.32407406, -0.23569024}, }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -111,20 +123,20 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { /*transpose_a=*/false, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, {0.64393939, 0.06565657, -0.03030303, 0.72727273}, {1.4520202, 0.2003367, 0.01010101, 1.09090909}, }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -132,20 +144,20 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { /*transpose_a=*/true, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, {0.64393939, 0.06565657, -0.03030303, 0.72727273}, {1.4520202, 0.2003367, 0.01010101, 1.09090909}, }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -153,20 +165,20 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { /*transpose_a=*/false, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {0.5, 0.08333334, 0.04629629, 0.03367003}, {2.5, -0.25, -0.1388889, -0.1010101}, {4.5, -0.58333331, -0.32407406, -0.23569024}, }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -174,7 +186,7 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { /*transpose_a=*/true, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {-0.89646465, -0.69444444, -0.49242424}, {-0.27441077, -0.24074074, -0.20707071}, {-0.23232323, -0.22222222, -0.21212121}, @@ -182,13 +194,13 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -196,7 +208,7 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { /*transpose_a=*/false, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {0.5, 1.0, 1.5}, {0.41666667, 0.33333333, 0.25}, {0.23148148, 0.18518519, 0.13888889}, @@ -204,13 +216,13 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -218,7 +230,7 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) { /*transpose_a=*/false, /*conjugate_a=*/false, /*block_size=*/3); - xla::Array2D expected({ + Array2D expected({ {0.5, 1.0, 1.5}, {0.41666667, 0.33333333, 0.25}, {0.23148148, 0.18518519, 0.13888889}, @@ -226,13 +238,13 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) { }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -240,7 +252,7 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { /*transpose_a=*/true, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {0.5, 1.0, 1.5}, {0.41666667, 0.33333333, 0.25}, {0.23148148, 0.18518519, 0.13888889}, @@ -248,13 +260,13 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -262,7 +274,7 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { /*transpose_a=*/false, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {-0.89646465, -0.69444444, -0.49242424}, {-0.27441077, -0.24074074, -0.20707071}, {-0.23232323, -0.22222222, -0.21212121}, @@ -270,13 +282,13 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsLowerComplex(), 0, "a", &builder, &a); auto b_data = @@ -286,7 +298,7 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { /*transpose_a=*/true, /*conjugate_a=*/true, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {0.5, complex64(0.08333333, 0.08333333), complex64(0.02777778, -0.0462963), complex64(0.06313131, -0.01094276)}, {2.5, complex64(-0.25, 0.41666667), complex64(-0.23148148, -0.37962963), @@ -295,15 +307,14 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { complex64(0.11026936, -0.03114478)}, }); - ComputeAndCompareR2(&builder, expected, - {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ComputeAndCompareR2( + &builder, expected, {a_data.get(), b_data.get()}, ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpperComplex(), 0, "a", &builder, &a); auto b_data = @@ -313,7 +324,7 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { /*transpose_a=*/true, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {0.5, 1., 1.5}, {0.41666667, 0.33333333, 0.25}, {complex64(0.20020325, -2.81504065e-01), @@ -324,9 +335,8 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { complex64(0.15798226, 5.12749446e-01)}, }); - ComputeAndCompareR2(&builder, expected, - {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ComputeAndCompareR2( + &builder, expected, {a_data.get(), b_data.get()}, ErrorSpec(1e-2, 1e-2)); } } // namespace diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index 20609cad58d920c0c272899c41efeb99d23cd490..a9a91648ac377987e7f226116e11c9c697ace103 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -22,49 +22,49 @@ limitations under the License. #include "tensorflow/compiler/xla/parse_flags_from_env.h" namespace xla { -namespace { -DebugOptions* flag_values; -std::vector* flag_objects; -std::once_flag flags_init; - -void SetDebugOptionsDefaults(DebugOptions* flags) { - flags->set_xla_llvm_enable_alias_scope_metadata(true); - flags->set_xla_llvm_enable_noalias_metadata(true); - flags->set_xla_llvm_enable_invariant_load_metadata(true); - flags->set_xla_llvm_disable_expensive_passes(false); - flags->set_xla_backend_optimization_level(3); - flags->set_xla_cpu_multi_thread_eigen(true); - flags->set_xla_gpu_cuda_data_dir("./cuda_sdk_lib"); - flags->set_xla_eliminate_hlo_implicit_broadcast(true); +DebugOptions DefaultDebugOptionsIgnoringFlags() { + DebugOptions opts; + opts.set_xla_llvm_enable_alias_scope_metadata(true); + opts.set_xla_llvm_enable_noalias_metadata(true); + opts.set_xla_llvm_enable_invariant_load_metadata(true); + opts.set_xla_llvm_disable_expensive_passes(false); + opts.set_xla_backend_optimization_level(3); + opts.set_xla_cpu_multi_thread_eigen(true); + opts.set_xla_gpu_cuda_data_dir("./cuda_sdk_lib"); + opts.set_xla_eliminate_hlo_implicit_broadcast(true); + opts.set_xla_hlo_dump_as_html(false); #ifdef INTEL_MKL - flags->set_xla_cpu_use_mkl_dnn(true); + opts.set_xla_cpu_use_mkl_dnn(true); #endif // INTEL_MKL - flags->set_xla_gpu_max_kernel_unroll_factor(4); + opts.set_xla_gpu_max_kernel_unroll_factor(4); // Set cudnn batchnorm off by default; it does not provide a performance win // on average. - flags->set_xla_gpu_use_cudnn_batchnorm(false); + opts.set_xla_gpu_use_cudnn_batchnorm(false); // Run all GPU work on one stream by default. Using multiple streams // increases memory usage and we lack strong motivating benchmarks for tuning // the heuristics needed to decide when to run on multiple streams. See // b/77879207. - flags->set_xla_gpu_disable_multi_streaming(true); + opts.set_xla_gpu_disable_multi_streaming(true); // TODO(jlebar): Disable fastmath once doing so is not a performance // regression. - flags->set_xla_cpu_enable_fast_math(true); - flags->set_xla_gpu_enable_fast_min_max(true); + opts.set_xla_cpu_enable_fast_math(true); + opts.set_xla_gpu_enable_fast_min_max(true); - flags->set_xla_force_host_platform_device_count(1); + opts.set_xla_force_host_platform_device_count(1); + return opts; } +static DebugOptions* flag_values; +static std::vector* flag_objects; +static std::once_flag flags_init; + // Allocates flag_values and flag_objects; this function must not be called more // than once - its call done via call_once. -void AllocateFlags() { - flag_values = new DebugOptions; - - SetDebugOptionsDefaults(flag_values); +static void AllocateFlags() { + flag_values = new DebugOptions(DefaultDebugOptionsIgnoringFlags()); // Returns a lambda that calls "member_setter" on "flag_values" with the // argument passed in to the lambda. @@ -133,6 +133,11 @@ void AllocateFlags() { bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_graphdef), flag_values->xla_hlo_dump_as_graphdef(), "Dump HLO graphs as TensorFlow GraphDefs."), + tensorflow::Flag("xla_hlo_dump_as_html", + bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_html), + flag_values->xla_hlo_dump_as_html(), + "Dump HLO graphs as an HTML (DOT rendered into SVG " + "inlined in HTML)."), tensorflow::Flag( "xla_hlo_graph_sharding_color", bool_setter_for(&DebugOptions::set_xla_hlo_graph_sharding_color), @@ -202,6 +207,16 @@ void AllocateFlags() { "Comma-separated list of hlo passes to be disabled. These names " "must exactly match the passes' names; no whitespace around " "commas."), + tensorflow::Flag( + "xla_disable_all_hlo_passes", + bool_setter_for(&DebugOptions::set_xla_disable_all_hlo_passes), false, + "Disables all HLO passes. Notes that some passes are necessary for " + "correctness and the invariants that must be satisfied by 'fully " + "optimized' HLO are different for different devices and may change " + "over time. The only 'guarantee', such as it is, is that if you " + "compile XLA and dump the optimized HLO for some graph, you should " + "be able to run it again on the same device with the same build of " + "XLA."), tensorflow::Flag( "xla_embed_ir_in_executable", bool_setter_for(&DebugOptions::set_xla_embed_ir_in_executable), @@ -344,8 +359,6 @@ void AllocateFlags() { ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", *flag_objects); } -} // namespace - void AppendDebugOptionsFlags(std::vector* flag_list) { std::call_once(flags_init, &AllocateFlags); flag_list->insert(flag_list->end(), flag_objects->begin(), diff --git a/tensorflow/compiler/xla/debug_options_flags.h b/tensorflow/compiler/xla/debug_options_flags.h index 60e59abc2a2e0f1cce3de1afc928f9fe36f75b33..dbf86a40f052af09c61da0e1abb3116ef5214357 100644 --- a/tensorflow/compiler/xla/debug_options_flags.h +++ b/tensorflow/compiler/xla/debug_options_flags.h @@ -29,7 +29,10 @@ void AppendDebugOptionsFlags(std::vector* flag_list); // Fetches a DebugOptions proto message from flags provided to the program. // Flags must be registered with the flags parser using AppendDebugOptionsFlags // first. -xla::DebugOptions GetDebugOptionsFromFlags(); +DebugOptions GetDebugOptionsFromFlags(); + +// Gets a DebugOptions proto that reflects the defaults as if no flags were set. +DebugOptions DefaultDebugOptionsIgnoringFlags(); } // namespace xla diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index d888b1f23f36f33ef94ef0e22374e0c796e47a89..002ebc31b992826b4dfc53f31a9e3625cde3c5d0 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -38,25 +38,25 @@ Alltoall is a collective operation that sends data from all cores to all cores. It has two phases: 1. the scatter phase. On each core, the operand is split into `split_count` - number of blocks along the `split_dimensions`, and the blocks are scattered - to all cores, e.g., the ith block is send to the ith core. +number of blocks along the `split_dimensions`, and the blocks are scattered +to all cores, e.g., the ith block is send to the ith core. 2. the gather phase. Each core concatenates the received blocks along the - `concat_dimension`. +`concat_dimension`. The participating cores can be configured by: - `replica_groups`: each ReplicaGroup contains a list of replica id. If empty, - all replicas belong to one group in the order of 0 - (n-1). Alltoall will be - applied within subgroups in the specified order. For example, replica - groups = {{1,2,3},{4,5,0}} means, an Alltoall will be applied within replica - 1, 2, 3, and in the gather phase, the received blocks will be concatenated - in the order of 1, 2, 3; another Alltoall will be applied within replica 4, - 5, 0, and the concatenation order is 4, 5, 0. +all replicas belong to one group in the order of 0 - (n-1). Alltoall will be +applied within subgroups in the specified order. For example, replica +groups = {{1,2,3},{4,5,0}} means, an Alltoall will be applied within replica +1, 2, 3, and in the gather phase, the received blocks will be concatenated +in the order of 1, 2, 3; another Alltoall will be applied within replica 4, +5, 0, and the concatenation order is 4, 5, 0. Prerequisites: - The dimension size of the operand on the split_dimension is divisible by - split_count. +split_count. - The operand's shape is not tuple. `AllToAll(operand, split_dimension, concat_dimension, split_count, @@ -93,7 +93,7 @@ AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/4); ```
- +
In this example, there are 4 cores participating the Alltoall. On each core, the @@ -387,34 +387,34 @@ For example, let v be an array of 24 elements: ``` let v = f32[4x2x3] {{{10, 11, 12}, {15, 16, 17}}, - {{20, 21, 22}, {25, 26, 27}}, - {{30, 31, 32}, {35, 36, 37}}, - {{40, 41, 42}, {45, 46, 47}}}; +{{20, 21, 22}, {25, 26, 27}}, +{{30, 31, 32}, {35, 36, 37}}, +{{40, 41, 42}, {45, 46, 47}}}; // Collapse to a single dimension, leaving one dimension. let v012 = Collapse(v, {0,1,2}); then v012 == f32[24] {10, 11, 12, 15, 16, 17, - 20, 21, 22, 25, 26, 27, - 30, 31, 32, 35, 36, 37, - 40, 41, 42, 45, 46, 47}; +20, 21, 22, 25, 26, 27, +30, 31, 32, 35, 36, 37, +40, 41, 42, 45, 46, 47}; // Collapse the two lower dimensions, leaving two dimensions. let v01 = Collapse(v, {0,1}); then v01 == f32[4x6] {{10, 11, 12, 15, 16, 17}, - {20, 21, 22, 25, 26, 27}, - {30, 31, 32, 35, 36, 37}, - {40, 41, 42, 45, 46, 47}}; +{20, 21, 22, 25, 26, 27}, +{30, 31, 32, 35, 36, 37}, +{40, 41, 42, 45, 46, 47}}; // Collapse the two higher dimensions, leaving two dimensions. let v12 = Collapse(v, {1,2}); then v12 == f32[8x3] {{10, 11, 12}, - {15, 16, 17}, - {20, 21, 22}, - {25, 26, 27}, - {30, 31, 32}, - {35, 36, 37}, - {40, 41, 42}, - {45, 46, 47}}; +{15, 16, 17}, +{20, 21, 22}, +{25, 26, 27}, +{30, 31, 32}, +{35, 36, 37}, +{40, 41, 42}, +{45, 46, 47}}; ``` @@ -441,9 +441,9 @@ replicas. Note that there are the following restrictions on the `source_target_pair`: - Any two pairs should not have the same target replica id, and they should - not have the same source replica id. +not have the same source replica id. - If a replica id is not a target in any pair, then the output on that replica - is a tensor consists of 0(s) with the same shape as the input. +is a tensor consists of 0(s) with the same shape as the input. ## Concatenate @@ -480,25 +480,25 @@ Concat({{2, 3}, {4, 5}, {6, 7}}, 0) ``` let a = { - {1, 2}, - {3, 4}, - {5, 6}, +{1, 2}, +{3, 4}, +{5, 6}, }; let b = { - {7, 8}, +{7, 8}, }; Concat({a, b}, 0) >>> { - {1, 2}, - {3, 4}, - {5, 6}, - {7, 8}, +{1, 2}, +{3, 4}, +{5, 6}, +{7, 8}, } ``` Diagram:
- +
## Conditional @@ -566,20 +566,20 @@ the rhs is also an input. In a neural network, these are the input activations. The n+2 dimensions are, in this order: * `batch`: Each coordinate in this dimension represents an independent input - for which convolution is carried out. +for which convolution is carried out. * `z/depth/features`: Each (y,x) position in the base area has a vector - associated to it, which goes into this dimension. +associated to it, which goes into this dimension. * `spatial_dims`: Describes the `n` spatial dimensions that define the base - area that the window moves across. +area that the window moves across. The `rhs` argument is a rank n+2 array describing the convolutional filter/kernel/window. The dimensions are, in this order: * `output-z`: The `z` dimension of the output. * `input-z`: The size of this dimension times `feature_group_count` should - equal the size of the `z` dimension in lhs. +equal the size of the `z` dimension in lhs. * `spatial_dims`: Describes the `n` spatial dimensions that define the n-d - window that moves across the base area. +window that moves across the base area. The `window_strides` argument specifies the stride of the convolutional window in the spatial dimensions. For example, if the stride in the first spatial @@ -633,7 +633,7 @@ The output shape has these dimensions, in this order: * `batch`: Same size as `batch` on the input (`lhs`). * `z`: Same size as `output-z` on the kernel (`rhs`). * `spatial_dims`: One value for each valid placement of the convolutional - window. +window. The valid placements of the convolutional window are determined by the strides and the size of the base area after padding. @@ -658,15 +658,15 @@ Here is pseudo-code for a 2d convolution with padding and striding: ``` for (b, oz, oy, ox) { // output coordinates - value = 0; - for (iz, ky, kx) { // kernel coordinates and input z - iy = oy*stride_y + ky - pad_low_y; - ix = ox*stride_x + kx - pad_low_x; - if ((iy, ix) inside the base area considered without padding) { - value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx); - } - } - output(b, oz, oy, ox) = value; +value = 0; +for (iz, ky, kx) { // kernel coordinates and input z +iy = oy*stride_y + ky - pad_low_y; +ix = ox*stride_x + kx - pad_low_x; +if ((iy, ix) inside the base area considered without padding) { +value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx); +} +} +output(b, oz, oy, ox) = value; } ``` @@ -777,19 +777,19 @@ Here is an example of an implementation of `myfunc`: ``` extern "C" void myfunc(void* out, void** in) { - float (&x)[2] = *static_cast(in[0]); - float (&y)[2][3] = *static_cast(in[1]); - EXPECT_EQ(1, x[0]); - EXPECT_EQ(2, x[1]); - EXPECT_EQ(10, y[0][0]); - EXPECT_EQ(20, y[0][1]); - EXPECT_EQ(30, y[0][2]); - EXPECT_EQ(40, y[1][0]); - EXPECT_EQ(50, y[1][1]); - EXPECT_EQ(60, y[1][2]); - float (&z)[3][3] = *static_cast(out); - z[0][0] = x[1] + y[1][0]; - // ... +float (&x)[2] = *static_cast(in[0]); +float (&y)[2][3] = *static_cast(in[1]); +EXPECT_EQ(1, x[0]); +EXPECT_EQ(2, x[1]); +EXPECT_EQ(10, y[0][0]); +EXPECT_EQ(20, y[0][1]); +EXPECT_EQ(30, y[0][2]); +EXPECT_EQ(40, y[1][0]); +EXPECT_EQ(50, y[1][1]); +EXPECT_EQ(60, y[1][2]); +float (&z)[3][3] = *static_cast(out); +z[0][0] = x[1] + y[1][0]; +// ... } ``` @@ -864,17 +864,17 @@ Example with contracting dimension numbers: ``` lhs = { {1.0, 2.0, 3.0}, - {4.0, 5.0, 6.0} } +{4.0, 5.0, 6.0} } rhs = { {1.0, 1.0, 1.0}, - {2.0, 2.0, 2.0} } +{2.0, 2.0, 2.0} } DotDimensionNumbers dnums; dnums.add_lhs_contracting_dimensions(1); dnums.add_rhs_contracting_dimensions(1); DotGeneral(lhs, rhs, dnums) -> { {6.0, 12.0}, - {15.0, 30.0} } +{15.0, 30.0} } ``` Associated batch dimension numbers from the 'lhs' and 'rhs' must have the same @@ -886,14 +886,14 @@ Example with batch dimension numbers (batch size 2, 2x2 matrices): ``` lhs = { { {1.0, 2.0}, - {3.0, 4.0} }, - { {5.0, 6.0}, - {7.0, 8.0} } } +{3.0, 4.0} }, +{ {5.0, 6.0}, +{7.0, 8.0} } } rhs = { { {1.0, 0.0}, - {0.0, 1.0} }, - { {1.0, 0.0}, - {0.0, 1.0} } } +{0.0, 1.0} }, +{ {1.0, 0.0}, +{0.0, 1.0} } } DotDimensionNumbers dnums; dnums.add_lhs_contracting_dimensions(2); @@ -902,9 +902,9 @@ dnums.add_lhs_batch_dimensions(0); dnums.add_rhs_batch_dimensions(0); DotGeneral(lhs, rhs, dnums) -> { { {1.0, 2.0}, - {3.0, 4.0} }, - { {5.0, 6.0}, - {7.0, 8.0} } } +{3.0, 4.0} }, +{ {5.0, 6.0}, +{7.0, 8.0} } } ``` | Input | Output | Semantics | @@ -963,22 +963,22 @@ let a = {0.0, 1.0, 2.0, 3.0, 4.0} let s = {2} DynamicSlice(a, s, {2}) produces: - {2.0, 3.0} +{2.0, 3.0} ``` 2-dimensional example: ``` let b = - { {0.0, 1.0, 2.0}, - {3.0, 4.0, 5.0}, - {6.0, 7.0, 8.0}, - {9.0, 10.0, 11.0} } +{ {0.0, 1.0, 2.0}, +{3.0, 4.0, 5.0}, +{6.0, 7.0, 8.0}, +{9.0, 10.0, 11.0} } let s = {2, 1} DynamicSlice(b, s, {2, 2}) produces: - { { 7.0, 8.0}, - {10.0, 11.0} } +{ { 7.0, 8.0}, +{10.0, 11.0} } ``` ## DynamicUpdateSlice @@ -1027,29 +1027,29 @@ let u = {5.0, 6.0} let s = {2} DynamicUpdateSlice(a, u, s) produces: - {0.0, 1.0, 5.0, 6.0, 4.0} +{0.0, 1.0, 5.0, 6.0, 4.0} ``` 2-dimensional example: ``` let b = - { {0.0, 1.0, 2.0}, - {3.0, 4.0, 5.0}, - {6.0, 7.0, 8.0}, - {9.0, 10.0, 11.0} } +{ {0.0, 1.0, 2.0}, +{3.0, 4.0, 5.0}, +{6.0, 7.0, 8.0}, +{9.0, 10.0, 11.0} } let u = - { {12.0, 13.0}, - {14.0, 15.0}, - {16.0, 17.0} } +{ {12.0, 13.0}, +{14.0, 15.0}, +{16.0, 17.0} } let s = {1, 1} DynamicUpdateSlice(b, u, s) produces: - { {0.0, 1.0, 2.0}, - {3.0, 12.0, 13.0}, - {6.0, 14.0, 15.0}, - {9.0, 16.0, 17.0} } +{ {0.0, 1.0, 2.0}, +{3.0, 12.0, 13.0}, +{6.0, 14.0, 15.0}, +{9.0, 16.0, 17.0} } ``` ## Element-wise binary arithmetic operations @@ -1235,42 +1235,42 @@ shape of `start_indices` to be `[6,7,1]`). The bounds for the output array along dimension `i` is computed as follows: - 1. If `i` is present in `batch_dims` (i.e. is equal to `batch_dims[k]` for - some `k`) then we pick the corresponding dimension bounds out of - `start_indices.shape`, skipping `index_vector_dim` (i.e. pick - `start_indices.shape.dims`[`k`] if `k` < `index_vector_dim` and - `start_indices.shape.dims`[`k`+`1`] otherwise). +1. If `i` is present in `batch_dims` (i.e. is equal to `batch_dims[k]` for +some `k`) then we pick the corresponding dimension bounds out of +`start_indices.shape`, skipping `index_vector_dim` (i.e. pick +`start_indices.shape.dims`[`k`] if `k` < `index_vector_dim` and +`start_indices.shape.dims`[`k`+`1`] otherwise). - 2. If `i` is present in `offset_dims` (i.e. equal to `offset_dims`[`k`] for - some `k`) then we pick the corresponding bound out of `slice_sizes` after - accounting for `collapsed_slice_dims` (i.e. we pick - `adjusted_slice_sizes`[`k`] where `adjusted_slice_sizes` is `slice_sizes` - with the bounds at indices `collapsed_slice_dims` removed). +2. If `i` is present in `offset_dims` (i.e. equal to `offset_dims`[`k`] for +some `k`) then we pick the corresponding bound out of `slice_sizes` after +accounting for `collapsed_slice_dims` (i.e. we pick +`adjusted_slice_sizes`[`k`] where `adjusted_slice_sizes` is `slice_sizes` +with the bounds at indices `collapsed_slice_dims` removed). Formally, the operand index `In` corresponding to an output index `Out` is computed as follows: - 1. Let `G` = { `Out`[`k`] for `k` in `batch_dims` }. Use `G` to slice out - vector `S` such that `S`[`i`] = `start_indices`[Combine(`G`, `i`)] where - Combine(A, b) inserts b at position `index_vector_dim` into A. Note that - this is well defined even if `G` is empty -- if `G` is empty then `S` = - `start_indices`. - - 2. Create a starting index, `S``in`, into `operand` using `S` by - scattering `S` using `start_index_map`. More precisely: - 1. `S``in`[`start_index_map`[`k`]] = `S`[`k`] if `k` < - `start_index_map.size`. - 2. `S``in`[`_`] = `0` otherwise. - - 3. Create an index `O``in` into `operand` by scattering the indices - at the offset dimensions in `Out` according to the `collapsed_slice_dims` - set. More precisely: - 1. `O``in`[`expand_offset_dims`(`k`)] = - `Out`[`offset_dims`[`k`]] if `k` < `offset_dims.size` - (`expand_offset_dims` is defined below). - 2. `O``in`[`_`] = `0` otherwise. - 4. `In` is `O``in` + `S``in` where + is element-wise - addition. +1. Let `G` = { `Out`[`k`] for `k` in `batch_dims` }. Use `G` to slice out +vector `S` such that `S`[`i`] = `start_indices`[Combine(`G`, `i`)] where +Combine(A, b) inserts b at position `index_vector_dim` into A. Note that +this is well defined even if `G` is empty -- if `G` is empty then `S` = +`start_indices`. + +2. Create a starting index, `S``in`, into `operand` using `S` by +scattering `S` using `start_index_map`. More precisely: +1. `S``in`[`start_index_map`[`k`]] = `S`[`k`] if `k` < +`start_index_map.size`. +2. `S``in`[`_`] = `0` otherwise. + +3. Create an index `O``in` into `operand` by scattering the indices +at the offset dimensions in `Out` according to the `collapsed_slice_dims` +set. More precisely: +1. `O``in`[`expand_offset_dims`(`k`)] = +`Out`[`offset_dims`[`k`]] if `k` < `offset_dims.size` +(`expand_offset_dims` is defined below). +2. `O``in`[`_`] = `0` otherwise. +4. `In` is `O``in` + `S``in` where + is element-wise +addition. `expand_offset_dims` is the monotonic function with domain [`0`, `offset.size`) and range [`0`, `operand.rank`) \ `collapsed_slice_dims`. So if, e.g., @@ -1282,21 +1282,21 @@ and range [`0`, `operand.rank`) \ `collapsed_slice_dims`. So if, e.g., Informally, every index `Out` in the output array corresponds to an element `E` in the operand array, computed as follows: - - We use the batch dimensions in `Out` to look up a starting index from - `start_indices`. +- We use the batch dimensions in `Out` to look up a starting index from +`start_indices`. - - We use `start_index_map` to map the starting index (which may have size less - than operand.rank) to a "full" starting index into operand. +- We use `start_index_map` to map the starting index (which may have size less +than operand.rank) to a "full" starting index into operand. - - We dynamic-slice out a slice with size `slice_sizes` using the full starting - index. +- We dynamic-slice out a slice with size `slice_sizes` using the full starting +index. - - We reshape the slice by collapsing the `collapsed_slice_dims` dimensions. - Since all collapsed slice dimensions have to have bound 1 this reshape is - always legal. +- We reshape the slice by collapsing the `collapsed_slice_dims` dimensions. +Since all collapsed slice dimensions have to have bound 1 this reshape is +always legal. - - We use the offset dimensions in `Out` to index into this slice to get the - input element, `E`, corresponding to output index `Out`. +- We use the offset dimensions in `Out` to index into this slice to get the +input element, `E`, corresponding to output index `Out`. `index_vector_dim` is set to `start_indices.rank` - `1` in all of the examples that follow. More interesting values for `index_vector_dim` does not @@ -1315,7 +1315,7 @@ the output shape, and maps it to an element in the input array in the following way:
- +
We first select an (`X`,`Y`) vector from the gather indices array using `G`. @@ -1334,7 +1334,7 @@ version of the example above using a "gather indices" array of shape `[4,5,2]` would translate indices like this:
- +
Again, this acts as a batch dynamic slice `G``0` and @@ -1343,27 +1343,27 @@ Again, this acts as a batch dynamic slice `G``0` and The gather operation in XLA generalizes the informal semantics outlined above in the following ways: - 1. We can configure which dimensions in the output shape are the offset - dimensions (dimensions containing `O``0`, `O``1` in - the last example). The output batch dimensions (dimensions containing - `G``0`, `G``1` in the last example) are defined to be - the output dimensions that are not offset dimensions. +1. We can configure which dimensions in the output shape are the offset +dimensions (dimensions containing `O``0`, `O``1` in +the last example). The output batch dimensions (dimensions containing +`G``0`, `G``1` in the last example) are defined to be +the output dimensions that are not offset dimensions. - 2. The number of output offset dimensions explicitly present in the output - shape may be smaller than the input rank. These "missing" dimensions, which - are listed explicitly as `collapsed_slice_dims`, must have a slice size of - `1`. Since they have a slice size of `1` the only valid index for them is - `0` and eliding them does not introduce ambiguity. +2. The number of output offset dimensions explicitly present in the output +shape may be smaller than the input rank. These "missing" dimensions, which +are listed explicitly as `collapsed_slice_dims`, must have a slice size of +`1`. Since they have a slice size of `1` the only valid index for them is +`0` and eliding them does not introduce ambiguity. - 3. The slice extracted from the "Gather Indices" array ((`X`, `Y`) in the last - example) may have fewer elements than the input array rank, and an explicit - mapping dictates how the index should be expanded to have the same rank as - the input. +3. The slice extracted from the "Gather Indices" array ((`X`, `Y`) in the last +example) may have fewer elements than the input array rank, and an explicit +mapping dictates how the index should be expanded to have the same rank as +the input. As a final example, we use (2) and (3) to implement `tf.gather_nd`:
- +
`G``0` and `G``1` are used to slice out a starting index @@ -1442,11 +1442,11 @@ dependency between the while loops. ``` result1 = while (condition, init = init_value) { - Infeed(shape) +Infeed(shape) } result2 = while (condition, init = result1) { - Infeed(shape) +Infeed(shape) } ``` @@ -1464,7 +1464,9 @@ Infeed of the device. Builds a constant literal on device rather than a potentially large host transfer. Creates a rank 1 array of values starting at zero and incrementing by -one. +one. For floating-point types, the produced array is equivalent to +`ConvertElementType(Iota(...))` where the `Iota` is of integral type and the +conversion is to the floating-point type. Arguments | Type | Semantics ---------------- | --------------- | ------------------------------------ diff --git a/tensorflow/compiler/xla/layout.cc b/tensorflow/compiler/xla/layout.cc new file mode 100644 index 0000000000000000000000000000000000000000..e3b5fcd5274881cec31ecf906e3461685f82a1f4 --- /dev/null +++ b/tensorflow/compiler/xla/layout.cc @@ -0,0 +1,96 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/layout.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/layout_util.h" + +namespace xla { + +TileProto Tile::ToProto() const { + TileProto tile_proto; + for (int64 i : dimensions()) { + tile_proto.add_dimensions(i); + } + return tile_proto; +} + +string Tile::ToString() const { + return absl::StrCat("(", absl::StrJoin(dimensions(), ","), ")"); +} + +/* static */ Layout Layout::CreateFromProto(const LayoutProto& proto) { + Layout layout; + layout.set_format(proto.format()); + layout.minor_to_major_.reserve(proto.minor_to_major_size()); + for (const int64 dimension : proto.minor_to_major()) { + layout.add_minor_to_major(dimension); + } + layout.set_max_sparse_elements(proto.max_sparse_elements()); + for (const TileProto& tile_proto : proto.tiles()) { + *layout.add_tiles() = Tile::CreateFromProto(tile_proto); + } + layout.set_element_size_in_bits(proto.element_size_in_bits()); + return layout; +} + +LayoutProto Layout::ToProto() const { + LayoutProto proto; + proto.set_format(format_); + proto.mutable_minor_to_major()->Reserve(minor_to_major_size()); + for (const int64 dimension : minor_to_major()) { + proto.add_minor_to_major(dimension); + } + proto.set_max_sparse_elements(max_sparse_elements_); + for (const Tile& tile : tiles()) { + *proto.add_tiles() = tile.ToProto(); + } + proto.set_element_size_in_bits(element_size_in_bits()); + return proto; +} + +string Layout::ToString() const { + // TODO(b/119839262): Emit tiles in string. + if (format() == SPARSE) { + return absl::StrCat("sparse{", max_sparse_elements(), "}"); + } else if (format() == DENSE) { + return absl::StrCat("{", absl::StrJoin(minor_to_major(), ","), "}"); + } else { + CHECK_EQ(format(), INVALID_FORMAT); + return "invalid{}"; + } +} + +bool Layout::operator==(const Layout& other) const { + return (other.format() == format() && + other.minor_to_major() == minor_to_major() && + other.element_size_in_bits() == element_size_in_bits() && + other.max_sparse_elements() == max_sparse_elements() && + other.tiles() == tiles()); +} + +std::ostream& operator<<(std::ostream& out, const Tile& tile) { + out << tile.ToString(); + return out; +} + +std::ostream& operator<<(std::ostream& out, const Layout& layout) { + out << layout.ToString(); + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/layout.h b/tensorflow/compiler/xla/layout.h new file mode 100644 index 0000000000000000000000000000000000000000..313368c39e4c976fc481941eb17325101f2ba69a --- /dev/null +++ b/tensorflow/compiler/xla/layout.h @@ -0,0 +1,187 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_LAYOUT_H_ +#define TENSORFLOW_COMPILER_XLA_LAYOUT_H_ + +#include + +#include "absl/types/span.h" + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Describes a tile used in tiling-based layout. Refer to +// g3doc/third_party/tensorflow/compiler/xla/g3doc/layout_with_tiling.md for +// details. +class Tile { + public: + Tile() = default; + explicit Tile(absl::Span dimensions) + : dimensions_(dimensions.begin(), dimensions.end()) {} + + // De/Serialize a Tile to and from a TileProto. + static Tile CreateFromProto(const TileProto& tile_proto) { + return Tile(AsInt64Slice(tile_proto.dimensions())); + } + TileProto ToProto() const; + + bool operator==(const Tile& other) const { + return dimensions() == other.dimensions(); + } + bool operator!=(const Tile& other) const { return !(*this == other); } + + string ToString() const; + + // Returns the bound of the tile in the given dimension index. + int64 dimension(int i) const { return dimensions_.at(i); } + + // Returns the dimensions of the tile. + const std::vector& dimensions() const { return dimensions_; } + + private: + // The bounds of the tile. + std::vector dimensions_; +}; + +class Layout { + public: + Layout() = default; + + // Constructs a dense layout with the given minor-to-major order. + explicit Layout(absl::Span minor_to_major) + : format_(DENSE), + minor_to_major_(minor_to_major.begin(), minor_to_major.end()) {} + + // Constructs a dense tiled layout with the given minor-to-major order and + // tiles. + Layout(absl::Span minor_to_major, absl::Span tiles) + : format_(DENSE), + minor_to_major_(minor_to_major.begin(), minor_to_major.end()), + tiles_(tiles.begin(), tiles.end()) {} + + // Construct a shape from a LayoutProto. + static Layout CreateFromProto(const LayoutProto& proto); + + // Returns a LayoutProto representation of the Layout. + LayoutProto ToProto() const; + + // Returns a human-readable string that represents this layout. + string ToString() const; + + bool operator==(const Layout& other) const; + bool operator!=(const Layout& other) const { return !(*this == other); } + + // The following methods mirror the protobuf generated code interface for the + // message LayoutProto. This enabled easy migration of this data structure + // from a proto to a proper C++ class. + // + // TODO(b/29771030): Replace or augment these methods with a more ergonomic + // interface. + + // Methods for accessing the format. + Format format() const { return format_; } + Layout& set_format(Format value) { + format_ = value; + return *this; + } + + // Methods for accessing the minor-to-major array. + int minor_to_major_size() const { return minor_to_major_.size(); } + int64 minor_to_major(int index) const { return minor_to_major_.at(index); } + Layout& set_minor_to_major(int index, int64 value) { + minor_to_major_.at(index) = value; + return *this; + } + Layout& add_minor_to_major(int64 value) { + minor_to_major_.push_back(value); + return *this; + } + Layout& clear_minor_to_major() { + minor_to_major_.clear(); + return *this; + } + const std::vector& minor_to_major() const { return minor_to_major_; } + std::vector* mutable_minor_to_major() { return &minor_to_major_; } + + // Methods for accessing the tile field. + int tiles_size() const { return tiles_.size(); } + const Tile& tiles(int index) const { return tiles_.at(index); } + Tile* mutable_tiles(int index) { return &tiles_.at(index); } + Tile* add_tiles() { + tiles_.push_back(Tile()); + return &tiles_.back(); + } + Layout& clear_tiles() { + tiles_.clear(); + return *this; + } + const std::vector& tiles() const { return tiles_; } + std::vector* mutable_tiles() { return &tiles_; } + + // Methods for accessing the int64 fields. + int64 max_sparse_elements() const { return max_sparse_elements_; } + Layout& set_max_sparse_elements(int64 value) { + max_sparse_elements_ = value; + return *this; + } + int64 element_size_in_bits() const { return element_size_in_bits_; } + Layout& set_element_size_in_bits(int64 value) { + element_size_in_bits_ = value; + return *this; + } + + void Swap(Layout* other) { + using std::swap; + swap(*this, *other); + } + + void Clear() { + format_ = INVALID_FORMAT; + minor_to_major_.clear(); + max_sparse_elements_ = 0; + element_size_in_bits_ = 0; + } + + public: + // The format of this layout. + Format format_ = INVALID_FORMAT; + + // Sequence of dimension numbers, from minor (fastest varying index) to major + // (slowest varying index). + std::vector minor_to_major_; + + // The maximum number of elements that can be stored for SPARSE formats. This + // can be used to determine the maximum size in bytes of arrays stored in + // memory. This field must be zero unless the format is SPARSE. + int64 max_sparse_elements_ = 0; + + // The number of bits used to store an individual array element. + int64 element_size_in_bits_ = 0; + + // The tiles used in tiling-based layout. + std::vector tiles_; +}; + +std::ostream& operator<<(std::ostream& out, const Tile& Tile); +std::ostream& operator<<(std::ostream& out, const Layout& layout); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LAYOUT_H_ diff --git a/tensorflow/compiler/xla/layout_test.cc b/tensorflow/compiler/xla/layout_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..fb6abd3f6523b978e72b21ec082ae06973e86243 --- /dev/null +++ b/tensorflow/compiler/xla/layout_test.cc @@ -0,0 +1,104 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/layout.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +class LayoutTest : public ::testing::Test {}; + +TEST_F(LayoutTest, ToString) { + EXPECT_EQ(Layout().ToString(), "invalid{}"); + EXPECT_EQ(Layout({4, 5, 6}).ToString(), "{4,5,6}"); + EXPECT_EQ(Layout().set_format(SPARSE).set_max_sparse_elements(123).ToString(), + "sparse{123}"); + EXPECT_EQ(Layout({4, 5, 6}).ToString(), "{4,5,6}"); + EXPECT_EQ(Layout({3, 2, 1, 0}, {Tile({42, 123}), Tile({4, 5})}).ToString(), + "{3,2,1,0}"); + EXPECT_EQ( + Layout({1, 0}, {Tile({2, 55})}).set_element_size_in_bits(42).ToString(), + "{1,0}"); +} + +TEST_F(LayoutTest, StreamOut) { + { + std::ostringstream oss; + oss << Tile({7, 8}); + EXPECT_EQ(oss.str(), "(7,8)"); + } + + { + std::ostringstream oss; + oss << Layout({0, 1, 2}); + EXPECT_EQ(oss.str(), "{0,1,2}"); + } +} + +TEST_F(LayoutTest, SparseLayoutMaxElements) { + EXPECT_EQ(LayoutUtil::MaxSparseElements(LayoutUtil::MakeSparseLayout(101)), + 101); +} + +TEST_F(LayoutTest, Equality) { + EXPECT_EQ(Layout(), Layout()); + const std::vector empty_dims; + EXPECT_EQ(Layout(empty_dims), Layout(empty_dims)); + EXPECT_NE(Layout(), Layout(empty_dims)); + EXPECT_EQ(Layout({0, 1, 2, 3}), Layout({0, 1, 2, 3})); + EXPECT_NE(Layout({0, 1, 2, 3}), Layout({0, 1, 2})); + EXPECT_EQ(Layout({0, 1, 2}, {Tile({42, 44})}), + Layout({0, 1, 2}, {Tile({42, 44})})); + EXPECT_NE(Layout({0, 1, 2}, {Tile({42, 44})}), + Layout({0, 1, 2}, {Tile({42, 45})})); + EXPECT_NE(Layout({0, 1, 2}, {Tile({42, 44})}), Layout({0, 1, 2, 3})); + EXPECT_EQ(Layout({0, 1, 2}).set_element_size_in_bits(33), + Layout({0, 1, 2}).set_element_size_in_bits(33)); + EXPECT_NE(Layout({0, 1, 2}).set_element_size_in_bits(33), + Layout({0, 1, 2}).set_element_size_in_bits(7)); + EXPECT_EQ(Layout().set_format(SPARSE), Layout().set_format(SPARSE)); + EXPECT_EQ(Layout().set_format(SPARSE).set_max_sparse_elements(42), + Layout().set_format(SPARSE).set_max_sparse_elements(42)); + EXPECT_NE(Layout().set_format(SPARSE).set_max_sparse_elements(42), + Layout().set_format(SPARSE).set_max_sparse_elements(24)); +} + +TEST_F(LayoutTest, LayoutToFromProto) { + // Round-trips a Layout through proto de/serialization. + auto expect_unchanged = [](const Layout& layout) { + EXPECT_EQ(layout, Layout::CreateFromProto(layout.ToProto())); + }; + + expect_unchanged(Layout()); + expect_unchanged(Layout({1, 3, 2, 0})); + expect_unchanged(Layout().set_format(SPARSE)); + expect_unchanged(Layout().set_format(SPARSE).set_max_sparse_elements(123)); + expect_unchanged(Layout({0, 1}).set_element_size_in_bits(42)); + expect_unchanged(Layout({3, 2, 1, 0}, {Tile({42, 123}), Tile({4, 5})})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index dbb81381acde645f08639737b6e7b6f6ad971f9b..ddccd8c798df5b926d2e5aea8975cb6cb6640824 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -41,15 +41,13 @@ namespace { // Internal helper for GetDefaultLayoutForShape and SetToDefaultLayout. Sets // minor_to_major to the value that represents the default layout. -void SetDefaultLayoutToContainer( - tensorflow::protobuf::RepeatedField* - minor_to_major) { +void SetDefaultLayoutToContainer(std::vector* minor_to_major) { // The default XLA layout is major-to-minor (dim 0 is major). // For more information on XLA layouts, see: // https://www.tensorflow.org/performance/xla/shapes const int64 size = minor_to_major->size(); for (int64 i = 0; i < size; ++i) { - minor_to_major->Set(i, size - 1 - i); + (*minor_to_major)[i] = size - 1 - i; } } @@ -94,9 +92,8 @@ namespace { Layout CreateDefaultLayoutForRank(int64 rank) { Layout layout; layout.set_format(DENSE); - tensorflow::protobuf::RepeatedField* - minor_to_major = layout.mutable_minor_to_major(); - minor_to_major->Resize(rank, 0); + std::vector* minor_to_major = layout.mutable_minor_to_major(); + minor_to_major->resize(rank, 0); SetDefaultLayoutToContainer(minor_to_major); return layout; } @@ -139,9 +136,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) { shape->clear_layout(); } else if (ShapeUtil::IsArray(*shape)) { shape->mutable_layout()->set_format(DENSE); - tensorflow::protobuf::RepeatedField* - minor_to_major = shape->mutable_layout()->mutable_minor_to_major(); - minor_to_major->Resize(shape->dimensions_size(), 0); + auto* minor_to_major = shape->mutable_layout()->mutable_minor_to_major(); + minor_to_major->resize(shape->dimensions_size(), 0); SetDefaultLayoutToContainer(minor_to_major); } else { // Opaque, token types etc. have no layout. @@ -210,9 +206,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } if (layout.format() == INVALID_FORMAT || !Format_IsValid(layout.format())) { - return InvalidArgument( - "Layout has an invalid format (%d) in layout {%s}, shape {%s}", - layout.format(), layout.ShortDebugString(), shape.ShortDebugString()); + return InvalidArgument("Layout has an invalid format (%d)", + layout.format()); } if (layout.format() == DENSE) { @@ -316,7 +311,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ bool LayoutUtil::Equal(const Layout& lhs, const Layout& rhs) { - return protobuf_util::ProtobufEquals(lhs, rhs); + return lhs == rhs; } /* static */ absl::Span LayoutUtil::MinorToMajor( @@ -358,11 +353,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ string LayoutUtil::HumanString(const Layout& layout) { - if (IsSparse(layout)) { - return absl::StrCat("sparse{", layout.max_sparse_elements(), "}"); - } - CHECK(IsDense(layout)); - return absl::StrCat("{", absl::StrJoin(layout.minor_to_major(), ","), "}"); + return layout.ToString(); } namespace { @@ -444,11 +435,6 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { return true; } -std::ostream& operator<<(std::ostream& out, const Layout& layout) { - out << LayoutUtil::HumanString(layout); - return out; -} - /*static*/ size_t LayoutUtil::Hash(const Layout& layout) { using tensorflow::hash; using tensorflow::Hash64Combine; diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 6c298e57252449ce3f1f9055436e918f2d9f17f1..609dba67bcdbcb11be0906b7d87a52a17ba0dfbd 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/types/span.h" +#include "tensorflow/compiler/xla/layout.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" @@ -195,8 +196,6 @@ class LayoutUtil { TF_DISALLOW_COPY_AND_ASSIGN(LayoutUtil); }; -std::ostream& operator<<(std::ostream& out, const Layout& layout); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_LAYOUT_UTIL_H_ diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index 12ce2d2d7c6fa8c590035f9ff2af50001ccf80d8..4cc94c270cd64eb19761cc1044861c7d185b7888 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -317,17 +317,6 @@ TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) { ShapeUtil::MakeShape(F32, {10, 20, 30, 15, 25})))); } -TEST_F(LayoutUtilTest, SparseLayoutMaxElements) { - EXPECT_EQ(LayoutUtil::MaxSparseElements(LayoutUtil::MakeSparseLayout(101)), - 101); -} - -TEST_F(LayoutUtilTest, StreamOut) { - std::ostringstream oss; - oss << LayoutUtil::MakeLayout({0, 1, 2}); - EXPECT_EQ(oss.str(), "{0,1,2}"); -} - TEST_F(LayoutUtilTest, ValidateLayout_ValidArrayLayout) { Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {0, 1}); auto status = diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 8f480c1f1079b4e1a5be53958ebdf6e004ad9ebe..277c98721e59ac12965392500fdfdc3d91e59a8b 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -1028,20 +1028,21 @@ string ShapeToString(bool print_layout, const Shape& shape) { } void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, - bool print_layout, std::vector* pieces); + bool print_shape, bool print_layout, + std::vector* pieces); void TupleToStringHelper(const LiteralBase& literal, - const ShapeIndex& shape_index, bool print_layout, - std::vector* pieces) { + const ShapeIndex& shape_index, bool print_shape, + bool print_layout, std::vector* pieces) { const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); - pieces->push_back(ShapeToString(print_layout, subshape)); - pieces->push_back(" (\n"); + pieces->push_back("(\n"); std::vector tuple_pieces; for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) { ShapeIndex element_index = shape_index; element_index.push_back(i); std::vector element_pieces; - ToStringHelper(literal, element_index, print_layout, &element_pieces); + ToStringHelper(literal, element_index, print_shape, print_layout, + &element_pieces); tuple_pieces.push_back(absl::StrJoin(element_pieces, "")); } pieces->push_back(absl::StrJoin(tuple_pieces, ",\n")); @@ -1049,9 +1050,11 @@ void TupleToStringHelper(const LiteralBase& literal, } void SparseArrayToStringHelper(const LiteralBase& literal, - const Shape& subshape, bool print_layout, - std::vector* pieces) { - pieces->push_back(ShapeToString(print_layout, subshape)); + const Shape& subshape, bool print_shape, + bool print_layout, std::vector* pieces) { + if (print_shape) { + pieces->push_back(ShapeToString(print_layout, subshape)); + } pieces->push_back("{"); int64 rank = ShapeUtil::Rank(subshape); int64 num_elements = literal.sparse_element_count(); @@ -1073,8 +1076,8 @@ void SparseArrayToStringHelper(const LiteralBase& literal, } void DenseArrayToStringHelper(const LiteralBase& literal, - const ShapeIndex& shape_index, bool print_layout, - std::vector* pieces) { + const ShapeIndex& shape_index, bool print_shape, + bool print_layout, std::vector* pieces) { const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); int64 rank = ShapeUtil::Rank(subshape); @@ -1135,7 +1138,7 @@ void DenseArrayToStringHelper(const LiteralBase& literal, } }; - if (rank > 1) { + if (print_shape) { pieces->push_back(ShapeToString(print_layout, subshape)); pieces->push_back(" "); } @@ -1146,19 +1149,23 @@ void DenseArrayToStringHelper(const LiteralBase& literal, } void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, - bool print_layout, std::vector* pieces) { + bool print_shape, bool print_layout, + std::vector* pieces) { const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); CHECK(LayoutUtil::HasLayout(literal.shape())); CHECK(LayoutUtil::HasLayout(subshape)); if (ShapeUtil::IsTuple(subshape)) { - TupleToStringHelper(literal, shape_index, print_layout, pieces); + TupleToStringHelper(literal, shape_index, print_shape, print_layout, + pieces); } else if (ShapeUtil::IsToken(subshape)) { pieces->push_back("token"); } else if (LayoutUtil::IsSparseArray(subshape)) { - SparseArrayToStringHelper(literal, subshape, print_layout, pieces); + SparseArrayToStringHelper(literal, subshape, print_shape, print_layout, + pieces); } else { CHECK(LayoutUtil::IsDenseArray(subshape)); - DenseArrayToStringHelper(literal, shape_index, print_layout, pieces); + DenseArrayToStringHelper(literal, shape_index, print_shape, print_layout, + pieces); } } @@ -1169,10 +1176,27 @@ int64 LiteralBase::sparse_element_count() const { return sparse_indices()->index_count(); } -string LiteralBase::ToString(bool print_layout) const { +string LiteralBase::ToString() const { + std::vector pieces; + CHECK(LayoutUtil::HasLayout(this->shape())); + ToStringHelper(*this, {}, /*print_shape=*/true, + /*print_layout=*/false, &pieces); + return absl::StrJoin(pieces, ""); +} + +string LiteralBase::ToStringWithoutShape() const { + std::vector pieces; + CHECK(LayoutUtil::HasLayout(this->shape())); + ToStringHelper(*this, {}, /*print_shape=*/false, + /*print_layout=*/false, &pieces); + return absl::StrJoin(pieces, ""); +} + +string LiteralBase::ToStringWithLayout() const { std::vector pieces; CHECK(LayoutUtil::HasLayout(this->shape())); - ToStringHelper(*this, {}, print_layout, &pieces); + ToStringHelper(*this, {}, /*print_shape=*/true, + /*print_layout=*/true, &pieces); return absl::StrJoin(pieces, ""); } diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index fa9a71af4ceb998a7a289443cbef70eb52cb1a11..67e908e7ec4d4346f4e26a99a42aac26928ec0c2 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -92,9 +92,20 @@ class LiteralBase { // array. string GetR1U8AsString() const; - // Returns a string representation of the literal value. - // Warning: this function can take minutes for multi-million element Literals. - string ToString(bool print_layout = false) const; + // Returns a string representation of the literal value. The Shape of the + // literal is a prefix of the literal value in the string. + + // Warning: this function can take minutes for multi-million + // element Literals. + string ToString() const; + + // Returns a string representation of the literal value which does *not* + // include the shape string. + string ToStringWithoutShape() const; + + // Returns a string representation of the literal value which includes the + // shape string with its layout.does *not* include the shape string. + string ToStringWithLayout() const; // Gets an element in the literal at the given index. The multi_index is // CHECKed against the dimension sizes. diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index b044f0ad73f13a0599e77f1f43888bc974e31f73..1ac9a48e805daa86f0dc65b54626195c89241020 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -46,68 +46,102 @@ uint16 GetRawValue(Eigen::half val) { return val.x; } // between the left-hand-side and right-hand-side, by bit-casting to UnsignedT // -- on miscompare, a nice error message is given in the AssertionFailure. template -Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs, - absl::Span multi_index) { +bool CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs, + absl::Span multi_index) { + auto ulhs = absl::bit_cast(GetRawValue(lhs)); + auto urhs = absl::bit_cast(GetRawValue(rhs)); + return ulhs == urhs; +} + +// Templated comparator that specializes for float equality comparison with the +// bitwise helper above (this is the un-specialized fallback, to just use the +// default gunit implementation). +template +bool CompareEqual(NativeT lhs, NativeT rhs, + absl::Span multi_index) { + return lhs == rhs; +} + +// Specializations for floating types that do bitwise comparisons when equality +// comparison is requested. +template <> +bool CompareEqual(bfloat16 lhs, bfloat16 rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +} +template <> +bool CompareEqual(Eigen::half lhs, Eigen::half rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +} +template <> +bool CompareEqual(float lhs, float rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +} +template <> +bool CompareEqual(double lhs, double rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +} +template <> +bool CompareEqual(complex64 lhs, complex64 rhs, + absl::Span multi_index) { + return CompareEqual(lhs.real(), rhs.real(), multi_index) && + CompareEqual(lhs.imag(), rhs.imag(), multi_index); +} + +template +Status MakeBitwiseErrorStatus(NativeT lhs, NativeT rhs, + absl::Span multi_index) { auto ulhs = absl::bit_cast(GetRawValue(lhs)); auto urhs = absl::bit_cast(GetRawValue(rhs)); auto lhs_double = static_cast(lhs); auto rhs_double = static_cast(rhs); - if (ulhs != urhs) { return InvalidArgument( "floating values are not bitwise-equal; and equality testing " "was requested: %s=%g=%a vs %s=%g=%a at array index %s", StrCat(absl::Hex(ulhs)), lhs_double, lhs_double, StrCat(absl::Hex(urhs)), rhs_double, rhs_double, LiteralUtil::MultiIndexAsString(multi_index)); - } - return Status::OK(); } -// Templated comparator that specializes for float equality comparison with the -// bitwise helper above (this is the un-specialized fallback, to just use the -// default gunit implementation). template -Status CompareEqual(NativeT lhs, NativeT rhs, - absl::Span multi_index) { - if (lhs == rhs) { - return Status::OK(); - } +Status MakeErrorStatus(NativeT lhs, NativeT rhs, + absl::Span multi_index) { return InvalidArgument( "first mismatch at array index %s:\n expected value: %s\n actual " "value: %s", LiteralUtil::MultiIndexAsString(multi_index), StrCat(lhs), StrCat(rhs)); } -// Specializations for floating types that do bitwise comparisons when equality -// comparison is requested. template <> -Status CompareEqual(bfloat16 lhs, bfloat16 rhs, - absl::Span multi_index) { - return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +Status MakeErrorStatus(bfloat16 lhs, bfloat16 rhs, + absl::Span multi_index) { + return MakeBitwiseErrorStatus(lhs, rhs, multi_index); } template <> -Status CompareEqual(Eigen::half lhs, Eigen::half rhs, - absl::Span multi_index) { - return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +Status MakeErrorStatus(Eigen::half lhs, Eigen::half rhs, + absl::Span multi_index) { + return MakeBitwiseErrorStatus(lhs, rhs, multi_index); } template <> -Status CompareEqual(float lhs, float rhs, - absl::Span multi_index) { - return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +Status MakeErrorStatus(float lhs, float rhs, + absl::Span multi_index) { + return MakeBitwiseErrorStatus(lhs, rhs, multi_index); } template <> -Status CompareEqual(double lhs, double rhs, - absl::Span multi_index) { - return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +Status MakeErrorStatus(double lhs, double rhs, + absl::Span multi_index) { + return MakeBitwiseErrorStatus(lhs, rhs, multi_index); } template <> -Status CompareEqual(complex64 lhs, complex64 rhs, - absl::Span multi_index) { - auto res = CompareEqual(lhs.real(), rhs.real(), multi_index); - if (!res.ok()) { - return res; +Status MakeErrorStatus(complex64 lhs, complex64 rhs, + absl::Span multi_index) { + if (!CompareEqual(lhs.real(), rhs.real(), multi_index)) { + return MakeErrorStatus(lhs.real(), rhs.real(), multi_index); } - return CompareEqual(lhs.imag(), rhs.imag(), multi_index); + return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index); } // A recursive function which iterates through every index of expected and @@ -119,7 +153,11 @@ Status Equal(LiteralSlice expected, LiteralSlice actual, if (dimension == expected.shape().dimensions_size()) { NativeT expected_value = expected.Get(multi_index); NativeT actual_value = actual.Get(multi_index); - return CompareEqual(expected_value, actual_value, multi_index); + bool result = + CompareEqual(expected_value, actual_value, multi_index); + return result ? Status::OK() + : MakeErrorStatus(expected_value, actual_value, + multi_index); } Status result; @@ -330,7 +368,7 @@ class NearComparator { NanMismatch(expected, actual, error_.relaxed_nans); float abs_error; float rel_error; - if (CompareEqual(expected, actual, {linear_index}).ok()) { + if (CompareEqual(expected, actual, {linear_index})) { abs_error = 0; rel_error = 0; } else if (is_nan_mismatch) { @@ -344,7 +382,7 @@ class NearComparator { } else if (IsInf(expected) || IsInf(actual)) { // If either the expected or actual value is infinity but not both, // then both absolute and relative error are regarded as inifity. - CHECK(!CompareEqual(expected, actual, {linear_index}).ok()); + CHECK(!CompareEqual(expected, actual, {linear_index})); abs_error = std::numeric_limits::infinity(); rel_error = std::numeric_limits::infinity(); } else { diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index 49363ad802ddb9520f89b53257216bc7ddaf8ff5..d8c7141cacb8f60cb4ce56d07ac5827a8dbf9b20 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -98,42 +98,42 @@ class LiteralUtilTest : public ::testing::Test { TEST_F(LiteralUtilTest, LiteralScalarToString) { auto true_lit = LiteralUtil::CreateR0(true); - EXPECT_EQ("true", true_lit.ToString()); + EXPECT_EQ("pred[] true", true_lit.ToString()); auto false_lit = LiteralUtil::CreateR0(false); - EXPECT_EQ("false", false_lit.ToString()); + EXPECT_EQ("pred[] false", false_lit.ToString()); auto u32_lit = LiteralUtil::CreateR0(42); - EXPECT_EQ("42", u32_lit.ToString()); + EXPECT_EQ("u32[] 42", u32_lit.ToString()); auto s32_lit = LiteralUtil::CreateR0(-999); - EXPECT_EQ("-999", s32_lit.ToString()); + EXPECT_EQ("s32[] -999", s32_lit.ToString()); auto f32_lit = LiteralUtil::CreateR0(3.14f); - EXPECT_EQ("3.14", f32_lit.ToString()); + EXPECT_EQ("f32[] 3.14", f32_lit.ToString()); auto f16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); - EXPECT_EQ("0.5", f16_lit.ToString()); + EXPECT_EQ("f16[] 0.5", f16_lit.ToString()); auto c64_lit = LiteralUtil::CreateR0({3.14f, 2.78f}); - EXPECT_EQ("(3.14, 2.78)", c64_lit.ToString()); + EXPECT_EQ("c64[] (3.14, 2.78)", c64_lit.ToString()); auto bf16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); - EXPECT_EQ("0.5", bf16_lit.ToString()); + EXPECT_EQ("bf16[] 0.5", bf16_lit.ToString()); // 3.14 will be rounded to 3.14062 in bfloat16 format. auto bf16_lit_truncated = LiteralUtil::CreateR0(static_cast(3.14f)); - ASSERT_EQ("3.14062", bf16_lit_truncated.ToString()); + ASSERT_EQ("bf16[] 3.14062", bf16_lit_truncated.ToString()); auto bf16_lit_truncated2 = LiteralUtil::CreateR0(static_cast(9.001f)); - EXPECT_EQ("9", bf16_lit_truncated2.ToString()); + EXPECT_EQ("bf16[] 9", bf16_lit_truncated2.ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { auto pred_vec = LiteralUtil::CreateR1({true, false, true}); - EXPECT_EQ("{1, 0, 1}", pred_vec.ToString()); + EXPECT_EQ("pred[3] {1, 0, 1}", pred_vec.ToString()); } TEST_F(LiteralUtilTest, R2ToString) { @@ -210,8 +210,8 @@ TEST_F(LiteralUtilTest, TupleToString) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); - const string expected = R"((f32[], f32[2,2]) ( -1, + const string expected = R"(( +f32[] 1, f32[2,2] { { 1, 2 }, { 3, 4 } @@ -1890,7 +1890,7 @@ TEST_F(LiteralUtilTest, SortSparseElements) { literal.AppendSparseElement({3, 4, 5}, 3.0); literal.AppendSparseElement({1, 2, 3}, 1.0); literal.SortSparseElements(); - EXPECT_EQ(literal.ToString(false), + EXPECT_EQ(literal.ToString(), "f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}"); } diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index 0f86f9f35e105713aa3072a9ebf572d33d35d66d..339660cf44fd64fc5859e72255d63762fcf20efe 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -42,8 +42,7 @@ PackedLiteralReader::~PackedLiteralReader() { delete file_; } StatusOr PackedLiteralReader::Read(const Shape& shape, const Layout* layout) { VLOG(3) << "reading shape from file: " << ShapeUtil::HumanString(shape) - << " layout: " - << (layout == nullptr ? "" : layout->ShortDebugString()); + << " layout: " << (layout == nullptr ? "" : layout->ToString()); Shape literal_shape = shape; if (layout != nullptr) { TF_RETURN_IF_ERROR( diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc index b16147e3be71771269d8b7a18528bef3a8c72d99..00ad01fc407017624a9183d69e61cb0d382e3f11 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/primitive_util.h" +#include "absl/strings/ascii.h" +#include "absl/strings/numbers.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -90,5 +93,65 @@ bool IsArrayType(PrimitiveType primitive_type) { primitive_type != OPAQUE && primitive_type != TOKEN; } +// Class to memoize the computation of +// absl::AsciiStrToLower(PrimitiveType_Name(p)) +// for all PrimitiveType values "p" +class PrimitiveTypeNameGenerator { + public: + PrimitiveTypeNameGenerator() { + for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { + if (PrimitiveType_IsValid(i)) { + lowercase_name_[i] = absl::AsciiStrToLower( + PrimitiveType_Name(static_cast(i))); + } + } + } + const string& LowercaseName(PrimitiveType t) { + return lowercase_name_[static_cast(t)]; + } + + private: + string lowercase_name_[PrimitiveType_ARRAYSIZE]; +}; + +const string& LowercasePrimitiveTypeName(PrimitiveType s) { + static auto* gen = new PrimitiveTypeNameGenerator(); + return gen->LowercaseName(s); +} + +namespace { + +// Returns a map from lower-case primitive type name to primitive type. +const std::unordered_map& GetPrimitiveTypeStringMap() { + static std::unordered_map* name_to_type = [] { + static auto* map = new std::unordered_map; + for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { + if (PrimitiveType_IsValid(i) && i != PRIMITIVE_TYPE_INVALID) { + auto value = static_cast(i); + (*map)[LowercasePrimitiveTypeName(value)] = value; + } + } + return map; + }(); + return *name_to_type; +} + +} // namespace + +StatusOr StringToPrimitiveType(absl::string_view name) { + const auto& map = GetPrimitiveTypeStringMap(); + auto found = map.find(string(name)); + if (found == map.end()) { + return InvalidArgument("Invalid element type string: \"%s\".", name); + } + return found->second; +} + +bool IsPrimitiveTypeName(absl::string_view name) { + const auto& map = GetPrimitiveTypeStringMap(); + auto found = map.find(string(name)); + return found != map.end(); +} + } // namespace primitive_util } // namespace xla diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index 889e9a1ceca675689406d255d348c82c398563aa..70603b6fed1be50c427799e6dce7b8bf9631a6f4 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -20,6 +20,9 @@ limitations under the License. #include +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -221,6 +224,17 @@ template <> struct PrimitiveTypeToNative { using type = complex64; }; + +// Returns the lower-case name of the given primitive type. +const string& LowercasePrimitiveTypeName(PrimitiveType s); + +// Returns the PrimitiveType matching the given name. The given name is expected +// to be lower-case. +StatusOr StringToPrimitiveType(absl::string_view name); + +// Returns true if the given name is a primitive type string (lower-case). +bool IsPrimitiveTypeName(absl::string_view name); + } // namespace primitive_util } // namespace xla diff --git a/tensorflow/compiler/xla/primitive_util_test.cc b/tensorflow/compiler/xla/primitive_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1f765d6da9ef65849fe8ede56ced7597d623cb59 --- /dev/null +++ b/tensorflow/compiler/xla/primitive_util_test.cc @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/primitive_util.h" + +#include +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +TEST(PrimitiveUtilTest, StringToPrimitiveType) { + auto expect_ok_and_equal = [](const string& str, PrimitiveType expected) { + TF_ASSERT_OK_AND_ASSIGN(PrimitiveType actual, + primitive_util::StringToPrimitiveType(str)); + EXPECT_EQ(expected, actual); + }; + expect_ok_and_equal("f32", F32); + expect_ok_and_equal("tuple", TUPLE); + expect_ok_and_equal("pred", PRED); + expect_ok_and_equal("s32", S32); + + EXPECT_IS_NOT_OK(primitive_util::StringToPrimitiveType("F32").status()); + EXPECT_IS_NOT_OK(primitive_util::StringToPrimitiveType("Pred").status()); + EXPECT_IS_NOT_OK(primitive_util::StringToPrimitiveType("preD").status()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 63ac1c6649210cbae9e238a74e0a45fb8ee4da63..4a57b1051e081a706267df66e239dc9d330c57ba 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -66,7 +66,10 @@ cc_library( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:cholesky", "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/lib:qr", + "//tensorflow/compiler/xla/client/lib:triangular_solve", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xrt:xrt_proto", diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 6e2ee866321a070d55a7221c7c68024ceaa93448..5d191f5a18ebad8213c29fcc08f317db9626e4ed 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -24,7 +24,10 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/xla/client/lib/cholesky.h" #include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/qr.h" +#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/executable_run_options.h" @@ -644,6 +647,15 @@ LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) { return xla::ConstantLiteral(&builder_, literal); } +LocalOp LocalComputationBuilder::Iota(PrimitiveType element_type, int64 size) { + return xla::Iota(&builder_, element_type, size); +} + +LocalOp LocalComputationBuilder::BroadcastedIota(const Shape& shape, + int64 dimension) { + return xla::Iota(&builder_, shape, dimension); +} + LocalOp LocalComputationBuilder::Broadcast( const LocalOp& operand, absl::Span broadcast_sizes) { return xla::Broadcast(operand.op(), broadcast_sizes); @@ -780,6 +792,21 @@ LocalOp LocalComputationBuilder::Call(const LocalComputation& local_computation, return xla::Call(&builder_, local_computation.computation(), xla_ops); } +LocalOp LocalComputationBuilder::CustomCall( + const string& call_target_name, absl::Span operands, + const Shape& shape_with_layout, + const std::vector& operand_shapes_with_layout, + const string& opaque) { + std::vector xla_ops; + xla_ops.reserve(operands.size()); + for (const auto& op : operands) { + xla_ops.push_back(op.op()); + } + return xla::CustomCallWithLayout(&builder_, call_target_name, xla_ops, + shape_with_layout, + operand_shapes_with_layout, opaque); +} + LocalOp LocalComputationBuilder::Transpose( const LocalOp& operand, absl::Span permutation) { return xla::Transpose(operand.op(), permutation); @@ -865,6 +892,27 @@ LocalOp LocalComputationBuilder::SortKeyVal(const LocalOp& keys, return xla::Sort(keys.op(), {values.op()}, dimension); } +LocalOp LocalComputationBuilder::Cholesky(const LocalOp& a) { + return xla::Cholesky(a.op()); +} + +LocalOp LocalComputationBuilder::QR(const LocalOp& a, bool full_matrices) { + XlaBuilder* builder = a.op().builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto qr, xla::QRDecomposition(a.op(), full_matrices)); + return xla::Tuple(builder, {qr.q, qr.r}); + }); +} + +LocalOp LocalComputationBuilder::TriangularSolve(const LocalOp& a, + const LocalOp& b, + bool left_side, bool lower, + bool transpose_a, + bool conjugate_a) { + return xla::TriangularSolve(a.op(), b.op(), left_side, lower, transpose_a, + conjugate_a); +} + StatusOr LocalComputationBuilder::BuildConstantSubGraph( const LocalOp& operand) { TF_ASSIGN_OR_RETURN(XlaComputation computation, diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 149e44570df5c6a3df88bbe2ffa779be47842d82..c6e58ac971d93662c41fc7a6001f94fb26d2eff5 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -286,6 +286,10 @@ class LocalComputationBuilder { LocalOp ConstantLiteral(const Literal& literal); + LocalOp Iota(PrimitiveType element_type, int64 size); + + LocalOp BroadcastedIota(const Shape& shape, int64 dimension); + LocalOp Broadcast(const LocalOp& operand, absl::Span broadcast_sizes); @@ -352,6 +356,12 @@ class LocalComputationBuilder { LocalOp Call(const LocalComputation& local_computation, absl::Span operands); + LocalOp CustomCall(const string& call_target_name, + absl::Span operands, + const Shape& shape_with_layout, + const std::vector& operand_shapes_with_layout, + const string& opaque); + LocalOp Transpose(const LocalOp& operand, absl::Span permutation); @@ -394,6 +404,13 @@ class LocalComputationBuilder { LocalOp SortKeyVal(const LocalOp& keys, const LocalOp& values, int64 dimension); + LocalOp QR(const LocalOp& a, bool full_matrices); + + LocalOp Cholesky(const LocalOp& a); + + LocalOp TriangularSolve(const LocalOp& a, const LocalOp& b, bool left_side, + bool lower, bool transpose_a, bool conjugate_a); + StatusOr BuildConstantSubGraph(const LocalOp& operand); #define _FORWARD(method_name, return_sig, args_sig) \ diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index d23d693c1e5bde43b52959e4397aa311268411bb..11fb00e616ad410fd1e5b0225ca3cd5362fef59b 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -1051,6 +1051,8 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Outfeed; %unignore xla::swig::LocalComputationBuilder::ConstantLiteral; %unignore xla::swig::LocalComputationBuilder::ConstantR0; +%unignore xla::swig::LocalComputationBuilder::Iota; +%unignore xla::swig::LocalComputationBuilder::BroadcastedIota; %unignore xla::swig::LocalComputationBuilder::Broadcast; %unignore xla::swig::LocalComputationBuilder::BroadcastInDim; %unignore xla::swig::LocalComputationBuilder::Pad; @@ -1144,6 +1146,10 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Imag; %unignore xla::swig::LocalComputationBuilder::Conj; %unignore xla::swig::LocalComputationBuilder::Complex; +%unignore xla::swig::LocalComputationBuilder::Cholesky; +%unignore xla::swig::LocalComputationBuilder::QR; +%unignore xla::swig::LocalComputationBuilder::TriangularSolve; +%unignore xla::swig::LocalComputationBuilder::CustomCall; %unignore xla::swig::DeleteLocalComputation; %unignore xla::swig::DestructureLocalShapedBufferTuple; %unignore xla::swig::DestructureXrtAllocationTuple; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index c91a2aaf56dfe2127168628c78e0c4b868a28055..4166fa0327eba5edd0dee030e283c86ade627040 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -831,6 +831,33 @@ class ComputationBuilder(object): return self.ParameterWithShape( Shape.from_pyval(value), name=name, parameter_num=parameter_num) + def Iota(self, dtype, size): + """Enqueues an iota constant onto the computation. + + Args: + dtype: expected numpy dtype of the output. + size: integer, the number of elements in the array. + + Returns: + A LocalOp representing the added iota constant. + """ + element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))] + return self._client.Iota(element_type, size) + + def BroadcastedIota(self, dtype, shape, dimension): + """Enqueues a broadcasted iota constant onto the computation. + + Args: + dtype: expected numpy dtype of the output. + shape: tuple of integers, the expected output shape (dimensions). + dimension: positive integer, dimension along which to increment values. + + Returns: + A LocalOp representing the added broadcasted iota constant. + """ + xla_shape = Shape.array_shape(dtype, shape) + return self._client.BroadcastedIota(xla_shape, dimension) + def Broadcast(self, operand, sizes): """Enqueues a broadcast operation onto the computation. @@ -1102,6 +1129,31 @@ class ComputationBuilder(object): """ return self._client.Call(computation_to_apply.computation, operands) + def CustomCall(self, + call_target_name, + operands, + shape_with_layout, + operand_shapes_with_layout, + opaque=None): + """Enqueues a custom call operation onto the computation. + + Args: + call_target_name: the name of the function to call. + operands: an iterable of LocalOp. The number and types of operands must + match the arity of `operand_shapes_with_layout`. + shape_with_layout: the shape of the operator's output, with layout. + operand_shapes_with_layout: the shapes of `operands`, including the + expected layouts. + opaque: an opaque string passed to the backend. + + Returns: + A LocalOp representing the added custom call op. + """ + opaque = opaque or '' + return self._client.CustomCall(call_target_name, operands, + shape_with_layout, + operand_shapes_with_layout, opaque) + def Map(self, operands, computation_to_apply, dimensions): """Enqueues a map operation onto the computation. @@ -1411,6 +1463,20 @@ class ComputationBuilder(object): """Enqueues a key-value sort operation onto the computation.""" return self._client.SortKeyVal(keys, values, dimension) + def Cholesky(self, a): + """Enqueues a Cholesky decomposition onto the computation.""" + return self._client.Cholesky(a) + + def QR(self, a, full_matrices=True): + """Enqueues a QR decomposition onto the computation.""" + return self._client.QR(a, full_matrices) + + def TriangularSolve(self, a, b, left_side=False, lower=False, + transpose_a=False, conjugate_a=False): + """Enqueues a triangular-solve operation onto the computation.""" + return self._client.TriangularSolve( + a, b, left_side, lower, transpose_a, conjugate_a) + def _forward_methods_to_local_builder(): """Forward remaining ComputationBuilder methods to the C API. diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 21b5c93b615ec429a5da0b4ffe89e8f75f59ef1b..95c6dc8c4570564e361c27fd2bca5c90eebb4661 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import itertools import threading @@ -51,9 +52,11 @@ class LocalComputationTest(unittest.TestCase): def _ExecuteAndCompareExact(self, c, arguments=(), expected=None): self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, expected) - def _ExecuteAndCompareClose(self, c, arguments=(), expected=None): - self._ExecuteAndAssertWith(np.testing.assert_allclose, c, arguments, - expected) + def _ExecuteAndCompareClose(self, c, arguments=(), expected=None, rtol=1e-7, + atol=0): + self._ExecuteAndAssertWith( + functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol), + c, arguments, expected) def NumpyArrayF32(*args, **kwargs): @@ -143,6 +146,17 @@ class ComputationsWithConstantsTest(LocalComputationTest): c.Pow(c.Constant(NumpyArrayF64([1.5, 2.5, 3.0])), c.ConstantF64Scalar(2.)) self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.]) + def testIota(self): + c = self._NewComputation() + c.Iota(np.float32, 10) + self._ExecuteAndCompareExact(c, expected=np.arange(10, dtype=np.float32)) + + def testBroadcastedIota(self): + c = self._NewComputation() + c.BroadcastedIota(np.int64, (2, 3), 1) + expected = np.array([[0, 1, 2], [0, 1, 2]], dtype=np.int64) + self._ExecuteAndCompareExact(c, expected=expected) + def testBooleanAnd(self): c = self._NewComputation() c.And( @@ -1057,6 +1071,38 @@ class SingleOpTest(LocalComputationTest): self.assertTrue(np.all(lo <= result)) self.assertTrue(np.all(result < hi)) + def testCholesky(self): + l = np.array([[4, 0, 0, 0], [6, 5, 0, 0], [2, 14, 16, 0], [3, 6, 1, 4]], + dtype=np.float32) + c = self._NewComputation() + c.Cholesky(c.Constant(np.dot(l, l.T))) + self._ExecuteAndCompareClose(c, expected=l, rtol=1e-4) + + def testQR(self): + a = np.array( + [[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]], + dtype=np.float32) + c = self._NewComputation() + c.QR(c.Constant(a), full_matrices=True) + q, r = self._Execute(c, ()) + np.testing.assert_allclose(np.dot(q, r), a, rtol=1e-4) + + def testTriangularSolve(self): + a_vals = np.array( + [[2, 0, 0, 0], [3, 6, 0, 0], [4, 7, 9, 0], [5, 8, 10, 11]], + dtype=np.float32) + b_vals = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], + dtype=np.float32) + + c = self._NewComputation() + c.TriangularSolve(c.Constant(a_vals), c.Constant(b_vals), left_side=False, + lower=True, transpose_a=True) + self._ExecuteAndCompareClose(c, expected=np.array([ + [0.5, 0.08333334, 0.04629629, 0.03367003], + [2.5, -0.25, -0.1388889, -0.1010101], + [4.5, -0.58333331, -0.32407406, -0.23569024], + ], dtype=np.float32), rtol=1e-4) + def testIsConstant(self): c = self._NewComputation() a = c.ConstantS32Scalar(3) diff --git a/tensorflow/compiler/xla/python_api/xla_shape.py b/tensorflow/compiler/xla/python_api/xla_shape.py index 95b2bf300ec67e9f034f77450416544cb088ae55..bdcd4abd6cc708795416b15412f37dde10d7fe97 100644 --- a/tensorflow/compiler/xla/python_api/xla_shape.py +++ b/tensorflow/compiler/xla/python_api/xla_shape.py @@ -20,6 +20,8 @@ from __future__ import print_function import numpy as _np # Avoids becoming a part of public Tensorflow API. +from six.moves import xrange + from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python_api import types diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index ceb5e74db7c3b9305e9d77068df9ae0a3690af8a..a27e2005dae3a44f4e49032e70f62d633f64779a 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -21,7 +21,6 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -32,48 +31,19 @@ limitations under the License. namespace xla { -namespace { - -template -std::unique_ptr> MatmulArray2DImpl( - const Array2D& lhs, const Array2D& rhs, - const std::function& impl_fn) { - CHECK_EQ(lhs.width(), rhs.height()); - int m = lhs.height(); - int n = rhs.width(); - int k = lhs.width(); - auto result = absl::make_unique>(m, n); - // Because Eigen is a header-oriented library, make sure that the Eigen code - // is the same as the code used by the CPU backend (otherwise the linker will - // randomly pick *some* definition). - impl_fn( - /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m, - k, - /*transpose_lhs=*/0, - /*transpose_rhs=*/0); - return result; -} - -} // namespace - /* static */ std::unique_ptr> ReferenceUtil::MatmulArray2D( const Array2D& lhs, const Array2D& rhs) { - return MatmulArray2DImpl( - lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF16); + return HloEvaluator::MatmulArray2D(lhs, rhs); } /* static */ std::unique_ptr> ReferenceUtil::MatmulArray2D( const Array2D& lhs, const Array2D& rhs) { - return MatmulArray2DImpl( - lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF32); + return HloEvaluator::MatmulArray2D(lhs, rhs); } /* static */ std::unique_ptr> ReferenceUtil::MatmulArray2D( const Array2D& lhs, const Array2D& rhs) { - return MatmulArray2DImpl( - lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF64); + return HloEvaluator::MatmulArray2D(lhs, rhs); } /* static */ std::unique_ptr> ReferenceUtil::Array2DF32ToF64( diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 4c21ae2a427477caa86fb4130616c38eb3bcf006..f20121e4908053044b0d8eeaea3e90637f822f51 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -241,6 +241,7 @@ cc_library( ":hlo_casting_utils", ":hlo_query", ":shape_inference", + "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -249,6 +250,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", @@ -1012,6 +1014,7 @@ cc_library( srcs = ["name_uniquer.cc"], hdrs = ["name_uniquer.h"], deps = [ + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", @@ -1576,6 +1579,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -1782,6 +1786,7 @@ tf_cc_test( ":hlo_cse", ":hlo_dce", ":hlo_matchers", + ":hlo_parser", ":hlo_pass", ":hlo_pass_pipeline", ":tuple_simplifier", @@ -3163,6 +3168,7 @@ cc_library( name = "hlo_graph_dumper", srcs = [ "hlo_graph_dumper.cc", + "hlo_graph_html_renderer.cc", ], hdrs = ["hlo_graph_dumper.h"], deps = [ @@ -3624,7 +3630,6 @@ cc_library( srcs = ["hlo_lexer.cc"], hdrs = [ "hlo_lexer.h", - "hlo_token.h", ], deps = [ "//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 985c5af1c4d89425dd6693585e42e22510fe21f8..1287dcf546d9fe575dd440d48323ed8efbf1de9d 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -25,6 +26,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/types/optional.h" @@ -239,6 +241,13 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // more fusion than leaving the nodes as Dot operations. StatusOr HandleDotStrengthReduction(HloInstruction* dot); + // Removes dimension dim from hlo. + HloInstruction* StripDim(HloInstruction* hlo, int64 dim) { + CHECK_EQ(hlo->shape().dimensions(dim), 1); + return computation_->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::DeleteDimension(dim, hlo->shape()), hlo)); + } + // Reshapes an instruction to rank 1 if it is not already rank 1. HloInstruction* Flatten(HloInstruction* hlo) { if (ShapeUtil::Rank(hlo->shape()) == 1) { @@ -908,21 +917,51 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( HloInstruction* dot) { HloInstruction *lhs, *rhs; CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); - int64 lhs_collapsing_dim = - dot->dot_dimension_numbers().lhs_contracting_dimensions(0); + + const auto kept_dim = [](int64 rank, int64 contracting_dimension, + absl::Span batch_dimensions) -> int64 { + for (int64 i = 0; i < rank; ++i) { + if (i != contracting_dimension && + !absl::c_linear_search(batch_dimensions, i)) { + return i; + } + } + return -1; + }; + + const int64 dot_rank = ShapeUtil::Rank(dot->shape()); + const int64 rhs_rank = ShapeUtil::Rank(rhs->shape()); + const int64 lhs_rank = ShapeUtil::Rank(lhs->shape()); + const auto& dnums = dot->dot_dimension_numbers(); + if (dnums.rhs_contracting_dimensions_size() > 1) { + return false; + } + if (dot_rank > 2 && (lhs_rank != rhs_rank || lhs_rank != dot_rank)) { + return false; + } + int64 lhs_collapsing_dim = dnums.lhs_contracting_dimensions(0); + int64 lhs_kept_dim = kept_dim(lhs_rank, lhs_collapsing_dim, + AsInt64Slice(dnums.lhs_batch_dimensions())); + // If there is no non-contracting dimension in rank 2, do not strength reduce. + if (lhs_kept_dim == -1 && lhs_rank > 1) { + return false; + } if (lhs->IsRank2Transpose()) { lhs = lhs->mutable_operand(0); - lhs_collapsing_dim = 1 - lhs_collapsing_dim; + std::swap(lhs_collapsing_dim, lhs_kept_dim); } - const int64 lhs_kept_dim = 1 - lhs_collapsing_dim; - int64 rhs_collapsing_dim = - dot->dot_dimension_numbers().rhs_contracting_dimensions(0); + int64 rhs_collapsing_dim = dnums.rhs_contracting_dimensions(0); + int64 rhs_kept_dim = kept_dim(rhs_rank, rhs_collapsing_dim, + AsInt64Slice(dnums.rhs_batch_dimensions())); + // If there is no non-contracting dimension in rank 2, do not strength reduce. + if (rhs_kept_dim == -1 && rhs_rank > 1) { + return false; + } if (rhs->IsRank2Transpose()) { rhs = rhs->mutable_operand(0); - rhs_collapsing_dim = 1 - rhs_collapsing_dim; + std::swap(rhs_collapsing_dim, rhs_kept_dim); } - const int64 rhs_kept_dim = 1 - rhs_collapsing_dim; auto as_type = [&](HloInstruction* hlo, const PrimitiveType element_type) { if (hlo->shape().element_type() == element_type) { @@ -945,10 +984,15 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( return AddReduce(as_type(hlo, F32), dim); }; + auto broadcast = [&](HloInstruction* hlo, const Shape& shape, + absl::Span dims) { + return computation_->AddInstruction( + HloInstruction::CreateBroadcast(shape, hlo, dims)); + }; + auto broadcast_to_dim = [&](HloInstruction* hlo, const Shape& shape, int64 dim) { - return computation_->AddInstruction( - HloInstruction::CreateBroadcast(shape, hlo, {dim})); + return broadcast(hlo, shape, {dim}); }; auto multiply = [&](HloInstruction* local_lhs, HloInstruction* local_rhs) { @@ -959,11 +1003,9 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( // Strength reduce dot(a[K] , b[K]) = // reshape(result.shape, // reduce_sum(multiply(a, b), {0})) - if (ShapeUtil::Rank(rhs->shape()) == 1 && - ShapeUtil::Rank(lhs->shape()) == 1) { - TF_RETURN_IF_ERROR( - ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32( - multiply(Flatten(lhs), Flatten(rhs)), 0)))); + if (rhs_rank == 1 && lhs_rank == 1) { + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary(add_reduce_in_f32(multiply(lhs, rhs), 0)))); return true; } @@ -977,8 +1019,7 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( // Simplify outer product into multiply with implicit broadcasting. // // A dot(a[M, 1], b[1, N]) = multiply(a [M,1], b [1, N]) - if (ShapeUtil::Rank(rhs->shape()) == 2 && - rhs->shape().dimensions(rhs_collapsing_dim) == 1) { + if (rhs_rank == 2 && rhs->shape().dimensions(rhs_collapsing_dim) == 1) { TF_RETURN_IF_ERROR(ReplaceInstruction( dot, multiply(broadcast_to_dim(Flatten(lhs), dot->shape(), 0), broadcast_to_dim(Flatten(rhs), dot->shape(), 1)))); @@ -992,9 +1033,8 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( // {0}) // ) // ) - if (ShapeUtil::Rank(lhs->shape()) == 1 || - (ShapeUtil::Rank(lhs->shape()) == 2 && - lhs->shape().dimensions(lhs_kept_dim) == 1)) { + if (lhs_rank == 1 || + (lhs_rank == 2 && lhs->shape().dimensions(lhs_kept_dim) == 1)) { if (ShapeUtil::Rank(rhs->shape()) == 1) { TF_RETURN_IF_ERROR( ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32( @@ -1014,9 +1054,8 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( // reshape(result.shape, // reduce_sum(multiply(a, broadcast(reshape([K],b), {1})), {0}) // ) - if (ShapeUtil::Rank(rhs->shape()) == 1 || - (ShapeUtil::Rank(rhs->shape()) == 2 && - rhs->shape().dimensions(rhs_kept_dim) == 1)) { + if (rhs_rank == 1 || + (rhs_rank == 2 && rhs->shape().dimensions(rhs_kept_dim) == 1)) { TF_RETURN_IF_ERROR(ReplaceInstruction( dot, reshape_if_necessary(add_reduce_in_f32( multiply(lhs, broadcast_to_dim(Flatten(rhs), lhs->shape(), @@ -1024,6 +1063,97 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( lhs_collapsing_dim)))); return true; } + + // Only consider kDot with batch dimension. + if (dot_rank <= 2) { + return false; + } + + CHECK_EQ(rhs_rank, lhs_rank); + CHECK_EQ(dot_rank, lhs_rank); + // If there is more than one non-contracting dimension or the batch dimensions + // are not equal, bail out since transposes may be required to do a strength + // reduction. + if (dnums.rhs_batch_dimensions_size() + 2 != dot_rank || + !absl::c_equal(dnums.lhs_batch_dimensions(), + dnums.rhs_batch_dimensions())) { + return false; + } + + auto broadcast_dims = [](int64 rank, int64 non_broadcast_dim) { + absl::InlinedVector dims; + for (int64 i = 0; i < rank; ++i) { + if (i != non_broadcast_dim) { + dims.push_back(i); + } + } + return dims; + }; + + // If the contracting dimension is 1, remove the degnerate dimnesions from the + // lhs and rhs, broadcast each to the result shape and multiply. + if (lhs->shape().dimensions(lhs_collapsing_dim) == 1 && + (rhs_kept_dim == rhs_rank - 1 || + (rhs_collapsing_dim == rhs_rank - 1 && rhs_kept_dim == rhs_rank - 2))) { + CHECK_EQ(rhs->shape().dimensions(rhs_collapsing_dim), 1); + const int64 lhs_kept_dim_in_output = + lhs_kept_dim > lhs_collapsing_dim ? (lhs_kept_dim - 1) : lhs_kept_dim; + absl::InlinedVector lhs_broadcast_dims; + for (const int64 dim : dnums.lhs_batch_dimensions()) { + lhs_broadcast_dims.push_back(dim > lhs_collapsing_dim ? (dim - 1) : dim); + } + absl::InlinedVector rhs_broadcast_dims = lhs_broadcast_dims; + lhs_broadcast_dims.push_back(lhs_kept_dim_in_output); + absl::c_sort(lhs_broadcast_dims); + rhs_broadcast_dims.push_back(dot_rank - 1); + absl::c_sort(rhs_broadcast_dims); + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary( + multiply(broadcast(StripDim(lhs, lhs_collapsing_dim), + dot->shape(), lhs_broadcast_dims), + broadcast(StripDim(rhs, rhs_collapsing_dim), + dot->shape(), rhs_broadcast_dims))))); + return true; + } + + // If the lhs and rhs non-contracting dimensions are both one, strip each one, + // multiply and then reduce the collapsing dimension + if (lhs->shape().dimensions(lhs_kept_dim) == 1 && + rhs->shape().dimensions(rhs_kept_dim) == 1 && + lhs_kept_dim == rhs_kept_dim) { + auto new_lhs = StripDim(lhs, lhs_kept_dim); + auto new_rhs = StripDim(rhs, rhs_kept_dim); + const int64 reduce_dim = rhs_kept_dim < rhs_collapsing_dim + ? (rhs_collapsing_dim - 1) + : rhs_collapsing_dim; + TF_RETURN_IF_ERROR( + ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32( + multiply(new_lhs, new_rhs), reduce_dim)))); + return true; + } + + // If the lhs non-contracting dimensions is one, strip the one, brodcast to + // the rhs shape, multiply and then reduce the collapsing dimension + if (lhs->shape().dimensions(lhs_kept_dim) == 1) { + auto new_lhs = broadcast(StripDim(lhs, lhs_kept_dim), rhs->shape(), + broadcast_dims(rhs_rank, rhs_kept_dim)); + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary(add_reduce_in_f32(multiply(new_lhs, rhs), + rhs_collapsing_dim)))); + return true; + } + + // If the rhs non-contracting dimensions is one, strip the one, brodcast to + // the lhs shape, multiply and then reduce the collapsing dimension + if (rhs->shape().dimensions(rhs_kept_dim) == 1) { + auto new_rhs = broadcast(StripDim(rhs, rhs_kept_dim), lhs->shape(), + broadcast_dims(lhs_rank, lhs_kept_dim)); + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary(add_reduce_in_f32(multiply(lhs, new_rhs), + lhs_collapsing_dim)))); + return true; + } + return false; } @@ -1302,25 +1432,31 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { HloInstruction *lhs, *rhs; CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); - // Only optimize F32 or BF16 dot operations where the dot, rhs and lhs are - // rank 2 or below. - if ((dot->shape().element_type() != F32 && - dot->shape().element_type() != BF16) || - ShapeUtil::Rank(lhs->shape()) > 2 || ShapeUtil::Rank(rhs->shape()) > 2 || - ShapeUtil::Rank(dot->shape()) > 2) { - return Status::OK(); - } - // Replace a zero element dot with a broadcast of the constant 0. if (ShapeUtil::IsZeroElementArray(dot->shape()) || ShapeUtil::IsZeroElementArray(lhs->shape()) || ShapeUtil::IsZeroElementArray(rhs->shape())) { - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + auto zero = computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(dot->shape().element_type()))); return ReplaceWithNewInstruction( dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); } + // Only optimize F32 or BF16 dot operations where the dot, rhs and lhs are + // rank 2 or below. + if (dot->shape().element_type() != F32 && + dot->shape().element_type() != BF16) { + return Status::OK(); + } + if (ShapeUtil::Rank(lhs->shape()) > 2 || ShapeUtil::Rank(rhs->shape()) > 2 || + ShapeUtil::Rank(dot->shape()) > 2) { + if (options_.enable_dot_strength_reduction() && + !options_.is_layout_sensitive()) { + TF_RETURN_IF_ERROR(HandleDotStrengthReduction(dot).status()); + } + return Status::OK(); + } + TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_concat_optimized, OptimizeDotOfConcat(dot)); if (dot_of_concat_optimized) { @@ -2026,6 +2162,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { reshape, HloInstruction::CreateReshape(reshape->shape(), operand->mutable_operand(0))); } + if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) { *operand->mutable_shape() = reshape->shape(); return ReplaceInstruction(reshape, operand); @@ -2748,6 +2885,22 @@ Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) { return Status::OK(); } +namespace { +bool OnlyPermutesMoreThanOneDegenerateDim(const Shape& shape, + absl::Span perm) { + std::vector new_permutation; + int64 degenerate_count = 0; + for (int64 i = 0; i < perm.size(); ++i) { + if (shape.dimensions(i) != 1) { + new_permutation.push_back(perm[i]); + } else { + ++degenerate_count; + } + } + return degenerate_count > 1 && absl::c_is_sorted(new_permutation); +} +} // namespace + Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { auto operand = transpose->mutable_operand(0); if (std::is_sorted(transpose->dimensions().begin(), @@ -2764,6 +2917,15 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { transpose->dimensions()))); } + // Replace transpose with a reshape if more than one degenerate method is + // permuted. + if (OnlyPermutesMoreThanOneDegenerateDim(transpose->shape(), + transpose->dimensions())) { + return ReplaceWithNewInstruction( + transpose, HloInstruction::CreateReshape( + transpose->shape(), transpose->mutable_operand(0))); + } + if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) { *operand->mutable_shape() = transpose->shape(); return ReplaceInstruction(transpose, operand); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 14ce519b6a0fd221070006d336d23bddeb6cd621..cfb4c48277605a6f90ef51debac1c3bc26bed070 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -2047,6 +2047,27 @@ TEST_F(AlgebraicSimplifierTest, TransposesMerged) { computation->root_instruction()->dimensions()); } +TEST_F(AlgebraicSimplifierTest, TransposeIsReshape) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param = f32[10] parameter(0) + reshaped = f32[1,1,10] reshape(f32[10] param) + transposed = f32[10,1,1] transpose(f32[1,1,10] reshaped), dimensions={2,1,0} + ROOT reshaped_again = f32[10] reshape(f32[10,1,1] transposed) + } + )"; + TF_ASSERT_OK_AND_ASSIGN( + auto module, + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + + HloPassFix simplifier(default_options_); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Parameter())); +} + // Test merging reshape and broadcast. TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) { auto m = CreateNewVerifiedModule(); @@ -4083,6 +4104,57 @@ INSTANTIATE_TEST_CASE_P( PadReduceWindowEffectiveBroadcastTest, ::testing::ValuesIn(PadReduceWindowEffectiveBroadcastCases())); +class BatchDotStrengthReductionTest + : public AlgebraicSimplifierTest, + public ::testing::WithParamInterface< + ::testing::tuple> {}; +TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) { + auto module = CreateNewVerifiedModule(); + int m, k, n; + PrimitiveType element_type; + std::tie(m, k, n, element_type) = GetParam(); + + Shape dot_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, m, n}); + Shape lhs_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, m, k}); + Shape rhs_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, k, n}); + HloComputation::Builder builder(TestName()); + + auto lhs = builder.AddInstruction( + HloInstruction::CreateParameter(0, lhs_shape, "lhs")); + auto rhs = builder.AddInstruction( + HloInstruction::CreateParameter(1, rhs_shape, "rhs")); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_batch_dimensions(0); + dot_dnums.add_lhs_batch_dimensions(1); + dot_dnums.add_lhs_batch_dimensions(2); + dot_dnums.add_rhs_batch_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + dot_dnums.add_rhs_batch_dimensions(2); + dot_dnums.add_lhs_contracting_dimensions(4); + dot_dnums.add_rhs_contracting_dimensions(3); + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); + auto computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get())); + const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1; + const bool computation_should_be_modified = dot_should_be_transformed; + EXPECT_EQ(changed, computation_should_be_modified); + bool has_no_dot = true; + for (const auto& hlo : computation->instructions()) { + if (hlo->opcode() == HloOpcode::kDot) { + has_no_dot = false; + break; + } + } + EXPECT_EQ(has_no_dot, dot_should_be_transformed); +} + +INSTANTIATE_TEST_CASE_P( + BatchDotStrengthReductionTestInstantiation, BatchDotStrengthReductionTest, + ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2), + ::testing::Values(1, 2), ::testing::Values(F32, BF16))); + class DotStrengthReductionTest : public AlgebraicSimplifierTest, public ::testing::WithParamInterface< diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.cc b/tensorflow/compiler/xla/service/ar_crs_combiner.cc index 362bc44a1cf377b51c5519c6ab5e0d9628e80e58..47d2c7e35705698d49950c2fa042af1c6327d521 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.cc @@ -36,24 +36,40 @@ namespace { namespace m = match; -// If the argument instruction is a CRS in the sequence -// AR -> Convert -> Add -> CRS -// then return the AR in the sequence. -// TODO(b/117554291): Rewrite this to recognize more general patterns, -// not just the specific one of AR -> Add -> Convert -> CRS. -absl::optional MatchesArCrsPattern( - HloInstruction* instruction) { - HloInstruction *ar, *convert, *add, *crs; - if (Match(instruction, - m::CrossReplicaSum( - &crs, m::Add(&add, m::Op(), - m::Convert(&convert, - m::CrossReplicaSum(&ar, m::Op()))))) && - ar->users().size() == 1 && ar->shape().element_type() == BF16 && - convert->shape().element_type() == F32 && !crs->all_reduce_id()) { - return ar; +// Returns true iff the argument instruction is an AllReduce, followed by a +// certain sequence of instructions and then a CRS. It must be possible to move +// the AR past each instruction in the sequence. +bool MatchesArCrsPattern(HloInstruction* instruction) { + auto can_ar_move_past_instruction = [](HloInstruction* instruction) -> bool { + if (instruction->user_count() != 1) { + return false; + } + auto opcode = instruction->opcode(); + return opcode == HloOpcode::kBitcast || opcode == HloOpcode::kTranspose || + opcode == HloOpcode::kReshape || opcode == HloOpcode::kConvert || + opcode == HloOpcode::kAdd || opcode == HloOpcode::kSubtract || + opcode == HloOpcode::kMultiply; + }; + + auto computation_is_addition = [](HloComputation* c) { + return c->instruction_count() == 3 && + Match(c->root_instruction(), m::Add(m::Parameter(), m::Parameter())); + }; + + if (!instruction->IsCrossModuleAllReduce() || + !computation_is_addition(instruction->called_computations()[0]) || + instruction->user_count() != 1) { + return false; } - return absl::optional(); + auto next = instruction->users()[0]; + while (!next->IsCrossReplicaAllReduce()) { + if (can_ar_move_past_instruction(next)) { + next = next->users()[0]; + } else { + return false; + } + } + return computation_is_addition(next->called_computations()[0]); } } // namespace @@ -195,9 +211,8 @@ bool ArCrsCombiner::InstructionsComputeSameValue( void ArCrsCombiner::GroupAllReducesById(HloModule* module) { for (HloComputation* computation : module->MakeNonfusionComputations()) { for (HloInstruction* instruction : computation->instructions()) { - auto ar = MatchesArCrsPattern(instruction); - if (ar) { - all_reduce_map_[*((*ar)->all_reduce_id())].push_back(*ar); + if (MatchesArCrsPattern(instruction)) { + all_reduce_map_[*(instruction->all_reduce_id())].push_back(instruction); } } } @@ -205,21 +220,23 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) { void ArCrsCombiner::KeepProvablyEqualInstructionGroups() { for (auto it : all_reduce_map_) { + auto all_reduce_id = it.first; auto instruction_vec = it.second; CHECK_EQ(instruction_vec.size(), num_spatial_partitions_); - auto instr_0 = instruction_vec[0]; - auto add_0 = instr_0->users()[0]->users()[0]; - CHECK_EQ(HloOpcode::kAdd, add_0->opcode()); - for (int i = 1; i < instruction_vec.size(); ++i) { auto instr_i = instruction_vec[i]; - auto add_i = instr_i->users()[0]->users()[0]; - CHECK_EQ(HloOpcode::kAdd, add_i->opcode()); + auto next_0 = instr_0->users()[0]; + auto next_i = instr_i->users()[0]; absl::flat_hash_map visited_pairs; - if (!InstructionsComputeSameValue(add_0, add_i, &visited_pairs)) { - all_reduce_map_.erase(it.first); - } + do { + if (!InstructionsComputeSameValue(next_0, next_i, &visited_pairs)) { + all_reduce_map_.erase(all_reduce_id); + break; + } + next_0 = next_0->users()[0]; + next_i = next_i->users()[0]; + } while (!next_0->IsCrossReplicaAllReduce()); } } } @@ -228,47 +245,51 @@ StatusOr ArCrsCombiner::RewriteGraph() { if (all_reduce_map_.empty()) { return false; } - - auto computation_is_addition = [](HloComputation* c) { - return c->instruction_count() == 3 && - Match(c->root_instruction(), m::Add(m::Parameter(), m::Parameter())); - }; - for (auto it : all_reduce_map_) { auto instruction_vec = it.second; for (auto all_reduce : instruction_vec) { auto parent_computation = all_reduce->parent(); - auto convert = all_reduce->users()[0]; - auto add = convert->users()[0]; - auto crs = add->users()[0]; - - if (!computation_is_addition(all_reduce->called_computations()[0]) || - !computation_is_addition(crs->called_computations()[0])) { - continue; + auto all_reduce_id = all_reduce->all_reduce_id(); + auto prev = all_reduce->mutable_operand(0); + auto next = all_reduce->users()[0]; + TF_CHECK_OK(all_reduce->ReplaceUseWith(next, prev)); + TF_CHECK_OK(parent_computation->RemoveInstruction(all_reduce)); + while (!next->IsCrossReplicaAllReduce()) { + switch (next->opcode()) { + case HloOpcode::kBitcast: + case HloOpcode::kTranspose: + case HloOpcode::kReshape: + case HloOpcode::kConvert: + case HloOpcode::kMultiply: + break; + case HloOpcode::kAdd: + case HloOpcode::kSubtract: { + auto other_operand = (next->operands()[0] == prev) + ? next->operands()[1] + : next->operands()[0]; + // To move the AR past the addition/subtraction, we need to divide + // other_operand by the number of spatial partitions. + auto shape = other_operand->shape(); + Literal lit(shape); + lit.PopulateWithValue(num_spatial_partitions_); + auto divisor = parent_computation->AddInstruction( + HloInstruction::CreateConstant(lit.Clone())); + auto division = + parent_computation->AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDivide, other_operand, divisor)); + TF_CHECK_OK(other_operand->ReplaceUseWith(next, division)); + break; + } + default: + LOG(FATAL) << "Unexpected instruction: " << next->ToShortString(); + } + prev = next; + next = next->users()[0]; } - HloInstruction* other_summand = (add->operands()[0] == convert) - ? add->operands()[1] - : add->operands()[0]; - // To move the AR past the addition, we need to divide other_summand by - // the number of spatial partitions. - CHECK_EQ(all_reduce->user_count(), 1); - TF_CHECK_OK( - all_reduce->ReplaceAllUsesWith(all_reduce->mutable_operand(0))); - auto shape = other_summand->shape(); - Literal lit(shape); - lit.PopulateWithValue(num_spatial_partitions_); - auto divisor = parent_computation->AddInstruction( - HloInstruction::CreateConstant(lit.Clone())); - auto division = - parent_computation->AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDivide, other_summand, divisor)); - TF_CHECK_OK(other_summand->ReplaceUseWith(add, division)); // The AllReduce and the CRS are combined to an all-core AllReduce. - crs->set_all_reduce_id(all_reduce->all_reduce_id()); - TF_CHECK_OK(parent_computation->RemoveInstruction(all_reduce)); + next->set_all_reduce_id(all_reduce_id); } } - return true; } diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.h b/tensorflow/compiler/xla/service/ar_crs_combiner.h index f6a7ef76ec3b76972d1b2c7fb548cecfb9423160..6be7e1002dc6822bf0b563721f00896da171c0a9 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.h +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.h @@ -25,9 +25,12 @@ limitations under the License. namespace xla { -// Combine an AllReduce and a CrossReplicaSum when they are close to each other -// in the graph, to use an efficient CrossReplicaSum implementation that -// fully utilizes the interconnect bandwidth. +// When the HLO graph contains an AllReduce, followed by some simple linear +// operations, followed by a CrossReplicaSum, we can combine the AR and the CRS, +// to use an efficient CrossReplicaSum implementation that fully utilizes the +// interconnect bandwidth. +// Such sequences appear in spatially partitioned models. +// This pass must run right after spatial partitioning. class ArCrsCombiner : public HloModulePass { public: ArCrsCombiner(int num_spatial_partitions) diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc index 10171835d83c75fef091a34b8fe102d263211307..8a4fd0ee1b25ec82f5dadfc8446af185914d4033 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc @@ -32,8 +32,8 @@ HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { %p = f32[2,2] parameter(0) - %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) - %constant.f32.2 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant({{1, 2}, {3, 4}}) ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2) } )"; @@ -91,7 +91,7 @@ HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> ((f32[2,2]), (f32[2,2], f32[2,2])) { %p = f32[2,2] parameter(0) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}}) %tuple1 = (f32[2,2]) tuple(%constant.f32) %tuple2 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) ROOT %tuple = ((f32[2,2]), (f32[2,2], f32[2,2])) tuple(%tuple1, %tuple2) @@ -152,7 +152,7 @@ HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { %p = f32[2,2] parameter(0) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}}) %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0 %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=0 @@ -174,7 +174,7 @@ HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { %p = f32[2,2] parameter(0) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}}) %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0 %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=1 @@ -196,8 +196,8 @@ HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { %p = f32[2,2] parameter(0) - %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) - %constant.f32.2 = f32[2,2] constant(f32[2,2] {{2, 3}, {4, 5}}) + %constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant({{2, 3}, {4, 5}}) %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2) %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0 %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=1 @@ -226,7 +226,7 @@ HloModule foobar %body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { %x = (f32[2,2], f32[2,2]) parameter(0) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}}) %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0 %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1 %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32) @@ -235,7 +235,7 @@ HloModule foobar } ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { - %constant.f32 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) + %constant.f32 = f32[2,2] constant({{3, 4}, {5, 6}}) %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body } @@ -263,7 +263,7 @@ HloModule foobar %body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { %x = (f32[2,2], f32[2,2]) parameter(0) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}}) %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0 %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1 %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32) @@ -272,8 +272,8 @@ HloModule foobar } ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { - %constant.f32.1 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) - %constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {7, 8}}) + %constant.f32.1 = f32[2,2] constant({{3, 4}, {5, 6}}) + %constant.f32.2 = f32[2,2] constant({{3, 4}, {7, 8}}) %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2) ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body } @@ -301,8 +301,8 @@ HloModule foobar %body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { %x = (f32[2,2], f32[2,2]) parameter(0) - %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) - %constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {1, 2}}) + %constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant({{3, 4}, {1, 2}}) %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0 %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1 %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32.1) @@ -311,7 +311,7 @@ HloModule foobar } ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { - %constant.f32 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) + %constant.f32 = f32[2,2] constant({{3, 4}, {5, 6}}) %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body } @@ -326,11 +326,27 @@ ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); } -TEST_F(ArCrsCombinerTest, RewritePatternArConvertAddCrs) { +void CompareReplicaGroups(const std::vector& groups_before, + const std::vector& groups_after) { + ASSERT_EQ(groups_before.size(), groups_after.size()); + for (int i = 0; i < groups_before.size(); ++i) { + // Somewhat verbose way to compare the replica_ids, because EqualsProto + // is not available in the open-source build. + auto group_before = groups_before[i]; + std::vector ids_before(group_before.replica_ids().begin(), + group_before.replica_ids().end()); + auto group_after = groups_after[i]; + std::vector ids_after(group_after.replica_ids().begin(), + group_after.replica_ids().end()); + EXPECT_EQ(ids_before, ids_after); + } +} + +TEST_F(ArCrsCombinerTest, RewriteArConvertCrs) { const char* module_str = R"( HloModule foobar -%binary_add (a: bf16[], b: bf16[]) -> bf16[] { +%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] { %a = bf16[] parameter(0) %b = bf16[] parameter(1) ROOT %add = bf16[] add(%a, %b) @@ -342,48 +358,257 @@ HloModule foobar ROOT %add = f32[] add(%x, %y) } -ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { - %p = f32[2,2] parameter(0) - %constant.bf16 = bf16[2,2] constant(bf16[2,2] {{1, 2}, {3, 4}}) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) +ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { + %p = bf16[] parameter(0) + + %cross-replica-sum.ar.1 = bf16[] + cross-replica-sum(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.bf16, + sharding={maximal device=0} + %convert.1 = f32[] + convert(%cross-replica-sum.ar.1), + sharding={maximal device=0} + %cross-replica-sum.1 = f32[] + cross-replica-sum(%convert.1), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=0} + + %cross-replica-sum.ar.2 = bf16[] + cross-replica-sum(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.bf16, + sharding={maximal device=1} + %convert.2 = f32[] + convert(%cross-replica-sum.ar.2), + sharding={maximal device=1} + %cross-replica-sum.2 = f32[] + cross-replica-sum(%convert.2), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%cross-replica-sum.1, %cross-replica-sum.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::CrossReplicaSum(op::Convert(op::Parameter())), + op::CrossReplicaSum(op::Convert(op::Parameter())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, RewriteArBitcastCrs) { + const char* module_str = R"( +HloModule foobar + +%sum.1 (a: f32[2,1], b: f32[2,1]) -> f32[2,1] { + %a = f32[2,1] parameter(0) + %b = f32[2,1] parameter(1) + ROOT %add = f32[2,1] add(%a, %b) +} + +%sum.2 (x: f32[2], y: f32[2]) -> f32[2] { + %x = f32[2] parameter(0) + %y = f32[2] parameter(1) + ROOT %add = f32[2] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) { + %p = f32[2,1] parameter(0) + + %cross-replica-sum.ar.1 = f32[2,1] + cross-replica-sum(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.1, + sharding={maximal device=0} + %bitcast.1 = f32[2]{0} bitcast(f32[2,1]{1,0} %cross-replica-sum.ar.1) + %cross-replica-sum.1 = f32[2] + cross-replica-sum(%bitcast.1), + replica_groups={{0,1}}, + to_apply=%sum.2, + sharding={maximal device=0} + + %cross-replica-sum.ar.2 = f32[2,1] + cross-replica-sum(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.1, + sharding={maximal device=1} + %bitcast.2 = f32[2]{0} bitcast(f32[2,1]{1,0} %cross-replica-sum.ar.2) + %cross-replica-sum.2 = f32[2] + cross-replica-sum(%bitcast.2), + replica_groups={{0,1}}, + to_apply=%sum.2, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%cross-replica-sum.1, %cross-replica-sum.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::CrossReplicaSum(op::Bitcast(op::Parameter())), + op::CrossReplicaSum(op::Bitcast(op::Parameter())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, RewriteArMultiplyCrs) { + const char* module_str = R"( +HloModule foobar + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %constant.f32 = f32[] constant(123) + + %cross-replica-sum.ar.1 = f32[] + cross-replica-sum(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.f32, + sharding={maximal device=0} + %multiply.1 = f32[] + multiply(%cross-replica-sum.ar.1, %constant.f32), + sharding={maximal device=0} + %cross-replica-sum.1 = f32[] + cross-replica-sum(%multiply.1), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=0} + + %cross-replica-sum.ar.2 = f32[] + cross-replica-sum(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.f32, + sharding={maximal device=1} + %multiply.2 = f32[] + multiply(%cross-replica-sum.ar.2, %constant.f32), + sharding={maximal device=1} + %cross-replica-sum.2 = f32[] + cross-replica-sum(%multiply.2), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%cross-replica-sum.1, %cross-replica-sum.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple( + op::CrossReplicaSum(op::Multiply(op::Parameter(), op::Constant())), + op::CrossReplicaSum(op::Multiply(op::Parameter(), op::Constant())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, RewriteArConvertAddCrs) { + const char* module_str = R"( +HloModule foobar + +%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] { + %a = bf16[] parameter(0) + %b = bf16[] parameter(1) + ROOT %add = bf16[] add(%a, %b) +} + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %constant.bf16 = bf16[] constant(1) + %constant.f32 = f32[] constant(2) - %cross-replica-sum.ar.1 = bf16[2,2] + %cross-replica-sum.ar.1 = bf16[] cross-replica-sum(%constant.bf16), replica_groups={{0},{1}}, all_reduce_id=1, - to_apply=%binary_add, + to_apply=%sum.bf16, sharding={maximal device=0} - %convert.1 = f32[2,2] + %convert.1 = f32[] convert(%cross-replica-sum.ar.1), sharding={maximal device=0} - %add.1 = f32[2,2] + %add.1 = f32[] add(%constant.f32, %convert.1), sharding={maximal device=0} - %cross-replica-sum.1 = f32[2,2] + %cross-replica-sum.1 = f32[] cross-replica-sum(%add.1), replica_groups={{0,1}}, to_apply=%sum.f32, sharding={maximal device=0} - %cross-replica-sum.ar.2 = bf16[2,2] + %cross-replica-sum.ar.2 = bf16[] cross-replica-sum(%constant.bf16), replica_groups={{0},{1}}, all_reduce_id=1, - to_apply=%binary_add, + to_apply=%sum.bf16, sharding={maximal device=1} - %convert.2 = f32[2,2] + %convert.2 = f32[] convert(%cross-replica-sum.ar.2), sharding={maximal device=1} - %add.2 = f32[2,2] + %add.2 = f32[] add(%constant.f32, %convert.2), sharding={maximal device=1} - %cross-replica-sum.2 = f32[2,2] + %cross-replica-sum.2 = f32[] cross-replica-sum(%add.2), replica_groups={{0,1}}, to_apply=%sum.f32, sharding={maximal device=1} - ROOT %tuple = (f32[2,2], f32[2,2]) + ROOT %tuple = (f32[], f32[]) tuple(%cross-replica-sum.1, %cross-replica-sum.2), sharding={{maximal device=0}, {maximal device=1}} } @@ -407,25 +632,14 @@ ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { auto crs_after = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_after = crs_after->replica_groups(); - ASSERT_EQ(replica_groups_before.size(), replica_groups_after.size()); - for (int i = 0; i < replica_groups_before.size(); ++i) { - // Somewhat verbose way to compare the replica_ids, because EqualsProto - // is not available in the open-source build. - auto group_before = replica_groups_before[i]; - std::vector ids_before(group_before.replica_ids().begin(), - group_before.replica_ids().end()); - auto group_after = replica_groups_after[i]; - std::vector ids_after(group_after.replica_ids().begin(), - group_after.replica_ids().end()); - EXPECT_EQ(ids_before, ids_after); - } + CompareReplicaGroups(replica_groups_before, replica_groups_after); } TEST_F(ArCrsCombinerTest, OtherSummandNotTheSameDontRewrite) { const char* module_str = R"( HloModule foobar -%binary_add (a: bf16[], b: bf16[]) -> bf16[] { +%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] { %a = bf16[] parameter(0) %b = bf16[] parameter(1) ROOT %add = bf16[] add(%a, %b) @@ -437,49 +651,49 @@ HloModule foobar ROOT %add = f32[] add(%x, %y) } -ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { - %p = f32[2,2] parameter(0) - %constant.bf16 = bf16[2,2] constant(bf16[2,2] {{1, 2}, {3, 4}}) - %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) - %constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %constant.bf16 = bf16[] constant(1) + %constant.f32.1 = f32[] constant(2) + %constant.f32.2 = f32[] constant(3) - %cross-replica-sum.ar.1 = bf16[2,2] + %cross-replica-sum.ar.1 = bf16[] cross-replica-sum(%constant.bf16), replica_groups={{0},{1}}, all_reduce_id=1, - to_apply=%binary_add, + to_apply=%sum.bf16, sharding={maximal device=0} - %convert.1 = f32[2,2] + %convert.1 = f32[] convert(%cross-replica-sum.ar.1), sharding={maximal device=0} - %add.1 = f32[2,2] + %add.1 = f32[] add(%constant.f32.1, %convert.1), sharding={maximal device=0} - %cross-replica-sum.1 = f32[2,2] + %cross-replica-sum.1 = f32[] cross-replica-sum(%add.1), replica_groups={{0,1}}, to_apply=%sum.f32, sharding={maximal device=0} - %cross-replica-sum.ar.2 = bf16[2,2] + %cross-replica-sum.ar.2 = bf16[] cross-replica-sum(%constant.bf16), replica_groups={{0},{1}}, all_reduce_id=1, - to_apply=%binary_add, + to_apply=%sum.bf16, sharding={maximal device=1} - %convert.2 = f32[2,2] + %convert.2 = f32[] convert(%cross-replica-sum.ar.2), sharding={maximal device=1} - %add.2 = f32[2,2] + %add.2 = f32[] add(%constant.f32.2, %convert.2), sharding={maximal device=1} - %cross-replica-sum.2 = f32[2,2] + %cross-replica-sum.2 = f32[] cross-replica-sum(%add.2), replica_groups={{0,1}}, to_apply=%sum.f32, sharding={maximal device=1} - ROOT %tuple = (f32[2,2], f32[2,2]) + ROOT %tuple = (f32[], f32[]) tuple(%cross-replica-sum.1, %cross-replica-sum.2), sharding={{maximal device=0}, {maximal device=1}} } diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index ce4c2a9cc69240b9565b35a3f2504d7fc9373917..4173af5179ba096523db973ca7e0466faefda38a 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -572,6 +572,7 @@ cc_library( ":runtime_matvec", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/core:framework_lite", + "//tensorflow/core/kernels:eigen_contraction_kernel", "//third_party/eigen3", ], ) diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 6374822c81bf42fd12829f57cf93c19457128219..f3dfa4d64264808e0d5c9f86693bb844b2011964 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -635,18 +635,17 @@ StatusOr> CpuCompiler::RunBackend( .EmitComputation( embedded_computation, embedded_computation->name(), /*is_top_level_computation=*/false, - &schedule.sequence(embedded_computation).instructions()) + schedule.sequence(embedded_computation).instructions()) .status()); } string function_name_prefix = entry_computation->name().empty() ? "__compute" : entry_computation->name(); - TF_ASSIGN_OR_RETURN( - llvm::Function * entry_function, - ir_emitter.EmitComputation( - entry_computation, function_name_prefix, - /*is_top_level_computation=*/true, - &schedule.sequence(entry_computation).instructions())); + TF_ASSIGN_OR_RETURN(llvm::Function * entry_function, + ir_emitter.EmitComputation( + entry_computation, function_name_prefix, + /*is_top_level_computation=*/true, + schedule.sequence(entry_computation).instructions())); string function_name = [&]() { llvm::SmallVector function_name_vector; @@ -835,7 +834,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, .EmitComputation( embedded_computation, embedded_computation->name(), /*is_top_level_computation=*/false, - &schedule.sequence(embedded_computation).instructions()) + schedule.sequence(embedded_computation).instructions()) .status()); } const string& entry_point_name = options.entry_point_name(); @@ -843,7 +842,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, ir_emitter.EmitComputation( computation, entry_point_name, /*is_top_level_computation=*/true, - &schedule.sequence(computation).instructions())); + schedule.sequence(computation).instructions())); CHECK(entry_function->getName() == llvm_ir::AsStringRef(entry_point_name)); diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 97f9b85a606e140fd7f3b1e3ecfb0dd5ba289f03..a33035ad1081d7d73ceed6ce3a208af5910d2d2c 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -323,11 +323,11 @@ void ColumnMajorMatrixVectorProductEmitter::Emit() { int64 column_remainder = k() % tile_cols(); int64 column_limit = k() - column_remainder; - ksl_.ForReturnVoid("dot.outer.tiled", - /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(), - [&](llvm::Value* column, bool is_first_column) { - EmitOuterLoopBody(column, tile_cols(), is_first_column); - }); + ksl_.For("dot.outer.tiled", + /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(), + [&](llvm::Value* column, bool is_first_column) { + EmitOuterLoopBody(column, tile_cols(), is_first_column); + }); if (column_remainder != 0) { EmitOuterLoopBody(b_->getInt64(column_limit), column_remainder, @@ -340,7 +340,7 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( int64 columns, bool is_first_column) { int64 row_limit = m() - (m() % tile_rows()); - ksl_.ForReturnVoid( + ksl_.For( "dot.inner.tiled", /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(), [&](llvm::Value* row) { std::vector lhs_tile = @@ -372,7 +372,7 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( // // initialized. // } - ksl_.ForReturnVoid( + ksl_.For( "dot.inner.epilg.outer", /*start=*/current_tile_col, /*end=*/b_->CreateAdd(columns_llvm, current_tile_col), /*step=*/1, /*peel_first_iteration=*/false, @@ -381,14 +381,14 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( llvm::Value* total_offset = b_->CreateMul(col, b_->getInt64(m())); llvm::Value* lhs_base_pointer = vsl_.ComputeOffsetPointer(lhs_, total_offset); - ksl_.ForReturnVoid( + ksl_.For( "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m(), /*step=*/1, [&](llvm::Value* scalar_row) { llvm::Value* product = vsl_.Mul( vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element); llvm::Value* setting_result_first_time = b_->CreateAnd( is_first_scalar_col, b_->getInt1(is_first_tiled_column)); - ksl_.IfReturnVoid( + ksl_.If( setting_result_first_time, /*true_block_generator=*/ [&]() { @@ -568,10 +568,9 @@ void RowMajorMatrixVectorProductEmitter::Emit() { int64 row_remainder = m() % tile_rows(); int64 row_limit = m() - row_remainder; - ksl_.ForReturnVoid( - "dot.outer.tiled", - /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(), - [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); }); + ksl_.For("dot.outer.tiled", + /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(), + [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); }); if (row_remainder != 0) { EmitOuterLoopBody(b_->getInt64(row_limit), row_remainder); @@ -583,17 +582,17 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( std::vector* vector_accumulators) { int64 column_limit = k() - (k() % tile_cols()); - ksl_.ForReturnVoid("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, - /*step=*/tile_cols(), [&](llvm::Value* col) { - std::vector lhs_tile = - lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col); - llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); - for (int i = 0; i < rows; i++) { - llvm::Value* old_sum = (*vector_accumulators)[i].Get(); - (*vector_accumulators)[i].Set(vsl_.Add( - old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); - } - }); + ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, + /*step=*/tile_cols(), [&](llvm::Value* col) { + std::vector lhs_tile = + lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col); + llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); + for (int i = 0; i < rows; i++) { + llvm::Value* old_sum = (*vector_accumulators)[i].Get(); + (*vector_accumulators)[i].Set( + vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); + } + }); } void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( @@ -609,7 +608,7 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( b_->CreateAdd(b_->getInt64(r), current_tile_row), b_->getInt64(k())); llvm::Value* lhs_base_pointer = vsl_.ComputeOffsetPointer(lhs_, total_offset); - ksl_.ForReturnVoid( + ksl_.For( "dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(), /*step=*/1, [&](llvm::Value* scalar_col) { llvm::Value* product = @@ -813,7 +812,7 @@ void TiledSmallGemmEmitter::HandleResiduesOnN() { if (n_start != dims().n()) { VectorSupportLibrary vsl(scalar_type(), 1, b_, "gemm"); - ksl_.ForReturnVoid("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { + ksl_.For("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { llvm::Value* n_i_next = b_->CreateAdd(n_i, b_->getInt64(1)); HandleResiduesOnK(&vsl, n_i, n_i_next); }); @@ -924,7 +923,7 @@ void TiledSmallGemmEmitter::EmitTiledGemm( VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end, int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) { - ksl_.ForReturnVoid( + ksl_.For( "dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) { MemoryTile result_memory_tile( vsl, b_, /*matrix=*/result_, @@ -935,11 +934,11 @@ void TiledSmallGemmEmitter::EmitTiledGemm( /*matrix_size_along_minor_dim=*/dims().k(), /*major_dim_offset=*/m_i, /*tile_size_along_major_dim=*/tile_size_m); - ksl_.ForReturnVoid( + ksl_.For( "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) { TileVariable result_tile_var(vsl, result_memory_tile.LoadTile(n_i)); - ksl_.ForReturnVoid( + ksl_.For( "dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) { MemoryTile rhs_memory_tile(vsl, b_, rhs_, dims().n(), k_i, tile_size_k); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 4032c2da2f33ee61da8771ae6225a14172cbe6e8..62a4e8d3507a4e678e80c1abea680c030d048de5 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -111,10 +111,9 @@ IrEmitter::IrEmitter( StatusOr IrEmitter::EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - const std::vector* instruction_order) { + absl::Span instruction_order) { string function_name = name_uniquer_.GetUniqueName(function_name_prefix); - VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix - << "]; ordered? " << (instruction_order != nullptr); + VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]"; is_top_level_computation_ = is_top_level_computation; num_dynamic_loop_bounds_ = 0; if (!computation->root_instruction()->outer_dimension_partitions().empty()) { @@ -141,11 +140,7 @@ StatusOr IrEmitter::EmitComputation( bool use_rdtscp = arch_type_ == llvm::Triple::ArchType::x86 || arch_type_ == llvm::Triple::ArchType::x86_64; profiling_state_ = ProfilingState(use_rdtscp); - if (instruction_order == nullptr) { - TF_RETURN_IF_ERROR(computation->Accept(this)); - } else { - TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, *instruction_order)); - } + TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, instruction_order)); llvm::Function* ir_function = compute_function_->function(); InsertOrDie(&emitted_functions_, computation, ir_function); // Delete 'compute_function', finalizing 'ir_function' and restoring caller @@ -2271,6 +2266,22 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { /*isVarArg=*/false))); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call)); + // Write the tuple table if the output is a tuple. + if (ShapeUtil::IsTuple(custom_call->shape())) { + std::vector base_ptrs; + for (int i = 0; i < ShapeUtil::TupleElementCount(custom_call->shape()); + ++i) { + const Shape& elem_shape = + ShapeUtil::GetTupleElementShape(custom_call->shape(), i); + TF_RET_CHECK(!ShapeUtil::IsTuple(elem_shape)) + << "Nested tuples not implemented"; + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, + assignment_.GetUniqueSlice(custom_call, {i})); + llvm::Value* addr = EmitBufferPointer(slice, elem_shape); + base_ptrs.push_back(addr); + } + llvm_ir::EmitTuple(GetIrArrayFor(custom_call), base_ptrs, &b_, module_); + } auto* output_address_arg = PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 559a8162a2d53f28ea6817653503c216af90a610..1db75cc8becea80f121289a843d4eb16ee9a8c8a 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -101,7 +101,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, StatusOr EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - const std::vector* instruction_order); + absl::Span instruction_order); llvm::IRBuilder<>* b() { return &b_; } diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index f0b65046c14ccec5336abf7c4d05d1d755f783bd..35ae62b42dfa768c6abd0508097d6b235b2ebf54 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -112,10 +112,10 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) { const string hlo_string = R"( HloModule TestTaskParallel_infeed_outfeed ENTRY InfeedOutfeed { - token = token[] after-all() - infeed0 = (u32[12345678,2]{1,0}, token[]) infeed(token) + token0 = token[] after-all() + infeed0 = (u32[12345678,2]{1,0}, token[]) infeed(token0) infeed0.data = u32[12345678,2]{1,0} get-tuple-element((u32[12345678,2]{1,0}, token[]) infeed0), index=0 - ROOT outfeed0 = token[] outfeed(infeed0.data, token) + ROOT outfeed0 = token[] outfeed(infeed0.data, token0) } )"; diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc index a71a85913cfef271bc2a226cb0cf2dd4204499a4..56f018abdd496e804dc4dea5420d400175491db3 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc @@ -23,6 +23,10 @@ limitations under the License. #include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/types.h" +#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) +#include "tensorflow/core/kernels/eigen_contraction_kernel.h" +#endif + using tensorflow::int32; using tensorflow::int64; diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index efccadedf27181a4cddf4f1dc3610f7c6db1d821..296f39a4853f2d3f7030209a921001e92c39d609 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -139,7 +139,7 @@ llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) { } if (func_addr == nullptr) { - VLOG(2) << "Unable to resolve runtime symbol: " << name; + LOG(ERROR) << "Unable to resolve runtime symbol: " << name; return nullptr; } llvm::JITEvaluatedSymbol symbol_info(reinterpret_cast(func_addr), @@ -296,6 +296,9 @@ bool RegisterKnownJITSymbols() { REGISTER_LIBM_SYMBOL(sin, double (*)(double)); #ifdef __APPLE__ REGISTER_LIBM_SYMBOL(__sincos, void (*)(double, double*, double*)); + registry->Register("__sincosf_stret", + reinterpret_cast(__sincosf_stret)); + registry->Register("__sincos_stret", reinterpret_cast(__sincos_stret)); #else REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*)); #endif @@ -311,6 +314,13 @@ bool RegisterKnownJITSymbols() { registry->Register("memcpy", reinterpret_cast(memcpy)); registry->Register("memmove", reinterpret_cast(memmove)); registry->Register("memset", reinterpret_cast(memset)); + +#ifdef __APPLE__ + registry->Register("__bzero", reinterpret_cast(bzero)); + registry->Register("memset_pattern16", + reinterpret_cast(memset_pattern16)); +#endif + return true; } diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc index fa0e09ff6b5694c0e97963b83c6e541b858a1376..0584c0484f810a03ccccd522163f54535440ef8b 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -31,29 +31,27 @@ HloModule RepeatedConstants while_body { arg_body = f32[2,3,2] parameter(0) ROOT const = f32[2,3,2] constant( - f32[2,3,2] {{{1, 2}, {1001, 1002}, {2001, 2002}}, {{2, 1}, {2001, 3002}, {2001, 2002}}}) } while_cond { arg_cond = f32[2,3,2] parameter(0) - token = token[] after-all() - infeed = (pred[], token[]) infeed(token) + token0 = token[] after-all() + infeed = (pred[], token[]) infeed(token0) ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0 } ENTRY main { param = f32[2,3,2] parameter(0) const_a = f32[2,3,2] constant( - f32[2,3,2] {{{1, 2}, {1001, 1002}, {2001, 2002}}, {{2, 1}, {2001, 3002}, {2001, 2002}}}) const_b = f32[2,3,2] while(f32[2,3,2] const_a), condition=while_cond, body=while_body - token = token[] after-all() - out0 = token[] outfeed(f32[2,3,2] const_a, token[] token) - ROOT out1 = token[] outfeed(f32[2,3,2] const_b, token[] token) + token0 = token[] after-all() + out0 = token[] outfeed(f32[2,3,2] const_a, token[] token0) + ROOT out1 = token[] outfeed(f32[2,3,2] const_b, token[] token0) } )"; @@ -82,24 +80,24 @@ HloModule RepeatedConstants while_body { arg_body = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) - ROOT const = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} )) + ROOT const = (f32[2,1]{1,0}, f32[1]{0}) constant(({ { 1 }, { 2 } }, {2} )) } while_cond { arg_cond = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) - token = token[] after-all() - infeed = (pred[], token[]) infeed(token) + token0 = token[] after-all() + infeed = (pred[], token[]) infeed(token0) ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0 } ENTRY main { param = f32[2,3,2] parameter(0) - const_a = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} )) + const_a = (f32[2,1]{1,0}, f32[1]{0}) constant(( { { 1 }, { 2 } }, {2} )) const_b = (f32[2,1]{1,0}, f32[1]{0}) while((f32[2,1]{1,0}, f32[1]{0}) const_a), condition=while_cond, body=while_body - token = token[] after-all() - out0 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_a, token[] token) - ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_b, token[] token) + token0 = token[] after-all() + out0 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_a, token[] token0) + ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_b, token[] token0) } )"; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc index e2c7af541eede5265f274c72f55305549f059839..aab7f0b393881642437f1891256bd138823a3b87 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc @@ -28,12 +28,11 @@ HloModule Outfeed ENTRY main { const_a = f32[2,3,2] constant( - f32[2,3,2] {{{1, 2}, {1001, 1002}, {2001, 2002}}, {{2, 1}, {2001, 3002}, {2001, 2002}}}) - token = token[] after-all() - outfeed = token[] outfeed(f32[2,3,2] const_a, token) + token0 = token[] after-all() + outfeed = token[] outfeed(f32[2,3,2] const_a, token0) ROOT root = () tuple() } )"; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc index 443883a89f66a747def1049bc5afb53fec3c2409..73af18f87aeeedaefac4fc37fb7b6f78f506bb4f 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc @@ -599,7 +599,7 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveConstantFilter) { Array4D constant_arr(4, 4, 2, 2); constant_arr.FillIota(0); string constant_str = - LiteralUtil::CreateR4FromArray4D(constant_arr).ToString(); + LiteralUtil::CreateR4FromArray4D(constant_arr).ToStringWithoutShape(); const string module_str = absl::StrFormat(R"( HloModule test diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 2ffc8bfb49b205dced0d540ba72426e72d95e596..29756d27260b0f41b2dd4b649ea9b1610ff90268 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -369,7 +369,7 @@ TEST_F(LayoutAssignmentTest, SortLayout) { const char* hlo_text = R"( HloModule SortLayout ENTRY sort { - keys = f32[3,2]{0,1} constant(f32[3,2]{0,1}{{0,1},{0,1},{0,1}}) + keys = f32[3,2]{0,1} constant({{0,1},{0,1},{0,1}}) values = f32[2,3]{1,0} parameter(0) transpose = f32[3,2]{1,0} transpose(values), dimensions={1,0} ROOT sort = (f32[3,2]{1,0}, f32[3,2]{1,0}) sort(keys, transpose), diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index fb040aff30d48bf5817946ce53d37bc6685941e4..87d16c0afcc3c115f652558b5d8c24606ff56733 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include "absl/algorithm/container.h" -#include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/types/optional.h" @@ -548,91 +547,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { // TODO(b/112040122): Support variadic reduce. return Unimplemented("Variadic reduce is not supported on GPU"); } - VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString(); - std::vector> thunks; - absl::Span output_instructions = - root->opcode() == HloOpcode::kTuple - ? root->operands() - : absl::Span(&root, 1); - - // For multi-output fusion emit an initializer for each tuple element. - // Otherwise it's sufficient to just initialize the single output. - HloInstruction* first_reduce = nullptr; - for (int i = 0, e = output_instructions.size(); i != e; ++i) { - if (output_instructions[i]->opcode() == HloOpcode::kReduce) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr initializer_thunk, - BuildInitializerThunk(fusion, output_instructions[i] == root - ? ShapeIndex() - : ShapeIndex({i}))); - thunks.push_back(std::move(initializer_thunk)); - first_reduce = - first_reduce == nullptr ? output_instructions[i] : first_reduce; - } - } - CHECK(first_reduce != nullptr); - std::unique_ptr kernel_thunk = - BuildKernelThunk(fusion, /*implements_whole_instruction=*/false); - GpuElementalIrEmitter elemental_emitter( - hlo_module_config_, ir_emitter_context_->llvm_module(), &b_, - GetNestedComputer()); - FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion), - &elemental_emitter); - TF_RETURN_IF_ERROR(root->Accept(&fused_emitter)); - - // For multi-output fusion CHECK the constraints and feed all the - // reduces into a single loop code generator. Single-output reduce - // fusion is a special case of that. - InlinedVector input_gens; - InlinedVector init_value_gens; - std::vector> - extra_output_gens; - InlinedVector reducers; - InlinedVector reduce_output_shapes; - for (int i = 0, e = output_instructions.size(); i != e; ++i) { - const HloInstruction* inst = output_instructions[i]; - ShapeIndex output_shape_index; - if (root->opcode() == HloOpcode::kTuple) { - output_shape_index = {i}; - } - if (inst->opcode() == HloOpcode::kReduce) { - CHECK(IsReductionToVector(*inst)) - << "Only reductions to vector are supported"; - // Shapes, layouts and dimensions must be the same for all reduces - // inside of this fusion. - CHECK(ShapeUtil::Equal(first_reduce->shape(), inst->shape())); - CHECK(ShapeUtil::Equal(first_reduce->operand(0)->shape(), - inst->operand(0)->shape())); - CHECK(ShapeUtil::Equal(first_reduce->operand(1)->shape(), - inst->operand(1)->shape())); - CHECK(first_reduce->dimensions() == inst->dimensions()); - input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0))); - init_value_gens.push_back( - fused_emitter.GetGenerator(inst->operand(1))); - reducers.push_back(inst->to_apply()); - reduce_output_shapes.push_back(std::move(output_shape_index)); - } else { - // For extra outputs we can relax shape equality to allow different - // types (with the same number of elements). Layouts still have to - // match. - CHECK(ShapeUtil::CompatibleIgnoringElementType( - first_reduce->operand(0)->shape(), inst->shape())); - CHECK(LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(), - inst->shape().layout())); - extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst), - std::move(output_shape_index)); - } - } - const Shape& input_shape = first_reduce->operand(0)->shape(); - TF_CHECK_OK(EmitReductionToVector( - kernel_thunk.get(), first_reduce, input_shape, input_gens, - init_value_gens, first_reduce->dimensions(), reducers, - reduce_output_shapes, extra_output_gens)); - thunks.push_back(std::move(kernel_thunk)); - std::unique_ptr sequential_thunk = - absl::make_unique(std::move(thunks), fusion); - AddThunkToThunkSequence(std::move(sequential_thunk)); - return Status::OK(); + return EmitReductionToVector(fusion); } default: LOG(FATAL) << "Bad opcode for input fusion: " @@ -702,13 +617,12 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { } Status IrEmitterUnnested::EmitExtraOutputsForReduce( - const HloInstruction* reduce, const IrArray::Index& index, + const HloInstruction* unnested_hlo, const IrArray::Index& index, absl::Span> extra_output_gens) { for (int i = 0; i != extra_output_gens.size(); ++i) { - const HloInstruction* output = reduce->parent()->FusionInstruction(); llvm::Value* extra_output_address = - GetIrArray(*output, *output, extra_output_gens[i].second) + GetIrArray(*unnested_hlo, *unnested_hlo, extra_output_gens[i].second) .EmitArrayElementAddress(index, &b_, "extra_output_element_address"); TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value, @@ -718,984 +632,13 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce( return Status::OK(); } -Status IrEmitterUnnested::EmitReductionToScalar( - KernelThunk* kernel_thunk, HloInstruction* reduce, const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens) { - // Number of elements processed by a single thread. - constexpr int64 kTileSize = 16; - int64 num_elems = ShapeUtil::ElementsIn(input_shape); - - // Round up the number of tiles to a multiple of the warp size. This is - // necessary for correctness. We launch one thread per tile, and if the - // number of threads isn't a multiple of the number of the warp size, our - // shuffles will read from inactive threads, producing undefined values. - int64 num_tiles = - RoundUpToNearest(CeilOfRatio(num_elems, kTileSize), kWarpSize); - - Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( - reduce->shape().element_type(), {num_tiles}, {0}); - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - tiled_input_shape, ir_emitter_context_->device_description()); - - llvm::Type* index_ty = - GetIndexTypeForKernel(reduce, launch_dimensions.launch_bound(), &b_); - - auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - - // Check whether every thread will process a full tile's worth of elements - // without reading outside the bounds of the input. If this is true, we can - // skip some bounds checks in the final algorithm. - bool all_threads_in_bounds = num_tiles * kTileSize == num_elems; - - // __global__ void full_reduce_kernel() { - // x_in_tiles = threadIdx.x + blockIdx.x * blockDim.x; - // x = x_in_tiles * kTileSize; - // - // partial_result = init_value; - // if (all_threads_in_bounds || x + kTileSize <= num_elems) { - // for (i = 0; i < kTileSize; ++i) { - // partial_result = Reducer(partial_result, input[x + i]); - // } - // } else { - // for (i = 0; i < kTileSize; ++i) { - // if (x + i < num_elems) { - // partial_result = Reducer(partial_result, input[x + i]); - // } - // } - // } - // for (i = warpSize / 2; i > 0; i /= 2) { - // partial_result = Reducer(partial_result, - // __shfl_down(partial_result, i)); - // } - // if (lane_id == 0) { - // AtomicReducer(&output[y], partial_result); - // } - // } - // - // // Choose num_blocks and threads_per_block such that: - // // - // // num_blocks * threads_per_block = - // // RoundUpToNextMultipleOf(Ceil(num_elems / kTileSize), warpSize), - // // - // // and threads_per_block is a multiple of warpSize. - // reduce_kernel // - auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status { - const int num_reduces = reducers.size(); - llvm::Type* element_ir_type = - llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); - std::vector partial_reduction_result_addresses; - for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result_address = - Alloca(element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + llvm::Twine(i)); - TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, - init_value_gens[i](IrArray::Index(index_ty))); - Store(init_ir_value, partial_reduction_result_address); - partial_reduction_result_addresses.push_back( - partial_reduction_result_address); - } - - llvm::Value* x_in_tiles = tile_index[0]; - x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty); - - // Emit an inner for-loop that reduces the elements in the tile. - auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { - std::unique_ptr tile_element_loop = - llvm_ir::ForLoop::EmitForLoop( - "element_id_in_tile", index_typed_constant(0), - index_typed_constant(kTileSize), index_typed_constant(1), &b_); - - // Emit the body of the partial reduction loop. - llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), - &b_); - llvm::Value* x = - NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileSize)), - tile_element_loop->GetIndVarValue()); - // Unless we know the tile is entirely in bounds, we have to emit a - // x-in-bounds check before reading from the input. - if (!tile_in_bounds) { - llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - ICmpULT(x, index_typed_constant(num_elems)), "x_in_bounds", &b_); - - // Emit code that reads the input element and accumulates it to - // the partial reduction result. - llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); - } - - IrArray::Index input_index( - /*linear=*/x, input_shape, &b_); - llvm::Value* input_address = Alloca(element_ir_type); - for (int i = 0; i != num_reduces; ++i) { - TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, - input_gens[i](input_index)); - Store(input_ir_value, input_address); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducers[i], - {partial_reduction_result_addresses[i], input_address}, - partial_reduction_result_addresses[i])); - } - return EmitExtraOutputsForReduce(reduce, input_index, extra_output_gens); - }; - - // x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's - // immediately beyond the tile. - llvm::Value* x_end = - NSWAdd(index_typed_constant(kTileSize), - NSWMul(x_in_tiles, index_typed_constant(kTileSize))); - // The tile is entirely in bound if all_threads_in_bounds or - // x_end <= num_elems. - llvm::Value* tile_in_bounds = - Or(ICmpULE(x_end, index_typed_constant(num_elems)), - b_.getInt1(all_threads_in_bounds)); - llvm_ir::LlvmIfData if_tile_in_bounds_data = - llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &b_); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, &b_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true)); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block, &b_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false)); - - // After the if-then-else statement on tile_in_bounds, emit calls to - // shfl_down that accumulate the partial reduction results of all threads - // from the warp. - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, &b_); - int bit_width = llvm_ir::GetSizeInBits(element_ir_type); - // bitcast cannot be applied to aggregate types (even packed ones), so we - // instead bitcast addresses of load/store to intN* of the same bit-width. - llvm::Type* shuffle_ir_type = element_ir_type->isStructTy() - ? b_.getIntNTy(bit_width) - : element_ir_type; - for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1; - shuffle_distance /= 2) { - llvm::Value* result_from_other_lane = - Alloca(element_ir_type, nullptr, "result_from_other_lane"); - for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result = - Load(BitCast(partial_reduction_result_addresses[i], - shuffle_ir_type->getPointerTo()), - "partial_reduction_result"); - CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0) - << "Requires block size a multiple of the warp size, otherwise we " - "will read undefined elements."; - Store(EmitFullWarpShuffleDown(partial_reduction_result, - b_.getInt32(shuffle_distance), &b_), - BitCast(result_from_other_lane, shuffle_ir_type->getPointerTo())); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducers[i], - {partial_reduction_result_addresses[i], result_from_other_lane}, - partial_reduction_result_addresses[i])); - } - } - - const HloInstruction* output = - reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; - - // Emit an atomic operation that accumulates the partial reduction result of - // lane 0 (which holds the partially accumulated result for its warp) to the - // output element. - llvm::Value* lane_id = - URem(x_in_tiles, index_typed_constant(kWarpSize), "lane_id"); - llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_); - llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); - - for (int i = 0; i != num_reduces; ++i) { - llvm::Value* output_address = - GetIrArray(*output, *output, reduce_output_shapes[i]) - .EmitArrayElementAddress( - IrArray::Index( - /*linear=*/b_.getInt64(0), - ShapeUtil::GetSubshape(output->shape(), - reduce_output_shapes[i]), - &b_), - &b_, "output_element_address"); - TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( - *reducers[i], output_address, partial_reduction_result_addresses[i])); - } - return Status::OK(); - }; - - // Emit a parallel loop that iterates through all input tiles, one per thread. - UpdateLaunchDimensions(launch_dimensions, kernel_thunk, - ir_emitter_context_->llvm_module()); - return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, - launch_dimensions, &b_) - .EmitLoop(IrName(reduce), index_ty); -} - -Status IrEmitterUnnested::EmitColumnReduction( - KernelThunk* kernel_thunk, int64 height, int64 width, - HloInstruction* reduce, const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens) { - // Divide the input matrix into tiles of size KxL. For example, when the - // input matrix is 4x4, K=2, and L=1 the tiled matrix looks like - // - // 0123 - // 0123 - // 4567 - // 4567 // Numbers indicate tile IDs. - // - // Each tile is first partially reduced to a scalar by a thread, and then the - // scalar is accumulated to the output vector using atomic operations. - // - // We choose 128 as the tile size based on empirical evidence. It's big enough - // to reduce the amount of atomic adds in the end, maximizing the memory - // bandwidth. A tile width of 2 allows for high memory bandwidth utilization - // on 16b input data. - constexpr int64 kTileHeight = 128; - constexpr int64 kTileWidth = 2; - - // If the height is not a multiple of kTileHeight, we pad the bottom of the - // input matrix. - const int64 height_in_tiles = CeilOfRatio(height, kTileHeight); - // If width is not a multiple of kTileWidth the rightmost thread will process - // fewer input elements. - const int64 width_in_tiles = CeilOfRatio(width, kTileWidth); - Shape tiled_input_shape = - ShapeUtil::MakeShapeWithLayout(reduce->shape().element_type(), - {height_in_tiles, width_in_tiles}, {1, 0}); - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - tiled_input_shape, ir_emitter_context_->device_description()); - - // TODO(b/110211620): Convert to use i32 index_type when it is possible. - llvm::Type* index_ty = b_.getInt64Ty(); - - auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - - // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x; - // linear_index < height_in_tiles * width_in_tiles; - // linear_index += blockDim.x * gridDim.x) { - // y_in_tiles = linear_index / width_in_tiles; - // x_in_tiles = linear_index % width_in_tiles; - // - // partial_results[kTileWidth] = init_values; - // tile_in_y_bounds = height % kTileHeight == 0 || - // y_in_tiles * kTileHeight + kTileHeight <= height; - // tile_in_x_bounds = width % kTileWidth == 0 || - // x_in_tiles * kTileWidth + kTileWidth <= width; - // // The implementation handles y and x bound checks separately. - // if (tile_in_y_bounds && tile_in_x_bounds) { - // for (y_offset : range(kTileHeight)) { - // y = y_in_tiles * kTileHeight + y_offset; - // for (x_offset : range(kTileWidth)) { - // x = x_in_tiles * kTileWidth + x_offset; - // partial_result = Reducer(partial_result[x_offset], input[y][x]); - // } - // } - // } else { - // for (y_offset : range(kTileHeight)) { - // y = y_in_tiles * kTileHeight + y_offset; - // for (y_offset : range(kTileHeight)) { - // x = x_in_tiles * kTileWidth + x_offset; - // if (y < height && x < width) { - // partial_result = Reducer(partial_result, input[y][x]); - // } - // } - // } - // } - // for (x_offset : range(kTileWidth)) { - // AtomicReducer(&output[x + x_offset], partial_result[x_offset]); - // } - // } - auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status { - const int num_reduces = reducers.size(); - // Emit the loop body that reduces one tile. - llvm::Type* element_ir_type = - llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); - std::vector partial_reduction_result_addresses; - for (int i = 0; i != num_reduces; ++i) { - for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { - llvm::Value* partial_reduction_result_address = - Alloca(element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + - llvm::Twine(i * kTileWidth + x_offset)); - TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, - init_value_gens[i](IrArray::Index(index_ty))); - Store(init_ir_value, partial_reduction_result_address); - partial_reduction_result_addresses.push_back( - partial_reduction_result_address); - } - } - - // Emit an inner for-loop that partially reduces the elements in the given - // tile. - llvm::Value* y_in_tiles = tile_index[0]; - llvm::Value* x_in_tiles = tile_index[1]; - - y_in_tiles = ZExtOrTrunc(y_in_tiles, index_ty); - x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty); - - auto emit_tile_element_loop = [=](bool tile_in_y_bounds, - bool tile_in_x_bounds) -> Status { - std::unique_ptr tile_element_loop = - llvm_ir::ForLoop::EmitForLoop( - "element_id_in_tile", index_typed_constant(0), - index_typed_constant(kTileHeight), index_typed_constant(1), &b_); - - // Emit the body of the partial reduction loop. - llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), - &b_); - llvm::Value* y = - NSWAdd(NSWMul(y_in_tiles, index_typed_constant(kTileHeight)), - tile_element_loop->GetIndVarValue()); - - // Unless we know that y is in bounds, we have to emit a check before - // reading from the input. - if (!tile_in_y_bounds) { - llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - ICmpULT(y, index_typed_constant(height)), "y_in_bounds", &b_); - - // Emit code that reads the input element and accumulates it to - // the partial reduction result. - llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); - } - for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { - llvm::Value* x = - NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)), - index_typed_constant(x_offset)); - // Unless we know that x is in bounds, we have to emit a check before - // reading from the input. - if (!tile_in_x_bounds) { - llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - ICmpULT(x, index_typed_constant(width)), "x_in_bounds", &b_); - llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); - } - llvm::Value* input_address = Alloca(element_ir_type); - // {y,x} is an index to input_matrix_shape [height,width]. We need to - // convert that to an index to input_shape (the shape of the operand of - // "reduce"). This conversion is composed of a transposition from - // input_shape to normalized_input_shape and a reshape from - // normalized_input_shape to input_matrix_shape. - const Shape normalized_input_shape = - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( - input_shape); - auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape); - const std::vector transpose_dimension_mapping( - input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); - - const Shape input_matrix_shape = - ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(), - {height, width}); - const IrArray::Index input_matrix_index({y, x}, input_matrix_shape, - &b_); - const IrArray::Index input_index = - input_matrix_index - .SourceIndexOfReshape(input_matrix_shape, - normalized_input_shape, &b_) - .SourceIndexOfTranspose(normalized_input_shape, input_shape, - transpose_dimension_mapping, &b_); - for (int i = 0; i != num_reduces; ++i) { - TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, - input_gens[i](input_index)); - Store(input_ir_value, input_address); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducers[i], - {partial_reduction_result_addresses[i * kTileWidth + x_offset], - input_address}, - partial_reduction_result_addresses[i * kTileWidth + x_offset])); - TF_RETURN_IF_ERROR(EmitExtraOutputsForReduce(reduce, input_index, - extra_output_gens)); - } - } - return Status::OK(); - }; - - // y_end = kTileHeight + y_in_tiles * kTileHeight, i.e., the y location - // that's immediately beyond the tile. - llvm::Value* y_end = - NSWAdd(index_typed_constant(kTileHeight), - NSWMul(y_in_tiles, index_typed_constant(kTileHeight))); - // x_end = kTileWidth + x_in_tiles * kTileWidth, i.e., the x location - // that's immediately beyond the tile. - llvm::Value* x_end = - NSWAdd(index_typed_constant(kTileWidth), - NSWMul(x_in_tiles, index_typed_constant(kTileWidth))); - llvm::Value* tile_in_y_bounds = - Or(ICmpULE(y_end, index_typed_constant(height)), - b_.getInt1(height % kTileHeight == 0)); - llvm::Value* tile_in_x_bounds = - Or(ICmpULE(x_end, index_typed_constant(width)), - b_.getInt1(width % kTileWidth == 0)); - // The tile is in y bounds if "height" is a multiple of kTileHeight or - // y_end <= height. - llvm_ir::LlvmIfData if_tile_in_y_bounds_data = - llvm_ir::EmitIfThenElse(tile_in_y_bounds, "tile_in_y_bounds", &b_); - llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.true_block, &b_); - // The tile is in x bounds if "width" is a multiple of kTileWidth or - // x_end <= width. - llvm_ir::LlvmIfData if_tile_in_x_bounds_data = - llvm_ir::EmitIfThenElse(tile_in_x_bounds, "tile_in_x_bounds", &b_); - llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block, &b_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true, - /*tile_in_x_bounds=*/true)); - llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block, &b_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true, - /*tile_in_x_bounds=*/false)); - llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.false_block, &b_); - if_tile_in_x_bounds_data = - llvm_ir::EmitIfThenElse(tile_in_x_bounds, "tile_in_x_bounds", &b_); - llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block, &b_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false, - /*tile_in_x_bounds=*/true)); - llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block, &b_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false, - /*tile_in_x_bounds=*/false)); - - // After the nested if-then-else statement on tile_in_y_bounds and - // tile_in_x_bounds, emit atomic operations to accumulate the partial - // reduction result to the output element. - llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.after_block, &b_); - const HloInstruction* output = - reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; - for (int i = 0; i != num_reduces; ++i) { - for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { - llvm::Value* x = - NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)), - index_typed_constant(x_offset)); - llvm::Value* output_address = - GetIrArray(*output, *output, reduce_output_shapes[i]) - .EmitArrayElementAddress( - IrArray::Index( - x, - ShapeUtil::GetSubshape(output->shape(), - reduce_output_shapes[i]), - &b_), - &b_, "output_element_address"); - TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( - *reducers[i], output_address, - partial_reduction_result_addresses[i * kTileWidth + x_offset])); - } - } - return Status::OK(); - }; - - // Emit a parallel loop that iterate through all input tiles. - UpdateLaunchDimensions(launch_dimensions, kernel_thunk, - ir_emitter_context_->llvm_module()); - return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, - launch_dimensions, &b_) - .EmitLoop(IrName(reduce), index_ty); -} - -static std::pair ComputeKernelMappingSchemeForReduction( - int64 depth, int64 width, int64 kWarpSize) { - constexpr int64 kTargetNumElementsPerThread = 64; - int64 x_tile_size = kTargetNumElementsPerThread; - int64 z_tile_size = 1; - - // Only tile along the x dimension with tile size kTargetNumElementsPerThread - // if doing so doesn't require a slow version of loop with bound check on each - // dimension. A more sophisticated heuristics is to enable tile along the - // x dimension with tile size kTargetNumElementsPerThread when either width is - // a factor of (kWarpSize * kTargetNumElementsPerThread) or width is big - // enough so that only a small fraction of the threads execute the slow - // version of loop with bound check. - if (width % (kWarpSize * kTargetNumElementsPerThread) != 0) { - x_tile_size = 8; - z_tile_size = 8; - while (depth % z_tile_size != 0) { - z_tile_size -= 1; - } - } - - return std::pair(x_tile_size, z_tile_size); -} - -Status IrEmitterUnnested::EmitRowReduction( - KernelThunk* kernel_thunk, int64 depth, int64 height, int64 width, - HloInstruction* reduce, const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens) { - // A naive algorithm is: - // 1. Divide the x dimension of the input tensor into tiles of size 1x1xX. - // 2. Partially reduces each tile to a scalar using one thread. - // 3. Accumulates that scalar to the output vector using atomic operations. - // - // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x; - // linear_index < depth * height * width_in_tiles; - // linear_index += blockDim.x * gridDim.x) { - // int x_in_tiles = linear_index % width_in_tiles; - // int y = linear_index / width_in_tiles % height; - // int z = linear_index / (height * width_in_tiles); - // float partial_result = 0; - // for (element_id_in_tile : range(x_tile_size)) { - // int x = x_in_tiles * x_tile_size + element_id_in_tile; - // if (x < width) - // partial_result = reducer(partial_result, input[z][y][x]); - // } - // AtomicReducer(&output[y], partial_result); - // } - // - // Four optimizations are performed. - // - // 1. To coalesce global memory accesses, dilate the tile with a factor of 32 - // (i.e. the warp size). For example, suppose the width is 8x32=256. Instead - // of making each tile consecutive, we let make tile 0 column - // [0,32,64,...,224], tile 1 column [1,33,65,...,225], and so on. This ensures - // that threads in a warp access consecutive memory in one iteration (i.e. - // coalesced). In the above example, the warp that contains thread 0-31 - // accesses column 0-31 in the first iteration, and 32-63 in the second - // iteration, and so on. - // - // 2. Partially accumulate partial reduced results computed by threads in the - // same warp using shfl_down. Using shfl_down is faster than directly using - // atomic operations because shfl_down transfers the data between threads - // using shared memory and threads in the same warp run in lock step (thus no - // extra synchronization needed). See - // https://devblogs.nvidia.com/parallelforall/faster-parallel-reductions-kepler/ - // for details. The downside is, to produce correct results when using - // shfl_down, we need to guarantee threads in the same warp work on input - // elements with the same y, so the number of tiles in each row must be a - // multiple of 32. - // - // 3. Specialize the case that the entire tile is in bounds. When that is - // true, we don't need to emit "if(x 0; shuffle_distance /= 2) - // partial_result = Reducer( - // partial_result, - // __shfl_down_sync(CUDA_WARP_ALL, partial_result, shuffle_distance)); - // if (lane_id == 0) - // AtomicReducer(&output[y], partial_result); - // } - // - - int64 x_tile_size; - int64 z_tile_size; - std::tie(x_tile_size, z_tile_size) = - ComputeKernelMappingSchemeForReduction(depth, width, kWarpSize); - - // Round the width in tiles up to the nearest multiple of kWarpSize, so that - // the use of shfl_down is valid. - const int64 width_in_tiles = - RoundUpToNearest(CeilOfRatio(width, x_tile_size), kWarpSize); - Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( - reduce->shape().element_type(), - {depth / z_tile_size, height, width_in_tiles}, {2, 1, 0}); - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - tiled_input_shape, ir_emitter_context_->device_description()); - llvm::Type* index_ty = - GetIndexTypeForKernel(reduce, launch_dimensions.launch_bound(), &b_); - - auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - - auto loop_body_emitter = [=](const IrArray::Index& tile_index) { - const int num_reduces = reducers.size(); - llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType( - input_shape.element_type(), ir_emitter_context_->llvm_module()); - std::vector partial_reduction_result_addresses; - for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result_address = - Alloca(element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + llvm::Twine(i)); - TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, - init_value_gens[i](IrArray::Index(index_ty))); - Store(init_ir_value, partial_reduction_result_address); - partial_reduction_result_addresses.push_back( - partial_reduction_result_address); - } - - llvm::Value* z_tile = tile_index[0]; - llvm::Value* y = tile_index[1]; - llvm::Value* x_tile = tile_index[2]; - - x_tile = ZExtOrTrunc(x_tile, index_ty); - - llvm::Value* warp_id = - UDiv(x_tile, index_typed_constant(kWarpSize), "warp_id"); - llvm::Value* lane_id = - URem(x_tile, index_typed_constant(kWarpSize), "lane_id"); - - // The x-location of the last element in this z-x-tile. - // last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * x_tile_size); - llvm::Value* last_x = NSWAdd( - lane_id, - NSWMul(index_typed_constant(kWarpSize), - NSWAdd(index_typed_constant(x_tile_size - 1), - NSWMul(warp_id, index_typed_constant(x_tile_size))))); - - KernelSupportLibrary ksl( - &b_, - /*unroll_mode=*/xla::llvm_ir::UnrollMode::kFullyUnroll, - /*prevent_vectorization=*/false); - - // Emit a for-loop that partially reduces the elements in the given - // z-x-tile. - auto emit_z_x_tile_element_loop = [&](bool x_tile_in_bounds, - int64 x_tile_loop_bound) -> Status { - auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status { - llvm::Value* z = - NSWAdd(z_indvar, NSWMul(index_typed_constant(z_tile_size), z_tile)); - TF_RETURN_IF_ERROR(ksl.For( - "x_tile", - /*start=*/index_typed_constant(0), - /*end=*/index_typed_constant(x_tile_loop_bound), - /*step=*/1, [&](llvm::Value* x_indvar) -> Status { - // x = lane_id + - // warpSize * (element_id_in_x_tile + warp_id * x_tile_size); - llvm::Value* x = NSWAdd( - lane_id, - NSWMul(index_typed_constant(kWarpSize), - NSWAdd(x_indvar, - NSWMul(warp_id, llvm::ConstantInt::get( - index_ty, x_tile_size))))); - - // Unless we know the x-tile is entirely in bounds, we have to - // emit a x-in-bounds check before reading from the input. - if (!x_tile_in_bounds) { - llvm_ir::LlvmIfData if_x_in_bounds_data = - llvm_ir::EmitIfThenElse( - ICmpULT(x, index_typed_constant(width)), "x_in_bounds", - &b_); - // Points b_ to the then-block. - llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block, - &b_); - } - - // Emit code that reads the input element and accumulates it - // to the partial reduction result. - llvm::Value* input_address = Alloca(element_ir_type); - { - // {z,y,x} is an index to input_3d_tensor_shape - // [depth,height,width]. We need to convert that to an index - // to input_shape (the shape of the operand of "reduce"). - // This conversion is composed of a transposition from - // input_shape to normalized_input_shape and a reshape from - // normalized_input_shape to input_3d_tensor_shape. - const Shape normalized_input_shape = ShapeUtil:: - MakeShapeWithDescendingLayoutAndSamePhysicalLayout( - input_shape); - auto input_shape_min2maj = - LayoutUtil::MinorToMajor(input_shape); - const std::vector transpose_dimension_mapping( - input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); - const Shape input_3d_tensor_shape = - ShapeUtil::MakeShapeWithDescendingLayout( - input_shape.element_type(), {depth, height, width}); - const IrArray::Index input_3d_tensor_index( - {z, y, x}, input_3d_tensor_shape, &b_); - const IrArray::Index input_index = - input_3d_tensor_index - .SourceIndexOfReshape(input_3d_tensor_shape, - normalized_input_shape, &b_) - .SourceIndexOfTranspose( - normalized_input_shape, input_shape, - transpose_dimension_mapping, &b_); - - for (int i = 0; i != num_reduces; ++i) { - TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, - input_gens[i](input_index)); - Store(input_ir_value, input_address); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducers[i], - {partial_reduction_result_addresses[i], input_address}, - partial_reduction_result_addresses[i])); - } - return EmitExtraOutputsForReduce(reduce, input_index, - extra_output_gens); - } - })); - return Status::OK(); - }; - - return ksl.For("z_tile", - /*start=*/index_typed_constant(0), - /*end=*/index_typed_constant(z_tile_size), - /*step=*/1, emit_z_tile_element_loop); - }; - - llvm::Value* tile_in_bounds = - Or(b_.getInt1(width % (x_tile_size * kWarpSize) == 0), - ICmpULT(last_x, index_typed_constant(width))); - - TF_RETURN_IF_ERROR( - ksl.If(tile_in_bounds, - /*true_block_generator=*/ - [&]() -> Status { - return emit_z_x_tile_element_loop(/*x_tile_in_bounds=*/true, - x_tile_size); - }, - /*false_block_generator=*/ - [&]() -> Status { - return emit_z_x_tile_element_loop( - /*x_tile_in_bounds=*/false, - CeilOfRatio(width % (x_tile_size * kWarpSize), kWarpSize)); - })); - - // After accumulating the elements of the z_x_tile, emit calls to - // shfl_down that accumulate the partial reduction results of all - // threads in a warp. - int bit_width = llvm_ir::GetSizeInBits(element_ir_type); - // bitcast cannot be applied to aggregate types (even packed ones), so we - // instead bitcast addresses of load/store to intN* of the same bit-width. - llvm::Type* shuffle_ir_type = element_ir_type->isStructTy() - ? b_.getIntNTy(bit_width) - : element_ir_type; - for (int shuffle_distance = 16; shuffle_distance >= 1; - shuffle_distance /= 2) { - llvm::Value* result_from_other_lane = - Alloca(element_ir_type, nullptr, "result_from_other_lane"); - for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result = - Load(BitCast(partial_reduction_result_addresses[i], - shuffle_ir_type->getPointerTo()), - "partial_reduction_result"); - CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0) - << "Requires block size a multiple of the warp size, otherwise we " - "will read undefined elements."; - Store(EmitFullWarpShuffleDown(partial_reduction_result, - b_.getInt32(shuffle_distance), &b_), - BitCast(result_from_other_lane, shuffle_ir_type->getPointerTo())); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducers[i], - {partial_reduction_result_addresses[i], result_from_other_lane}, - partial_reduction_result_addresses[i])); - } - } - - const HloInstruction* output = - reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; - - // Emit an atomic operation that accumulates the partial reduction result of - // lane 0 (which holds the partially accumulated result for its warp) to the - // output element. - llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_); - llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); - for (int i = 0; i != num_reduces; ++i) { - llvm::Value* output_address = - GetIrArray(*output, *output, reduce_output_shapes[i]) - .EmitArrayElementAddress( - IrArray::Index(y, - ShapeUtil::GetSubshape( - output->shape(), reduce_output_shapes[i]), - &b_), - &b_, "output_element_address"); - // We don't need to emit atomic operations if there is only one tile of - // results. 'depth' is the z dimension, 'width' is the x dimension. - if (z_tile_size >= depth && x_tile_size >= width) { - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducers[i], - {output_address, partial_reduction_result_addresses[i]}, - output_address)); - } else { - TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( - *reducers[i], output_address, - partial_reduction_result_addresses[i])); - } - } - return Status::OK(); - }; - - // Emit a parallel loop that iterates through every input tiles. - UpdateLaunchDimensions(launch_dimensions, kernel_thunk, - ir_emitter_context_->llvm_module()); - return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, - launch_dimensions, &b_) - .EmitLoop(IrName(reduce), index_ty); -} - -// Figures out whether `reduce` is a row or column reduction, and which -// dimensions to reduce, and calls either `EmitRowReduction` or -// `EmitColumnReduction` as appropriate. -// Prerequisite: all the dimensions to keep are contiguous in the input layout -// and, if `reduce` is fused, the fused subgraph is pure -// elementwise. -Status IrEmitterUnnested::EmitReductionToVector( - KernelThunk* kernel_thunk, HloInstruction* reduce, const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span dimensions_to_reduce, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens) { - // This emission requires "reduce" to have an input layout. It is either set - // by LayoutAssignment (for a top-level kReduce) or by InstructionFusion (for - // a fused kReduce). - CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion " - "doesn't set the input layout of " - << reduce->ToString(); - - // Specialize multi-dimensional-array-to-vector reduction. - std::vector input_dims_to_keep; - for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape); - ++input_dim) { - if (std::find(dimensions_to_reduce.begin(), dimensions_to_reduce.end(), - input_dim) == dimensions_to_reduce.end()) { - input_dims_to_keep.push_back(input_dim); - } - } - - // Sort the dimensions to keep from minor to major, to facilitate checking - // whether another dimension is major or minor of them. - std::sort(input_dims_to_keep.begin(), input_dims_to_keep.end(), - [&input_shape](int64 dim_a, int64 dim_b) { - return PositionInContainer(LayoutUtil::MinorToMajor(input_shape), - dim_a) < - PositionInContainer(LayoutUtil::MinorToMajor(input_shape), - dim_b); - }); - // Now, if output rank is at least 1, `input_dims_to_keep.front()` is - // minormost and `input_dims_to_keep.back()` is majormost. - - // If the dimensions to keep are minormost, emit a column reduction. As all - // the dimensions to keep are contiguous, by prerequisite of - // `EmitReductionToVector`, we only need to check whether the minormost - // dimension of the input is to keep. - if (ShapeUtil::IsEffectiveScalar(reduce->shape())) { - return EmitReductionToScalar(kernel_thunk, reduce, input_shape, input_gens, - init_value_gens, reducers, - reduce_output_shapes, extra_output_gens); - } else if (input_dims_to_keep.front() == - LayoutUtil::Minor(input_shape.layout(), 0)) { - // Column reduction. Treat the result of "input" as a matrix whose width - // is the most minor dimension and height the product of other dimensions, - // and treat "reduce" as a column reduction of the input matrix. - const int64 width = ShapeUtil::ElementsIn(reduce->shape()); - // "width" can be zero, so don't do - // height = ShapeUtil::ElementsIn(input_shape) / width; - int64 height = 1; - for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape); - ++input_dim) { - if (!std::count(input_dims_to_keep.begin(), input_dims_to_keep.end(), - input_dim)) { - height *= input_shape.dimensions(input_dim); - } - } - return EmitColumnReduction(kernel_thunk, height, width, reduce, input_shape, - input_gens, init_value_gens, reducers, - reduce_output_shapes, extra_output_gens); - } else { - // Reduce the row dimension of a matrix or reduce dimension 0 and 2 in a - // 3D tensor. The size of dimension 1 (the height) is the size of the - // dimension to keep, the size of dimension 0 (the depth) is the product - // of dimensions that are more major than the dimension to keep, and the - // size of dimension 2 (the width) is the product of more minor - // dimensions. - int64 depth = 1; - int64 width = 1; - for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape); - ++input_dim) { - if (PositionInContainer(LayoutUtil::MinorToMajor(input_shape), - input_dim) > - PositionInContainer(LayoutUtil::MinorToMajor(input_shape), - input_dims_to_keep.back())) { - depth *= input_shape.dimensions(input_dim); - } else if (PositionInContainer(LayoutUtil::MinorToMajor(input_shape), - input_dim) < - PositionInContainer(LayoutUtil::MinorToMajor(input_shape), - input_dims_to_keep.front())) { - width *= input_shape.dimensions(input_dim); - } - } - const int64 height = ShapeUtil::ElementsIn(reduce->shape()); - return EmitRowReduction(kernel_thunk, depth, height, width, reduce, - input_shape, input_gens, init_value_gens, reducers, - reduce_output_shapes, extra_output_gens); - } -} - Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { // TODO(b/112040122): Support multi-output reduce. if (!ShapeUtil::IsArray(reduce->shape())) { return Unimplemented("Multi-output reduce is not supported on GPU"); } - auto input = reduce->operand(0); - auto init_value = reduce->operand(1); - absl::Span dimensions_to_reduce(reduce->dimensions()); - HloComputation* reducer = reduce->to_apply(); - // HandleReduce specializes reduction from a multi-dimensional array to a 1D - // array. The specialized version requires an initializer thunk that - // initializes the output array to the initial value of the reduce. if (IsReductionToVector(*reduce)) { - TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, - BuildInitializerThunk(reduce)); - std::vector> thunks; - thunks.push_back(std::move(initializer_thunk)); - std::unique_ptr kernel_thunk = - BuildKernelThunk(reduce, /*implements_whole_instruction=*/false); - - TF_CHECK_OK(EmitReductionToVector( - kernel_thunk.get(), reduce, input->shape(), - {[&](const IrArray::Index& index) { - return GetIrArray(*input, *reduce).EmitReadArrayElement(index, &b_); - }}, - {[&](const IrArray::Index& index) { - return GetIrArray(*init_value, *reduce) - .EmitReadArrayElement(index, &b_); - }}, - dimensions_to_reduce, {reducer}, {{}}, {})); - - thunks.push_back(std::move(kernel_thunk)); - - std::unique_ptr sequential_thunk = - absl::make_unique(std::move(thunks), reduce); - AddThunkToThunkSequence(std::move(sequential_thunk)); - return Status::OK(); + return EmitReductionToVector(reduce); } return IrEmitter::HandleReduce(reduce); @@ -1820,7 +763,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // Create the inner loop to iterate over the window. llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), &b_, index_type); - std::vector window_size; + DimensionVector window_size; for (const auto& dim : window.dimensions()) { window_size.push_back(dim.size()); CHECK_GT(dim.size(), 0); @@ -3121,11 +2064,9 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( // pressure, since we touch threadIdx.x and blockIdx.x at the beginning of the // kernel *anyway*. std::vector output_arrays = ConstructIrArrayForOutputs(hlo); - TF_RETURN_IF_ERROR( - KernelSupportLibrary(&b_).If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] { - llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_); - return Status::OK(); - })); + KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] { + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_); + }); // For multioutput fusion, we need to emit each operand and the root. TF_RETURN_IF_ERROR( @@ -3195,34 +2136,36 @@ int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape( namespace { -void EmitFullTile(const KernelMappingScheme* mapping_scheme, - const IrArray::Index& tile_origin_index, - llvm::IRBuilder<>* builder, llvm::Value* y, llvm::Value* x, - llvm::Type* index_ty, - const std::function& emit_elem_function) { +void EmitFullElementalTile( + const KernelMappingScheme* mapping_scheme, + const IrArray::Index& tile_origin_index, const string& loop_name, + KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, + llvm::Value* x, llvm::Type* index_ty, + const std::function& emit_elem_function) { int64 num_threads_x = mapping_scheme->GetNumberOfThreadsForDimensionX(); int64 num_threads_y = mapping_scheme->GetNumberOfThreadsForDimensionY(); int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); int64 tile_size_y = mapping_scheme->GetTileSizeForDimensionY(); - for (int64 i = 0; i < tile_size_y; i += num_threads_y) { - IrArray::Index source_idx_y = - tile_origin_index.AddOffsetToDim(llvm::ConstantInt::get(index_ty, i), - KernelMappingScheme::DimY, builder); - llvm::Value* y_loc = - builder->CreateAdd(llvm::ConstantInt::get(index_ty, i), y); - for (int64 j = 0; j < tile_size_x; j += num_threads_x) { - IrArray::Index source_idx = - source_idx_y.AddOffsetToDim(llvm::ConstantInt::get(index_ty, j), - KernelMappingScheme::DimX, builder); - llvm::Value* x_loc = - builder->CreateAdd(llvm::ConstantInt::get(index_ty, j), x); - emit_elem_function(source_idx, y_loc, x_loc); - } - } -} - -void EmitPartialTile( + ksl->For(loop_name + "_y", /*start=*/llvm::ConstantInt::get(index_ty, 0), + /*end=*/llvm::ConstantInt::get(index_ty, tile_size_y), + /*step=*/llvm::ConstantInt::get(index_ty, num_threads_y), + [&](llvm::Value* y_indvar) { + IrArray::Index source_idx_y = tile_origin_index.AddOffsetToDim( + y_indvar, KernelMappingScheme::DimY, builder); + llvm::Value* y_loc = builder->CreateAdd(y_indvar, y); + for (int64 j = 0; j < tile_size_x; j += num_threads_x) { + IrArray::Index source_idx = source_idx_y.AddOffsetToDim( + llvm::ConstantInt::get(index_ty, j), + KernelMappingScheme::DimX, builder); + llvm::Value* x_loc = + builder->CreateAdd(llvm::ConstantInt::get(index_ty, j), x); + emit_elem_function(source_idx, y_loc, x_loc); + } + }); +} + +void EmitPartialElementalTile( const KernelMappingScheme* mapping_scheme, const IrArray::Index& tile_origin_index, const string& loop_name, KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, @@ -3241,8 +2184,9 @@ void EmitPartialTile( llvm::Value* x_loc = builder->CreateAdd(llvm::ConstantInt::get(index_ty, j), x); - ksl->IfReturnVoid( - "x_in_tile", builder->CreateICmpULT(x_loc, tile_width), [&] { + ksl->If( + loop_name + "_x_in_tile", builder->CreateICmpULT(x_loc, tile_width), + [&] { // tile_height_bound = // ceil(tile_height / num_threads_y) * num_threads_y llvm::Value* ceiling_of_ratio = builder->CreateUDiv( @@ -3252,15 +2196,15 @@ void EmitPartialTile( llvm::Value* tile_height_bound = builder->CreateMul( ceiling_of_ratio, llvm::ConstantInt::get(index_ty, num_threads_y)); - ksl->ForReturnVoid( + ksl->For( loop_name, /*start=*/llvm::ConstantInt::get(index_ty, 0), /*end=*/tile_height_bound, /*step=*/llvm::ConstantInt::get(index_ty, num_threads_y), [&](llvm::Value* y_indvar) { llvm::Value* y_loc = builder->CreateAdd(y_indvar, y); - ksl->IfReturnVoid( - "y_in_tile", builder->CreateICmpULT(y_loc, tile_height), - [&] { + ksl->If( + loop_name + "_y_in_tile", + builder->CreateICmpULT(y_loc, tile_height), [&] { emit_elem_function( source_idx.AddOffsetToDim( y_indvar, KernelMappingScheme::DimY, builder), @@ -3290,21 +2234,21 @@ void EmitTiledElementalCodeWithBoundsCheck( int64 tile_size_y = mapping_scheme->GetTileSizeForDimensionY(); llvm::Type* index_ty = tile_width->getType(); - ksl->IfReturnVoid( - "full_tile", + ksl->If( + loop_name + "_full_tile", builder->CreateAnd( builder->CreateICmpEQ(llvm::ConstantInt::get(index_ty, tile_size_x), tile_width), builder->CreateICmpEQ(llvm::ConstantInt::get(index_ty, tile_size_y), tile_height)), [&] { - EmitFullTile(mapping_scheme, tile_origin_index, builder, y, x, index_ty, - emit_elem_function); + EmitFullElementalTile(mapping_scheme, tile_origin_index, loop_name, ksl, + builder, y, x, index_ty, emit_elem_function); }, [&] { - EmitPartialTile(mapping_scheme, tile_origin_index, loop_name, ksl, - builder, y, x, tile_height, tile_width, index_ty, - emit_elem_function); + EmitPartialElementalTile(mapping_scheme, tile_origin_index, loop_name, + ksl, builder, y, x, tile_height, tile_width, + index_ty, emit_elem_function); }); } } // namespace @@ -3382,7 +2326,395 @@ void IrEmitterUnnested::EmitTileElementForFusion( } } -// Emits a block of tiles, given a function object to emit one tile. +// Information to support the code generation for a tiled reduction kernel. +using AddressVector = InlinedVector; +class ReductionCodegenInfo : public IrEmitterUnnested::KernelCodegenInfo { + public: + explicit ReductionCodegenInfo(llvm_ir::KernelMappingScheme* mapping_scheme, + bool is_row_reduction) + : KernelCodegenInfo(mapping_scheme), + current_output_linear_index_address_(nullptr), + current_output_inbound_address_(nullptr), + is_row_reduction_(is_row_reduction) {} + + void SetCurrentOutputLinearIndexAddress(llvm::AllocaInst* a) { + current_output_linear_index_address_ = a; + } + // Returns the address of the memory that stores the linear index of the + // current output. Since we are processing reduction to contiguous physical + // dimensions, this linear index is the linear index of the 1D output array. + llvm::AllocaInst* GetCurrentOutputLinearIndexAddress() const { + return current_output_linear_index_address_; + } + + void SetCurrentOutputInboundAddress(llvm::AllocaInst* a) { + current_output_inbound_address_ = a; + } + + llvm::AllocaInst* GetCurrentOutputInboundAddress() const { + return current_output_inbound_address_; + } + + AddressVector* GetMutablePartialResultAddresses() { + return &partial_result_addresses_; + } + absl::Span GetPartialResultAddresses() const { + return partial_result_addresses_; + } + + AddressVector* GetMutableReductionInputAddresses() { + return &reduction_input_addresses_; + } + absl::Span GetReductionInputAddresses() const { + return reduction_input_addresses_; + } + + InlinedVector* GetMutableReducers() { return &reducers_; } + const InlinedVector& GetReducers() const { + return reducers_; + } + int GetNumberOfReduces() const { return reducers_.size(); } + + InlinedVector* GetMutableReductionOutputShapeIndices() { + return &reduction_output_shape_indices_; + } + absl::Span GetReductionOutputShapeIndices() const { + return reduction_output_shape_indices_; + } + + bool IsRowReduction() const { return is_row_reduction_; } + + // Return the dimension that is being reduced between DimX and DimY. + int GetReducedDimensionEnum() const { + return IsRowReduction() ? llvm_ir::KernelMappingScheme::DimX + : llvm_ir::KernelMappingScheme::DimY; + } + + // Return the dimension that is being ketp between DimX and DimY. + int GetKeptDimensionEnum() const { + return IsRowReduction() ? llvm_ir::KernelMappingScheme::DimY + : llvm_ir::KernelMappingScheme::DimX; + } + + private: + AddressVector partial_result_addresses_; + AddressVector reduction_input_addresses_; + InlinedVector reducers_; + InlinedVector reduction_output_shape_indices_; + llvm::AllocaInst* current_output_linear_index_address_; + llvm::AllocaInst* current_output_inbound_address_; + bool is_row_reduction_; +}; + +namespace { +// Returns a group of instructions that generate the output for the kernel +// containing the given HLO instruction. The result may be an unnested kReduce +// HLO, a nested kReduce HLO of a kInput fusion, or the operands of the tuple +// for a multiple output fusion. +absl::Span GetOutputInstructions( + HloInstruction* const* reduce_or_tuple_pointer) { + HloOpcode opcode = (*reduce_or_tuple_pointer)->opcode(); + CHECK(opcode == HloOpcode::kReduce || opcode == HloOpcode::kTuple); + return opcode == HloOpcode::kTuple + ? (*reduce_or_tuple_pointer)->operands() + : absl::Span(reduce_or_tuple_pointer, 1); +} + +const HloInstruction* GetFirstReduceInstruction( + absl::Span instructions) { + auto first_reduce_iter = + absl::c_find_if(instructions, [](const HloInstruction* inst) { + return inst->opcode() == HloOpcode::kReduce; + }); + CHECK_NE(first_reduce_iter, instructions.end()); + return *first_reduce_iter; +} + +}; // namespace + +void IrEmitterUnnested::EmitPrologueForOneReduction( + HloInstruction* unnested_hlo, HloInstruction* reduce_inst, int reduce_idx, + KernelCodegenInfo* kernel_info, GpuElementalIrEmitter* elemental_emitter, + ShapeIndex output_shape_index) { + ReductionCodegenInfo* reduction_info = + static_cast(kernel_info); + + InlinedVector* reducers = + reduction_info->GetMutableReducers(); + CHECK(IsReductionToVector(*reduce_inst)); + reducers->push_back(reduce_inst->to_apply()); + + InlinedVector* reduction_output_shape_indices = + reduction_info->GetMutableReductionOutputShapeIndices(); + reduction_output_shape_indices->push_back(std::move(output_shape_index)); + + AddressVector* reduction_input_addresses = + reduction_info->GetMutableReductionInputAddresses(); + llvm::Type* element_type = llvm_ir::PrimitiveTypeToIrType( + reduce_inst->shape().element_type(), ir_emitter_context_->llvm_module()); + llvm::AllocaInst* reduction_input_address = Alloca(element_type); + reduction_input_addresses->push_back(reduction_input_address); + + AddressVector* partial_result_addresses = + reduction_info->GetMutablePartialResultAddresses(); + llvm::AllocaInst* partial_result_address = + Alloca(element_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(reduce_idx)); + partial_result_addresses->push_back(partial_result_address); + + // Initialize the partial result with the initial value of the reduction. + llvm::Value* init_ir_value; + if (unnested_hlo->opcode() == HloOpcode::kFusion) { + HloInstruction* init_value_operand = reduce_inst->mutable_operand(1); + FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo), + elemental_emitter); + + TF_CHECK_OK(init_value_operand->Accept(&fused_emitter)); + init_ir_value = + fused_emitter + .GetGenerator(init_value_operand)(IrArray::Index(b_.getInt32Ty())) + .ValueOrDie(); + } else { + const HloInstruction* init_value = unnested_hlo->operand(1); + init_ir_value = + GetIrArray(*init_value, *unnested_hlo) + .EmitReadArrayElement(IrArray::Index(b_.getInt32Ty()), &b_); + } + + Store(init_ir_value, partial_result_address); +} + +void IrEmitterUnnested::EmitPrologueForReduction( + HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info) { + VLOG(10) << "Emit prologue for reduction " << unnested_hlo->ToString(); + // Find the unnested kReduce or the tuple that contains a list of kReduce. + HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion + ? unnested_hlo->fused_expression_root() + : unnested_hlo; + absl::Span output_instructions = + GetOutputInstructions(&reduce_or_tuple); + ReductionCodegenInfo* reduction_info = + static_cast(kernel_info); + GpuElementalIrEmitter elemental_emitter(hlo_module_config_, + ir_emitter_context_->llvm_module(), + &b_, GetNestedComputer()); + const HloInstruction* first_reduce = nullptr; + for (int i = 0, e = output_instructions.size(); i != e; ++i) { + if (output_instructions[i]->opcode() != HloOpcode::kReduce) { + continue; + } + HloInstruction* reduce_inst = output_instructions[i]; + if (first_reduce == nullptr) { + first_reduce = reduce_inst; + } else { + CHECK(first_reduce->dimensions() == reduce_inst->dimensions()); + } + ShapeIndex output_shape_index; + if (reduce_or_tuple->opcode() == HloOpcode::kTuple) { + output_shape_index = {i}; + } + + EmitPrologueForOneReduction(unnested_hlo, reduce_inst, i, kernel_info, + &elemental_emitter, + std::move(output_shape_index)); + } + + // Allocate stack storage to store the current output linear index and record + // the address of the storage. + reduction_info->SetCurrentOutputLinearIndexAddress( + Alloca(reduction_info->GetIndexType())); + + if (!reduction_info->IsRowReduction()) { + llvm::Type* bool_ty = b_.getInt1Ty(); + llvm::AllocaInst* output_inbound_addr = Alloca(bool_ty); + Store(llvm::ConstantInt::get(bool_ty, 0), output_inbound_addr); + reduction_info->SetCurrentOutputInboundAddress(output_inbound_addr); + } +} + +void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForAllReduces( + absl::Span reducers, + absl::Span partial_result_addresses) { + for (int distance = 16; distance >= 1; distance /= 2) { + for (int i = 0; i != reducers.size(); ++i) { + llvm::Type* element_type = + partial_result_addresses[i]->getType()->getElementType(); + int bit_width = llvm_ir::GetSizeInBits(element_type); + llvm::Value* result_from_other_lane = Alloca( + element_type, nullptr, "result_from_other_lane" + llvm::Twine(i)); + // Bitcast cannot be applied to aggregate types (even packed ones), so + // we bitcast addresses of load/store to intN* of the same bit-width. + llvm::Type* shuffled_value_type = + element_type->isStructTy() ? b_.getIntNTy(bit_width) : element_type; + auto convert_pointer_for_shuffle = [&](llvm::Value* ptr) { + return BitCast(ptr, shuffled_value_type->getPointerTo()); + }; + llvm::Value* partial_result = + Load(convert_pointer_for_shuffle(partial_result_addresses[i]), + "partial_reduction_result"); + Store(EmitFullWarpShuffleDown(partial_result, b_.getInt32(distance), &b_), + convert_pointer_for_shuffle(result_from_other_lane)); + TF_CHECK_OK(EmitCallToNestedComputation( + *reducers[i], {partial_result_addresses[i], result_from_other_lane}, + partial_result_addresses[i])); + } + } +} + +void IrEmitterUnnested::EmitEpilogueForReduction( + HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info) { + ReductionCodegenInfo* reduction_info = + static_cast(kernel_info); + int num_reduces = reduction_info->GetNumberOfReduces(); + absl::Span partial_result_addresses = + reduction_info->GetPartialResultAddresses(); + const InlinedVector& reducers = + reduction_info->GetReducers(); + absl::Span reduction_output_shape_indices = + reduction_info->GetReductionOutputShapeIndices(); + + if (reduction_info->IsRowReduction()) { + EmitFullWarpShuffleDownLoopForAllReduces(reducers, + partial_result_addresses); + llvm::Value* lane_id = reduction_info->GetLaneId(); + llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( + ICmpEQ(lane_id, llvm::ConstantInt::get(lane_id->getType(), 0)), + "lane_id_is_zero", &b_); + llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); + } else { + llvm::Value* output_inbound_addr = + reduction_info->GetCurrentOutputInboundAddress(); + llvm::Value* output_inbound = Load(output_inbound_addr); + llvm_ir::LlvmIfData if_output_inbound_data = llvm_ir::EmitIfThenElse( + ICmpEQ(output_inbound, + llvm::ConstantInt::get(output_inbound->getType(), 1)), + "output_inbound", &b_); + llvm_ir::SetToFirstInsertPoint(if_output_inbound_data.true_block, &b_); + } + + // Emit an atomic operation that accumulates the partial reduction to the + // output element. For row reduction, this is only for lane 0 due to the + // if-statement emitted above. + for (int i = 0; i != num_reduces; ++i) { + IrArray::Index element_index( + /*linear=*/Load(reduction_info->GetCurrentOutputLinearIndexAddress(), + "output_linear_addr"), + ShapeUtil::GetSubshape(unnested_hlo->shape(), + reduction_output_shape_indices[i]), + &b_); + llvm::Value* output_address = + GetIrArray(*unnested_hlo, *unnested_hlo, + reduction_output_shape_indices[i]) + .EmitArrayElementAddress(element_index, &b_, + "output_element_address"); + // Do not emit atomic operations if each element in the reduction result is + // computed by one block, that is the dimension being reduced has only one + // block. + const llvm_ir::KernelMappingScheme* mapping_scheme = + reduction_info->GetKernelMappingScheme(); + if (mapping_scheme->GetTileBlockSizeForDimension( + llvm_ir::KernelMappingScheme::DimZ) == 1 && + mapping_scheme->GetTileBlockSizeForDimension( + reduction_info->GetReducedDimensionEnum()) == 1) { + TF_CHECK_OK(EmitCallToNestedComputation( + *reducers[i], {output_address, partial_result_addresses[i]}, + output_address)); + } else { + TF_CHECK_OK(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, partial_result_addresses[i])); + } + } +} + +void IrEmitterUnnested::EmitTileElementForReduction( + HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, + llvm::Value* x_loc) { + VLOG(10) << "Emit tile element for reduce " << unnested_hlo->ToString(); + HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion + ? unnested_hlo->fused_expression_root() + : unnested_hlo; + llvm_ir::TiledParameterInfo* tiled_param_info = + kernel_info->GetTiledParameterInfo(); + tiled_param_info->set_y(y_loc); + tiled_param_info->set_x(x_loc); + + // Record the linear address for the current reduction. + const ReductionCodegenInfo* reduction_info = + dynamic_cast(kernel_info); + Store(index[reduction_info->GetKeptDimensionEnum()], + reduction_info->GetCurrentOutputLinearIndexAddress()); + if (!reduction_info->IsRowReduction()) { + llvm::Type* bool_ty = b_.getInt1Ty(); + llvm::AllocaInst* output_inbound_addr = + reduction_info->GetCurrentOutputInboundAddress(); + Store(llvm::ConstantInt::get(bool_ty, 1), output_inbound_addr); + } + + InlinedVector input_gens; + std::vector> + extra_output_gens; + GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_, + GetNestedComputer()); + FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo), + &elem_emitter); + absl::Span output_instructions = + GetOutputInstructions(&reduce_or_tuple); + // Construct the ElementGenerator for each reduction and extra output in the + // the group of output instructions. + if (unnested_hlo->opcode() == HloOpcode::kFusion) { + fused_emitter.SetTiledParameterInfo(tiled_param_info); + TF_CHECK_OK(unnested_hlo->fused_expression_root()->Accept(&fused_emitter)); + + for (int i = 0, e = output_instructions.size(); i != e; ++i) { + const HloInstruction* inst = output_instructions[i]; + ShapeIndex output_shape_index; + if (reduce_or_tuple->opcode() == HloOpcode::kTuple) { + output_shape_index = {i}; + } + if (inst->opcode() == HloOpcode::kReduce) { + input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0))); + } else { + extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst), + std::move(output_shape_index)); + } + } + } else { + input_gens.push_back([&](const IrArray::Index& index) { + return GetIrArray(*unnested_hlo->operand(0), *unnested_hlo) + .EmitReadArrayElement(index, &b_); + }); + } + + IrArray::Index input_index = + reduction_info->GetKernelMappingScheme()->GetUnnormalizedIndex( + index, + GetFirstReduceInstruction(output_instructions)->operand(0)->shape()); + absl::Span partial_reduction_result_addresses = + reduction_info->GetPartialResultAddresses(); + absl::Span reduction_input_addresses = + reduction_info->GetReductionInputAddresses(); + const InlinedVector& reducers = + reduction_info->GetReducers(); + + // Emit code to generate the input and perform the reduction computation for + // each reduction instruction. + for (int i = 0; i != reducers.size(); ++i) { + llvm::Value* const input_ir_value = input_gens[i](input_index).ValueOrDie(); + Store(input_ir_value, reduction_input_addresses[i]); + TF_CHECK_OK(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], reduction_input_addresses[i]}, + partial_reduction_result_addresses[i])); + } + + // Emit code to generate the output for the non-reduction instructions in the + // fusion, if any. + TF_CHECK_OK( + EmitExtraOutputsForReduce(unnested_hlo, input_index, extra_output_gens)); +} + +// Emits a kernel for the hlo instruction using the given tiling scheme. void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile, const KernelCodegenInfo* kernel_info, KernelSupportLibrary& ksl, @@ -3419,15 +2751,14 @@ void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile, Select(ICmpEQ(last_block_for_dim, block_id_for_dim), last_block_size_for_dim, block_size_for_dim); - ksl.ForReturnVoid( - loop_name, - /*start=*/index_typed_constant(0), - /*end=*/num_tiles_in_block, - /*step=*/1, [&](llvm::Value* block_dim_induction_var) { - IrArray::Index tile_index = starting_tile.AddOffsetToDim( - block_dim_induction_var, dim_id, &b_); - emit_next_block_dim(tile_index); - }); + ksl.For(loop_name, + /*start=*/index_typed_constant(0), + /*end=*/num_tiles_in_block, + /*step=*/1, [&](llvm::Value* block_dim_induction_var) { + IrArray::Index tile_index = starting_tile.AddOffsetToDim( + block_dim_induction_var, dim_id, &b_); + emit_next_block_dim(tile_index); + }); } }; @@ -3509,11 +2840,22 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( << llvm_ir::DumpToString(*param_shmem_buffers[id]); } - CHECK_EQ(mapping_scheme->GetThreadsPerTile() % kWarpSize, 0); - LaunchDimensions launch_dimensions = LaunchDimensions( - mapping_scheme->GetNumberOfBlocks(), mapping_scheme->GetThreadsPerTile()); - llvm::Type* index_ty = GetIndexTypeForKernel( - unnested_hlo, launch_dimensions.launch_bound(), &b_); + const ReductionCodegenInfo* reduction_info = + dynamic_cast(kernel_info); + bool is_column_reduction = + (reduction_info && !reduction_info->IsRowReduction()); + + LaunchDimensions launch_dimensions = + LaunchDimensions(mapping_scheme->GetNumberOfBlocks(), + mapping_scheme->GetThreadsPerBlock()); + + // TODO(b/110211620): Enable int32 index type for column reduction. + llvm::Type* index_ty = + is_column_reduction + ? b_.getInt64Ty() + : GetIndexTypeForKernel(unnested_hlo, + launch_dimensions.launch_bound(), &b_); + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { return llvm::ConstantInt::get(index_ty, c); }; @@ -3523,14 +2865,13 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( // but we do it at the beginning in the hopes of reducing register pressure, // since we touch threadIdx.x and blockIdx.x at the beginning of the kernel // *anyway*. - if (unnested_hlo->IsMultiOutputFusion()) { - TF_CHECK_OK(KernelSupportLibrary(&b_).If( + if (!reduction_info && unnested_hlo->IsMultiOutputFusion()) { + KernelSupportLibrary{&b_}.If( "emit_mof_tuple", IsBlock0Thread0(&b_), [&] { llvm_ir::EmitTuple(GetIrArray(*unnested_hlo, *unnested_hlo), ConstructIrArrayForOutputs(*unnested_hlo), &b_, module_); - return Status::OK(); - })); + }); } // For each tiled parameter, cast its input IrArray to the corresponding @@ -3553,6 +2894,7 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( kernel_info->SetLaneId( mapping_scheme->GetNumberOfThreadsForDimensionX() == kWarpSize ? x : nullptr); + kernel_info->SetIndexType(index_ty); KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); // Curry a few parameters to EmitTiledElementalCodeWithBoundsCheck. @@ -3577,29 +2919,31 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( input_tile_origin.AddOffsetToDim(x, KernelMappingScheme::DimX, &b_) .AddOffsetToDim(y, KernelMappingScheme::DimY, &b_); - // Copy input parameter values to shared memory buffers: - // tile[y, x] = input[index] - // Note that tile_width and tile_height are flipped here because we are - // reading a transposed tile. - emit_tiled_elemental_code_with_bounds_check( - input_index, "input", output_tile_bounds[2], output_tile_bounds[1], - [&](const IrArray::Index& index, llvm::Value* y_loc, - llvm::Value* x_loc) { - for (int64 id : tiled_param_ids) { - IrArray& input_in_logical_shape = param_in_reduced_shape_arrays[id]; - llvm::Value* shmem_buffer = param_shmem_buffers[id]; - // TODO(jlebar): Add AA metadata to this store. Tile buffers are - // global variables, so LLVM can't infer much about it. - Store(input_in_logical_shape.EmitReadArrayElement(index, &b_, - "input_element"), - GEP(shmem_buffer, {index_typed_constant(0), y_loc, x_loc})); - } - }); - // If shared memory transpose is needed, wait for all threads to reach this // point, lest we copy a value from tile to output before the other thread // copies it from input to tile. This is `__syncthreads` in CUDA. if (!tiled_param_ids.empty()) { + // Copy input parameter values to shared memory buffers: + // tile[y, x] = input[index] + // Note that tile_width and tile_height are flipped here because we are + // reading a transposed tile. + emit_tiled_elemental_code_with_bounds_check( + input_index, "input", output_tile_bounds[2], output_tile_bounds[1], + [&](const IrArray::Index& index, llvm::Value* y_loc, + llvm::Value* x_loc) { + for (int64 id : tiled_param_ids) { + IrArray& input_in_logical_shape = + param_in_reduced_shape_arrays[id]; + llvm::Value* shmem_buffer = param_shmem_buffers[id]; + // TODO(jlebar): Add AA metadata to this store. Tile buffers are + // global variables, so LLVM can't infer much about it. + Store(input_in_logical_shape.EmitReadArrayElement( + index, &b_, "input_element"), + GEP(shmem_buffer, {index_typed_constant(0), y_loc, x_loc})); + } + }); + + // Wait for all threads to reach this point using `__syncthreads` in CUDA. llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, &b_); } @@ -3619,6 +2963,7 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( kernel_generator.GetTileElementGenerator()(unnested_hlo, index, kernel_info, y_loc, x_loc); }); + // If a tile block contains multiple tiles and shared memory buffers are // used, we need to wait for all threads to finish using the shared memory // buffer for the current tile before we move on to process the next tile @@ -3814,6 +3159,249 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { return true; } +namespace { +// Checks that the outputs of a fusion with reduction are consistent. +Status AreFusedReductionOutputsConsistent( + absl::Span output_instructions, + const HloInstruction* first_reduce) { + for (const HloInstruction* inst : output_instructions) { + if (inst->opcode() == HloOpcode::kReduce) { + // Shapes, layouts and dimensions must be the same for all reduces + // inside of this fusion. + TF_RET_CHECK(ShapeUtil::Equal(first_reduce->shape(), inst->shape())); + TF_RET_CHECK(ShapeUtil::Equal(first_reduce->operand(0)->shape(), + inst->operand(0)->shape())); + TF_RET_CHECK(ShapeUtil::Equal(first_reduce->operand(1)->shape(), + inst->operand(1)->shape())); + TF_RET_CHECK(first_reduce->dimensions() == inst->dimensions()); + } else { + // For extra outputs we can relax shape equality to allow different + // types (with the same number of elements). Layouts still have to + // match. + TF_RET_CHECK(ShapeUtil::CompatibleIgnoringElementType( + first_reduce->operand(0)->shape(), inst->shape())); + TF_RET_CHECK(LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(), + inst->shape().layout())); + } + } + return Status::OK(); +} + +// Finds the dimensions to keep for the reduction, sorts and returns the +// dimensions from minor to major. +DimensionVector GetDimensionsToKeepMinorToMajor( + const Shape& input_shape, absl::Span dims_to_reduce) { + DimensionVector input_dims(ShapeUtil::Rank(input_shape), 0); + absl::c_iota(input_dims, 0); + DimensionVector input_dims_to_keep; + for (int input_dim : input_dims) { + auto it = absl::c_find_if(dims_to_reduce, [&](int64 dim_to_reduce) { + return dim_to_reduce == input_dim; + }); + if (it == dims_to_reduce.end()) { + input_dims_to_keep.push_back(input_dim); + } + } + + // Sort the dimensions to keep from minor to major. + absl::c_sort(input_dims_to_keep, [&input_shape](int64 dim_a, int64 dim_b) { + return PositionInContainer(LayoutUtil::MinorToMajor(input_shape), dim_a) < + PositionInContainer(LayoutUtil::MinorToMajor(input_shape), dim_b); + }); + + VLOG(10) << "dims to keep minor to major" + << absl::StrJoin(input_dims_to_keep, ","); + return input_dims_to_keep; +} + +// Given the input shape and dimensions to reduce for the reduction to vector, +// returns : +// num_kept: the number of elements in the contiguous dimensions to keep. +// num_reduced_major: the number of elements in the dimensions to reduce that +// are more major than the dimensions to keep. +// num_reduced_minor: the number of elements in the dimensions to reduce that +// are more minor than the dimensions to kept. +std::tuple GetReductionToVectorDimensions( + const Shape& input_shape, absl::Span dims_to_reduce) { + DimensionVector input_dims_to_keep_minor_to_major = + GetDimensionsToKeepMinorToMajor(input_shape, dims_to_reduce); + CHECK(LayoutUtil::AreDimensionsConsecutive( + input_shape.layout(), input_dims_to_keep_minor_to_major)); + int num_reduced_major = 1, num_kept = 1, num_reduced_minor = 1; + if (input_dims_to_keep_minor_to_major.empty()) { + return std::make_tuple(num_reduced_major, num_kept, num_reduced_minor); + } + DimensionVector input_dims(ShapeUtil::Rank(input_shape), 0); + absl::c_iota(input_dims, 0); + absl::Span minor_to_major = + LayoutUtil::MinorToMajor(input_shape); + for (int input_dim : input_dims) { + int64 curr_dim_size = input_shape.dimensions(input_dim); + if (PositionInContainer(minor_to_major, input_dim) > + PositionInContainer(minor_to_major, + input_dims_to_keep_minor_to_major.back())) { + num_reduced_major *= curr_dim_size; + } else if (PositionInContainer(minor_to_major, input_dim) < + PositionInContainer(minor_to_major, + input_dims_to_keep_minor_to_major.front())) { + num_reduced_minor *= curr_dim_size; + } else { + num_kept *= curr_dim_size; + } + } + + return std::make_tuple(num_reduced_major, num_kept, num_reduced_minor); +} + +} // namespace + +std::tuple +IrEmitterUnnested::ComputeMappingSchemeAndReductionKind( + const HloInstruction* first_reduce) { + int64 depth = 1; + int64 height = 1; + int64 width = 1; + bool is_row_reduction = true; + int64 tile_size_x = 1; + int64 tile_size_y = 1; + int64 block_size_z = 1; + int64 num_threads_x = 1; + int64 num_threads_y = 1; + const Shape& input_shape = first_reduce->operand(0)->shape(); + int64 num_input_elems = ShapeUtil::ElementsIn(input_shape); + int64 num_output_elems = ShapeUtil::ElementsIn(first_reduce->shape()); + int64 num_reduced_major, num_kept, num_reduced_minor; + std::tie(num_reduced_major, num_kept, num_reduced_minor) = + GetReductionToVectorDimensions(input_shape, first_reduce->dimensions()); + CHECK_EQ(num_output_elems, num_kept); + + if (num_kept == 1) { + // Scalar reduction is a special row reduction with depth = height = 1. + width = num_input_elems; + tile_size_x = kWarpSize * 16; + num_threads_x = kWarpSize; + } else if (num_reduced_minor == 1) { + // Column reduction reduces inputs with dimension [height, width], where + // width is the minor dimension, to dimension [width]. + height = num_reduced_major; + width = num_kept; + is_row_reduction = false; + // Column reduction without transpose doesn't require communication among + // threads processing elements in the same tile. The current implementation + // only support the use of on hardware thread block to process one block of + // tiles in the KernelMappingScheme. We try to maximize the values of + // num_threads_x and tile_size_x to allow a bigger hardware thread block. + int64 hw_threads_per_block_limit = + ThreadsPerBlockLimit(ir_emitter_context_->device_description()); + tile_size_x = std::min(hw_threads_per_block_limit, num_kept); + num_threads_x = tile_size_x; + int64 kNumElementsPerPartialSum = 128; + tile_size_y = kNumElementsPerPartialSum; + } else { + // Row reduction reduces inputs with dimension [depth, height, width], + // where width is the most minor dimension, to dimension [height] . + depth = num_reduced_major; + height = num_kept; + width = num_reduced_minor; + num_threads_x = kWarpSize; + if (width % (kWarpSize * 64) == 0) { + tile_size_x = kWarpSize * 64; + } else { + tile_size_x = kWarpSize * 8; + block_size_z = 8; + while (depth % block_size_z != 0) { + block_size_z -= 1; + } + } + } + DCHECK_EQ(depth * height * width, num_input_elems); + VLOG(10) << "is_row_reduction " << is_row_reduction << depth << " " << height + << " " << width; + + DimensionVector dims_in_elem{depth, height, width}; + DimensionVector req_block_sizes{block_size_z, 1, 1}; + llvm_ir::KernelMappingScheme mapping_scheme( + dims_in_elem, tile_size_y, tile_size_x, req_block_sizes, num_threads_y, + num_threads_x, &b_); + return std::make_tuple(mapping_scheme, is_row_reduction); +} + +Status IrEmitterUnnested::EmitReductionToVector(HloInstruction* unnested_hlo) { + VLOG(10) << "Emitting reduction to vector " << unnested_hlo->ToString(); + + HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion + ? unnested_hlo->fused_expression_root() + : unnested_hlo; + absl::Span output_instructions = + GetOutputInstructions(&reduce_or_tuple); + const HloInstruction* first_reduce = + GetFirstReduceInstruction(output_instructions); + + if (output_instructions.size() > 1) { + TF_RETURN_IF_ERROR( + AreFusedReductionOutputsConsistent(output_instructions, first_reduce)); + } + + // Build an initializer thunk to initialize each reduction output. + std::vector> thunks; + for (int i = 0, e = output_instructions.size(); i != e; ++i) { + if (output_instructions[i]->opcode() != HloOpcode::kReduce) { + continue; + } + TF_ASSIGN_OR_RETURN( + std::unique_ptr initializer_thunk, + BuildInitializerThunk(unnested_hlo, + (output_instructions[i] == reduce_or_tuple) + ? ShapeIndex() + : ShapeIndex({i}))); + thunks.push_back(std::move(initializer_thunk)); + } + + // Build a kernel thunk to compute all the outputs. + std::unique_ptr kernel_thunk = + BuildKernelThunk(unnested_hlo, /*implements_whole_instruction=*/false); + + const Shape& input_shape = first_reduce->operand(0)->shape(); + // The layout of a reduction input is either set by LayoutAssignment for + // unnested kReduce or by InstructionFusion for fused kReduce. + CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion " + "doesn't set the input layout of " + << first_reduce->ToString(); + + bool is_row_reduction; + llvm_ir::KernelMappingScheme mapping_scheme; + std::tie(mapping_scheme, is_row_reduction) = + ComputeMappingSchemeAndReductionKind(first_reduce); + ReductionCodegenInfo reduction_info(&mapping_scheme, is_row_reduction); + KernelCodeGenerator kernel_generator( + /*tile_element_generator=*/ + [&](HloInstruction* hlo, const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, + llvm::Value* x_loc) { + EmitTileElementForReduction(hlo, index, kernel_info, y_loc, x_loc); + }, + /*block_prologue_generator=*/ + [&](HloInstruction* hlo, KernelCodegenInfo* kernel_info) { + EmitPrologueForReduction(hlo, kernel_info); + }, + /*block_epilogue_generator*/ + [&](HloInstruction* hlo, KernelCodegenInfo* kernel_info) { + EmitEpilogueForReduction(hlo, kernel_info); + }); + + LaunchDimensions launch_dimensions = + EmitKernel(unnested_hlo, {}, kernel_generator, &reduction_info); + UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(), + ir_emitter_context_->llvm_module()); + + thunks.push_back(std::move(kernel_thunk)); + std::unique_ptr sequential_thunk = + absl::make_unique(std::move(thunks), unnested_hlo); + AddThunkToThunkSequence(std::move(sequential_thunk)); + + return Status::OK(); +} + Status IrEmitterUnnested::EmitConstantGlobals() { for (const BufferAllocation& allocation : ir_emitter_context_->buffer_assignment().Allocations()) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index e09ed657a812be6ab4859a0e365a51c45a37bfed..1ebea7ab48664e693937b45561d096f7ec15132f 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_ +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" @@ -68,9 +69,12 @@ class IrEmitterUnnested : public IrEmitter { explicit KernelCodegenInfo(llvm_ir::KernelMappingScheme* mapping_scheme) : mapping_scheme_(mapping_scheme), tiled_param_info_(nullptr), - lane_id_(nullptr) {} + lane_id_(nullptr), + index_ty_(nullptr) {} + virtual ~KernelCodegenInfo() {} void SetLaneId(llvm::Value* v) { lane_id_ = v; } + void SetIndexType(llvm::Type* t) { index_ty_ = t; } void SetTiledParamInfo(llvm_ir::TiledParameterInfo* tiled_param_info) { CHECK_EQ(tiled_param_info_, nullptr); tiled_param_info_ = tiled_param_info; @@ -83,11 +87,13 @@ class IrEmitterUnnested : public IrEmitter { llvm_ir::TiledParameterInfo* GetTiledParameterInfo() const { return tiled_param_info_; } + llvm::Type* GetIndexType() const { return index_ty_; } private: llvm_ir::KernelMappingScheme* mapping_scheme_; llvm_ir::TiledParameterInfo* tiled_param_info_; llvm::Value* lane_id_; + llvm::Type* index_ty_; }; // A function object to prepare for the code generation for a tile block. @@ -200,82 +206,19 @@ class IrEmitterUnnested : public IrEmitter { // Helper for writing extra outputs from inside a reduce kernel. Status EmitExtraOutputsForReduce( - const HloInstruction* reduce, const llvm_ir::IrArray::Index& index, + const HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index, absl::Span> extra_output_gens); - // EmitColumnReduction and EmitRowReduction emit code for column and row - // reduction of a matrix and/or 3D tensor. Row and column reduction have - // different memory access pattern, so for performance their implementations - // are significantly different. + // Generates code for reduction to contiguous dimensions. // - // Emits code that reduces a matrix of shape [height x width] to a vector of - // [width]. Other parameters have the same meaning as those of - // `EmitReductionToVector`. Note that input shape might not be - // [height x width], but can be bitcast to [height x width] with "height" - // being the major dimension. - Status EmitColumnReduction( - KernelThunk* kernel_thunk, int64 height, int64 width, - HloInstruction* reduce, const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens); - - // Emits code that reduces a 3D tensor of shape [depth x height x width] to a - // vector of shape [height]. Other parameters have the same meaning as those - // of `EmitReductionToVector`. Note that input shape might not be - // [depth x height x width], but can be bitcast to [depth x height x width] - // with "depth" being the most major dimension. - Status EmitRowReduction( - KernelThunk* kernel_thunk, int64 depth, int64 height, int64 width, - HloInstruction* reduce, const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens); - - // Emits code that reduces a tensor of arbitrary rank to a scalar. - Status EmitReductionToScalar( - KernelThunk* kernel_thunk, HloInstruction* reduce, - const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens); + // Prerequisite: `IsReductionToVector(*unnested_hlo)` + Status EmitReductionToVector(HloInstruction* unnested_hlo); - // Figures out whether `reduce` is a row or column reduction, and which - // dimensions to reduce, and calls either `EmitRowReduction` or - // `EmitColumnReduction` as appropriate. `input_shape` is the shape of the - // input array, which is the operand of the Reduce instruction if unfused or - // of the Fusion instruction if fused. `input_gen` and `init_value_gen` - // generate elements of the input and the initial value. Other parameters mean - // the same as for `HandleReduce`. - // - // Multiple reduces can be emitted in the same loop, assuming they have the - // same input and output shapes, and the same reduce dimensions. - // - // extra_output_gens can contain extra generators for intermediate outputs. - // These must have the same shape as the reduce input as they are computed - // when the reduce inputs are being read. - // - // Prerequisite: `IsReductionToVector(*reduce)` - Status EmitReductionToVector( - KernelThunk* kernel_thunk, HloInstruction* reduce, - const Shape& input_shape, - absl::Span input_gens, - absl::Span init_value_gens, - absl::Span dimensions_to_reduce, - absl::Span reducers, - absl::Span reduce_output_shapes, - absl::Span> - extra_output_gens); + // Computes the KernelMappingScheme for the reduce HLO and indicates whether + // the reduction is a row reduction. + std::tuple + ComputeMappingSchemeAndReductionKind(const HloInstruction* first_reduce); // Emits code for an in-place scatter, modifying `thunk`s launch dimensions in // the process. `scatter` may be fused, scatter indices are taken from @@ -314,6 +257,28 @@ class IrEmitterUnnested : public IrEmitter { const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, llvm::Value* x_loc); + // Emits code to process a tensor element in a tile for the given input hlo + // that is either a unnested kReduce or a kInput fusion. + void EmitTileElementForReduction(HloInstruction* unnested_hlo, + const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, + llvm::Value* y_loc, llvm::Value* x_loc); + // Prepares for the code generation for a tile block of a reduction kernel. + void EmitPrologueForReduction(HloInstruction* unnested_hlo, + KernelCodegenInfo* kernel_info); + void EmitPrologueForOneReduction(HloInstruction* unnested_hlo, + HloInstruction* reduce_inst, int reduce_idx, + KernelCodegenInfo* kernel_info, + GpuElementalIrEmitter* elemental_emitter, + ShapeIndex output_shape_index); + // Wraps up the code generation for a tile block of a reduction kernel. + void EmitEpilogueForReduction(HloInstruction* unnested_hlo, + KernelCodegenInfo* kernel_info); + // For each reducer, emits the shuffle-down loop to accumulate the partial + // result to the global result. + void EmitFullWarpShuffleDownLoopForAllReduces( + absl::Span reducers, + absl::Span partial_result_addresses); // Generates the IrArray for each input of an hlo and returns a vector that // constains such IrArrays. diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index f3e17d888242a36c268dcbfa0d6530f80cedceb0..60f2116e6088fd2c5d3400b4463cb7fa8bbadfdc 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -108,27 +108,33 @@ namespace { namespace tracing = tensorflow::tracing; -// Returns the directory containing nvvm libdevice files. config_cuda_data_dir -// should be equal to config().debug_options().xla_gpu_cuda_data_dir() of the -// HloModule being compiled. -string GetLibdeviceDir(const string& config_cuda_data_dir) { - std::vector potential_libdevice_dirs; - if (!config_cuda_data_dir.empty()) { - potential_libdevice_dirs.push_back(config_cuda_data_dir); - } - potential_libdevice_dirs.push_back(tensorflow::LibdeviceRoot()); - - // Tries all potential libdevice directories in the order they are inserted. - // Returns the first directory that exists in the file system. - for (const string& potential_libdevice_dir : potential_libdevice_dirs) { - if (tensorflow::Env::Default()->IsDirectory(potential_libdevice_dir).ok()) { - VLOG(2) << "Found libdevice dir " << potential_libdevice_dir; - return potential_libdevice_dir; +// Returns a vector of potential locations of the CUDA root directory. +std::vector GetCudaRootCandidates( + const HloModuleConfig& hlo_module_config) { + std::vector potential_cuda_roots = tensorflow::CandidateCudaRoots(); + + // CUDA location explicitly specified by user via --xla_gpu_cuda_data_dir has + // highest priority. + string xla_gpu_cuda_data_dir = + hlo_module_config.debug_options().xla_gpu_cuda_data_dir(); + if (!xla_gpu_cuda_data_dir.empty()) { + potential_cuda_roots.insert(potential_cuda_roots.begin(), + xla_gpu_cuda_data_dir); + } + return potential_cuda_roots; +} + +// Returns the directory containing nvvm libdevice files. +string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) { + for (const string& cuda_root : GetCudaRootCandidates(hlo_module_config)) { + string libdevice_dir = + tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice"); + VLOG(2) << "Looking for libdevice at " << libdevice_dir; + if (tensorflow::Env::Default()->IsDirectory(libdevice_dir).ok()) { + VLOG(2) << "Found libdevice dir " << libdevice_dir; + return libdevice_dir; } - VLOG(2) << "Unable to find potential libdevice dir " - << potential_libdevice_dir; } - LOG(WARNING) << "Unable to find libdevice dir. Using '.'"; // Last resort: maybe in the current folder. return "."; @@ -478,14 +484,19 @@ void WarnIfBadDriverJITVersion() { // Compiles the given PTX string using ptxas and returns the resulting machine // code (i.e. a cubin) as a byte array. -StatusOr> CompilePtx(const string& ptx, int cc_major, - int cc_minor, - bool disable_ptx_optimizations) { +StatusOr> CompilePtx( + const string& ptx, int cc_major, int cc_minor, + const HloModuleConfig& hlo_module_config) { tracing::ScopedActivity activity("Compile PTX", /*is_expensive=*/true); - const string ptxas_path = - tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin", "ptxas"); - VLOG(2) << "Checking ptxas at " << ptxas_path; auto env = tensorflow::Env::Default(); + string ptxas_path; + for (const string& cuda_root : GetCudaRootCandidates(hlo_module_config)) { + ptxas_path = tensorflow::io::JoinPath(cuda_root, "bin", "ptxas"); + VLOG(2) << "Looking for ptxas at " << ptxas_path; + if (env->FileExists(ptxas_path).ok()) { + break; + } + } TF_RETURN_IF_ERROR(env->FileExists(ptxas_path)); VLOG(2) << "Using ptxas at " << ptxas_path; @@ -520,7 +531,7 @@ StatusOr> CompilePtx(const string& ptx, int cc_major, if (VLOG_IS_ON(2)) { ptxas_args.push_back("-v"); } - if (disable_ptx_optimizations) { + if (hlo_module_config.debug_options().xla_gpu_disable_ptxas_optimizations()) { ptxas_args.push_back("-O0"); } ptxas_info_dumper.SetProgram(ptxas_path, ptxas_args); @@ -685,12 +696,8 @@ StatusOr> NVPTXCompiler::RunBackend( // Find the directory containing libdevice. To avoid searching for it every // time, we have a one-element cache, keyed on the module's config's // cuda_data_dir. - const auto& config_cuda_data_dir = - module->config().debug_options().xla_gpu_cuda_data_dir(); - if (cached_libdevice_dir_.empty() || - cached_cuda_data_dir_ != config_cuda_data_dir) { - cached_cuda_data_dir_ = config_cuda_data_dir; - cached_libdevice_dir_ = GetLibdeviceDir(config_cuda_data_dir); + if (cached_libdevice_dir_.empty()) { + cached_libdevice_dir_ = GetLibdeviceDir(module->config()); } libdevice_dir = cached_libdevice_dir_; } @@ -743,9 +750,8 @@ StatusOr> NVPTXCompiler::RunBackend( } } - const std::vector cubin = CompilePtxOrGetCachedResult( - ptx, cc_major, cc_minor, - module->config().debug_options().xla_gpu_disable_ptxas_optimizations()); + const std::vector cubin = + CompilePtxOrGetCachedResult(ptx, cc_major, cc_minor, module->config()); auto thunk_schedule = absl::make_unique( ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment), @@ -779,7 +785,7 @@ StatusOr> NVPTXCompiler::RunBackend( std::vector NVPTXCompiler::CompilePtxOrGetCachedResult( const string& ptx, int cc_major, int cc_minor, - bool disable_ptx_optimizations) { + const HloModuleConfig& hlo_module_config) { XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::CompilePtxOrGetCachedResult"); tracing::ScopedActivity activity("PTX->CUBIN", /*is_expensive=*/true); bool inserted; @@ -807,8 +813,8 @@ std::vector NVPTXCompiler::CompilePtxOrGetCachedResult( if (inserted) { CHECK(!cache_value->compilation_done); if (!ptx.empty()) { - StatusOr> maybe_cubin = CompilePtx( - *cache_ptx, cc_major, cc_minor, disable_ptx_optimizations); + StatusOr> maybe_cubin = + CompilePtx(*cache_ptx, cc_major, cc_minor, hlo_module_config); if (maybe_cubin.ok()) { cache_value->cubin_data = std::move(maybe_cubin).ValueOrDie(); VLOG(2) << "Compiled PTX size:" << ptx.size() diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h index be5e31a50112686841e6f18b76f382a56e61bafc..b2077f42fd097330703fde063d80a20704fa48e2 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h @@ -99,7 +99,7 @@ class NVPTXCompiler : public LLVMCompiler { // compiled cubin. If compilation was unsuccessful, returns an empty vector. std::vector CompilePtxOrGetCachedResult( const string& ptx, int cc_major, int cc_minor, - bool disable_ptx_optimizations); + const HloModuleConfig& hlo_module_config); // The compilation_cache_ map is a cache from {ptx string, cc_major, cc_minor} // -> cubin so we don't recompile the same ptx twice. This is important for diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc index 375f68a15957936151aee068582a714b62694af2..bfed4f5230dfe37bca48560ce83a2dd82c8950a4 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc @@ -39,6 +39,25 @@ std::ostream& operator<<(std::ostream& out, return out; } +int64 ThreadsPerBlockLimit(const se::DeviceDescription& device_desc) { + int64 threads_per_block = device_desc.threads_per_block_limit(); + if (threads_per_block == 0) { + static std::atomic log_count{0}; + if (log_count.fetch_add(1) < 8) { + LOG(WARNING) << "Attempting to calculate launch dimensions for GPU " + "without full information about its capabilities. " + "StreamExecutor's PopulateDeviceDescription should be " + "updated for this device."; + } + threads_per_block = device_desc.threads_per_warp(); + if (threads_per_block == 0) { + // Fall back to *something* if we can't even get num threads per warp. + threads_per_block = 32; + } + } + return threads_per_block; +} + // Calculates the launch dimensions used to invoke `hlo`. LaunchDimensions CalculateLaunchDimensions( const Shape& shape, const se::DeviceDescription& device_desc, @@ -62,21 +81,7 @@ LaunchDimensions CalculateLaunchDimensions( // // * = - int64 threads_per_block = device_desc.threads_per_block_limit(); - if (threads_per_block == 0) { - static std::atomic log_count{0}; - if (log_count.fetch_add(1) < 8) { - LOG(WARNING) << "Attempting to calculate launch dimensions for GPU " - "without full information about its capabilities. " - "StreamExecutor's PopulateDeviceDescription should be " - "updated for this device."; - } - threads_per_block = device_desc.threads_per_warp(); - if (threads_per_block == 0) { - // Fall back to *something* if we can't even get num threads per warp. - threads_per_block = 32; - } - } + int64 threads_per_block = ThreadsPerBlockLimit(device_desc); if (num_elements < threads_per_block) { threads_per_block = num_elements; diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.h b/tensorflow/compiler/xla/service/gpu/partition_assignment.h index 02471129e004b4876ce20a62cade34060c65b478..eb41dcccb938ccc088c2371def96ca73276771ab 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.h @@ -57,6 +57,9 @@ class LaunchDimensions { std::ostream& operator<<(std::ostream& out, const LaunchDimensions& launch_dims); +// Returns the maximum number of threads per block allowed by the device. +int64 ThreadsPerBlockLimit(const se::DeviceDescription& device_desc); + LaunchDimensions CalculateLaunchDimensions( const Shape& shape, const se::DeviceDescription& device_desc, int unroll_factor = 1); diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h index 1fc46bafa10e7ba6c896f081d5c836bd400886c9..92e4d6dbbc1bd564657f8a5de09d23d5ae81a93e 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ +#include "tensorflow/compiler/xla/layout.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index ff122b529bdcdcc69d2245136e19101902dbf957..ca663b8b4a970900a4a899a7ad9d33dc45af9d99 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -711,8 +711,6 @@ bool HloComputation::operator==(const HloComputation& other) const { return eq(root_instruction(), other.root_instruction()); } -uint64 HloComputation::Hash() const { return root_instruction()->Hash(); } - Status HloComputation::ReplaceWithNewInstruction( HloInstruction* old_instruction, std::unique_ptr new_instruction) { @@ -797,7 +795,7 @@ Status HloComputation::AcceptWithOperandOrder( template Status HloComputation::AcceptOrdered( DfsHloVisitorBase* visitor, - const std::vector& order) const { + absl::Span order) const { VLOG(3) << "Accepting visitor with order."; for (HloInstruction* root : CollectUnreachableRoots()) { TF_RET_CHECK(std::find(order.begin(), order.end(), root) != order.end()) @@ -827,9 +825,9 @@ Status HloComputation::AcceptOrdered( // Explicit instantiations. template Status HloComputation::AcceptOrdered( - DfsHloVisitor*, const std::vector&) const; + DfsHloVisitor*, absl::Span) const; template Status HloComputation::AcceptOrdered( - ConstDfsHloVisitor*, const std::vector&) const; + ConstDfsHloVisitor*, absl::Span) const; Status HloComputation::Accept( const std::function& visitor_func) { diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index c584e4c7ca5770533f28352b0df9dadd9dbe1860..5467d0a68b18170891dcd9f67e44d3bb269bf920 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -264,12 +264,6 @@ class HloComputation { // Return whether `*this` and `other` are functionally equivalent. bool operator==(const HloComputation& other) const; - // Generates a hash value of an HLO computation. Hash considers - // information on opcode, shape, operands, and typically a root instruction. - // This function returns the same hash value for equivalent HLO computations, - // with respect to HloInstruction::Identical() method. - uint64 Hash() const; - // Replaces old instruction with newly created instruction. Removes old // instruction from computation. Updates uses and root instruction. Status ReplaceWithNewInstruction( @@ -307,7 +301,7 @@ class HloComputation { // be a topological sort of all instructions in the computation. template Status AcceptOrdered(DfsHloVisitorBase* visitor, - const std::vector& order) const; + absl::Span order) const; // Same as Accept() above, but the visitor is given as a function. Status Accept(const std::function& visitor_func); diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 4f81dc94e577a63c09ae4019e5e8158252c712ce..92b748d813c3efef83ef0155f1d5d3c637ce2c57 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -252,7 +252,7 @@ const char* const kConstantFoldLargePad = R"( HloModule ConstantFoldLargePad ENTRY r { - a = f32[1,1,1] constant(f32[1,1,1]{{{7}}}) + a = f32[1,1,1] constant({{{7}}}) b = f32[] constant(42) ROOT pad = f32[2048,2048,128] pad(a, b), padding=1024_1023x1024_1023x64_63 })"; diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index f7a1f19a6f52befd58a405d0e406d7d0d37a8e57..94de7c55dd2402e55ec344b79c24af2d8283fe73 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -1882,8 +1882,8 @@ TEST_P(HloDataflowAnalysisTest, AddDependency) { HloModule AddDependency ENTRY %AddDependency (p: f32[3]) -> f32[3] { %p = f32[3] parameter(0) - %token = token[] after-all() - ROOT %add_dep = f32[3] add-dependency(f32[3] %p, token[] %token) + %token0 = token[] after-all() + ROOT %add_dep = f32[3] add-dependency(f32[3] %p, token[] %token0) } )"; TF_ASSERT_OK_AND_ASSIGN( diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index acdb42128e3d9a1fb912a466c9c2c3cbbe3d3f83..fd4fb0246d8d42ab7329c05dc23e386303cdce3c 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -195,10 +195,10 @@ HloModule Module ENTRY entry { p0 = (f32[4]) parameter(0) a = f32[4] get-tuple-element(p0), index=0 - token = token[] after-all() - b = (f32[4], u32[], token[]) send(a, token), channel_id=1, sharding={maximal device=0} + token0 = token[] after-all() + b = (f32[4], u32[], token[]) send(a, token0), channel_id=1, sharding={maximal device=0} c = token[] send-done(b), channel_id=1, sharding={maximal device=0} - d = (f32[4], u32[], token[]) recv(token), channel_id=2, sharding={maximal device=0} + d = (f32[4], u32[], token[]) recv(token0), channel_id=2, sharding={maximal device=0} e = (f32[4], token[]) recv-done(d), channel_id=2, sharding={maximal device=0} e_element = f32[4] get-tuple-element(e), index=0, sharding={maximal device=0} f = f32[4] add(a, e_element) @@ -235,12 +235,12 @@ TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) { HloModule Module ENTRY entry { - token = token[] after-all(), sharding={maximal device=-1} - a = (f32[4], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=-1} + token0 = token[] after-all(), sharding={maximal device=-1} + a = (f32[4], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=-1} b = (f32[4], token[]) recv-done(a), channel_id=1, sharding={maximal device=-1} b_element = f32[4] get-tuple-element(b), index=0, sharding={maximal device=-1} c = f32[4] add(b_element, b_element), sharding={maximal device=-1} - d = (f32[4], u32[], token[]) send(c, token), channel_id=2, sharding={maximal device=-1} + d = (f32[4], u32[], token[]) send(c, token0), channel_id=2, sharding={maximal device=-1} ROOT e = token[] send-done(d), channel_id=2, sharding={maximal device=-1} } )"; @@ -259,12 +259,12 @@ TEST_F(HloDomainTest, CheckNormalizationOnPureIOComputation) { HloModule Module ENTRY entry { - token = token[] after-all(), sharding={maximal device=0} - a = (f32[4], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=0} + token0 = token[] after-all(), sharding={maximal device=0} + a = (f32[4], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=0} b = (f32[4], token[]) recv-done(a), channel_id=1, sharding={maximal device=0} b_element = f32[4] get-tuple-element(b), index=0, sharding={maximal device=0} c = f32[4] add(b_element, b_element) - d = (f32[4], u32[], token[]) send(c, token), channel_id=2, sharding={maximal device=0} + d = (f32[4], u32[], token[]) send(c, token0), channel_id=2, sharding={maximal device=0} ROOT e = token[] send-done(d), channel_id=2, sharding={maximal device=0} } )"; @@ -344,8 +344,8 @@ TEST_F(HloDomainTest, CheckNormalizationOnInfeedTuple) { HloModule Module ENTRY entry { - token = token[] after-all() - infeed = ((f32[4], f32[4]), token[]) infeed(token), + token0 = token[] after-all() + infeed = ((f32[4], f32[4]), token[]) infeed(token0), sharding={{maximal device=1}, {maximal device=0}, {maximal device=0}} infeed.data = (f32[4], f32[4]) get-tuple-element(infeed), index=0, sharding={{maximal device=1}, {maximal device=0}} diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc index c170e36c73ad2bef830e528de3ec72d38683d888..a3b56a44a0b02923585c1dcb69571479236188a3 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc @@ -57,10 +57,10 @@ TEST_F(HloElementTypeConverterTest, InfeedsOutfeedsNotConverted) { const string& hlo_string = R"( HloModule InfeedOutfeed ENTRY RoundTrip16MiBR1.v2 { - token = token[] after-all() - infeed = (bf16[4]{0}, token[]) infeed(token) + token0 = token[] after-all() + infeed = (bf16[4]{0}, token[]) infeed(token0) ROOT infeed.data = bf16[4]{0} get-tuple-element(infeed), index=0 - outfeed = token[] outfeed(infeed.data, token) + outfeed = token[] outfeed(infeed.data, token0) } )"; auto module = CreateModuleFromHloString(hlo_string); @@ -96,13 +96,13 @@ TEST_F(HloElementTypeConverterTest, BatchNormGradBF16Converted) { const string& hlo_string = R"( HloModule BatchNormGrad ENTRY BatchNormGrad.v6 { - constant.4 = bf16[2,2,2,1]{3,2,1,0} constant(bf16[2,2,2,1] { { /*i0=0*/ + constant.4 = bf16[2,2,2,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {0}, {0} }, { /*i1=1*/ {0}, {0} } }, { /*i0=1*/ { /*i1=0*/ {0}, {0} }, { /*i1=1*/ {0}, {0} } } }) constant.5 = bf16[2]{0} constant({1, 1}) constant.6 = bf16[2]{0} constant({0, 0}) constant.7 = bf16[2]{0} constant({1, 1}) - constant.8 = bf16[2,2,2,1]{3,2,1,0} constant(bf16[2,2,2,1] { { /*i0=0*/ + constant.8 = bf16[2,2,2,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} } }, { /*i0=1*/ { /*i1=0*/ {5}, {6} }, { /*i1=1*/ {7}, {8} } } }) ROOT batch-norm-grad = (bf16[2,2,2,1]{3,2,1,0}, bf16[2]{0}, bf16[2]{0}) diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 3a7652a8dc856b23c8988c4676916c8199e78860..934c082bb9f003b1d2d80835f09a8f4109c7e7fd 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -629,8 +630,11 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) { evaluated_[compare], Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); } break; - case F16: - return Unimplemented("unhandled primitive type: F16."); + case F16: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; case BF16: { TF_ASSIGN_OR_RETURN(evaluated_[compare], Compare(compare->shape(), opcode, @@ -1449,4 +1453,46 @@ template StatusOr HloEvaluator::Evaluate( template StatusOr HloEvaluator::Evaluate( HloInstruction* instruction, absl::Span arg_literals); +namespace { +template +std::unique_ptr> MatmulArray2DImpl( + const Array2D& lhs, const Array2D& rhs, + const std::function& impl_fn) { + CHECK_EQ(lhs.width(), rhs.height()); + int m = lhs.height(); + int n = rhs.width(); + int k = lhs.width(); + auto result = absl::make_unique>(m, n); + // Because Eigen is a header-oriented library, make sure that the Eigen code + // is the same as the code used by the CPU backend (otherwise the linker will + // randomly pick *some* definition). + impl_fn( + /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m, + k, + /*transpose_lhs=*/0, + /*transpose_rhs=*/0); + return result; +} +} // namespace + +std::unique_ptr> HloEvaluator::MatmulArray2D( + const Array2D& lhs, const Array2D& rhs) { + return MatmulArray2DImpl( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF16); +} + +std::unique_ptr> HloEvaluator::MatmulArray2D( + const Array2D& lhs, const Array2D& rhs) { + return MatmulArray2DImpl( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF32); +} + +std::unique_ptr> HloEvaluator::MatmulArray2D( + const Array2D& lhs, const Array2D& rhs) { + return MatmulArray2DImpl( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF64); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 45ed8131dc6b71f706fce45d65b206363dd79ac3..d363a51c63de6fd4246c4970f580b68f4a627df8 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -21,6 +21,7 @@ limitations under the License. #include "absl/container/node_hash_map.h" #include "absl/memory/memory.h" #include "absl/types/span.h" +#include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -119,6 +120,17 @@ class HloEvaluator : public DfsHloVisitorWithDefault { const PrecisionConfig& precision_config, const Literal& lhs, const Literal& rhs); + // Enable the fast path for certain operations like dot or convolution. + void set_use_fast_path(bool value) { use_fast_path_ = value; } + + // Returns the result of a matrix multiply `lhs x rhs`. + static std::unique_ptr> MatmulArray2D( + const Array2D& lhs, const Array2D& rhs); + static std::unique_ptr> MatmulArray2D( + const Array2D& lhs, const Array2D& rhs); + static std::unique_ptr> MatmulArray2D( + const Array2D& lhs, const Array2D& rhs); + protected: // Make HloEvaluatorTypedVisitor a friend because it is logically part of this // class. @@ -217,6 +229,9 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // we cannot use flat_hash_map any more. absl::node_hash_map evaluated_; + // Use fast path that uses eigen in the evaluator. + bool use_fast_path_ = false; + private: template static StatusOr ElementWiseUnaryOpImpl( @@ -250,6 +265,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator); }; +std::unique_ptr> MatmulArray2D(const Array2D& lhs, + const Array2D& rhs); } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index b87fc3e34012e75ee07bff6c1e113dce404f83cb..03d42990ce9dcd3f689831078354f878bcb0800f 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -23,6 +23,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/types/optional.h" +#include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" @@ -105,6 +106,12 @@ bool SafeLess(const NativeT& a, const NativeT& b) { template class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { private: + Status UnsupportedTypeError(HloInstruction* instruction) { + return InvalidArgument( + "Unsupported type for %s: %s", HloOpcodeString(instruction->opcode()), + PrimitiveType_Name(instruction->shape().element_type())); + } + // Get the value in the given literal static_cast as a double. template < typename NativeT, @@ -224,7 +231,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleRound(HloInstruction* round) { - return InvalidArgument("Unsupported type for Round"); + return UnsupportedTypeError(round); } Status HandleRound(HloInstruction* round) override { @@ -246,7 +253,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleCeil(HloInstruction* ceil) { - return InvalidArgument("Unsupported type for Ceil"); + return UnsupportedTypeError(ceil); } Status HandleCeil(HloInstruction* ceil) override { @@ -297,8 +304,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template < typename NativeT, typename std::enable_if::value>::type* = nullptr> - Status HandleExpm1(HloInstruction* floor) { - return InvalidArgument("Unsupported type for Expm1"); + Status HandleExpm1(HloInstruction* expm1) { + return UnsupportedTypeError(expm1); } Status HandleExpm1(HloInstruction* floor) override { @@ -321,7 +328,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleFloor(HloInstruction* floor) { - return InvalidArgument("Unsupported type for Floor"); + return UnsupportedTypeError(floor); } Status HandleFloor(HloInstruction* floor) override { @@ -351,12 +358,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template < typename NativeT, typename std::enable_if::value>::type* = nullptr> - Status HandleLog1p(HloInstruction* floor) { - return InvalidArgument("Unsupported type for Log1p"); + Status HandleLog1p(HloInstruction* log1p) { + return UnsupportedTypeError(log1p); } - Status HandleLog1p(HloInstruction* floor) override { - return HandleLog1p(floor); + Status HandleLog1p(HloInstruction* log1p) override { + return HandleLog1p(log1p); } template ::value>::type* = nullptr> Status HandleNot(HloInstruction* not_) { - return InvalidArgument("Unsupported type for Not"); + return UnsupportedTypeError(not_); } Status HandleNot(HloInstruction* not_) override { @@ -476,7 +483,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> Status HandleAtan2(HloInstruction* atan2) { - return InvalidArgument("Unsupported type for Atan2"); + return UnsupportedTypeError(atan2); } Status HandleAtan2(HloInstruction* atan2) override { @@ -624,7 +631,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleMaximum(HloInstruction* maximum) { - return InvalidArgument("Unsupported type for Maximum"); + return UnsupportedTypeError(maximum); } Status HandleMaximum(HloInstruction* maximum) override { @@ -659,7 +666,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleMinimum(HloInstruction* minimum) { - return InvalidArgument("Unsupported type for Minimum"); + return UnsupportedTypeError(minimum); } Status HandleMinimum(HloInstruction* minimum) override { @@ -724,7 +731,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleRemainder(HloInstruction* remainder) { - return InvalidArgument("Unsupported type for Remainder"); + return UnsupportedTypeError(remainder); } Status HandleRemainder(HloInstruction* remainder) override { @@ -746,14 +753,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> Status HandleAnd(HloInstruction* and_) { - return InvalidArgument("Unsupported type for And"); + return UnsupportedTypeError(and_); } template < typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleAnd(HloInstruction* and_) { - return InvalidArgument("Unsupported type for And"); + return UnsupportedTypeError(and_); } Status HandleAnd(HloInstruction* and_) override { @@ -775,7 +782,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> Status HandleOr(HloInstruction* or_) { - return InvalidArgument("Unsupported type for Or"); + return UnsupportedTypeError(or_); } template < @@ -804,14 +811,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> Status HandleXor(HloInstruction* xor_) { - return InvalidArgument("Unsupported type for Xor"); + return UnsupportedTypeError(xor_); } template < typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleXor(HloInstruction* xor_) { - return InvalidArgument("Unsupported type for Xor"); + return UnsupportedTypeError(xor_); } Status HandleXor(HloInstruction* xor_) override { @@ -836,8 +843,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value || std::is_same::value>::type* = nullptr> - Status HandleShiftLeft(HloInstruction*) { - return InvalidArgument("Unsupported type for ShiftLeft"); + Status HandleShiftLeft(HloInstruction* shift) { + return UnsupportedTypeError(shift); } Status HandleShiftLeft(HloInstruction* shl) override { @@ -866,8 +873,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value || std::is_same::value>::type* = nullptr> - Status HandleShiftRightArithmetic(HloInstruction*) { - return InvalidArgument("Unsupported type for ShiftRightArithmetic"); + Status HandleShiftRightArithmetic(HloInstruction* shift) { + return UnsupportedTypeError(shift); } Status HandleShiftRightArithmetic(HloInstruction* shra) override { @@ -897,8 +904,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value || std::is_same::value>::type* = nullptr> - Status HandleShiftRightLogical(HloInstruction*) { - return InvalidArgument("Unsupported type for ShiftRightLogical"); + Status HandleShiftRightLogical(HloInstruction* shift) { + return UnsupportedTypeError(shift); } Status HandleShiftRightLogical(HloInstruction* shrl) override { @@ -923,8 +930,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template < typename NativeT, typename std::enable_if::value>::type* = nullptr> - Status HandleClamp(HloInstruction*) { - return InvalidArgument("Unsupported type for Clamp"); + Status HandleClamp(HloInstruction* clamp) { + return UnsupportedTypeError(clamp); } Status HandleClamp(HloInstruction* clamp) override { @@ -1148,6 +1155,78 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } Status HandleDot(HloInstruction* dot) override { + if (parent_->use_fast_path_) { + return HandleDot(dot); + } + return HandleDotSlowPath(dot); + } + + template ::value>::type* = nullptr> + Status HandleDot(HloInstruction* dot) { + const HloInstruction* lhs = dot->operand(0); + const HloInstruction* rhs = dot->operand(1); + CHECK(ShapeUtil::IsArray(dot->shape())); + CHECK(ShapeUtil::IsArray(lhs->shape())); + CHECK(ShapeUtil::IsArray(rhs->shape())); + + const auto& dnums = dot->dot_dimension_numbers(); + + const int64 lhs_rank = ShapeUtil::Rank(lhs->shape()); + const int64 rhs_rank = ShapeUtil::Rank(rhs->shape()); + + CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); + CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape())); + + // There must be 1 and only 1 Contracting dimension for lhs and rhs. + CHECK_EQ(dnums.lhs_contracting_dimensions_size(), 1); + CHECK_EQ(dnums.rhs_contracting_dimensions_size(), 1); + const int64 lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0); + const int64 rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0); + // Contracted dimension sizes must be the same. + CHECK_EQ(lhs->shape().dimensions(lhs_contracting_dimension), + rhs->shape().dimensions(rhs_contracting_dimension)) + << "lhs contracted dimension: " + << lhs->shape().dimensions(lhs_contracting_dimension) + << " rhs contracted dimension: " + << rhs->shape().dimensions(rhs_contracting_dimension); + + // The fast path is for a simple rank 2 dot with default layout operands. + if (lhs_rank == 2 && rhs_rank == 2 && lhs_contracting_dimension == 1 && + rhs_contracting_dimension == 0 && + LayoutUtil::Equal(lhs->shape().layout(), + LayoutUtil::GetDefaultLayoutForR2()) && + LayoutUtil::Equal(rhs->shape().layout(), + LayoutUtil::GetDefaultLayoutForR2()) && + LayoutUtil::Equal(dot->shape().layout(), + LayoutUtil::GetDefaultLayoutForR2())) { + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + const int64 contracted_dimension_size = + lhs->shape().dimensions(lhs_contracting_dimension); + Array2D lhs_array(lhs->shape().dimensions(0), + contracted_dimension_size); + lhs_array.SetValues(lhs_literal.data()); + Array2D rhs_array(contracted_dimension_size, + rhs->shape().dimensions(1)); + rhs_array.SetValues(rhs_literal.data()); + std::unique_ptr> result_array = + HloEvaluator::MatmulArray2D(lhs_array, rhs_array); + Literal result(dot->shape()); + result.PopulateR2FromArray2D(*result_array); + parent_->evaluated_[dot] = std::move(result); + return Status::OK(); + } + return HandleDotSlowPath(dot); + } + + template ::value>::type* = nullptr> + Status HandleDot(HloInstruction* dot) { + return HandleDotSlowPath(dot); + } + + Status HandleDotSlowPath(HloInstruction* dot) { auto lhs = dot->operand(0); auto rhs = dot->operand(1); CHECK(ShapeUtil::IsArray(dot->shape())); @@ -1578,7 +1657,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::is_same::value>::type* = nullptr> Status HandleSort(HloInstruction* sort) { - return InvalidArgument("Unsupported type for Sort"); + return UnsupportedTypeError(sort); } Status HandleSort(HloInstruction* sort) override { @@ -2357,7 +2436,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::is_same::value || std::is_same::value)>::type* = nullptr> Status HandleClz(HloInstruction* clz) { - return InvalidArgument("Unsupported type for Clz"); + return UnsupportedTypeError(clz); } template ::value || is_complex_t::value>::type* = nullptr> Status HandleSin(HloInstruction* sin) { - return InvalidArgument("Unsupported type for Sin"); + return UnsupportedTypeError(sin); } Status HandleSin(HloInstruction* sin) override { @@ -2425,7 +2504,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value || is_complex_t::value>::type* = nullptr> Status HandleCos(HloInstruction* cos) { - return InvalidArgument("Unsupported type for Cos"); + return UnsupportedTypeError(cos); } Status HandleCos(HloInstruction* cos) override { @@ -2534,7 +2613,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value || is_complex_t::value>::type* = nullptr> Status HandleReducePrecision(HloInstruction* reduce_precision) { - return InvalidArgument("Unsupported type for reduce precision"); + return UnsupportedTypeError(reduce_precision); } Status HandleReducePrecision(HloInstruction* reduce_precision) override { @@ -2543,15 +2622,27 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value || + std::is_same::value || std::is_integral::value || std::is_floating_point::value>::type* = nullptr> Status HandleIota(HloInstruction* instruction) { auto* iota = Cast(instruction); + const int64 iota_size = iota->shape().dimensions(iota->iota_dimension()); // Avoid using std::vector since std::vector does not convert to // absl::Span. - absl::InlinedVector data( - iota->shape().dimensions(iota->iota_dimension())); - std::iota(data.begin(), data.end(), 0); + absl::InlinedVector data(iota_size); + // We don't use std::iota for two reasons: + // + // (1) std:iota does not support bfloat16 and float16. + // + // (2) std::iota saturates for floating point types when the value is not + // representable, but the definition of HLO iota is the value as a + // 64-bit integer cast to the native type. + for (int64 i = 0; i < iota_size; ++i) { + // static_cast is required for Eigen::half (F16). + data[i] = static_cast(i); + } auto result = LiteralUtil::CreateR1(data); if (ShapeUtil::Rank(iota->shape()) > 1) { @@ -2567,10 +2658,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template ::value || + !(std::is_same::value || + std::is_same::value || + std::is_integral::value || std::is_floating_point::value)>::type* = nullptr> Status HandleIota(HloInstruction* iota) { - return InvalidArgument("Unsupported type for iota"); + return UnsupportedTypeError(iota); } Status HandleIota(HloInstruction* iota) override { return HandleIota(iota); diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 302eca656be53a3cec86ddbf05a7fa3925c5185b..5db21e47ca94af3b017e0401237692913365a48c 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -1474,14 +1474,15 @@ string ExportGraph(const string& graph, GraphRendererInterface::GraphKind graph_kind, const DebugOptions& debug_options) { string path = debug_options.xla_hlo_graph_path(); - if (!path.empty()) { + if (!path.empty() && !debug_options.xla_hlo_dump_as_html()) { return SaveGraph(graph, graph_kind, path); } else { auto graph_renderer = GraphRendererRegistry::Default()->GetDefaultRenderer(); CHECK(graph_renderer != nullptr) << "No registered renderer for the HLO graph. " - "Use --xla_hlo_graph_path=PATH to export to local file system"; + "Use --xla_hlo_graph_path=PATH --xla_hlo_dump_as_html=false to " + "export to local file system"; return graph_renderer->RenderGraph(graph, graph_kind, debug_options); } } @@ -1589,5 +1590,143 @@ string MaybeDumpHloModule(const HloModule& module, const string& label, return graph_url; } +string WrapDotInHTML(const string& dot) { + static const char html_prefix[] = R"html( + + + + + + + + + + + +
+ + + +)html"; + + return html_prefix + dot + html_suffix; +} + +string RenderDotAsHTMLFile(const string& dot, + const DebugOptions& debug_options) { + string html = WrapDotInHTML(dot); + + auto env = tensorflow::Env::Default(); + std::vector dirs; + string output_dir = debug_options.xla_hlo_graph_path(); + if (output_dir.empty()) { + env->GetLocalTempDirectories(&dirs); + } else { + dirs.push_back(output_dir); + } + // Try each directory, as they might be full, have inappropriate + // permissions or have different problems at times. + string output; + for (const string& dir : dirs) { + string filename = tensorflow::io::JoinPath(dir, "graph-"); + if (env->CreateUniqueFileName(&filename, ".html")) { + output = filename; + break; + } + } + if (output.empty()) { + LOG(FATAL) << "Failed to create unique output file name."; + } + TF_CHECK_OK(tensorflow::WriteStringToFile(env, output, html)); + return "file://" + output; +} + } // namespace hlo_graph_dumper } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index de1eefab776f9c3d2c73959a5cd267e938a78a32..8e51454ef1cf992386cc7325e32705c08bf7712f 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -81,6 +81,12 @@ string DumpAllPathsFromTo(const HloInstruction& from, const HloInstruction& to, void DumpText(const HloModule& module, const string& label, const string& directory_path, bool do_prefix = true); +// Renders DOT graph as inline SVG and saves it in an HTML file in a temprary +// directory or directory specified via --xla_hlo_graph_path. Returns the file +// URI pointing to the file. +string RenderDotAsHTMLFile(const string& dot, + const DebugOptions& debug_options); + // Graph renderers may be added using a registration mechanism, e.g.: // XLA_REGISTER_GRAPH_RENDERER(AGraphRendererClass, 100) // The renderer with the highest numeric priority value is used. diff --git a/tensorflow/compiler/xla/service/hlo_graph_html_renderer.cc b/tensorflow/compiler/xla/service/hlo_graph_html_renderer.cc new file mode 100644 index 0000000000000000000000000000000000000000..84c4cf18df69816c611f4eb159ba247320ebc20e --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_graph_html_renderer.cc @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Implementation of an DOT graph renderer that uses Javascript to render DOT to +// SVG in a browser. + +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace xla { +namespace hlo_graph_dumper { +namespace { + +class GraphHtmlRenderer : public GraphRendererInterface { + public: + string RenderGraph(const string& graph, GraphKind graph_kind, + const DebugOptions& debug_options) override { + switch (graph_kind) { + case DOT_GRAPH: + return RenderDotAsHTMLFile(graph, debug_options); + default: + LOG(FATAL) << "Only DOT graphs can be rendered"; + } + } +}; + +XLA_REGISTER_GRAPH_RENDERER(GraphHtmlRenderer); + +} // namespace +} // namespace hlo_graph_dumper +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 21b1dbc1676cccd2fe5b331a1f9d6ff5e3a73fcd..8b2ace1e82eff250f4d9f0d5630e9e6d646cfe6d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -569,6 +569,11 @@ StatusOr> HloInstruction::CreateFromProto( instruction->SetAndSanitizeName(proto.name()); instruction->metadata_ = proto.metadata(); instruction->backend_config_ = proto.backend_config(); + + TF_RET_CHECK(proto.id() >= 0) + << "Instruction with negative id: " << proto.id(); + TF_RET_CHECK(proto.id() <= INT_MAX) + << "Instruction with id > INT_MAX: " << proto.id(); instruction->unique_id_ = proto.id(); if (proto.has_sharding()) { @@ -914,12 +919,8 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, HloInstruction* operand, HloInstruction* update, HloInstruction* start_indices) { - auto instruction = absl::WrapUnique( - new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(update); - instruction->AppendOperand(start_indices); - return instruction; + return absl::make_unique( + shape, operand, update, start_indices); } /* static */ std::unique_ptr HloInstruction::CreateConcatenate( @@ -1760,7 +1761,12 @@ bool HloInstruction::IdenticalSlowPath( return false; } -uint64 HloInstruction::Hash() const { +static uint64 HashOperand(const HloInstruction* hlo) { + return ShapeUtil::Hash(hlo->shape()); +} + +uint64 HloInstruction::Hash( + const std::function& hash_operand) const { using tensorflow::Hash64Combine; uint64 hash_value = Hash64Combine(0, static_cast(opcode())); @@ -1769,7 +1775,7 @@ uint64 HloInstruction::Hash() const { if (!IsCrossModuleAllReduce()) { if (!operands().empty()) { for (size_t i = 0; i < operands().size(); ++i) { - hash_value = Hash64Combine(hash_value, operand(i)->Hash()); + hash_value = Hash64Combine(hash_value, hash_operand(operand(i))); } } } @@ -1778,6 +1784,11 @@ uint64 HloInstruction::Hash() const { return hash_value; } +uint64 HloInstruction::Hash() const { + // Use HashOperand as an argument to prevent non-termination. + return Hash(HashOperand); +} + uint64 HloInstruction::InnerHash() const { return 13; } void HloInstruction::RemoveUser(HloInstruction* user) { @@ -2059,6 +2070,10 @@ bool HloInstruction::IsCrossModuleAllReduce() const { return opcode() == HloOpcode::kCrossReplicaSum && all_reduce_id(); } +bool HloInstruction::IsCrossReplicaAllReduce() const { + return opcode() == HloOpcode::kCrossReplicaSum && !all_reduce_id(); +} + string HloInstruction::ToStringWithCanonicalNameMap( const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index a54716217d6bbc5c0601f5d9ff7bf4072a6b30f5..dd77f101a049d7247dcf571d2d19cb4f74e2f8ea 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -909,6 +909,14 @@ class HloInstruction { // information on opcode, shape, operands, and typically a root instruction. // This function returns the same hash value for equivalent HLO instructions, // with respect to HloInstruction::Identical() method. + // + // Uses hash_operand function to compute hash values of its operands. + // At the very top level, hash_operand should be non-recursive to prevent + // non-termination. + uint64 Hash( + const std::function& hash_operand) const; + + // Calls the above method with non-recursive hash_operand function. uint64 Hash() const; // Returns whether the instruction has a constant operand. @@ -1174,9 +1182,12 @@ class HloInstruction { // Returns true if this instruction is elementwise on all its operands. bool IsElementwise() const; - // Returns true if this is an cross module all-reduce instrucion. + // Returns true if this is a cross module all-reduce instruction. bool IsCrossModuleAllReduce() const; + // Returns true if this is a cross-replica all-reduce instruction. + bool IsCrossReplicaAllReduce() const; + // Returns true if this elementwise instruction implicitly broadcasts operand // `operand_idx`. // diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 1ea02cf9c03866a598bec0e5356f0eb31ad27755..5521e5bd9acefcd1cb7721ed55fe987189623404 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -905,7 +905,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( options.print_large_constants())) { // Literal::ToString emits multidimensional arrays over multiple // lines. Compact this into one line by stripping out white space. - string tmp = literal().ToString(); + string tmp = literal().ToStringWithoutShape(); std::replace(tmp.begin(), tmp.end(), '\n', ' '); std::vector v = absl::StrSplit(tmp, ' '); bool first = true; @@ -1372,8 +1372,14 @@ bool HloFusionInstruction::IdenticalSlowPath( other.fused_instructions_computation()); } +static uint64 HashOperandRecursive(const HloInstruction* hlo) { + return hlo->Hash(HashOperandRecursive); +} + uint64 HloFusionInstruction::InnerHash() const { - return fused_instructions_computation()->Hash(); + // Use HashOperandRecursive to recursively compute hash on inner operands. + return fused_instructions_computation()->root_instruction()->Hash( + HashOperandRecursive); } std::unique_ptr HloFusionInstruction::CloneWithNewOperandsImpl( @@ -1994,12 +2000,21 @@ std::unique_ptr HloPadInstruction::CloneWithNewOperandsImpl( HloDynamicSliceInstruction::HloDynamicSliceInstruction( const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, absl::Span slice_sizes) - : HloInstruction(HloOpcode::kDynamicSlice, shape), + : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape), dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) { AppendOperand(operand); AppendOperand(start_indices); } +HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* update, + HloInstruction* start_indices) + : HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) { + AppendOperand(operand); + AppendOperand(update); + AppendOperand(start_indices); +} + HloInstructionProto HloDynamicSliceInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); for (int64 slice_size : dynamic_slice_sizes_) { diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index b5c28137a145667a977d39c9d3c40c6d36a8436e..5420d4ce11f4bdd068e82f208a98e9943ad4479e 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -1171,7 +1171,14 @@ class HloPadInstruction : public HloInstruction { PaddingConfig padding_config_; }; -class HloDynamicSliceInstruction : public HloInstruction { +class HloDynamicIndexInstruction : public HloInstruction { + public: + explicit HloDynamicIndexInstruction(HloOpcode opcode, const Shape& shape) + : HloInstruction(opcode, shape) {} + virtual int64 index_operand_number() const = 0; +}; + +class HloDynamicSliceInstruction : public HloDynamicIndexInstruction { public: explicit HloDynamicSliceInstruction(const Shape& shape, HloInstruction* operand, @@ -1189,6 +1196,8 @@ class HloDynamicSliceInstruction : public HloInstruction { // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; + int64 index_operand_number() const override { return 1; } + private: std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; @@ -1206,6 +1215,16 @@ class HloDynamicSliceInstruction : public HloInstruction { std::vector dynamic_slice_sizes_; }; +class HloDynamicUpdateSliceInstruction : public HloDynamicIndexInstruction { + public: + explicit HloDynamicUpdateSliceInstruction(const Shape& shape, + HloInstruction* operand, + HloInstruction* update, + HloInstruction* start_indices); + + int64 index_operand_number() const override { return 2; } +}; + class HloGatherInstruction : public HloInstruction { public: explicit HloGatherInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index 1390537101e95a08e4ba4eef7ae8d6059a40e916..dc712e5e42c449737bf4415f3a5e3eb9d81d9be4 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/strings/escaping.h" #include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -82,9 +83,23 @@ tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers( return tensorflow::RegexpStringPiece(begin, end - begin); } +TokKind HloLexer::LookAhead() { + if (GetKind() == TokKind::kEof || GetKind() == TokKind::kError) { + return GetKind(); + } + + const char* old_current_ptr = current_ptr_; + TokenState old_token_state = token_state_; + Lex(); + TokKind kind = GetKind(); + token_state_ = old_token_state; + current_ptr_ = old_current_ptr; + return kind; +} + TokKind HloLexer::LexToken() { while (true) { - token_start_ = current_ptr_; + token_state_.token_start = current_ptr_; int current_char = GetNextChar(); switch (current_char) { @@ -206,43 +221,37 @@ TokKind HloLexer::LexToken() { // dim_labels_pattern ::= [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} // identifiers ::= other cases that match [a-zA-Z_][a-zA-Z0-9_.-]* TokKind HloLexer::LexIdentifier() { - { - auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); - // 'consumable' will be advanced iff its prefix matches the pattern. - static LazyRE2 shape_pattern = { - R"(^(\w*\d*)\[([\d,\s]*)\](?:(dense|sparse)?{([\d,\s]+)})?)"}; - if (RE2::Consume(&consumable, *shape_pattern)) { - auto status_or_shape = ShapeUtil::ParseShapeString( - StringPieceFromPointers(token_start_, consumable.begin())); - if (status_or_shape.ok()) { - // This is a shape string. - shape_val_ = status_or_shape.ValueOrDie(); - current_ptr_ = consumable.begin(); - return TokKind::kShape; - } - } - } - while (IsIdentifierChar(PeekCurrentChar())) { current_ptr_++; } // If followed by ':', it's a name. if (PeekCurrentChar() == ':') { - str_val_.assign(token_start_, current_ptr_); + token_state_.str_val.assign(token_state_.token_start, current_ptr_); current_ptr_++; // skip ':' return TokKind::kName; } // If followed by '=', it's a attribute name. if (PeekCurrentChar() == '=') { - str_val_.assign(token_start_, current_ptr_); + token_state_.str_val.assign(token_state_.token_start, current_ptr_); current_ptr_++; // skip '=' return TokKind::kAttributeName; } absl::string_view identifier = - StringPieceFromPointers(token_start_, current_ptr_); + StringPieceFromPointers(token_state_.token_start, current_ptr_); + + // Primitive type strings are reserved words. The exception is 'tuple' whose + // type is represented using nested parentheses without the string 'tuple'. + if (primitive_util::IsPrimitiveTypeName(identifier)) { + PrimitiveType primitive_type = + primitive_util::StringToPrimitiveType(identifier).ValueOrDie(); + if (primitive_type != TUPLE) { + token_state_.primitive_type_val = primitive_type; + return TokKind::kPrimitiveType; + } + } // See if this is a keyword. #define KEYWORD(STR) \ @@ -261,21 +270,23 @@ TokKind HloLexer::LexIdentifier() { KEYWORD(ROOT); KEYWORD(maximal); KEYWORD(replicated); + KEYWORD(sparse); #undef KEYWORD { - auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + auto consumable = + RegexpStringPieceFromPointers(token_state_.token_start, buf_.end()); static LazyRE2 dim_labels_pattern = { R"([0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,})"}; if (RE2::Consume(&consumable, *dim_labels_pattern)) { current_ptr_ = consumable.begin(); - str_val_.assign(token_start_, current_ptr_); + token_state_.str_val.assign(token_state_.token_start, current_ptr_); return TokKind::kDimLabels; } } - str_val_ = string(identifier); + token_state_.str_val = string(identifier); return TokKind::kIdent; } @@ -289,7 +300,7 @@ TokKind HloLexer::LexPercent() { while (IsIdentifierChar(PeekCurrentChar())) { current_ptr_++; } - str_val_.assign(name_start, current_ptr_); + token_state_.str_val.assign(name_start, current_ptr_); return TokKind::kName; } return TokKind::kError; @@ -307,12 +318,14 @@ TokKind HloLexer::LexPercent() { // int ::= [-]?[0-9]+ // negative inf ::= '-inf' TokKind HloLexer::LexNumberOrPattern() { - auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + auto consumable = + RegexpStringPieceFromPointers(token_state_.token_start, buf_.end()); static LazyRE2 float_pattern = { R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"}; if (RE2::Consume(&consumable, *float_pattern)) { current_ptr_ = consumable.begin(); - CHECK(absl::SimpleAtod(string(token_start_, current_ptr_), &decimal_val_)); + CHECK(absl::SimpleAtod(string(token_state_.token_start, current_ptr_), + &token_state_.decimal_val)); return TokKind::kDecimal; } @@ -324,27 +337,28 @@ TokKind HloLexer::LexNumberOrPattern() { if (RE2::Consume(&consumable, *dim_labels_pattern)) { current_ptr_ = consumable.begin(); - str_val_.assign(token_start_, current_ptr_); + token_state_.str_val.assign(token_state_.token_start, current_ptr_); return TokKind::kDimLabels; } if (RE2::Consume(&consumable, *dxd_pattern)) { current_ptr_ = consumable.begin(); - str_val_.assign(token_start_, current_ptr_); + token_state_.str_val.assign(token_state_.token_start, current_ptr_); return TokKind::kDxD; } if (RE2::Consume(&consumable, *pad_pattern)) { current_ptr_ = consumable.begin(); - str_val_.assign(token_start_, current_ptr_); + token_state_.str_val.assign(token_state_.token_start, current_ptr_); return TokKind::kPad; } static LazyRE2 int_pattern = {R"([-]?\d+)"}; if (RE2::Consume(&consumable, *int_pattern)) { current_ptr_ = consumable.begin(); - auto slice = StringPieceFromPointers(token_start_, current_ptr_); - if (absl::SimpleAtoi(slice, &int64_val_)) { + auto slice = + StringPieceFromPointers(token_state_.token_start, current_ptr_); + if (absl::SimpleAtoi(slice, &token_state_.int64_val)) { return TokKind::kInt; } LOG(ERROR) << "Failed to parse int literal: " << slice; @@ -403,16 +417,17 @@ absl::string_view HloLexer::GetLine(LocTy loc) const { } // Lexes quoted string with escaping characters. If matched, the quoted string -// will be unescaped and stored to str_val_. +// will be unescaped and stored to token_state_.str_val. TokKind HloLexer::LexString() { - auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + auto consumable = + RegexpStringPieceFromPointers(token_state_.token_start, buf_.end()); static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"}; if (RE2::Consume(&consumable, *escaping_pattern)) { current_ptr_ = consumable.begin(); absl::string_view raw = - StringPieceFromPointers(token_start_ + 1, current_ptr_ - 1); + StringPieceFromPointers(token_state_.token_start + 1, current_ptr_ - 1); string error; - if (!absl::CUnescape(raw, &str_val_, &error)) { + if (!absl::CUnescape(raw, &token_state_.str_val, &error)) { LOG(ERROR) << "Failed unescaping string: " << raw << ". error: " << error; return TokKind::kError; } @@ -467,6 +482,10 @@ string TokKindToString(TokKind kind) { return "kw_inf"; case TokKind::kNegInf: return "kNegInf"; + case TokKind::kw_sparse: + return "kw_sparse"; + case TokKind::kPrimitiveType: + return "kPrimitiveType"; case TokKind::kName: return "kName"; case TokKind::kAttributeName: @@ -481,8 +500,6 @@ string TokKindToString(TokKind kind) { return "kIdent"; case TokKind::kString: return "kString"; - case TokKind::kShape: - return "kShape"; case TokKind::kInt: return "kInt"; case TokKind::kDecimal: diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h index d6a2b292a3916b2ff85f278cf5cb9f1567df88fa..41f5043904a2622814154693679a0e27cb92f642 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -19,7 +19,6 @@ limitations under the License. #include #include "absl/strings/string_view.h" -#include "tensorflow/compiler/xla/service/hlo_token.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -29,6 +28,57 @@ limitations under the License. namespace xla { +// Defines different kinds of tokens used by the HLO lexer. +// +// You shouldn't need to use this directly unless you're using HloLexer +// directly, and you probably don't need to do that. Use hlo_parser instead. +enum class TokKind { + // Markers + kEof, + kError, + + // Tokens with no info. + kEqual, // = + kComma, // , + kColon, // : + kLsquare, + kRsquare, // [ ] + kLbrace, + kRbrace, // { } + kLparen, + kRparen, // ( ) + + kArrow, // -> + + // Keywords + kw_HloModule, + kw_ENTRY, + kw_ROOT, + kw_true, + kw_false, + kw_maximal, + kw_replicated, + kw_nan, + kw_inf, + kw_sparse, + + kNegInf, // -inf + + // Typed tokens. + kPrimitiveType, // F32, PRED, etc. + kName, // %foo + kAttributeName, // dimensions= + kDimLabels, // [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} + kDxD, // [0-9]+(x[0-9]+)+ + kPad, // [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* + kIdent, // other identifiers + kString, // "abcd\"\n" + kInt, // 42 + kDecimal, // 4.2 +}; + +string TokKindToString(TokKind kind); + // Lexer for the HloModule::ToString() format text. // // This class is meant to be used by hlo_parser.cc. You shouldn't need to use @@ -39,9 +89,9 @@ class HloLexer { current_ptr_ = buf_.begin(); } - TokKind Lex() { return current_kind_ = LexToken(); } + TokKind Lex() { return token_state_.current_kind = LexToken(); } - TokKind GetKind() const { return current_kind_; } + TokKind GetKind() const { return token_state_.current_kind; } string GetStrVal() const { switch (GetKind()) { case TokKind::kName: @@ -51,28 +101,28 @@ class HloLexer { case TokKind::kPad: case TokKind::kString: case TokKind::kIdent: - return str_val_; + return token_state_.str_val; default: LOG(FATAL) << "This token does not have string value"; } } - Shape GetShapeVal() const { - CHECK(GetKind() == TokKind::kShape); - return shape_val_; - } tensorflow::int64 GetInt64Val() const { CHECK(GetKind() == TokKind::kInt); - return int64_val_; + return token_state_.int64_val; } double GetDecimalVal() const { CHECK(GetKind() == TokKind::kDecimal); - return decimal_val_; + return token_state_.decimal_val; + } + PrimitiveType GetPrimitiveTypeVal() const { + CHECK(GetKind() == TokKind::kPrimitiveType); + return token_state_.primitive_type_val; } typedef const char* LocTy; // Returns the location of the current token. - LocTy GetLoc() const { return token_start_; } + LocTy GetLoc() const { return token_state_.token_start; } // Returns the line and column of a location in the buffer. std::pair GetLineAndColumn(LocTy location) const; @@ -80,6 +130,9 @@ class HloLexer { // Returns the whole line given the location. absl::string_view GetLine(LocTy loc) const; + // Looks ahead one token and returns it. Lexer state is unchanged. + TokKind LookAhead(); + private: // Returns the current character. If it's neither the end of input buffer nor // an invalid character, moves the pointer forward. @@ -112,12 +165,15 @@ class HloLexer { const char* current_ptr_; // Information about the current token. - const char* token_start_ = nullptr; - TokKind current_kind_; - string str_val_; - Shape shape_val_; - tensorflow::int64 int64_val_; - double decimal_val_; + struct TokenState { + const char* token_start = nullptr; + TokKind current_kind; + string str_val; + tensorflow::int64 int64_val; + double decimal_val; + PrimitiveType primitive_type_val; + }; + TokenState token_state_; struct LineNoCacheTy { const char* last_query; diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc index e0ae1173c6114f0bc6ef18b2cfff9d54ccfe2faf..436cccb1fb9ecf6f4efad772c700c611b28ce628 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc @@ -403,9 +403,9 @@ TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) { HloModule OutfeedLoop WhileBody { body_param = (s32[]) parameter(0) - token = token[] after-all() + token0 = token[] after-all() constant.2 = s32[] constant(2) - outfeed_tuple = (s32[]) outfeed(constant.2, token) + outfeed_tuple = (s32[]) outfeed(constant.2, token0) get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 constant.1 = s32[] constant(1) add = s32[] add(get-tuple-element.1, constant.1) @@ -436,9 +436,9 @@ TEST_F(HloLivenessAnalysisTest, NestedWhileWithOutfeed) { HloModule OutfeedLoop InnerWhileBody { body_param = (s32[]) parameter(0) - token = token[] after-all() + token0 = token[] after-all() constant.2 = s32[] constant(2) - outfeed_tuple = (s32[]) outfeed(constant.2, token) + outfeed_tuple = (s32[]) outfeed(constant.2, token0) get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 constant.1 = s32[] constant(1) add = s32[] add(get-tuple-element.1, constant.1) diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 235efb19ce4ed28a5cd9fe5ca52ae5d8e9e5ba3d..1fbcbdf98d68204b1c6269d51d9b19363761ee04 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -312,8 +312,8 @@ inline ::testing::Matcher Shape( } inline ::testing::Matcher Shape( absl::string_view shape) { - return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher( - ShapeUtil::ParseShapeString(shape).ValueOrDie())); + return ::testing::MakeMatcher( + new ::xla::testing::HloShapeMatcher(ParseShape(shape).ValueOrDie())); } inline ::testing::Matcher ShapeWithLayout( const class Shape& shape) { @@ -323,7 +323,7 @@ inline ::testing::Matcher ShapeWithLayout( inline ::testing::Matcher ShapeWithLayout( absl::string_view shape) { return ::testing::MakeMatcher(new ::xla::testing::HloShapeAndLayoutMatcher( - ShapeUtil::ParseShapeString(shape).ValueOrDie())); + ParseShape(shape).ValueOrDie())); } // Verifies the value of the HloSharing against the provided sharding object. diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 7b9cbf9a53a2201b1312405bbd7ed2b88f65c9be..f1310e4b270898a21dbb4f86123edde4ba8993d0 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -136,7 +136,9 @@ class HloModule { // information on opcode, shape, operands, and typically a root instruction. // This function returns the same hash value for equivalent HLO modules, // with respect to HloInstruction::Identical() method. - uint64 Hash() const { return entry_computation()->Hash(); } + uint64 Hash() const { + return entry_computation()->root_instruction()->Hash(); + } // Gets the computations in this module. // diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc index bf66cc6bc37a5e11c9ecfc07a62ba0ea5ca11a03..e535b7d74943943069b4d795cf999a3b1e963360 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc @@ -373,9 +373,9 @@ TEST_F(HloModuleDceTest, WhileWithOutfeed) { HloModule OutfeedLoop WhileBody { body_param = (s32[]) parameter(0) - token = token[] after-all() + token0 = token[] after-all() constant.2 = s32[] constant(2) - outfeed_tuple = (s32[]) outfeed(constant.2, token) + outfeed_tuple = (s32[]) outfeed(constant.2, token0) get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 constant.1 = s32[] constant(1) add = s32[] add(get-tuple-element.1, constant.1) diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 9b5bb5d0bd6af104ef62eaa5d3e53cedbe0213d3..29bb088f6de9a5113d253b7e5559a8e66e7e408b 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -74,6 +75,7 @@ class HloParser { string GetError() const { return StrJoin(error_, "\n"); } // Stand alone parsing utils for various aggregate data types. + StatusOr ParseShapeOnly(); StatusOr ParseShardingOnly(); StatusOr ParseWindowOnly(); StatusOr ParseConvolutionDimensionNumbersOnly(); @@ -255,7 +257,9 @@ class HloParser { bool ParseName(string* result); bool ParseAttributeName(string* result); bool ParseString(string* result); + bool ParseDimensionSizes(std::vector* dimension_sizes); bool ParseShape(Shape* result); + bool ParseLayout(Layout* layout); bool ParseOpcode(HloOpcode* result); bool ParseFftType(FftType* result); bool ParseFusionKind(HloInstruction::FusionKind* result); @@ -279,9 +283,6 @@ class HloParser { // If the current token is 'kind', eats it (i.e. lexes the next token) and // returns true. bool EatIfPresent(TokKind kind); - // Parses a shape, and returns true if the result is compatible with the given - // shape. - bool EatShapeAndCheckCompatible(const Shape& shape); // Adds the instruction to the pool. Returns false and emits an error if the // instruction already exists. @@ -1697,11 +1698,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, } break; } - case TokKind::kShape: - // TODO(b/112302613): Left here for backward compatibility to ignore the - // removed tile shape data. - lexer_.Lex(); - break; case TokKind::kRbrace: break; default: @@ -1925,19 +1921,6 @@ bool HloParser::SetValueInLiteralHelper(ParsedElemT value, return true; } -bool HloParser::EatShapeAndCheckCompatible(const Shape& shape) { - Shape new_shape; - if (!ParseShape(&new_shape)) { - return TokenError(StrCat("expects shape ", ShapeUtil::HumanString(shape))); - } - if (!ShapeUtil::Compatible(shape, new_shape)) { - return TokenError(StrCat( - "expects shape ", ShapeUtil::HumanString(shape), - ", but sees a different shape: ", ShapeUtil::HumanString(new_shape))); - } - return true; -} - // literal // ::= tuple // ::= non_tuple @@ -1952,10 +1935,6 @@ bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) { // ::= /*empty*/ // ::= literal (',' literal)* bool HloParser::ParseTupleLiteral(Literal* literal, const Shape& shape) { - if (!EatShapeAndCheckCompatible(shape)) { - return TokenError(StrCat("expects tuple constant in shape ", - ShapeUtil::HumanString(shape))); - } if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) { return false; } @@ -1990,16 +1969,12 @@ bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) { return ParseSparseLiteral(literal, shape); } - CHECK(LayoutUtil::IsDenseArray(shape)); + CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ToString(true); return ParseDenseLiteral(literal, shape); } bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { const tensorflow::int64 rank = ShapeUtil::Rank(shape); - if (rank > 1 && !EatShapeAndCheckCompatible(shape)) { - return false; - } - // Create a literal with the given shape in default layout. *literal = LiteralUtil::CreateFromDimensions( shape.element_type(), AsInt64Slice(shape.dimensions())); @@ -2126,10 +2101,6 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { } bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) { - if (!EatShapeAndCheckCompatible(shape)) { - return false; - } - switch (shape.element_type()) { case PRED: return ParseSparseLiteralHelper(literal, shape); @@ -2994,6 +2965,39 @@ bool HloParser::ParseParamList() { return ParseToken(TokKind::kRparen, "expects ')' at the end of param list"); } +// dimension_sizes ::= '[' int64_list ']' +bool HloParser::ParseDimensionSizes(std::vector* dimension_sizes) { + auto parse_and_add_item = [&]() { + tensorflow::int64 i; + if (!ParseInt64(&i)) { + return false; + } + dimension_sizes->push_back(i); + return true; + }; + return ParseList(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma, + parse_and_add_item); +} + +// layout ::= '{' int64_list '}' +bool HloParser::ParseLayout(Layout* layout) { + std::vector minor_to_major; + auto parse_and_add_item = [&]() { + tensorflow::int64 i; + if (!ParseInt64(&i)) { + return false; + } + minor_to_major.push_back(i); + return true; + }; + if (!ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, + parse_and_add_item)) { + return false; + } + *layout = LayoutUtil::MakeLayout(minor_to_major); + return true; +} + // shape ::= shape_val_ // shape ::= '(' tuple_elements ')' // tuple_elements @@ -3017,19 +3021,61 @@ bool HloParser::ParseShape(Shape* result) { return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple."); } - if (lexer_.GetKind() != TokKind::kShape) { - return TokenError(absl::StrCat("expected shape, saw ", + if (lexer_.GetKind() != TokKind::kPrimitiveType) { + return TokenError(absl::StrCat("expected primitive type, saw ", TokKindToString(lexer_.GetKind()))); } - *result = lexer_.GetShapeVal(); + PrimitiveType primitive_type = lexer_.GetPrimitiveTypeVal(); lexer_.Lex(); + + std::vector dimension_sizes; + if (!ParseDimensionSizes(&dimension_sizes)) { + return false; + } + result->set_element_type(primitive_type); + *result->mutable_dimensions() = dimension_sizes; + LayoutUtil::SetToDefaultLayout(result); + + if (lexer_.GetKind() == TokKind::kw_sparse) { + lexer_.Lex(); + const string message = + "expects a brace-bracketed integer for sparse layout"; + tensorflow::int64 max_sparse_elements; + if (!ParseToken(TokKind::kLbrace, message) || + !ParseInt64(&max_sparse_elements) || + !ParseToken(TokKind::kRbrace, message)) { + return false; + } + *result->mutable_layout() = + LayoutUtil::MakeSparseLayout(max_sparse_elements); + return true; + } + + // We need to lookahead to see if a following open brace is the start of a + // layout. The specific problematic case is: + // + // ENTRY %foo (x: f32[42]) -> f32[123] { + // ... + // } + // + // The open brace could either be the start of a computation or the start of a + // layout for the f32[123] shape. We consider it the start of a layout if the + // next token after the open brace is a integer + if (lexer_.GetKind() == TokKind::kLbrace && + lexer_.LookAhead() == TokKind::kInt) { + Layout layout; + if (!ParseLayout(&layout)) { + return false; + } + *result->mutable_layout() = layout; + } return true; } bool HloParser::CanBeShape() { - // A non-tuple shape starts with a kShape token; a tuple shape starts with - // '('. - return lexer_.GetKind() == TokKind::kShape || + // A non-tuple shape starts with a kPrimitiveType token; a tuple shape starts + // with '('. + return lexer_.GetKind() == TokKind::kPrimitiveType || lexer_.GetKind() == TokKind::kLparen; } @@ -3332,6 +3378,18 @@ bool HloParser::AddComputation(const string& name, HloComputation* computation, return true; } +StatusOr HloParser::ParseShapeOnly() { + lexer_.Lex(); + Shape shape; + if (!ParseShape(&shape)) { + return InvalidArgument("Syntax error:\n%s", GetError()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument("Syntax error:\nExtra content after shape"); + } + return shape; +} + StatusOr HloParser::ParseShardingOnly() { lexer_.Lex(); OpSharding op_sharding; @@ -3475,4 +3533,9 @@ StatusOr ParsePaddingConfig(absl::string_view str) { return parser.ParsePaddingConfigOnly(); } +StatusOr ParseShape(absl::string_view str) { + HloParser parser(str); + return parser.ParseShapeOnly(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index d830fa61438239005875f785f85cf2486123ebc9..450a54c54c156c2ae27475d145a8e83dc841b431 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -60,6 +60,9 @@ StatusOr ParseConvolutionDimensionNumbers( // Parses the result of PaddingConfigToString(), e.g. "0_0x1_1". StatusOr ParsePaddingConfig(absl::string_view str); +// Parses and returns a Shape::ToString-format string. +StatusOr ParseShape(absl::string_view str); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index ab71f011ac9d77d00ddfb41aca7a224d26d416b7..80882d490d6b477403f87a4eb266d3ba2fdb3378 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -82,7 +82,7 @@ ENTRY %constant_pred () -> pred[] { R"(HloModule module ENTRY %constant_pred_array () -> pred[2,3] { - ROOT %constant = pred[2,3]{1,0} constant(pred[2,3] { { 0, 1, 0 }, { 1, 0, 1 } }) + ROOT %constant = pred[2,3]{1,0} constant({ { 0, 1, 0 }, { 1, 0, 1 } }) } )" @@ -128,7 +128,7 @@ ENTRY %ConstantF32Empty.v4 () -> f32[0] { R"(HloModule ConstantF32R4Empty_module ENTRY %ConstantF32R4Empty.v4 () -> f32[2,0,4,3] { - ROOT %constant = f32[2,0,4,3]{3,2,1,0} constant(f32[2,0,4,3] { { /*i0=0*/ }, { /*i0=1*/ } }) + ROOT %constant = f32[2,0,4,3]{3,2,1,0} constant({ { /*i0=0*/ }, { /*i0=1*/ } }) } )" @@ -139,7 +139,7 @@ ENTRY %ConstantF32R4Empty.v4 () -> f32[2,0,4,3] { R"(HloModule Small_3x2x1x1_module ENTRY %Small_3x2x1x1.v1 () -> f32[3,2,1,1] { - ROOT %constant = f32[3,2,1,1]{3,2,1,0} constant(f32[3,2,1,1] { { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) + ROOT %constant = f32[3,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) } )" @@ -196,7 +196,7 @@ ENTRY %add_constants () -> f32[] { R"(HloModule TupleConstant_module ENTRY %TupleConstant.v1 () -> (f32[2,1], f32[2]) { - ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { {1}, {2} }, {2, 42} )) + ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant(( { {1}, {2} }, {2, 42} )) } )" @@ -295,11 +295,11 @@ ENTRY %WhileWithScalarS32Result.v2 () -> s32[] { R"(HloModule TwoSendRecvBothWayRecvFist_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) { - %token = token[] after-all() - %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15, sharding={maximal device=1} + %token0 = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, sharding={maximal device=1} ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, sharding={maximal device=1} %constant = f32[] constant(2.1), sharding={maximal device=0} - %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv} + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv} %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, sharding={maximal device=0} } @@ -310,11 +310,11 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) { R"(HloModule HostTransferSendRecv_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) { - %token = token[] after-all() - %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15, is_host_transfer=true + %token0 = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, is_host_transfer=true ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, is_host_transfer=true %constant = f32[] constant(2.1), sharding={maximal device=0} - %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, is_host_transfer=true + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, is_host_transfer=true %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, is_host_transfer=true } @@ -327,7 +327,7 @@ R"(HloModule GetTupleElement_module ENTRY %GetTupleElement.v4 () -> s32[2,3] { %constant = f32[3]{0} constant({1, 2, 3}) - %constant.1 = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 4, 5, 6 } }) + %constant.1 = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 } }) %tuple = (f32[3]{0}, s32[2,3]{1,0}) tuple(f32[3]{0} %constant, s32[2,3]{1,0} %constant.1) ROOT %get-tuple-element = s32[2,3]{1,0} get-tuple-element((f32[3]{0}, s32[2,3]{1,0}) %tuple), index=1, sharding={maximal device=0} } @@ -434,7 +434,7 @@ ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f R"(HloModule Reverse4DFloatArrayOnDim01_module ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] { - %constant = f32[4,3,2,1]{0,1,2,3} constant(f32[4,3,2,1] { { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} }, { /*i1=2*/ {5}, {6} } }, { /*i0=1*/ { /*i1=0*/ {7}, {8} }, { /*i1=1*/ {9}, {10} }, { /*i1=2*/ {11}, {12} } }, { /*i0=2*/ { /*i1=0*/ {13}, {14} }, { /*i1=1*/ {15}, {16} }, { /*i1=2*/ {17}, {18} } }, { /*i0=3*/ { /*i1=0*/ {19}, {20} }, { /*i1=1*/ {21}, {22} }, { /*i1=2*/ {23}, {24} } } }) + %constant = f32[4,3,2,1]{0,1,2,3} constant({ { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} }, { /*i1=2*/ {5}, {6} } }, { /*i0=1*/ { /*i1=0*/ {7}, {8} }, { /*i1=1*/ {9}, {10} }, { /*i1=2*/ {11}, {12} } }, { /*i0=2*/ { /*i1=0*/ {13}, {14} }, { /*i1=1*/ {15}, {16} }, { /*i1=2*/ {17}, {18} } }, { /*i0=3*/ { /*i1=0*/ {19}, {20} }, { /*i1=1*/ {21}, {22} }, { /*i1=2*/ {23}, {24} } } }) ROOT %reverse = f32[4,3,2,1]{0,1,2,3} reverse(f32[4,3,2,1]{0,1,2,3} %constant), dimensions={0,1} } @@ -446,8 +446,8 @@ ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] { R"(HloModule Concat2x3With2x5_module ENTRY %Concat2x3With2x5.v3 () -> f32[2,8] { - %constant = f32[2,3]{1,0} constant(f32[2,3] { { 0, 1, 2 }, { 1000, 1001, 1002 } }) - %constant.1 = f32[2,5]{1,0} constant(f32[2,5] { { 64, 65, 66, 67, 68 }, { 1064, 1065, 1066, 1067, 1068 } }) + %constant = f32[2,3]{1,0} constant({ { 0, 1, 2 }, { 1000, 1001, 1002 } }) + %constant.1 = f32[2,5]{1,0} constant({ { 64, 65, 66, 67, 68 }, { 1064, 1065, 1066, 1067, 1068 } }) ROOT %concatenate = f32[2,8]{1,0} concatenate(f32[2,3]{1,0} %constant, f32[2,5]{1,0} %constant.1), dimensions={1} } @@ -471,8 +471,8 @@ R"(HloModule R4F32OverlapSmall_module } ENTRY %R4F32OverlapSmall.v4 () -> f32[4,5,1,1] { - %constant = f32[4,5,1,1]{3,2,1,0} constant(f32[4,5,1,1] { { /*i0=0*/ { /*i1=0*/ {7} }, { /*i1=1*/ {2} }, { /*i1=2*/ {5} }, { /*i1=3*/ {3} }, { /*i1=4*/ {8} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {8} }, { /*i1=2*/ {9} }, { /*i1=3*/ {3} }, { /*i1=4*/ {4} } }, { /*i0=2*/ { /*i1=0*/ {1} }, { /*i1=1*/ {5} }, { /*i1=2*/ {7} }, { /*i1=3*/ {5} }, { /*i1=4*/ {6} } }, { /*i0=3*/ { /*i1=0*/ {0} }, { /*i1=1*/ {6} }, { /*i1=2*/ {2} }, { /*i1=3*/ {10} }, { /*i1=4*/ {2} } } }) - %constant.1 = f32[2,2,1,1]{3,2,1,0} constant(f32[2,2,1,1] { { /*i0=0*/ { /*i1=0*/ {2} }, { /*i1=1*/ {6} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {1} } } }) + %constant = f32[4,5,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {7} }, { /*i1=1*/ {2} }, { /*i1=2*/ {5} }, { /*i1=3*/ {3} }, { /*i1=4*/ {8} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {8} }, { /*i1=2*/ {9} }, { /*i1=3*/ {3} }, { /*i1=4*/ {4} } }, { /*i0=2*/ { /*i1=0*/ {1} }, { /*i1=1*/ {5} }, { /*i1=2*/ {7} }, { /*i1=3*/ {5} }, { /*i1=4*/ {6} } }, { /*i0=3*/ { /*i1=0*/ {0} }, { /*i1=1*/ {6} }, { /*i1=2*/ {2} }, { /*i1=3*/ {10} }, { /*i1=4*/ {2} } } }) + %constant.1 = f32[2,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {2} }, { /*i1=1*/ {6} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {1} } } }) %constant.2 = f32[] constant(0) ROOT %select-and-scatter = f32[4,5,1,1]{3,2,1,0} select-and-scatter(f32[4,5,1,1]{3,2,1,0} %constant, f32[2,2,1,1]{3,2,1,0} %constant.1, f32[] %constant.2), window={size=2x3x1x1 stride=2x2x1x1}, select=%ge_F32.v3, scatter=%add_F32.v3 } @@ -523,7 +523,7 @@ ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { R"(HloModule Slice3x3x3_To_1x3x3_F32_module ENTRY %Slice3x3x3_To_1x3x3_F32.v2 () -> f32[1,3,3] { - %constant = f32[3,3,3]{2,1,0} constant(f32[3,3,3] { { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } }, { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } }, { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } } }) + %constant = f32[3,3,3]{2,1,0} constant({ { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } }, { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } }, { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } } }) ROOT %slice = f32[1,3,3]{2,1,0} slice(f32[3,3,3]{2,1,0} %constant), slice={[0:1], [0:3], [0:3]} } @@ -547,7 +547,7 @@ ENTRY %SliceR0.v2 () -> s32[] { R"(HloModule Transpose_module ENTRY %Transpose.v2 () -> s32[1,2,3] { - %constant = s32[1,2,3]{2,1,0} constant(s32[1,2,3] { { { 1, 2, 3 }, { 4, 5, 6 } } }) + %constant = s32[1,2,3]{2,1,0} constant({ { { 1, 2, 3 }, { 4, 5, 6 } } }) ROOT %transpose = s32[1,2,3]{2,1,0} transpose(s32[1,2,3]{2,1,0} %constant), dimensions={0,1,2} } @@ -588,7 +588,7 @@ ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_ R"(HloModule BasicTraining_module ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) { - %constant = f32[2,2,1,2]{3,2,1,0} constant(f32[2,2,1,2] { { /*i0=0*/ { /*i1=0*/ { 1, 2 } }, { /*i1=1*/ { 3, 4 } } }, { /*i0=1*/ { /*i1=0*/ { 5, 6 } }, { /*i1=1*/ { 7, 8 } } } }) + %constant = f32[2,2,1,2]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ { 1, 2 } }, { /*i1=1*/ { 3, 4 } } }, { /*i0=1*/ { /*i1=0*/ { 5, 6 } }, { /*i1=1*/ { 7, 8 } } } }) %constant.1 = f32[2]{0} constant({2, 3}) %constant.2 = f32[2]{0} constant({1, 2}) ROOT %batch-norm-training = (f32[2,2,1,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-training(f32[2,2,1,2]{3,2,1,0} %constant, f32[2]{0} %constant.1, f32[2]{0} %constant.2), epsilon=0.001, feature_index=3 @@ -728,7 +728,7 @@ R"(HloModule fusion_module } ENTRY %fusion.v3 () -> f32[3,2,1,1] { - %constant = f32[3,2,1,1]{3,2,1,0} constant(f32[3,2,1,1] { { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) + %constant = f32[3,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) %constant.1 = f32[2]{0} constant({3.14, 4.25}) ROOT %fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %constant, f32[2]{0} %constant.1), kind=kLoop, calls=%fused_computation } @@ -740,7 +740,7 @@ ENTRY %fusion.v3 () -> f32[3,2,1,1] { R"(HloModule sparse_f32 ENTRY %sparse () -> f32[2,3,4] { - ROOT %foo = f32[2,3,4]sparse{10} constant(f32[2,3,4]{[0, 1, 2]: 1, [1, 2, 3]: 2, [2, 3, 4]: 3}) + ROOT %foo = f32[2,3,4]sparse{10} constant({[0, 1, 2]: 1, [1, 2, 3]: 2, [2, 3, 4]: 3}) } )" @@ -750,7 +750,7 @@ ENTRY %sparse () -> f32[2,3,4] { R"(HloModule sparse_f32_empty ENTRY %sparse_f32_empty () -> f32[2,3,4] { - ROOT %foo = f32[2,3,4]sparse{10} constant(f32[2,3,4]{}) + ROOT %foo = f32[2,3,4]sparse{10} constant({}) } )" @@ -760,7 +760,7 @@ ENTRY %sparse_f32_empty () -> f32[2,3,4] { R"(HloModule sparse_f32_r1 ENTRY %sparse_f32_r1 () -> f32[9] { - ROOT %foo = f32[9]sparse{10} constant(f32[9]{1: 2, 3: 4, 5: 6}) + ROOT %foo = f32[9]sparse{10} constant({1: 2, 3: 4, 5: 6}) } )" @@ -931,11 +931,11 @@ ENTRY reduce_entry { R"(HloModule outfeed_module ENTRY InfeedToOutfeed { - token = token[] after-all() - infeed = ((u32[3]{0}, pred[]), token[]) infeed(token) + token0 = token[] after-all() + infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0) infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0 - outfeed = token[] outfeed(infeed.data, token) - ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token) + outfeed = token[] outfeed(infeed.data, token0) + ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token0) infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0 infeed.1.token = token[] get-tuple-element(infeed.1), index=1 outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token) @@ -1266,8 +1266,8 @@ R"(HloModule AddDependency ENTRY AddDependency { p = f32[] parameter(0) neg = f32[] negate(p) - token = token[] after-all(neg) - p_after_token = f32[] add-dependency(p, token) + token0 = token[] after-all(neg) + p_after_token = f32[] add-dependency(p, token0) exp = f32[] exponential(p_after_token) ROOT sum = f32[] add(neg, exp) } @@ -1419,7 +1419,7 @@ TEST_F(HloParserTest, MoreConstants) { ENTRY %SelectScalarS32True.v4 () -> s32[] { %constant.2 = pred[] constant(true) - %constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,2]1,2,3,4} + %constant.1 = s32[] constant(-42), sharding={devices=[2,2]1,2,3,4} %constant = s32[] constant(42) %select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant) } @@ -1462,7 +1462,7 @@ TEST_F(HloParserTest, LiteralDimensionsMismatch_2) { const string original = R"(HloModule some_2x3_module ENTRY %some_2x3 () -> f32[2,3] { - ROOT %constant = f32[2,3]{1,0} constant(f32[2,3] {1, 2, 3, 4, 5, 6}) + ROOT %constant = f32[2,3]{1,0} constant({1, 2, 3, 4, 5, 6}) } )"; @@ -1476,7 +1476,7 @@ TEST_F(HloParserTest, LiteralDimensionsMismatch_3) { const string original = R"(HloModule some_2x3x2_module ENTRY %some_2x3x2 () -> f32[2,3,2] { - ROOT %constant = f32[2,3,2]{2,1,0} constant(f32[2,3,2] {{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}, {11, 12}}}) + ROOT %constant = f32[2,3,2]{2,1,0} constant({{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}, {11, 12}}}) } )"; @@ -1594,11 +1594,11 @@ TEST_F(HloParserTest, UnexpectedAttribute) { const string original = R"(HloModule unexpected_attr_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %token = token[] after-all() - %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15 + %token0 = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15 %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15 ROOT %constant = f32[] constant(2.1) - %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, calls=%recv + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, calls=%recv %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16 } @@ -1611,11 +1611,11 @@ TEST_F(HloParserTest, MissingAttribute) { const string original = R"(HloModule missing_attr_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %token = token[] after-all() - %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15 + %token0 = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15 %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15 ROOT %constant = f32[] constant(-2.1) - %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token) + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0) %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16 } @@ -1628,11 +1628,11 @@ TEST_F(HloParserTest, PredecessorUndefined) { const string original = R"(HloModule pre_not_found_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %token = token[] after-all() - %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15 + %token0 = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15 %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15 ROOT %constant = f32[] constant(2.1) - %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, control-predecessors={%done} + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, control-predecessors={%done} %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16 } @@ -1940,8 +1940,8 @@ TEST_F(HloParserTest, ParsePaddingConfigInteriorPaddingImplicitZeroDim) { TEST_F(HloParserTest, NontupleInfeed) { const string original = R"(HloModule nontuple_infeed: ENTRY nontuple_infeed { - token = token[] after-all() - ROOT infeed = pred[] infeed(token) + token0 = token[] after-all() + ROOT infeed = pred[] infeed(token0) })"; ExpectHasSubstr(ParseHloString(original).status().error_message(), "infeed must have a non-empty tuple shape"); @@ -2239,7 +2239,7 @@ HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> f32[2,2] { %p = f32[2,2] parameter(0) - %constant.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.1 = f32[2,2] constant({{1, 2}, {3, 4}}) ROOT %add.1 = f32[2,2] add(f32[2,2] %p, f32[2,5] %constant.1) } )"; @@ -2249,7 +2249,85 @@ ENTRY %entrycomp (p: f32[2,2]) -> f32[2,2] { " with the shape of the operand instruction f32[2,2]{1,0}."); } -// custom call incompatible shape. +TEST_F(HloParserTest, ParseShapeStringR2F32) { + string shape_string = "f32[123,456]"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = ShapeUtil::MakeShape(F32, {123, 456}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseShapeStringTupleOfArrays) { + string shape_string = "(f32[1572864],s8[5120,1024])"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}), + ShapeUtil::MakeShape(S8, {5120, 1024})}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseShapeStringNestedTuple) { + string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {1}), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}), + ShapeUtil::MakeOpaqueShape(), + ShapeUtil::MakeShape(F32, {3}), + }); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseShapeStringWithLayout) { + string shape_string = "f32[123,456]{0,1}"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseShapeStringWithSparseLayout) { + string shape_string = "f32[123,456]sparse{10}"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = ShapeUtil::MakeShapeWithSparseLayout(F32, {123, 456}, 10); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseOpaqueType) { + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape("opaque[]")); + Shape expected = ShapeUtil::MakeOpaqueShape(); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseTokenType) { + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape("token[]")); + Shape expected = ShapeUtil::MakeTokenShape(); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseInvalidShapeString) { + string shape_strings[] = { + "f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}", + "f32[123,456]dense{foo}", "f32[123,456]sparse{foo}", + }; + for (const string& shape_string : shape_strings) { + StatusOr result = ParseShape(shape_string); + ASSERT_FALSE(result.ok()) << "shape: " << shape_string; + } +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 51177f24f5ee702be96fc8b4530ed38a5798109f..33ce7e23a82d840676bba5f1ca9c0ffc4433465d 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -77,6 +77,11 @@ std::vector HloPassPipeline::GetEnabledPasses( auto repeated_field = debug_options.xla_disable_hlo_passes(); absl::flat_hash_set disabled_pass_names(repeated_field.begin(), repeated_field.end()); + if (debug_options.xla_disable_all_hlo_passes()) { + VLOG(1) << "*All* passes disabled by --xla_disable_all_hlo_passes."; + return {}; + } + if (!disabled_pass_names.empty()) { VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: " << absl::StrJoin(disabled_pass_names, ", "); diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc index 981d06ce101644ecce587c4bd2f7a12c8edf6548..3a9ee57e5551ae5b608f02d9f8bd0428ff16db13 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc @@ -39,6 +39,7 @@ HloProto MakeHloProto(const HloModule& module) { StatusOr> CreateModuleFromProto( const HloModuleProto& proto, const HloModuleConfig& module_config) { + VLOG(4) << proto.ShortDebugString(); TF_ASSIGN_OR_RETURN(std::unique_ptr module, HloModule::CreateFromProto(proto, module_config)); TF_RETURN_IF_ERROR( diff --git a/tensorflow/compiler/xla/service/hlo_token.h b/tensorflow/compiler/xla/service/hlo_token.h deleted file mode 100644 index 4458c251dee4af365e39027dd4289925c8890efd..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/hlo_token.h +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { - -// Defines different kinds of tokens in a hlo module string. -// -// You shouldn't need to use this directly unless you're using HloLexer -// directly, and you probably don't need to do that. Use hlo_parser instead. -enum class TokKind { - // Markers - kEof, - kError, - - // Tokens with no info. - kEqual, // = - kComma, // , - kColon, // : - kLsquare, - kRsquare, // [ ] - kLbrace, - kRbrace, // { } - kLparen, - kRparen, // ( ) - - kArrow, // -> - - // Keywords - kw_HloModule, - kw_ENTRY, - kw_ROOT, - kw_true, - kw_false, - kw_maximal, - kw_replicated, - kw_nan, - kw_inf, - - kNegInf, // -inf - - // Typed tokens. - kName, // %foo - kAttributeName, // dimensions= - kDimLabels, // [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} - kDxD, // [0-9]+(x[0-9]+)+ - kPad, // [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* - kIdent, // other identifiers - kString, // "abcd\"\n" - kShape, // f32[2,3]{1,0} - kInt, // 42 - kDecimal, // 4.2 -}; - -string TokKindToString(TokKind kind); - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 77db7b098a38ff4efdcc7447935fae61561c9ff4..ace854ed6a243c3788a46333f41cb85d90c8e174 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -481,7 +481,9 @@ Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) { const Shape& operand_shape_with_layout = custom_call->operand_shapes_with_layout()[i]; TF_RET_CHECK(ShapeUtil::Compatible(custom_call->operand(i)->shape(), - operand_shape_with_layout)); + operand_shape_with_layout)) + << custom_call->operand(i)->shape().ToString() << " operand " + << operand_shape_with_layout.ToString(); TF_RET_CHECK(LayoutUtil::HasLayout(operand_shape_with_layout)); } } diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index 98246d5403e4aebc2f4d81e52145706355ddd9a9..295465c8481bcb7d1385192febe0d09614e393b3 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -99,7 +99,7 @@ TEST_F(IndexedArrayAnalysisTest, SimpleOneToOneConstantGather) { HloModule SimpleGather ENTRY main { - operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + operand = s32[3,3] constant({{1,2,3},{1,2,3},{1,2,3}}) indices = s32[5] parameter(0) ROOT gather = s32[5,3] gather(operand, indices), offset_dims={1}, @@ -119,7 +119,7 @@ TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed0) { HloModule SimpleGather ENTRY main { - operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + operand = s32[3,3] constant({{1,2,3},{1,2,3},{1,2,3}}) indices = s32[5,2] parameter(0) ROOT gather = s32[5] gather(operand, indices), offset_dims={}, @@ -195,7 +195,7 @@ TEST_F(IndexedArrayAnalysisTest, GatherOfGather_OneToOne) { HloModule SimpleGather ENTRY main { - operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + operand = s32[3,3] constant({{1,2,3},{1,2,3},{1,2,3}}) indices_a = s32[5] parameter(0) indices_b = s32[2] parameter(1) gather_a = s32[5,3] gather(operand, indices_a), @@ -309,7 +309,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather0) { HloModule ReshapeOfGather ENTRY main { - operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + operand = s32[3,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4}}) indices = s32[5] parameter(0) gather = s32[5,4] gather(operand, indices), offset_dims={1}, @@ -330,7 +330,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather1) { HloModule ReshapeOfGather ENTRY main { - operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + operand = s32[3,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4}}) indices = s32[5,7] parameter(0) gather = s32[5,4,7] gather(operand, indices), offset_dims={1}, @@ -352,7 +352,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather2) { HloModule ReshapeOfGather ENTRY main { - operand = s32[3,2,6] constant(s32[3,2,6]{ + operand = s32[3,2,6] constant({ {{1,2,3,4,5,6},{1,2,3,4,5,6}}, {{1,2,3,4,5,6},{1,2,3,4,5,6}}, {{1,2,3,4,5,6},{1,2,3,4,5,6}}}) @@ -377,7 +377,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather3) { HloModule ReshapeOfGather ENTRY main { - operand = s32[2,6] constant(s32[2,6]{ + operand = s32[2,6] constant({ {1,2,3,4,5,6},{1,2,3,4,5,6}}) indices = s32[1] parameter(0) gather = s32[1,6] gather(operand, indices), @@ -405,7 +405,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather4) { HloModule ReshapeOfGather ENTRY main { - operand = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 1, 2, 3 } }) + operand = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 1, 2, 3 } }) i.0 = s64[1,3]{1,0} parameter(0) g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), offset_dims={2}, @@ -438,7 +438,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather5) { HloModule ReshapeOfGather ENTRY main { - operand = s32[1,6] constant(s32[1,6]{{1,2,3,4,5,6}}) + operand = s32[1,6] constant({{1,2,3,4,5,6}}) indices = s32[1] parameter(0) gather = s32[1,6] gather(operand, indices), offset_dims={1}, @@ -465,7 +465,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather6) { HloModule ReshapeOfGather ENTRY main { - operand = s32[1,2,6] constant(s32[1,2,6]{{ + operand = s32[1,2,6] constant({{ {1,2,3,4,5,6},{1,2,3,4,5,6}}}) indices = s32[1] parameter(0) gather = s32[1,1,6] gather(operand, indices), @@ -496,7 +496,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather7) { HloModule ReshapeOfGather ENTRY main { - operand = s32[2,6] constant(s32[2,6]{ + operand = s32[2,6] constant({ {1,2,3,4,5,6},{1,2,3,4,5,6}}) indices = s32[1,5] parameter(0) gather = s32[1,5,6] gather(operand, indices), @@ -527,7 +527,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold0) { HloModule ReshapeOfGather ENTRY main { - operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + operand = s32[3,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4}}) indices = s32[5,6] parameter(0) gather = s32[5,4,6] gather(operand, indices), offset_dims={1}, @@ -556,7 +556,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold1) { HloModule ReshapeOfGather ENTRY main { - operand = s32[3,5,2] constant(s32[3,5,2]{ + operand = s32[3,5,2] constant({ {{1,2},{3,4},{5,6},{7,8},{9,10}}, {{1,2},{3,4},{5,6},{7,8},{9,10}}, {{1,2},{3,4},{5,6},{7,8},{9,10}}}) @@ -588,7 +588,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold2) { HloModule ReshapeOfGather ENTRY main { - operand = s32[3,4,1] constant(s32[3,4,1]{ + operand = s32[3,4,1] constant({ {{1},{2},{3},{4}}, {{1},{2},{3},{4}}, {{1},{2},{3},{4}}}) @@ -620,7 +620,7 @@ TEST_F(IndexedArrayAnalysisTest, UnaryOpOfGather) { HloModule UnaryOpOfGather ENTRY main { - operand = f32[3,4] constant(f32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + operand = f32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) indices = s32[5] parameter(0) gather = f32[5,4] gather(operand, indices), offset_dims={1}, @@ -645,7 +645,7 @@ TEST_F(IndexedArrayAnalysisTest, AddBroadcastedScalarWithGather) { HloModule AddBroadcastedScalarWithGather ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) constant = s32[] constant(5) constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} indices = s32[5] parameter(0) @@ -673,7 +673,7 @@ TEST_F(IndexedArrayAnalysisTest, HloModule SubtractBroadcastedScalarWithGather ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) constant = s32[] constant(5) constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} indices = s32[5] parameter(0) @@ -701,7 +701,7 @@ TEST_F(IndexedArrayAnalysisTest, HloModule SubtractBroadcastedScalarWithGather ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) constant = s32[] constant(5) constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} indices = s32[5] parameter(0) @@ -728,7 +728,7 @@ TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather) { HloModule AddBroadcastedVectorWithGather ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) constant_vect = s32[4] constant({10,11,12,13}) constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={1} indices = s32[5] parameter(0) @@ -755,7 +755,7 @@ TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather_Negative) { HloModule AddBroadcastedVectorWithGather ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) constant_vect = s32[5] constant({10,11,12,13,14}) constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={0} indices = s32[5] parameter(0) @@ -804,8 +804,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpBasic_0) { HloModule DotOp ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) - dot_rhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_rhs_constant = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) indices = s32[5] parameter(0) dot_lhs = s32[5,4] gather(gather_operand, indices), offset_dims={1}, @@ -831,8 +831,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpBasic_1) { HloModule DotOp ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) - dot_rhs_constant = s32[3,3] constant(s32[3,3]{{1,2,3},{4,5,6},{7,8,9}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_rhs_constant = s32[3,3] constant({{1,2,3},{4,5,6},{7,8,9}}) indices = s32[5] parameter(0) dot_lhs = s32[3,5] gather(gather_operand, indices), offset_dims={0}, @@ -859,8 +859,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpBasic_2) { HloModule DotOp ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) - dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_lhs_constant = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) indices = s32[5] parameter(0) dot_rhs = s32[3,5] gather(gather_operand, indices), offset_dims={0}, @@ -888,8 +888,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpBasic_3) { HloModule DotOp ENTRY main { - gather_operand = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) - dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + gather_operand = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + dot_lhs_constant = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) indices = s32[5] parameter(0) dot_rhs = s32[5,3] gather(gather_operand, indices), offset_dims={1}, @@ -917,8 +917,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpWithBatch) { HloModule DotOp ENTRY main { - gather_operand = s32[2,3,2] constant(s32[2,3,2]{{{1,2},{3,4},{5,6}},{{7,8},{9,10},{11,12}}}) - dot_lhs_constant = s32[2,2,3] constant(s32[2,2,3]{{{1,2,3},{4,5,6}},{{7,8,9},{10,11,12}}}) + gather_operand = s32[2,3,2] constant({{{1,2},{3,4},{5,6}},{{7,8},{9,10},{11,12}}}) + dot_lhs_constant = s32[2,2,3] constant({{{1,2,3},{4,5,6}},{{7,8,9},{10,11,12}}}) indices = s32[4] parameter(0) dot_rhs = s32[2,3,4] gather(gather_operand, indices), offset_dims={0,1}, @@ -948,8 +948,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpNegative) { HloModule DotOp ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) - dot_rhs_constant = s32[2,3] constant(s32[2,3]{{1,2,3},{4,5,6}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_rhs_constant = s32[2,3] constant({{1,2,3},{4,5,6}}) indices = s32[2] parameter(0) dot_lhs = s32[3,2] gather(gather_operand, indices), offset_dims={0}, diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 58b7135cea7419f13d60ed510ecf7a88126aee48..611cfd404d7622f561f0acc86fc9b05e16eea22e 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -259,8 +259,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) { add = f32[4,3]{1,0} add(p0, p0) abs1 = f32[4,3]{1,0} abs(add) log = f32[4,3]{1,0} log(abs1) - token = token[] after-all() - send = f32[4,3]{1,0} send(log, token), channel_id=0 + token0 = token[] after-all() + send = f32[4,3]{1,0} send(log, token0), channel_id=0 abs2 = f32[4,3]{1,0} abs(log) ROOT root = f32[4,3]{1,0} subtract(abs2, add) })") @@ -290,8 +290,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) { p0 = f32[4,3]{1,0} parameter(0) add1 = f32[4,3]{1,0} add(p0, p0) log = f32[4,3]{1,0} log(p0) - token = token[] after-all() - send = f32[4,3]{1,0} send(log, token), channel_id=0 + token0 = token[] after-all() + send = f32[4,3]{1,0} send(log, token0), channel_id=0 add2 = f32[4,3]{1,0} add(log, add1) ROOT root = f32[4,3]{1,0} subtract(add1, add2) })") @@ -324,8 +324,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) { add1 = f32[4,3]{1,0} add(p0, p0) add2 = f32[4,3]{1,0} add(add1, add1) log = f32[4,3]{1,0} log(add2) - token = token[] after-all() - send = f32[4,3]{1,0} send(log, token), channel_id=0 + token0 = token[] after-all() + send = f32[4,3]{1,0} send(log, token0), channel_id=0 sub1 = f32[4,3]{1,0} subtract(log, add2) sub2 = f32[4,3]{1,0} subtract(add2, add1) ROOT root = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub1, sub2) diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index 3a5177c418e3af8253df228a51f2fc0901d10041..d37ae94bf6c4c697bbf30390c02a5099271e00a4 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -76,9 +76,12 @@ StatusOr> InterpreterCompiler::RunBackend( // need to compile anything // Create executable from only the Hlo module. + auto evaluator = absl::make_unique(); + evaluator->set_use_fast_path( + hlo_module->config().debug_options().xla_hlo_evaluator_use_fast_path()); std::unique_ptr executable = - absl::make_unique( - std::move(hlo_module), absl::make_unique()); + absl::make_unique(std::move(hlo_module), + std::move(evaluator)); return std::move(executable); } diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 5c661bfacb08fe27f3cbdc1fb9db083315166008..9fe8c3accbf283f3b3eebbefbac8739c37df16bc 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -847,12 +847,12 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { ENTRY entry_computation { param = (f32[2,2]) parameter(0) gte = f32[2,2] get-tuple-element(param), index=0 - token = token[] after-all() - recv = (f32[2,2], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=1} + token0 = token[] after-all() + recv = (f32[2,2], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=1} recv-done = (f32[2,2], token[]) recv-done(recv), channel_id=1, sharding={maximal device=1} ROOT root = f32[2,2] get-tuple-element(recv-done), index=0 - send = (f32[2,2], u32[], token[]) send(gte, token), channel_id=1, + send = (f32[2,2], u32[], token[]) send(gte, token0), channel_id=1, sharding={maximal device=0} send-done = token[] send-done(send), channel_id=1, sharding={maximal device=0} } @@ -897,7 +897,7 @@ TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) { ar.0 = f32[2,2] cross-replica-sum(gte), all_reduce_id=1, replica_groups={{0}}, to_apply=add, sharding={maximal device=0} - const = f32[2,2] constant(f32[2,2]{{0,1},{2,3}}) + const = f32[2,2] constant({{0,1},{2,3}}) ROOT ar.1 = f32[2,2] cross-replica-sum(const), all_reduce_id=1, replica_groups={{0}}, to_apply=add, sharding={maximal device=1} diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc index bd0139f85b6a5c5dc23dad962263038451921e65..5eeb29c478a371dae83251771f2dc4844672d3e9 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc @@ -18,28 +18,29 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" namespace xla { -Status KernelSupportLibrary::For( +Status KernelSupportLibrary::ForWithStatus( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { - return If(b_->CreateICmpSLT(start, end), [&]() -> Status { + return IfWithStatus(b_->CreateICmpSLT(start, end), [&]() -> Status { TF_RETURN_IF_ERROR(for_body_generator(start, /*is_first_iteration=*/true)); - return For(name, b_->CreateAdd(start, step), end, step, - [&](llvm::Value* iv) { return for_body_generator(iv, false); }); + return ForWithStatus( + name, b_->CreateAdd(start, step), end, step, + [&](llvm::Value* iv) { return for_body_generator(iv, false); }); }); } -Status KernelSupportLibrary::For( +Status KernelSupportLibrary::ForWithStatus( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, bool peel_first_iteration, const std::function& for_body_generator) { if (peel_first_iteration) { - return For(name, start, end, step, true, - [&](llvm::Value* indvar, bool is_first_iteration) -> Status { - return for_body_generator(indvar, - b_->getInt1(is_first_iteration)); - }); + return ForWithStatus( + name, start, end, step, true, + [&](llvm::Value* indvar, bool is_first_iteration) -> Status { + return for_body_generator(indvar, b_->getInt1(is_first_iteration)); + }); } else { std::unique_ptr loop = llvm_ir::ForLoop::EmitForLoop( name, start, end, step, b_, @@ -55,7 +56,7 @@ Status KernelSupportLibrary::For( } } -Status KernelSupportLibrary::If( +Status KernelSupportLibrary::IfWithStatus( absl::string_view name, llvm::Value* condition, const std::function& true_block_generator, const std::function& false_block_generator) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index 43fec311f150d6054f6ad24f99db332f90ff94a3..612b839cfa15711061e1ae53358a72d5220e1801 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -48,41 +48,42 @@ class KernelSupportLibrary { // for (i64 i = `start` + `step`; i s< `end`; i += `step`) // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/false)`; // } - Status For( + Status ForWithStatus( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator); - void ForReturnVoid( + void For( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { CHECK_EQ(Status::OK(), - For(name, start, end, step, + ForWithStatus( + name, start, end, step, [&](llvm::Value* ind_var, bool is_first_iteration) -> Status { for_body_generator(ind_var, is_first_iteration); return Status::OK(); })); } - Status For(absl::string_view name, int64 start, int64 end, int64 step, - const std::function& - for_body_generator) { - return For(name, /*start=*/b_->getInt64(start), - /*end=*/b_->getInt64(end), - /*step=*/b_->getInt64(step), for_body_generator); + Status ForWithStatus( + absl::string_view name, int64 start, int64 end, int64 step, + const std::function& for_body_generator) { + return ForWithStatus(name, /*start=*/b_->getInt64(start), + /*end=*/b_->getInt64(end), + /*step=*/b_->getInt64(step), for_body_generator); } - void ForReturnVoid( + void For( absl::string_view name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - ForReturnVoid(name, /*start=*/b_->getInt64(start), - /*end=*/b_->getInt64(end), - /*step=*/b_->getInt64(step), for_body_generator); + For(name, /*start=*/b_->getInt64(start), + /*end=*/b_->getInt64(end), + /*step=*/b_->getInt64(step), for_body_generator); } // Generates the following control flow structure if `peel_first_iteration` is @@ -99,19 +100,19 @@ class KernelSupportLibrary { // for (i64 i = `start`; i s< `end`; i += `step`) // `for_body_generator(/*ind_var=*/,i, // /*is_first_iteration=*/,(i != `start`))`; - Status For(absl::string_view name, llvm::Value* start, llvm::Value* end, - llvm::Value* step, bool peel_first_iteration, - const std::function& - for_body_generator); + Status ForWithStatus( + absl::string_view name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, bool peel_first_iteration, + const std::function& + for_body_generator); - void ForReturnVoid(absl::string_view name, llvm::Value* start, - llvm::Value* end, llvm::Value* step, - bool peel_first_iteration, - const std::function& - for_body_generator) { - TF_CHECK_OK(For( + void For(absl::string_view name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, bool peel_first_iteration, + const std::function& + for_body_generator) { + TF_CHECK_OK(ForWithStatus( name, start, end, step, peel_first_iteration, [&](llvm::Value* ind_var, llvm::Value* is_first_iteration) -> Status { for_body_generator(ind_var, is_first_iteration); @@ -119,80 +120,81 @@ class KernelSupportLibrary { })); } - Status For(absl::string_view name, llvm::Value* start, llvm::Value* end, - int64 step, bool peel_first_iteration, - const std::function& - for_body_generator) { - return For(name, /*start=*/start, /*end=*/end, - /*step=*/llvm::ConstantInt::get(start->getType(), step), - peel_first_iteration, for_body_generator); + Status ForWithStatus( + absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, + bool peel_first_iteration, + const std::function& + for_body_generator) { + return ForWithStatus( + name, /*start=*/start, /*end=*/end, + /*step=*/llvm::ConstantInt::get(start->getType(), step), + peel_first_iteration, for_body_generator); } - void ForReturnVoid(absl::string_view name, llvm::Value* start, - llvm::Value* end, int64 step, bool peel_first_iteration, - const std::function& - for_body_generator) { - ForReturnVoid(name, /*start=*/start, /*end=*/end, - /*step=*/llvm::ConstantInt::get(start->getType(), step), - peel_first_iteration, for_body_generator); + void For(absl::string_view name, llvm::Value* start, llvm::Value* end, + int64 step, bool peel_first_iteration, + const std::function& + for_body_generator) { + For(name, /*start=*/start, /*end=*/end, + /*step=*/llvm::ConstantInt::get(start->getType(), step), + peel_first_iteration, for_body_generator); } - Status For( + Status ForWithStatus( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { - return For(name, start, end, step, - /*peel_first_iteration=*/false, - [&](llvm::Value* indvar, llvm::Value*) -> Status { - return for_body_generator(indvar); - }); + return ForWithStatus(name, start, end, step, + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) -> Status { + return for_body_generator(indvar); + }); } - void ForReturnVoid( + void For( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { - ForReturnVoid(name, start, end, step, - /*peel_first_iteration=*/false, - [&](llvm::Value* indvar, llvm::Value*) { - return for_body_generator(indvar); - }); + For(name, start, end, step, + /*peel_first_iteration=*/false, [&](llvm::Value* indvar, llvm::Value*) { + return for_body_generator(indvar); + }); } - Status For( + Status ForWithStatus( absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, const std::function& for_body_generator) { - return For(name, start, end, llvm::ConstantInt::get(start->getType(), step), - /*peel_first_iteration=*/false, - [&](llvm::Value* indvar, llvm::Value*) -> Status { - return for_body_generator(indvar); - }); + return ForWithStatus(name, start, end, + llvm::ConstantInt::get(start->getType(), step), + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) -> Status { + return for_body_generator(indvar); + }); } - void ForReturnVoid( + void For( absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, const std::function& for_body_generator) { - ForReturnVoid(name, start, end, - llvm::ConstantInt::get(start->getType(), step), - for_body_generator); + For(name, start, end, llvm::ConstantInt::get(start->getType(), step), + for_body_generator); } - Status For( + Status ForWithStatus( absl::string_view name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - return For(name, /*start=*/b_->getInt64(start), - /*end=*/b_->getInt64(end), - /*step=*/b_->getInt64(step), for_body_generator); + return ForWithStatus(name, /*start=*/b_->getInt64(start), + /*end=*/b_->getInt64(end), + /*step=*/b_->getInt64(step), for_body_generator); } - void ForReturnVoid( + void For( absl::string_view name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - ForReturnVoid(name, /*start=*/b_->getInt64(start), - /*end=*/b_->getInt64(end), - /*step=*/b_->getInt64(step), for_body_generator); + For(name, /*start=*/b_->getInt64(start), + /*end=*/b_->getInt64(end), + /*step=*/b_->getInt64(step), for_body_generator); } // Generates the following control flow structure: @@ -201,38 +203,43 @@ class KernelSupportLibrary { // `true_block_generator()`; // else // `false_block_generator()`; - Status If(absl::string_view name, llvm::Value* condition, - const std::function& true_block_generator, - const std::function& false_block_generator = - []() -> Status { return Status::OK(); }); + Status IfWithStatus( + absl::string_view name, llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = []() -> Status { + return Status::OK(); + }); - Status If(llvm::Value* condition, - const std::function& true_block_generator, - const std::function& false_block_generator = - []() -> Status { return Status::OK(); }) { - return If("", condition, true_block_generator, false_block_generator); + Status IfWithStatus( + llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = []() -> Status { + return Status::OK(); + }) { + return IfWithStatus("", condition, true_block_generator, + false_block_generator); } - void IfReturnVoid(llvm::Value* condition, - const std::function& true_block_generator, - const std::function& false_block_generator = []() { - }) { - IfReturnVoid("", condition, true_block_generator, false_block_generator); + void If( + llvm::Value* condition, const std::function& true_block_generator, + const std::function& false_block_generator = []() {}) { + If("", condition, true_block_generator, false_block_generator); } - void IfReturnVoid(absl::string_view name, llvm::Value* condition, - const std::function& true_block_generator, - const std::function& false_block_generator = []() { - }) { - TF_CHECK_OK(If(name, condition, - [&]() { - true_block_generator(); - return Status::OK(); - }, - [&]() { - false_block_generator(); - return Status::OK(); - })); + void If( + absl::string_view name, llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = []() {}) { + TF_CHECK_OK(IfWithStatus( + name, condition, + [&]() { + true_block_generator(); + return Status::OK(); + }, + [&]() { + false_block_generator(); + return Status::OK(); + })); } using ArgumentVector = absl::Span; diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc index c26711e526c9b89cdedcb6aed9f93d41dd25dc83..cebbc4290163d4e98003cd7b5df6ec906509a446 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc @@ -120,7 +120,7 @@ KernelMappingScheme::KernelMappingScheme( absl::Span req_block_sizes, int64 num_threads_y, int64 num_threads_x, llvm::IRBuilder<>* b) : b_(b), - dims_in_elems_(dims_in_elems), + dims_in_elems_(dims_in_elems.begin(), dims_in_elems.end()), tile_sizes_{1, tile_size_y, tile_size_x}, num_threads_x_(num_threads_x), num_threads_y_(num_threads_y) { @@ -170,14 +170,16 @@ IrArray::Index KernelMappingScheme::EmitBlockIndex(llvm::Type* index_ty) { IrArray::Index KernelMappingScheme::GetTileIndexForBlockOrigin( const IrArray::Index& block_index) { - IrArray::Index tile_index = block_index; + DCHECK_EQ(block_index.size(), block_sizes_.size()); + std::vector multidim; + multidim.reserve(block_sizes_.size()); for (int i = 0; i < block_sizes_.size(); ++i) { - tile_index[i] = b_->CreateMul( + multidim.push_back(b_->CreateMul( block_index[i], llvm::ConstantInt::get(block_index[i]->getType(), block_sizes_[i]), - "block_origin." + std::to_string(i)); + "block_origin." + std::to_string(i))); } - return tile_index; + return IrArray::Index(multidim, block_index[0]->getType()); } IrArray::Index KernelMappingScheme::GetElementIndexForTileOrigin( @@ -217,14 +219,14 @@ KernelMappingScheme::EmitThreadYXCoordinate(llvm::Type* index_ty) { // defined by (num_thread_y, num_thread_x) from thread_id. llvm::CallInst* thread_id_raw = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b_); - llvm_ir::AddRangeMetadata(0, GetThreadsPerTile(), thread_id_raw); + llvm_ir::AddRangeMetadata(0, GetThreadsPerBlock(), thread_id_raw); llvm::Value* thread_id_int = b_->CreateIntCast(thread_id_raw, index_ty, /*isSigned=*/true, "thread.id.x"); llvm::Value* num_thread_x = llvm::ConstantInt::get(index_ty, GetNumberOfThreadsForDimensionX()); - llvm::Value* x = b_->CreateURem(thread_id_int, num_thread_x); - llvm::Value* y = b_->CreateUDiv(thread_id_int, num_thread_x); + llvm::Value* x = b_->CreateURem(thread_id_int, num_thread_x, "thread.x"); + llvm::Value* y = b_->CreateUDiv(thread_id_int, num_thread_x, "thread.y"); return std::make_tuple(y, x); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h index 06002d57b0d7daa07f903feebe67a60a083c0e7c..fb633b12e60d1a9f3103fb2919ad2c3f3f14de20 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h @@ -90,15 +90,16 @@ class KernelMappingScheme { enum { DimZ = 0, DimY, DimX, DimTot }; public: + KernelMappingScheme() {} // dims_in_elems: the normalized tensor dimensions. // req_block_sizes: the requested block size in number of tiles for each // dimension. The actual block size is set to min(req_block_size, // dims_in_number_of_blocks). - explicit KernelMappingScheme(absl::Span dims_in_elems, - int64 tile_size_y, int64 tile_size_x, - absl::Span req_block_sizes, - int64 num_threads_y, int64 num_threads_x, - llvm::IRBuilder<>* b); + KernelMappingScheme(absl::Span dims_in_elems, int64 tile_size_y, + int64 tile_size_x, + absl::Span req_block_sizes, + int64 num_threads_y, int64 num_threads_x, + llvm::IRBuilder<>* b); absl::Span GetDimensionsInElements() const { return dims_in_elems_; @@ -133,11 +134,15 @@ class KernelMappingScheme { } absl::Span GetBlockSizes() const { return block_sizes_; } + int64 GetTileBlockSizeForDimension(int d) const { + DCHECK(d >= DimZ && d <= DimX); + return dims_in_blocks_[d]; + } int64 GetNumberOfThreadsForDimensionX() const { return num_threads_x_; } int64 GetNumberOfThreadsForDimensionY() const { return num_threads_y_; } - int64 GetThreadsPerTile() const { + int64 GetThreadsPerBlock() const { return GetNumberOfThreadsForDimensionX() * GetNumberOfThreadsForDimensionY(); } @@ -163,7 +168,7 @@ class KernelMappingScheme { private: llvm::IRBuilder<>* b_; // The number of elements in each dimension. - absl::Span dims_in_elems_; + std::vector dims_in_elems_; // The number of elements for each dimension of a tile. std::vector tile_sizes_; diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc index e22c2173c271fc9571be1ddb0759d2b31562dc98..6a9406bfebafcc02dc2e144b62284a9e83c3edeb 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc @@ -108,7 +108,7 @@ void EmitCompareLoopBody( // if (is_smaller_index && index_is_inbounds) KernelSupportLibrary ksl(b); - ksl.IfReturnVoid("smaller_comparison_index", do_comparison, [&]() { + ksl.If("smaller_comparison_index", do_comparison, [&]() { auto key1 = read_element(0, current_keys_index); auto key2 = read_element(0, compare_keys_index); auto compare_key1 = key1; @@ -155,7 +155,7 @@ void EmitCompareLoopBody( is_smaller_than = b->CreateOr( is_smaller_than, b->CreateAnd(keys_equal, index_is_smaller_than)); } - ksl.IfReturnVoid("is_smaller_than", is_smaller_than, [&]() { + ksl.If("is_smaller_than", is_smaller_than, [&]() { // Swap key1 with key2. write_element(0, current_keys_index, key2); write_element(0, compare_keys_index, key1); @@ -192,7 +192,7 @@ void EmitTiledCompareLoop( b->CreateShl(tiled_keys_index[dimension_to_sort], value_one); // We want to copy two adjacent elements. We first check whether the // first index position is within bounds. - ksl.IfReturnVoid( + ksl.If( "smaller_keys_index", b->CreateICmpSLT(current_keys_index, tiled_keys_index.GetConstantWithIndexType( @@ -203,15 +203,14 @@ void EmitTiledCompareLoop( // Increment to go the next index position. current_keys_index = b->CreateAdd(current_keys_index, value_one); // Here we check whether the next index position is within bounds. - ksl.IfReturnVoid( - "inner_smaller_keys_index", - b->CreateICmpSLT(current_keys_index, - tiled_keys_index.GetConstantWithIndexType( - dimension_to_sort_bound)), - [&]() { - cache_index = b->CreateAdd(cache_index, value_one); - read_or_write(cache_index, current_keys_index); - }); + ksl.If("inner_smaller_keys_index", + b->CreateICmpSLT(current_keys_index, + tiled_keys_index.GetConstantWithIndexType( + dimension_to_sort_bound)), + [&]() { + cache_index = b->CreateAdd(cache_index, value_one); + read_or_write(cache_index, current_keys_index); + }); }); }; @@ -253,7 +252,7 @@ void EmitTiledCompareLoop( if (dimension_to_sort_bound % tile_size) { // Otherwise we need a bounds check for the last tile. The last tile has // size 'dimension_to_sort_bound' % 'tile_size'. - ksl.IfReturnVoid( + ksl.If( "is_last_tile", b->CreateICmpUGE( b->CreateMul(tiled_keys_index[dimension_to_sort], diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index ac2f79674feceff436c0e9c65338967f498e4473..daa718879ddd45afb02725b557380b2f49fe833e 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -17,6 +17,7 @@ limitations under the License. #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -42,6 +43,7 @@ NameUniquer::NameUniquer(const string& separator) { if (name.empty()) { return ""; } + string result = name; char c = static_cast(result[0]); if (!isalpha(c) && c != '_') { @@ -52,6 +54,13 @@ NameUniquer::NameUniquer(const string& separator) { result[i] = '_'; } } + + // HLO primitive type names (with the exception of 'tuple') are keywords in + // the HLO text representation and cannot be names, so append an underscore if + // the name is a primitive type. + if (primitive_util::IsPrimitiveTypeName(result) && result != "tuple") { + result += "_"; + } return result; } diff --git a/tensorflow/compiler/xla/service/name_uniquer_test.cc b/tensorflow/compiler/xla/service/name_uniquer_test.cc index 3e2592c6ac626143f1421e545a31d9be91e376bc..d0d04147e0c29c66cba447550c0a9c703f35573a 100644 --- a/tensorflow/compiler/xla/service/name_uniquer_test.cc +++ b/tensorflow/compiler/xla/service/name_uniquer_test.cc @@ -104,5 +104,21 @@ TEST_F(NameUniquerTest, KeepNamesInRandomOrder) { EXPECT_EQ("foo.3", uniquer.GetUniqueName("foo.3")); } +TEST_F(NameUniquerTest, AvoidKeywords) { + NameUniquer uniquer("."); + + EXPECT_EQ("f32_", uniquer.GetUniqueName("f32")); + EXPECT_EQ("s64_", uniquer.GetUniqueName("s64")); + EXPECT_EQ("pred_", uniquer.GetUniqueName("pred")); + + // Though a primitive type, "tuple" is not a keyword. + EXPECT_EQ("tuple", uniquer.GetUniqueName("tuple")); + + // Keywords are not capitalized. + EXPECT_EQ("F32", uniquer.GetUniqueName("F32")); + EXPECT_EQ("S32", uniquer.GetUniqueName("S32")); + EXPECT_EQ("Pred", uniquer.GetUniqueName("Pred")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index c35f72699bfe90f7b8021916c0f81d5e1926ff4c..81db3bb643a989cafb6c6a8bcbd35e218fdcaf44 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -1737,7 +1737,8 @@ class HloConstantScalarImpl { literal_r0_as_val_ty_or.ValueOrDie() == val_literal && literal_r0 == val_as_literal_ty; if (!rv) { - EXPLAIN << "HloInstruction's constant value " << literal_r0.ToString() + EXPLAIN << "HloInstruction's constant value " + << literal_r0.ToStringWithoutShape() << " did not match expected value " << *val_; } return rv; diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index 186ef0c7911a2724df810780e018f52586e3e6a8..5c3c009a68bffbda8642fceedfb724879fbf1530 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -242,8 +242,8 @@ TEST(PatternMatcherTest, ConstantScalar) { HloModule test_module ENTRY test { a = s32[] constant(1) - b = s32[1,1] constant(s32[1,1]{{2}}) - c = s32[1,2] constant(s32[1,2]{{2,2}}) + b = s32[1,1] constant({{2}}) + c = s32[1,2] constant({{2,2}}) d = f32[] constant(1) e = f32[] constant(1.25) ROOT tuple = (s32[], s32[1,1], s32[1,2], f32[], f32[]) tuple(a,b,c,d,e) diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 5ec7fe2adedac2fc3d8a7588e853dba90e99006f..ae5bd93e7c56117cc78ecc729d370250787efac6 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -1078,9 +1078,11 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, ProgramShape program_shape(arg->computation().host_program_shape()); TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result())); + absl::optional output_layout; if (arg->has_output_layout()) { + output_layout = Layout::CreateFromProto(arg->output_layout()); TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape( - arg->output_layout(), program_shape.result())); + *output_layout, program_shape.result())); } HloModuleConfig config(program_shape); @@ -1096,8 +1098,8 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, // relayout here. // // TODO(b/77824332): Make HloEvaluator take care of the re-layout. - if (arg->has_output_layout()) { - result_literal = result_literal.Relayout(arg->output_layout()); + if (output_layout.has_value()) { + result_literal = result_literal.Relayout(*output_layout); } *result->mutable_literal() = result_literal.ToProto(); diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 17cdaa74fc328d156292f5af828d4222a9a01f1f..3ca53edc8171a134f2bfb9a36beacfd2d2e0d425 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -139,9 +139,9 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) { HloModule FoldDotTransposeConstant ENTRY entry_computation { - constant = f32[2,1]{1,0} constant(f32[2,1] { { 1 }, { 2 } }) + constant = f32[2,1]{1,0} constant({ { 1 }, { 2 } }) transpose = f32[1,2]{1,0} transpose(constant), dimensions={1,0} - constant.1 = f32[3,2]{1,0} constant(f32[3,2] { { 1, 2 }, { 3, 4 }, { 5, 6 } }) + constant.1 = f32[3,2]{1,0} constant({ { 1, 2 }, { 3, 4 }, { 5, 6 } }) transpose.1 = f32[2,3]{1,0} transpose(constant.1), dimensions={1,0} ROOT dot = f32[1,3]{1,0} dot(transpose, transpose.1), lhs_contracting_dims={1}, rhs_contracting_dims={0} } diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc index 75d406435b6f58faecc86b82c33e9e2dd6bccbea..3bcf5c38309a86e9e3cab3268f3f065005f7a923 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc @@ -129,7 +129,7 @@ condition { ENTRY entry { const_0 = f32[2] constant({1, 2}) - const_1 = (f32[2], f32[2]) constant((f32[2], f32[2]) ({2, 1},{3,1})) + const_1 = (f32[2], f32[2]) constant(({2, 1},{3,1})) while_init = (f32[2],(f32[2],f32[2])) tuple(const_0, const_1) ROOT while = (f32[2],(f32[2],f32[2])) while(while_init), condition=condition, body=body } @@ -206,8 +206,8 @@ body { p_body.0 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=0 p_body.1 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=1 - token = token[] after-all() - outfeed = token[] outfeed(p_body.0, token) + token0 = token[] after-all() + outfeed = token[] outfeed(p_body.0, token0) ROOT root = (f32[2],f32[2],f32[2]) tuple(p_body.0, p_body.1, p_body.1) } @@ -305,7 +305,7 @@ condition { ENTRY entry { const_0 = f32[] constant(0) - const_1 = (f32[], f32[]) constant((f32[], f32[]) (1, 10)) + const_1 = (f32[], f32[]) constant((1, 10)) while_init = (f32[],(f32[],f32[])) tuple(const_0, const_1) ROOT while = (f32[],(f32[],f32[])) while(while_init), condition=condition, body=body } diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 4950e8269e9cf0723d717bd1734518d104c0c9f2..3713989ca2f64ee1d94c9f77255017909d957da2 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -554,8 +555,7 @@ TEST_F(WhileLoopSimplifierTest, FlattenNestedTuple) { HloInstruction* new_while = FindFirstWhile(m.get()); Shape flat_tuple = - ShapeUtil::ParseShapeString("(s32[1], s32[2], s32[3], s32[4])") - .ValueOrDie(); + ParseShape("(s32[1], s32[2], s32[3], s32[4])").ValueOrDie(); SCOPED_TRACE(m->ToString()); EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), flat_tuple)); EXPECT_TRUE(ShapeUtil::Equal( @@ -567,8 +567,7 @@ TEST_F(WhileLoopSimplifierTest, FlattenNestedTuple) { flat_tuple)); EXPECT_TRUE(ShapeUtil::Equal( m->entry_computation()->root_instruction()->shape(), - ShapeUtil::ParseShapeString("((s32[1]), (s32[2], s32[3], (s32[4])))") - .ValueOrDie())); + ParseShape("((s32[1]), (s32[2], s32[3], (s32[4])))").ValueOrDie())); } // Edge-case: All elements of the loop carry are constants which can be removed, @@ -641,8 +640,7 @@ TEST_F(WhileLoopSimplifierTest, RemoveConstantFromLoopCarry) { EXPECT_TRUE(TupleSimplifier().Run(m.get()).ok()); HloInstruction* new_while = FindFirstWhile(m.get()); - Shape new_while_shape = - ShapeUtil::ParseShapeString("(s32[1], s32[3])").ValueOrDie(); + Shape new_while_shape = ParseShape("(s32[1], s32[3])").ValueOrDie(); EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), new_while_shape)); EXPECT_TRUE(ShapeUtil::Equal( new_while->while_body()->root_instruction()->shape(), new_while_shape)); @@ -652,9 +650,9 @@ TEST_F(WhileLoopSimplifierTest, RemoveConstantFromLoopCarry) { EXPECT_TRUE(ShapeUtil::Equal( new_while->while_condition()->parameter_instruction(0)->shape(), new_while_shape)); - EXPECT_TRUE(ShapeUtil::Equal( - m->entry_computation()->root_instruction()->shape(), - ShapeUtil::ParseShapeString("(s32[1], s32[2], s32[3])").ValueOrDie())); + EXPECT_TRUE( + ShapeUtil::Equal(m->entry_computation()->root_instruction()->shape(), + ParseShape("(s32[1], s32[2], s32[3])").ValueOrDie())); EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple(_, op::Constant(), _)); } @@ -712,7 +710,7 @@ TEST_F(WhileLoopSimplifierTest, MergeInductionVariables_Simple) { // We should have added a new loop counter for s32[] to the end of the tuple. SCOPED_TRACE(m->ToString()); Shape new_while_shape = - ShapeUtil::ParseShapeString("(s32[], s32[], s32[], s32[])").ValueOrDie(); + ParseShape("(s32[], s32[], s32[], s32[])").ValueOrDie(); EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), new_while_shape)); EXPECT_TRUE(ShapeUtil::Equal( new_while->while_body()->root_instruction()->shape(), new_while_shape)); diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc index 5e6941933330fde29bc9c779aae4bb3c36914660..d92b9870f373564ae8fd904c8bf9f0d1afbff9c4 100644 --- a/tensorflow/compiler/xla/service/while_util_test.cc +++ b/tensorflow/compiler/xla/service/while_util_test.cc @@ -180,8 +180,8 @@ body { cond { param.c = (s32[], s32[]) parameter(0) - token = token[] after-all() - infeed = (pred[], token[]) infeed(token) + token0 = token[] after-all() + infeed = (pred[], token[]) infeed(token0) ROOT condition = pred[] get-tuple-element(infeed), index=0 } diff --git a/tensorflow/compiler/xla/shape.cc b/tensorflow/compiler/xla/shape.cc index 746ab9e9977b1b10cdb0cb57197027d65bd50f55..b206345db2ac2940b1f139c82fa03a93538b5ccd 100644 --- a/tensorflow/compiler/xla/shape.cc +++ b/tensorflow/compiler/xla/shape.cc @@ -32,7 +32,7 @@ Shape::Shape(const ShapeProto& shape_proto) { *add_tuple_shapes() = Shape(element_shape); } if (shape_proto.has_layout()) { - *mutable_layout() = shape_proto.layout(); + *mutable_layout() = Layout::CreateFromProto(shape_proto.layout()); } } @@ -48,7 +48,7 @@ ShapeProto Shape::ToProto() const { *proto.add_tuple_shapes() = shape.ToProto(); } if (has_layout()) { - *proto.mutable_layout() = layout(); + *proto.mutable_layout() = layout().ToProto(); } return proto; } diff --git a/tensorflow/compiler/xla/shape.h b/tensorflow/compiler/xla/shape.h index 7f6b14ab4286c696dce64d2250a3fe8a57e4865b..7643f64d8a5f0450be1cddad35cf7422afb89048 100644 --- a/tensorflow/compiler/xla/shape.h +++ b/tensorflow/compiler/xla/shape.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/types/optional.h" +#include "tensorflow/compiler/xla/layout.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/types.h" @@ -76,21 +77,10 @@ class Shape { std::vector* mutable_tuple_shapes() { return &tuple_shapes_; } // Methods for accessing the layout field. - bool has_layout() const { return layout_.has_value(); } - const Layout& layout() const { - if (layout_.has_value()) { - return *layout_; - } else { - return Layout::default_instance(); - } - } - Layout* mutable_layout() { - if (!layout_.has_value()) { - layout_ = Layout(); - } - return &layout_.value(); - } - void clear_layout() { layout_.reset(); } + bool has_layout() const { return layout_.format() != INVALID_FORMAT; } + const Layout& layout() const { return layout_; } + Layout* mutable_layout() { return &layout_; } + void clear_layout() { layout_.Clear(); } void Swap(Shape* other) { using std::swap; @@ -101,7 +91,7 @@ class Shape { element_type_ = PRIMITIVE_TYPE_INVALID; dimensions_.clear(); tuple_shapes_.clear(); - layout_.reset(); + clear_layout(); } string SerializeAsString() const { return ToProto().SerializeAsString(); } @@ -118,8 +108,8 @@ class Shape { // The tuple element subshapes. This is nonempty only for tuple shapes. std::vector tuple_shapes_; - // The array layout of the shape. This is present only for array shapes. - absl::optional layout_; + // The layout of the shape. Only relevant for arrays. + Layout layout_; }; // Shape of the parameters and output of an XLA computation. This is analogous diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index a4d4e1e53e727bdf7822cacaa4559fcae59d4eae..be7d71ada009535a5c08aa3d3d062fa451cfeef3 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -164,9 +164,9 @@ StatusOr MakeShapeWithLayoutInternal( TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::MakeValidatedShape(element_type, dimensions)); auto min2maj = shape.mutable_layout()->mutable_minor_to_major(); - min2maj->Clear(); + min2maj->clear(); for (int64 value : minor_to_major) { - min2maj->Add(value); + min2maj->push_back(value); } if (!shape.has_layout()) { return InvalidArgument("Shape has no layout."); @@ -234,7 +234,7 @@ StatusOr MakeShapeWithLayoutInternal( /* static */ StatusOr ShapeUtil::MakeValidatedShape( PrimitiveType element_type, absl::Span dimensions) { - CHECK(IsArrayPrimitiveType(element_type)); + CHECK(IsArrayPrimitiveType(element_type)) << element_type; Shape result; TF_RETURN_IF_ERROR(PopulateShape(element_type, dimensions, &result)); return result; @@ -480,54 +480,6 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return IsScalar(shape) && shape.element_type() == element_type; } -namespace { - -// Class to memoize the computation of -// absl::AsciiStrToLower(PrimitiveType_Name(p)) -// for all PrimitiveType values "p" -class PrimitiveTypeNameGenerator { - public: - PrimitiveTypeNameGenerator() { - for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { - if (PrimitiveType_IsValid(i)) { - lowercase_name_[i] = absl::AsciiStrToLower( - PrimitiveType_Name(static_cast(i))); - } - } - } - const string& LowercaseName(PrimitiveType t) { - return lowercase_name_[static_cast(t)]; - } - - private: - string lowercase_name_[PrimitiveType_ARRAYSIZE]; -}; - -const string& LowercasePrimitiveTypeName(PrimitiveType s) { - static PrimitiveTypeNameGenerator* gen = new PrimitiveTypeNameGenerator(); - return gen->LowercaseName(s); -} - -StatusOr StringToPrimitiveType(const string& name) { - static std::unordered_map* name_to_type = [] { - static auto* map = new std::unordered_map; - for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { - if (PrimitiveType_IsValid(i)) { - auto value = static_cast(i); - (*map)[LowercasePrimitiveTypeName(value)] = value; - } - } - return map; - }(); - auto found = name_to_type->find(name); - if (found == name_to_type->end()) { - return InvalidArgument("Invalid element type string: \"%s\".", name); - } - return found->second; -} - -} // namespace - /* static */ string ShapeUtil::HumanString(const Shape& shape) { if (IsTuple(shape)) { string text = "("; @@ -539,8 +491,9 @@ StatusOr StringToPrimitiveType(const string& name) { text += ")"; return text; } - return StrCat(LowercasePrimitiveTypeName(shape.element_type()), "[", - absl::StrJoin(shape.dimensions(), ","), "]"); + return StrCat( + primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "[", + absl::StrJoin(shape.dimensions(), ","), "]"); } /* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { @@ -554,7 +507,8 @@ StatusOr StringToPrimitiveType(const string& name) { text += ")"; return text; } - string result = StrCat(LowercasePrimitiveTypeName(shape.element_type()), "["); + string result = StrCat( + primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "["); for (int i = 0; i < shape.dimensions().size(); i++) { StrAppend(&result, (i > 0) ? "," : "", shape.dimensions(i)); } @@ -580,116 +534,6 @@ StatusOr StringToPrimitiveType(const string& name) { HumanString(program_shape.result())); } -namespace { -// Parses shapes with simple recursive descent structure -- consumes from the -// front of s and passes that view recursively as required. -StatusOr ParseShapeStringInternal(absl::string_view* s) { - *s = absl::StripLeadingAsciiWhitespace(*s); - - if (absl::ConsumePrefix(s, "(")) { // Tuple. - std::vector shapes; - bool must_end = false; - while (true) { - if (absl::ConsumePrefix(s, ")")) { - break; - } else if (must_end) { - return InvalidArgument("Expected end of tuple; got: \"%s\"", *s); - } - shapes.emplace_back(); - TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s)); - *s = absl::StripLeadingAsciiWhitespace(*s); - must_end = !absl::ConsumePrefix(s, ","); - } - return ShapeUtil::MakeTupleShape(shapes); - } - - string element_type_string; - string dimensions_string; - string format_string; - string layout_string; - // absl::string_view is not compatible with internal RE2 StringPiece, so - // we convert in to the RE2-consumable type and then consume the corresponding - // amount from our string_view type. - static LazyRE2 shape_pattern = { - "^(\\w*\\d*)\\[([\\d,\\s]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,\\s]+)})" - "?"}; - tensorflow::RegexpStringPiece s_consumable(s->data(), s->size()); - if (RE2::Consume(&s_consumable, *shape_pattern, &element_type_string, - &dimensions_string, &format_string, &layout_string)) { - size_t consumed = s->size() - s_consumable.size(); - s->remove_prefix(consumed); - auto string_to_int64 = [&s](absl::string_view input) -> StatusOr { - int64 element; - if (!absl::SimpleAtoi(input, &element)) { - return InvalidArgument( - "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"", input, - *s); - } - return element; - }; - - auto comma_list_to_int64s = - [string_to_int64](const string& input) -> StatusOr> { - std::vector results; - for (const auto& piece : absl::StrSplit(input, ',', absl::SkipEmpty())) { - TF_ASSIGN_OR_RETURN(int64 element, string_to_int64(piece)); - results.push_back(element); - } - return results; - }; - - // Extract the dimensions. - TF_ASSIGN_OR_RETURN(std::vector dimensions, - comma_list_to_int64s(dimensions_string)); - - // Extract the primitive element type. - TF_ASSIGN_OR_RETURN(const PrimitiveType primitive_type, - StringToPrimitiveType(element_type_string)); - if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE) { - return InvalidArgument("Invalid element type string: \"%s\".", - element_type_string); - } - - Shape result; - if (primitive_type == OPAQUE) { - result = ShapeUtil::MakeOpaqueShape(); - } else if (primitive_type == TOKEN) { - result = ShapeUtil::MakeTokenShape(); - } else if (format_string.empty() && layout_string.empty()) { - // Create a shape without a layout set. - TF_ASSIGN_OR_RETURN( - result, ShapeUtil::MakeValidatedShape(primitive_type, dimensions)); - } else if (format_string == "sparse") { - TF_ASSIGN_OR_RETURN(int64 max_elements, string_to_int64(layout_string)); - result = ShapeUtil::MakeShapeWithSparseLayout(primitive_type, dimensions, - max_elements); - } else if (format_string.empty() || format_string == "dense") { - // Extract the layout minor-to-major and set it. - TF_ASSIGN_OR_RETURN(std::vector min2maj, - comma_list_to_int64s(layout_string)); - TF_ASSIGN_OR_RETURN(result, MakeShapeWithLayoutInternal( - primitive_type, dimensions, min2maj)); - } else { - // This should not be reached. - LOG(FATAL) << "Unhandled condition when parsing shape; format: \"" - << format_string << "\", layout: \"" << layout_string << "\""; - } - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(result)); - return std::move(result); - } - - return InvalidArgument("Invalid shape string to parse: \"%s\"", *s); -} -} // namespace - -/* static */ StatusOr ShapeUtil::ParseShapeString(absl::string_view s) { - TF_ASSIGN_OR_RETURN(Shape shape, ParseShapeStringInternal(&s)); - if (!s.empty()) { - return InvalidArgument("Invalid shape string to parse: \"%s\"", s); - } - return shape; -} - /* static */ bool ShapeUtil::SameDimensions(const Shape& lhs, const Shape& rhs) { CHECK(ShapeUtil::IsArray(lhs)); @@ -867,13 +711,13 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { if (shape.dimensions_size() != 0) { return InvalidArgument( "shape has %s element type, but has dimensions field: %s", - LowercasePrimitiveTypeName(shape.element_type()), + primitive_util::LowercasePrimitiveTypeName(shape.element_type()), shape.ShortDebugString()); } if (shape.has_layout()) { return InvalidArgument( "shape has %s element type, but has layout field: %s", - LowercasePrimitiveTypeName(shape.element_type()), + primitive_util::LowercasePrimitiveTypeName(shape.element_type()), shape.ShortDebugString()); } return Status::OK(); @@ -1067,6 +911,11 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { return absl::c_linear_search(shape.dimensions(), 1); } +/* static */ Shape ShapeUtil::DropDegenerateDimensions(const Shape& shape) { + return FilterDimensions( + [&](int64 dim) -> bool { return shape.dimensions()[dim] != 1; }, shape); +} + namespace { // Helper for ForEachSubshape which visits the subshapes of the given shape in @@ -1618,10 +1467,10 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, if (LayoutUtil::HasLayout(shape)) { Layout* layout = shape.mutable_layout(); layout->set_format(DENSE); - for (size_t i = 0; i < layout->minor_to_major().size();) { + for (int64 i = 0; i < layout->minor_to_major().size();) { if (layout->minor_to_major(i) == dim_to_delete) { layout->mutable_minor_to_major()->erase( - layout->minor_to_major().begin() + i); + layout->mutable_minor_to_major()->begin() + i); continue; } if (layout->minor_to_major(i) > dim_to_delete) { diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 84a27f662a57ba274562e2e9be57b7e971c9b477..6b7a9cd34f25f2088bdb8d2c7f0412e5d8519d23 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -241,10 +241,6 @@ class ShapeUtil { // (param_name: f32[42x12], ...) -> f32[24x42] static string HumanString(const ProgramShape& program_shape); - // Parses a ShapeUtil::HumanString-format shape string back into a shape - // object. - static StatusOr ParseShapeString(absl::string_view s); - // Returns whether the LHS and RHS shapes have the same dimensions; note: does // not check element type. // Precondition: IsArray(lhs) && IsArray(rhs) @@ -551,6 +547,9 @@ class ShapeUtil { // (dimensions with bound 1). static bool HasDegenerateDimensions(const Shape& shape); + // Drops any degenerate dimensions (i.e. dimensions of size 1) + static Shape DropDegenerateDimensions(const Shape& shape); + // Permutes the dimensions by the given permutation, so // return_value.dimensions[permutation[i]] = argument.dimensions[i]. // diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 60bdbe302045e6f3b4bae500c50bc68fb217525d..0a3081f5161f80ac97e864ba08d186df4fbdb55d 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -82,102 +82,6 @@ TEST(ShapeUtilTest, Rank4DimensionIndexing) { ASSERT_EQ(3, shape.dimensions(0)); } -TEST(ShapeUtilTest, ParseShapeStringR2F32) { - string shape_string = "f32[123,456]"; - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString(shape_string)); - Shape expected = ShapeUtil::MakeShape(F32, {123, 456}); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) { - string shape_string = "(f32[1572864],s8[5120,1024])"; - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString(shape_string)); - Shape expected = - ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}), - ShapeUtil::MakeShape(S8, {5120, 1024})}); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseShapeStringNestedTuple) { - string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])"; - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString(shape_string)); - Shape expected = ShapeUtil::MakeTupleShape({ - ShapeUtil::MakeShape(F32, {1}), - ShapeUtil::MakeTupleShape( - {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}), - ShapeUtil::MakeOpaqueShape(), - ShapeUtil::MakeShape(F32, {3}), - }); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseShapeStringWithLayout) { - string shape_string = "f32[123,456]{0,1}"; - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString(shape_string)); - Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseShapeStringWithExplicitDenseLayout) { - string shape_string = "f32[123,456]dense{0,1}"; - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString(shape_string)); - Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseShapeStringWithSparseLayout) { - string shape_string = "f32[123,456]sparse{10}"; - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString(shape_string)); - Shape expected = ShapeUtil::MakeShapeWithSparseLayout(F32, {123, 456}, 10); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseOpaqueType) { - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString("opaque[]")); - Shape expected = ShapeUtil::MakeOpaqueShape(); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseTokenType) { - TF_ASSERT_OK_AND_ASSIGN(Shape actual, ShapeUtil::ParseShapeString("token[]")); - Shape expected = ShapeUtil::MakeTokenShape(); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseInvalidShapeString) { - string shape_strings[] = { - "f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}", - "f32[123,456]dense{foo}", "f32[123,456]sparse{foo}", - }; - for (const string& shape_string : shape_strings) { - StatusOr result = ShapeUtil::ParseShapeString(shape_string); - ASSERT_FALSE(result.ok()) << "shape: " << shape_string; - } -} - TEST(ShapeUtilTest, CompatibleIdenticalShapes) { Shape shape1 = ShapeUtil::MakeShape(F32, {3, 2}); Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2}); diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 5a7a4faa7e89b27fb537f20d94c21cb4a76e000d..0300b64ed59a3d4d8b0cd161109c97cabfdc6734 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -1,6 +1,13 @@ # Description: # Base testing infrastructure for XLA. +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites", "generate_backend_test_macros", "xla_test", "xla_test_library") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") + licenses(["notice"]) # Apache 2.0 package( @@ -23,17 +30,6 @@ filegroup( ]), ) -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test_library") -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites") -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_test_macros") -load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load( - "//tensorflow/core:platform/default/build_config_root.bzl", - "tf_cuda_tests_tags", -) - # Generate test_suites for all backends, named "${backend}_tests". generate_backend_suites() @@ -1348,6 +1344,7 @@ xla_test( xla_test( name = "custom_call_test", srcs = ["custom_call_test.cc"], + backends = ["cpu"], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 12c029983336cc9aed0fde4ce6881c9a00a9869e..697236dc6236738df08205fa3631a2919dd361c5 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -74,6 +74,9 @@ ClientLibraryTestBase::ClientLibraryTestBase( // default. execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( "constant_folding"); + + execution_options_.mutable_debug_options() + ->set_xla_hlo_evaluator_use_fast_path(true); } ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform) @@ -88,6 +91,9 @@ ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform) execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( "constant_folding"); + + execution_options_.mutable_debug_options() + ->set_xla_hlo_evaluator_use_fast_path(true); } string ClientLibraryTestBase::TestName() const { diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 3622f2c1e84639baed13059b21b20609d1347da6..df005a67097bb8aaf070c57d1c51acd1909fee12 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -133,7 +133,9 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { // Reverse the minor-to-major order of the literal. Layout* literal_layout = literal.mutable_shape_do_not_use()->mutable_layout(); ASSERT_EQ(2, literal_layout->minor_to_major_size()); - literal_layout->mutable_minor_to_major()->SwapElements(0, 1); + // Swap the first and second elements. + *literal_layout->mutable_minor_to_major() = { + literal_layout->minor_to_major(1), literal_layout->minor_to_major(0)}; HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 738b6442354b01364278e3e3c713aa2cdb5cf47d..cad43d1b5547d74701760fa623e50466fc15c263 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -54,11 +54,20 @@ void Add1ToValues(float* out, float** in) { out[2] = array[2] + 1; out[3] = array[3] + 1; } + +void F32TupleSwap(float** out, float** in) { + TF_ANNOTATE_MEMORY_IS_INITIALIZED(in[0], sizeof(float)); + TF_ANNOTATE_MEMORY_IS_INITIALIZED(in[1], sizeof(float)); + *out[0] = *in[1]; + *out[1] = *in[0]; +} + } // namespace REGISTER_CUSTOM_CALL_TARGET(R0F32Add2); REGISTER_CUSTOM_CALL_TARGET(R2F32ReduceSum); REGISTER_CUSTOM_CALL_TARGET(Add1ToValues); +REGISTER_CUSTOM_CALL_TARGET(F32TupleSwap); namespace xla { namespace { @@ -69,7 +78,7 @@ class CustomCallTest : public HloTestBase { Shape r2f32_ = ShapeUtil::MakeShape(F32, {2, 2}); }; -XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { +XLA_TEST_F(CustomCallTest, CustomCallR0F32Add2) { auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); @@ -84,7 +93,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { LiteralTestUtil::ExpectR0Near(44.0f, result, error_spec_); } -XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { +XLA_TEST_F(CustomCallTest, CustomCallR2F32Reduce) { auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); @@ -105,7 +114,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { LiteralTestUtil::ExpectR0Near(10.0f, result, error_spec_); } -XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(UsedInOtherComputations)) { +XLA_TEST_F(CustomCallTest, UsedInOtherComputations) { auto module = CreateNewUnverifiedModule(); auto b = HloComputation::Builder(TestName()); @@ -129,7 +138,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(UsedInOtherComputations)) { Array3D{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result); } -XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(InputAndOutputLayoutDiffer)) { +XLA_TEST_F(CustomCallTest, InputAndOutputLayoutDiffer) { auto module = CreateNewUnverifiedModule(); auto b = HloComputation::Builder(TestName()); @@ -151,7 +160,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(InputAndOutputLayoutDiffer)) { LiteralTestUtil::ExpectR2Equal({{2.f, 4.f}, {3.f, 5.f}}, result); } -XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(LayoutConstrained)) { +XLA_TEST_F(CustomCallTest, LayoutConstrained) { // The argument and result of the computation are set to different layouts, // but the custom call is layout constrained to a fixed operand and result // layout, so the correct result should be produced. @@ -176,6 +185,26 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(LayoutConstrained)) { LiteralTestUtil::ExpectR2Equal({{2.f, 3.f}, {4.f, 5.f}}, result); } +XLA_TEST_F(CustomCallTest, TupleOutput) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT %custom-call = (f32[], f32[]) custom-call(f32[] %p0, f32[] %p1), custom_call_target="F32TupleSwap", operand_layout_constraints={f32[], f32[]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + Literal arg0 = LiteralUtil::CreateR0(7.f); + Literal arg1 = LiteralUtil::CreateR0(42.f); + + Literal expected = LiteralUtil::MakeTuple({&arg1, &arg0}); + Literal result = ExecuteAndTransfer(std::move(module), {&arg0, &arg1}); + EXPECT_EQ(result, expected); +} + class CustomCallClientAPITest : public ClientLibraryTestBase {}; // When using the client API, CustomCall targets can't begin with '$' -- these diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 989a7c705a8254f99e5cc0e97dfde5942f146964..d57846e19bb80c5b9c87d50596da2915f9aef317 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -181,6 +181,7 @@ DebugOptions HloTestBase::GetDebugOptionsForTest() { // TODO(b/38354253): Change tests to use Parameters instead of Constants. debug_options.add_xla_disable_hlo_passes("constant_folding"); debug_options.set_xla_gpu_max_kernel_unroll_factor(1); + debug_options.set_xla_hlo_evaluator_use_fast_path(true); return debug_options; } diff --git a/tensorflow/compiler/xla/tests/iota_test.cc b/tensorflow/compiler/xla/tests/iota_test.cc index 65205f53ddc582ae477d67705f161fef1e31b857..37b2c635eebe57590e1ba73c62f015ccf399b548 100644 --- a/tensorflow/compiler/xla/tests/iota_test.cc +++ b/tensorflow/compiler/xla/tests/iota_test.cc @@ -80,7 +80,7 @@ TEST_P(IotaR2Test, DoIt) { } INSTANTIATE_TEST_CASE_P(IotaR2TestInstantiation, IotaR2Test, - ::testing::Combine(::testing::Values(F32, S32), + ::testing::Combine(::testing::Values(F32, S32, BF16), ::testing::Range(/*start=*/10, /*end=*/1001, /*step=*/10), diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index b6f9b8156b51144e4f74d285b1e4111d098f13c2..ea9b3037cf482e41238413179888f125822d161c 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -89,11 +89,11 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { Literal literal = Literal::CreateFromProto(literal_proto).ConsumeValueOrDie(); if (result.find("expected") != string::npos) { - EXPECT_EQ("2", literal.ToString()); + EXPECT_EQ("f32[] 2", literal.ToString()); } else if (result.find("actual") != string::npos) { - EXPECT_EQ("4", literal.ToString()); + EXPECT_EQ("f32[] 4", literal.ToString()); } else if (result.find("mismatches") != string::npos) { - EXPECT_EQ("true", literal.ToString()); + EXPECT_EQ("pred[] true", literal.ToString()); } else { FAIL() << "unknown file in temporary directory: " << result; } @@ -105,9 +105,9 @@ TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) { auto actual = LiteralUtil::CreateR1({4, 5, 6}); ::testing::AssertionResult result = LiteralTestUtil::Equal(expected, actual); EXPECT_THAT(result.message(), - ::testing::HasSubstr("Expected literal:\n{1, 2, 3}")); + ::testing::HasSubstr("Expected literal:\ns32[3] {1, 2, 3}")); EXPECT_THAT(result.message(), - ::testing::HasSubstr("Actual literal:\n{4, 5, 6}")); + ::testing::HasSubstr("Actual literal:\ns32[3] {4, 5, 6}")); } TEST(LiteralTestUtilTest, NearComparatorR1) { diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index e8f5d7a9a79ebddea3cb989dbe8eab90b630d5e7..448a66cfdd897b17cce1c87c050520a2f2eb0ea2 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -61,11 +61,11 @@ XLA_TEST_F(TestUtilsTest, Token) { R"(HloModule outfeed_module ENTRY InfeedToOutfeed { - token = token[] parameter(0) - infeed = ((u32[3]{0}, pred[]), token[]) infeed(token) + token0 = token[] parameter(0) + infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0) infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0 - outfeed = token[] outfeed(infeed.data, token) - ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token) + outfeed = token[] outfeed(infeed.data, token0) + ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token0) infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0 infeed.1.token = token[] get-tuple-element(infeed.1), index=1 outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token) diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc index 601c6b06938fef1f1ae809b33209ae59b24c70a2..b77cf38ed8e29973985406015c0a3936916ad5e6 100644 --- a/tensorflow/compiler/xla/tests/token_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -214,8 +214,8 @@ ENTRY %AddDependency (p0: f32[], p1: f32[]) -> f32[] { %forty_two = f32[] constant(42.0) %add = f32[] add(f32[] %p0, f32[] %forty_two) - %token = token[] after-all(f32[] %add) - %p1_after_token = f32[] add-dependency(f32[] %p1, token[] %token) + %token0 = token[] after-all(f32[] %add) + %p1_after_token = f32[] add-dependency(f32[] %p1, token[] %token0) %neg = f32[] negate(f32[] %p1_after_token) ROOT %product = f32[] multiply(f32[] %add, f32[] %neg) } @@ -236,8 +236,8 @@ HloModule AddDependencyOfConstant, is_scheduled=true ENTRY %AddDependency (p0: f32[]) -> f32[] { %p0 = f32[] parameter(0) %forty_two = f32[] constant(42.0) - %token = token[] after-all(f32[] %p0) - %forty_two_after_token = f32[] add-dependency(f32[] %forty_two, token[] %token) + %token0 = token[] after-all(f32[] %p0) + %forty_two_after_token = f32[] add-dependency(f32[] %forty_two, token[] %token0) ROOT %product = f32[] multiply(f32[] %p0, f32[] %forty_two_after_token) } )"; @@ -255,8 +255,8 @@ HloModule AddDependencyAsRoot, is_scheduled=true ENTRY %AddDependency (p: f32[3]) -> f32[3] { %p = f32[3] parameter(0) %neg = f32[3] negate(f32[3] %p) - %token = token[] after-all() - ROOT %add_dep = f32[3] add-dependency(f32[3] %neg, token[] %token) + %token0 = token[] after-all() + ROOT %add_dep = f32[3] add-dependency(f32[3] %neg, token[] %token0) } )"; TF_ASSERT_OK_AND_ASSIGN( @@ -274,9 +274,9 @@ ENTRY %TupleShapedAddDependency (p0: f32[3], p1: f32[3]) -> f32[3] { %p0 = f32[3] parameter(0) %p1 = f32[3] parameter(1) %forty_two = f32[] constant(42.0) - %token = token[] after-all() - %tuple = (f32[3], token[], f32[3], f32[]) tuple(f32[3] %p0, token[] %token, f32[3] %p1, f32[] %forty_two) - %add_dep = (f32[3], token[], f32[3], f32[]) add-dependency((f32[3], token[], f32[3], f32[]) %tuple, token[] %token) + %token0 = token[] after-all() + %tuple = (f32[3], token[], f32[3], f32[]) tuple(f32[3] %p0, token[] %token0, f32[3] %p1, f32[] %forty_two) + %add_dep = (f32[3], token[], f32[3], f32[]) add-dependency((f32[3], token[], f32[3], f32[]) %tuple, token[] %token0) %elem0 = f32[3] get-tuple-element((f32[3], token[], f32[3], f32[]) %add_dep), index=0 %elem2 = f32[3] get-tuple-element((f32[3], token[], f32[3], f32[]) %add_dep), index=2 ROOT %diff = f32[3] subtract(f32[3] %elem0, f32[3] %elem2) diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 27ce243e9bd4afbdcc1fdc5b6873d4968086e459..9c586bdeb05afb7378e92caed1f3edc408e051bf 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -555,8 +555,8 @@ XLA_TEST_F(TupleHloTest, s = (f32[2],f32[2]) tuple-select(cond, tup0, tup1) gte = f32[2] get-tuple-element(s), index=0 tuple = (f32[2]) tuple(gte) - token = token[] after-all() - ROOT outfeed = token[] outfeed(tuple, token) + token0 = token[] after-all() + ROOT outfeed = token[] outfeed(tuple, token0) } )"; auto module = diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc index cdde88c1359416d423685f330e9cbdf77948034f..c78ec522aa5f13556c6d4602267544694887f250 100644 --- a/tensorflow/compiler/xla/text_literal_reader.cc +++ b/tensorflow/compiler/xla/text_literal_reader.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/strings/strip.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -66,7 +67,7 @@ StatusOr TextLiteralReader::ReadAllLines() { } absl::StripAsciiWhitespace(&shape_string); - TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::ParseShapeString(shape_string)); + TF_ASSIGN_OR_RETURN(Shape shape, ParseShape(shape_string)); if (shape.element_type() != F32) { return Unimplemented( "unsupported element type for text literal reading: %s", diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index ff2c3399928c0e6339304323c4f93e212933a340..27a8dd13308b29da9a5013ac9f696613981d68bb 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -118,7 +118,12 @@ StatusOr ReplayComputation(const HloSnapshot& module, std::vector> global_data_arguments; std::vector argument_ptrs; if (opts.use_fake_data) { - global_data_arguments = MakeFakeArgumentsOrDie(computation, client); + // Run fake computations with debug options ignoring XLA_FLAGS. Users very + // likely want XLA_FLAGS only to apply to the "real" computation being run, + // not to the fake computations we use for generating arguments. + auto debug_opts = DefaultDebugOptionsIgnoringFlags(); + global_data_arguments = + MakeFakeArgumentsOrDie(computation, client, &debug_opts); for (const auto& data : global_data_arguments) { argument_ptrs.push_back( client->GlobalDataToShapedBuffer(data->handle(), /*device_ordinal=*/0) @@ -140,8 +145,7 @@ StatusOr ReplayComputation(const HloSnapshot& module, bool provide_infeed = false; Shape infeed_shape; if (!opts.fake_infeed_shape.empty()) { - StatusOr shape_status = - ShapeUtil::ParseShapeString(opts.fake_infeed_shape); + StatusOr shape_status = ParseShape(opts.fake_infeed_shape); TF_CHECK_OK(shape_status.status()); infeed_shape = std::move(shape_status).ValueOrDie(); provide_infeed = true; diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index a37eac7fe441d91aa71e1b6fd7b84099fee2215b..0e8fa73f8170addfa5061b33f3d6882a13890bce 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -100,6 +100,14 @@ message DebugOptions { // names as specified by the HloPassInterface::name() method. repeated string xla_disable_hlo_passes = 30; + // Disables all HLO passes. Notes that some passes are necessary for + // correctness and the invariants that must be satisfied by "fully optimized" + // HLO are different for different devices and may change over time. The only + // "guarantee", such as it is, is that if you compile XLA and dump the + // optimized HLO for some graph, you should be able to run it again on the + // same device with the same build of XLA. + bool xla_disable_all_hlo_passes = 104; + // Numerical optimization level for the XLA compiler backend; the specific // interpretation of this value is left to the backends. int32 xla_backend_optimization_level = 31; @@ -216,6 +224,14 @@ message DebugOptions { // If set to true XLA:GPU invokes `ptxas` with -O0 (default is -O3). bool xla_gpu_disable_ptxas_optimizations = 103; + // Dump HLO graphs as an HTML (DOT -> SVG inlined in HTML) + bool xla_hlo_dump_as_html = 105; + + // Enable fast math with eigen in the HLO evaluator. + bool xla_hlo_evaluator_use_fast_path = 106; + + // Next id: 107 + // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. map xla_backend_extra_options = 500; @@ -389,7 +405,7 @@ message WaitForExecutionResponse { message ComputeConstantGraphRequest { HloModuleProto computation = 1; - Layout output_layout = 2; + LayoutProto output_layout = 2; } message ComputeConstantResponse { diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 85ec83437a10d973687a7fb84285c2e2541a53c7..e9c86abe5094244988d3465ef7c949509deaec37 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -100,6 +100,8 @@ message PaddingConfig { // A format specifies the method used by a layout to store an array in memory. enum Format { + // TODO(b/120869032): Rename this to FORMAT_NONE or something else which + // better corresponds to its meaning. INVALID_FORMAT = 0; // The default layout, with exactly one storage location per element. DENSE = 1; @@ -109,8 +111,9 @@ enum Format { } // Describes a tile used in tiling-based layout. Refer to -// g3doc/layout_with_tiling.md for details about tiling-based layout. -message Tile { +// g3doc/third_party/tensorflow/compiler/xla/g3doc/layout_with_tiling.md for +// details about tiling-based layout. +message TileProto { // Number of elements in each dimension of the tile. It's ordered from the // most major dimension of the tile to the most minor dimension of the tile. // The dimensions correspond to a suffix of the dimensions of the shape being @@ -128,7 +131,7 @@ message Tile { // See the XLA documentation for more information on shapes and layouts. // // LINT.IfChange -message Layout { +message LayoutProto { // The method used to store the data in memory. The format determines which of // the other fields are used by the layout. Format format = 4; @@ -153,7 +156,7 @@ message Layout { // // TODO(b/119839262): implement tiling in each backend or add Unimplemented // error. - repeated Tile tiles = 6; + repeated TileProto tiles = 6; // Bit size of each element. If the size is bigger than what the element // type requires, the value is stored in the least significant @@ -196,7 +199,7 @@ message ShapeProto { repeated ShapeProto tuple_shapes = 4; // The layout used to back this shape. - Layout layout = 5; + LayoutProto layout = 5; // Important: if any field is added, be sure to modify ShapeUtil::Equal(), // ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account for diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc index 3258286c10665225aab917107ffa614459c53f3d..1a5bfac337baf773b84b92af5f88ef7a4c8ba81f 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc @@ -120,4 +120,9 @@ REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle") .HostMemory("handle"), XRTReleaseAllocationOp); +REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllAllocations").Device(DEVICE_XLA_GPU), + XRTReleaseAllAllocationsOp); +REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllAllocations").Device(DEVICE_XLA_CPU), + XRTReleaseAllAllocationsOp); + } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h index 26a58fa42d8b730b365b11d2e5608e9945497763..e3b292e7907bfb82f1efc8ed0f27462c682848ce 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h @@ -469,6 +469,26 @@ class XRTReleaseAllocationOp : public OpKernel { } }; +// Op that discards a handle to device memory. +template +class XRTReleaseAllAllocationsOp : public OpKernel { + public: + explicit XRTReleaseAllAllocationsOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + ~XRTReleaseAllAllocationsOp() override = default; + XRTReleaseAllAllocationsOp(const XRTReleaseAllAllocationsOp&) = delete; + XRTReleaseAllAllocationsOp& operator=(const XRTReleaseAllAllocationsOp&) = + delete; + + void Compute(OpKernelContext* ctx) override { + VLOG(1) << "XRTReleaseAllAllocationsOp::Compute"; + + ResourceMgr* rm; + OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); + OP_REQUIRES_OK(ctx, XRTTupleAllocation::ReleaseAllAllocations(rm)); + } +}; + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_ diff --git a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc index a3d63106fa14674a9f5887ccfd908ce17dbc6384..fe6bee0dacf5dc2050613fc9ad34d3235b5a7b63 100644 --- a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc @@ -133,4 +133,11 @@ used. 'handle' is the id returned from the Op that produced the on-device allocation. )"); +REGISTER_OP("XRTReleaseAllAllocations") + .SetShapeFn(tensorflow::shape_inference::NoOutputs) + .Doc( + R"( +Discards all the XRT allocations. All the client held handles will be invalid. +)"); + } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc index abaa17e50e3f5e47a45f5a8a45fa2090d3efee39..730a2271677c91afecaf252f4a3d1a989a1ccfba 100644 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc @@ -265,6 +265,37 @@ TEST(RawApiTest, AllocAndRewrite) { &outputs)); } +TEST(RawApiTest, AllocAndClearAll) { + xrt::XLAAllocation alloc; + alloc.set_device_ordinal(0); + *alloc.mutable_value() = + xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto(); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto value = + ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString()); + auto handle = ops::XRTAllocate(root, value); + TF_ASSERT_OK(root.status()); + + tensorflow::ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({handle}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + int64 allocation_handle = outputs[0].scalar()(); + + auto clear_all = ops::XRTReleaseAllAllocations(root); + + outputs.clear(); + TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, + {clear_all}, &outputs)); + EXPECT_EQ(outputs.size(), 0); + + auto read_after_clear = ops::XRTReadLiteral(root, Input(allocation_handle)); + EXPECT_EQ(session.Run({read_after_clear}, &outputs).code(), + tensorflow::error::Code::NOT_FOUND); +} + TEST(RawApiTest, ReadAndWriteState) { xrt::XLAAllocation alloc; alloc.set_device_ordinal(0); diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc index 31603e044d17baa3ae0ae583f61837811bb12495..343460ff107fa81be127950837f786fe4eeadf26 100644 --- a/tensorflow/compiler/xrt/xrt_state.cc +++ b/tensorflow/compiler/xrt/xrt_state.cc @@ -272,6 +272,11 @@ const se::DeviceMemoryBase& XRTTupleAllocation::root_allocation() { return rm->Delete(kTupleContainer, key_string); } +/* static */ Status XRTTupleAllocation::ReleaseAllAllocations(ResourceMgr* rm) { + VLOG(1) << "Releasing all XRT held device memory"; + return rm->Cleanup(kTupleContainer); +} + // Helper typedef to make ShapeTree ForEach helper lambda signatures more // readable. They need a type of const T& where in this case T is the // following pointer. diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h index 3664c0cd4e6ad26945ae1012208fdb006164a066..3e3d5024124e13b87eed6f79596d50cd64325914 100644 --- a/tensorflow/compiler/xrt/xrt_state.h +++ b/tensorflow/compiler/xrt/xrt_state.h @@ -129,6 +129,10 @@ class XRTTupleAllocation : public ResourceBase { // Deletes the reference in the rm to an allocation interned under key. static Status DeleteFromResourceManager(ResourceMgr* rm, int64 key); + // Releases all the device memory allocated by XRT within the resource + // manager. + static Status ReleaseAllAllocations(ResourceMgr* rm); + // Adds the allocation to a ResourceMgr and returns the key that will be used // to retrieve it. Transfers a reference on *this to rm. Status Intern(ResourceMgr* rm, int64* key); diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc index e95dc577184f7e81d942755b41065f52131ce9f6..3fe71a2ea730cc9b60b2e2088a0d80a08b38d1a9 100644 --- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc @@ -399,6 +399,17 @@ BigtableTestClient::AsyncMutateRows( return nullptr; } +std::unique_ptr> +BigtableTestClient::AsyncCheckAndMutateRow( + grpc::ClientContext* context, + const google::bigtable::v2::CheckAndMutateRowRequest& request, + grpc::CompletionQueue* cq) { + LOG(WARNING) << "Call to InMemoryDataClient::" << __func__ + << "(); this will likely cause a crash!"; + return nullptr; +} + std::shared_ptr BigtableTestClient::Channel() { LOG(WARNING) << "Call to InMemoryDataClient::Channel(); this will likely " "cause a crash!"; diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h index c4a1f06bc504c3565c7bb09b42e48e7fbddb9cc6..85705904573e9e7710912e3f4ff30dd8fed5bf85 100644 --- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h @@ -80,6 +80,13 @@ class BigtableTestClient : public ::google::cloud::bigtable::DataClient { const ::google::bigtable::v2::MutateRowsRequest& request, ::grpc::CompletionQueue* cq, void* tag) override; + std::unique_ptr> + AsyncCheckAndMutateRow( + grpc::ClientContext* context, + const google::bigtable::v2::CheckAndMutateRowRequest& request, + grpc::CompletionQueue* cq) override; + std::shared_ptr Channel() override; private: diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 9fdc2fc0c2c7b85502f7a3f9ae7c85cf05d5916c..a5951fb7377d48748f5eb578c034176517df7749 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -614,13 +614,19 @@ class GradientBoostedDecisionTreeModel(object): predictions_dict[NUM_TREES_ATTEMPTED] % self._logits_dimension) return constant_op.constant(-1, dtype=dtypes.int32) - def update_stats(self, loss, predictions_dict): + def update_stats(self, loss, predictions_dict, gradients=None, hessians=None): """Update the accumulators with stats from this batch. Args: loss: A scalar tensor representing average loss of examples. predictions_dict: Dictionary of Rank 2 `Tensor` representing information about predictions per example. + gradients: A tensor with the gradients with the respect to logits from + predictions_dict. If not provided, tensorflow will do + autodifferentiation. + hessians: A tensor with the hessians with the respect to logits from + predictions_dict. If not provided, tensorflow will do + autodifferentiation. Returns: Three values: @@ -642,13 +648,14 @@ class GradientBoostedDecisionTreeModel(object): predictions = predictions_dict[PREDICTIONS] partition_ids = predictions_dict[PARTITION_IDS] ensemble_stamp = predictions_dict[ENSEMBLE_STAMP] - gradients = gradients_impl.gradients( - loss, - predictions, - name="Gradients", - colocate_gradients_with_ops=False, - gate_gradients=0, - aggregation_method=None)[0] + if gradients is None: + gradients = gradients_impl.gradients( + loss, + predictions, + name="Gradients", + colocate_gradients_with_ops=False, + gate_gradients=0, + aggregation_method=None)[0] strategy = self._learner_config.multi_class_strategy class_id = self._get_class_id(predictions_dict) @@ -657,17 +664,20 @@ class GradientBoostedDecisionTreeModel(object): # We build one vs rest trees. if self._logits_dimension == 1: # We have only 1 score, gradients is of shape [batch, 1]. - hessians = gradients_impl.gradients( - gradients, - predictions, - name="Hessian", - colocate_gradients_with_ops=False, - gate_gradients=0, - aggregation_method=None)[0] + if hessians is None: + hessians = gradients_impl.gradients( + gradients, + predictions, + name="Hessian", + colocate_gradients_with_ops=False, + gate_gradients=0, + aggregation_method=None)[0] squeezed_gradients = array_ops.squeeze(gradients, axis=[1]) squeezed_hessians = array_ops.squeeze(hessians, axis=[1]) else: + if hessians is not None: + raise ValueError("Providing hessians is not yet supported here.") hessian_list = self._diagonal_hessian(gradients, predictions) # Assemble hessian list into a tensor. hessians = array_ops.stack(hessian_list, axis=1) @@ -678,6 +688,8 @@ class GradientBoostedDecisionTreeModel(object): squeezed_hessians = array_ops.squeeze( _get_column_by_index(hessians, class_id)) else: + if hessians is not None: + raise ValueError("Providing hessians is not yet supported here.") # Other multiclass strategies. if strategy == learner_pb2.LearnerConfig.FULL_HESSIAN: hessian_list = self._full_hessian(gradients, predictions) @@ -835,9 +847,9 @@ class GradientBoostedDecisionTreeModel(object): stats_update_ops.append( control_flow_ops.cond( continue_centering, - self._make_update_bias_stats_fn( - ensemble_stamp, predictions, gradients, - bias_stats_accumulator), control_flow_ops.no_op)) + self._make_update_bias_stats_fn(ensemble_stamp, predictions, + gradients, bias_stats_accumulator, + hessians), control_flow_ops.no_op)) # Update handler stats. handler_reads = collections.OrderedDict() @@ -1162,7 +1174,8 @@ class GradientBoostedDecisionTreeModel(object): def get_max_tree_depth(self): return self._max_tree_depth - def train(self, loss, predictions_dict, labels): + def train(self, loss, predictions_dict, labels, gradients=None, + hessians=None): """Updates the accumalator stats and grows the ensemble. Args: @@ -1171,6 +1184,12 @@ class GradientBoostedDecisionTreeModel(object): about predictions per example. labels: Rank 2 `Tensor` representing labels per example. Has no effect on the training and is only kept for backward compatibility. + gradients: A tensor with the gradients with the respect to logits from + predictions_dict. If not provided, tensorflow will do + autodifferentiation. + hessians: A tensor with the hessians with the respect to logits from + predictions_dict. If not provided, tensorflow will do + autodifferentiation. Returns: An op that adds a new tree to the ensemble. @@ -1179,7 +1198,8 @@ class GradientBoostedDecisionTreeModel(object): ValueError: if inputs are not valid. """ del labels # unused; kept for backward compatibility. - update_op, _, training_state = self.update_stats(loss, predictions_dict) + update_op, _, training_state = self.update_stats(loss, predictions_dict, + gradients, hessians) with ops.control_dependencies(update_op): return self.increment_step_counter_and_maybe_update_ensemble( predictions_dict, training_state) @@ -1271,21 +1291,28 @@ class GradientBoostedDecisionTreeModel(object): ps_ops=ps_ops, ps_strategy=ps_strategy) - def _make_update_bias_stats_fn(self, ensemble_stamp, predictions, gradients, - bias_stats_accumulator): + def _make_update_bias_stats_fn(self, + ensemble_stamp, + predictions, + gradients, + bias_stats_accumulator, + hessians=None): """A method to create the function which updates the bias stats.""" def _update_bias_stats(): """A method to update the bias stats.""" # Get reduced gradients and hessians. grads_sum = math_ops.reduce_sum(gradients, 0) - hess = gradients_impl.gradients( - grads_sum, - predictions, - name="Hessians", - colocate_gradients_with_ops=False, - gate_gradients=0, - aggregation_method=None)[0] + if hessians is not None: + hess = hessians + else: + hess = gradients_impl.gradients( + grads_sum, + predictions, + name="Hessians", + colocate_gradients_with_ops=False, + gate_gradients=0, + aggregation_method=None)[0] hess_sum = math_ops.reduce_sum(hess, 0) # Accumulate gradients and hessians. diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py index 5c50a20490482856becedf7b1379d2a0583d9a11..346513dc586f208315fd777dc7ddfa500c82f0d7 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -70,6 +70,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): self._cross_device_ops = None self._num_gpus_per_worker = num_gpus_per_worker self._initialize_local_worker(num_gpus_per_worker) + assert isinstance(self._get_cross_device_ops(), + cross_device_ops_lib.CollectiveAllReduce) def _initialize_local_worker(self, num_gpus_per_worker): """Initializes the object for local training.""" @@ -86,7 +88,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): self._collective_keys = cross_device_utils.CollectiveKeys() self._initialize_local(local_devices) - self._cross_tower_ops = cross_device_ops_lib.CollectiveAllReduce( + self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( num_workers=self._num_workers, num_gpus_per_worker=num_gpus_per_worker, collective_keys=self._collective_keys) @@ -128,7 +130,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): self._collective_keys = cross_device_utils.CollectiveKeys() self._initialize_local(local_devices) - self._cross_tower_ops = cross_device_ops_lib.CollectiveAllReduce( + self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( num_workers=self._num_workers, num_gpus_per_worker=num_gpus_per_worker, collective_keys=self._collective_keys) @@ -267,6 +269,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): # already been initialized with a `cluster_spec`. self._initialize_multi_worker(self._num_gpus_per_worker, cluster_spec, task_type, task_id) + assert isinstance(self._get_cross_device_ops(), + cross_device_ops_lib.CollectiveAllReduce) if session_config: session_config.CopyFrom(self._update_config_proto(session_config)) diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py index 8a9e583f0afaac37a2057bae9b1ed79de43d68bc..6d7cd14ed5ad8a283e3d0d3405efc58fe670f9cd 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -82,7 +82,7 @@ class CollectiveAllReduceStrategyTestBase( instance_key_with_id_start=num_gpus * 10000 + CollectiveAllReduceStrategyTestBase.collective_key_base) distribution.extended._collective_keys = collective_keys - distribution.extended._inferred_cross_device_ops._collective_keys = ( + distribution.extended._cross_device_ops._collective_keys = ( collective_keys) if task_type and task_id is not None: return distribution, 'grpc://' + self._cluster_spec[task_type][ diff --git a/tensorflow/contrib/distribute/python/estimator_training_test.py b/tensorflow/contrib/distribute/python/estimator_training_test.py index b369a7fefe6f35cf5a9b64451419cf4f72a99471..3f55a8a1c8b88d1b8e4031547fa3fbe519983630 100644 --- a/tensorflow/contrib/distribute/python/estimator_training_test.py +++ b/tensorflow/contrib/distribute/python/estimator_training_test.py @@ -375,11 +375,13 @@ class DistributeCoordinatorIntegrationTest( threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn, cluster_spec, train_distribute, eval_distribute) + threads_to_join = [] for task_type, ts in threads.items(): if task_type == PS: continue for t in ts: - t.join() + threads_to_join.append(t) + self.join_independent_workers(threads_to_join) estimator = self._get_estimator(train_distribute, eval_distribute) self._inspect_train_and_eval_events(estimator) @@ -413,8 +415,7 @@ class DistributeCoordinatorIntegrationTest( threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn, cluster_spec, train_distribute, eval_distribute) - threads[WORKER][0].join() - threads[EVALUATOR][0].join() + self.join_independent_workers([threads[WORKER][0], threads[EVALUATOR][0]]) estimator = self._get_estimator(train_distribute, eval_distribute) self._inspect_train_and_eval_events(estimator) diff --git a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py index 6dfd85bcc4f3784e2744fd876a7190cc9581d96a..8c596549c4e20754675f69861d4c7f14f7c3c126 100644 --- a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py @@ -18,24 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import shutil -import tempfile from absl.testing import parameterized import numpy as np -import six from tensorflow.contrib.distribute.python import combinations -from tensorflow.core.protobuf import config_pb2 from tensorflow.python import keras -from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import distribution_strategy_context as ds_context -from tensorflow.python.estimator import run_config -from tensorflow.python.estimator import training -from tensorflow.python.estimator.canned import dnn_linear_combined -from tensorflow.python.estimator.canned import prediction_keys -from tensorflow.python.estimator.export import export -from tensorflow.python.estimator.inputs import numpy_io -from tensorflow.python.feature_column import feature_column_lib as feature_column from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -44,103 +32,7 @@ from tensorflow.python.keras.optimizer_v2 import gradient_descent from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -from tensorflow.python.platform import gfile from tensorflow.python.platform import test -from tensorflow.python.summary.writer import writer_cache - - -class KerasOptimizerV2IntegrationTest(test.TestCase, parameterized.TestCase): - - def setUp(self): - self._model_dir = tempfile.mkdtemp() - - def dataset_input_fn(self, x, y, batch_size): - - def input_fn(): - dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) - dataset = dataset.repeat(1).batch(batch_size) - return dataset - - return input_fn - - @combinations.generate( - combinations.combine( - mode=['graph'], - distribution=[ - combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_two_gpus - ], - use_train_and_evaluate=[True, False])) - def test_complete_flow_with_mode(self, distribution, use_train_and_evaluate): - label_dimension = 2 - input_dimension = label_dimension - batch_size = 10 - data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) - data = data.reshape(batch_size, label_dimension) - train_input_fn = self.dataset_input_fn( - x={'x': data}, - y=data, - batch_size=batch_size // distribution.num_replicas_in_sync) - eval_input_fn = self.dataset_input_fn( - x={'x': data}, - y=data, - batch_size=batch_size // distribution.num_replicas_in_sync) - predict_input_fn = numpy_io.numpy_input_fn( - x={'x': data}, batch_size=batch_size, shuffle=False) - - linear_feature_columns = [ - feature_column.numeric_column('x', shape=(input_dimension,)) - ] - dnn_feature_columns = [ - feature_column.numeric_column('x', shape=(input_dimension,)) - ] - feature_columns = linear_feature_columns + dnn_feature_columns - session_config = config_pb2.ConfigProto( - log_device_placement=True, allow_soft_placement=True) - estimator = dnn_linear_combined.DNNLinearCombinedRegressor( - linear_feature_columns=linear_feature_columns, - dnn_hidden_units=(2, 2), - dnn_feature_columns=dnn_feature_columns, - label_dimension=label_dimension, - model_dir=self._model_dir, - dnn_optimizer=adam.Adam(0.001), - linear_optimizer=adam.Adam(0.001), - config=run_config.RunConfig( - train_distribute=distribution, - eval_distribute=distribution, - session_config=session_config)) - - num_steps = 2 - if use_train_and_evaluate: - scores, _ = training.train_and_evaluate( - estimator, training.TrainSpec(train_input_fn, max_steps=num_steps), - training.EvalSpec(eval_input_fn)) - else: - estimator.train(train_input_fn, steps=num_steps) - scores = estimator.evaluate(eval_input_fn) - - self.assertIn('loss', six.iterkeys(scores)) - - predictions = np.array([ - x[prediction_keys.PredictionKeys.PREDICTIONS] - for x in estimator.predict(predict_input_fn) - ]) - self.assertAllEqual((batch_size, label_dimension), predictions.shape) - - feature_spec = feature_column.make_parse_example_spec(feature_columns) - serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( - feature_spec) - export_dir = estimator.export_savedmodel(tempfile.mkdtemp(), - serving_input_receiver_fn) - self.assertTrue(gfile.Exists(export_dir)) - - def tearDown(self): - if self._model_dir: - writer_cache.FileWriterCache.clear() - shutil.rmtree(self._model_dir) def get_model(): @@ -162,7 +54,9 @@ class MirroredStrategyOptimizerV2Test(test.TestCase, parameterized.TestCase): var = variables.Variable( 2.0, name='var', aggregation=variable_scope.VariableAggregation.SUM) # grad for cpu is 1, grad for gpu is 2, avg grad is 1.5. - loss = math_ops.cast(_replica_id() + 1, dtype=dtypes.float32) * var + def loss(): + return math_ops.cast(_replica_id() + 1, dtype=dtypes.float32) * var + optimizer = adam.Adam(learning_rate=0.01, beta_1=0.2, beta_2=0.2) train_op = optimizer.minimize(loss, var_list=[var]) m = optimizer.get_slot(var, 'm') diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index 683cc89bfbae9c877ea6794d311ffc00c96c6937..c53e76f922372d8c7937e05fde61772d0b064674 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -375,7 +375,9 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, @combinations.generate(combinations.combine( distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_two_gpus], mode=['graph'])) def test_train_functional_with_distribution_strategy(self, distribution): @@ -403,7 +405,9 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, @combinations.generate(combinations.combine( distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_two_gpus], mode=['graph'])) def test_train_sequential_with_distribution_strategy(self, distribution): @@ -430,8 +434,8 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, @combinations.generate(combinations.combine( distribution=[ - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], mode=['graph'])) def test_multi_inputs_multi_outputs_with_input_fn_as_dict(self, distribution): train_data, test_data = get_multi_inputs_multi_outputs_data() @@ -482,8 +486,8 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, @combinations.generate(combinations.combine( distribution=[ - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], mode=['graph'])) def test_keras_optimizer_with_distribution_strategy(self, distribution): keras_model = simple_sequential_model() @@ -904,10 +908,12 @@ class TestDistributionStrategyWithDatasets(test.TestCase, @combinations.generate(combinations.combine( distribution=[ - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], mode=['graph', 'eager'])) - def test_dataset_wrong_input_shape(self, distribution): + # TODO(b/120943676, b/120957836): Re-enable once the validation code is + # restored. + def DISABLED_test_dataset_wrong_input_shape(self, distribution): with self.cached_session(): model = get_model() @@ -927,9 +933,11 @@ class TestDistributionStrategyWithDatasets(test.TestCase, model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) @combinations.generate(combinations.combine( - distribution=[combinations.mirrored_strategy_with_two_gpus], + distribution=[combinations.mirrored_strategy_with_gpu_and_cpu], mode=['graph', 'eager'])) - def test_dataset_no_batch_input_validation(self, distribution): + # TODO(b/120943676, b/120957836): Re-enable once the validation code is + # restored. + def DISABLED_test_dataset_no_batch_input_validation(self, distribution): with self.cached_session(): model = get_model() @@ -967,7 +975,9 @@ class TestDistributionStrategyWithDatasets(test.TestCase, @combinations.generate(combinations.combine( distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_two_gpus], mode=['graph', 'eager'])) def test_learning_phase_value(self, distribution): @@ -1170,8 +1180,8 @@ class TestDistributionStrategyWithLossMasking(test.TestCase, # work for TPU due to some invalid datatype. @combinations.generate(combinations.combine( distribution=[ - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], mode=['graph', 'eager'])) def test_masking(self, distribution): with self.cached_session(): diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 20f1a08d4261b931a9353738147fba7d7dff9225..24399db6522c325722b95399fd002eed9fd955f2 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -28,7 +28,6 @@ from tensorflow.python.distribute import values # pylint: disable=protected-access,invalid-name _call_for_each_replica = mirrored_strategy._call_for_each_replica -_reduce_non_distributed_value = mirrored_strategy._reduce_non_distributed_value _create_mirrored_variable = mirrored_strategy._create_mirrored_variable all_local_devices = mirrored_strategy.all_local_devices CoreMirroredStrategy = mirrored_strategy.MirroredStrategy diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 36be5c83f8bafb6c934d1d7682b5227b1f71c089..337a86b3421fdb90c98cd5097dd880fdbe5871b9 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -183,6 +183,34 @@ class MirroredStrategyVariableCreatorStackTest( expected = ("main_thread:thread_0", "main_thread:thread_1") self.assertEqual(expected, result) +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) +class MirroredStrategyCallForEachReplicaTest(test.TestCase): + + def testExecutingEagerlyOutsideFunction(self, distribution): + """Verify we preserve the value of executing_eagerly_outside_functions().""" + def model_fn(): + return ops.executing_eagerly_outside_functions() + + originally = ops.executing_eagerly_outside_functions() + with distribution.scope(): + in_scope = ops.executing_eagerly_outside_functions() + in_model_fn = distribution.extended.call_for_each_replica(model_fn) + unwrapped = distribution.unwrap(in_model_fn) + self.assertEqual(in_scope, unwrapped[0]) + self.assertEqual(in_scope, originally) + + # Verify this all again, but this time in a FuncGraph. + with func_graph.FuncGraph("fg").as_default(), distribution.scope(): + in_scope = ops.executing_eagerly_outside_functions() + in_model_fn = distribution.extended.call_for_each_replica(model_fn) + unwrapped = distribution.unwrap(in_model_fn) + self.assertEqual(in_scope, unwrapped[0]) + self.assertEqual(in_scope, originally) + @combinations.generate(combinations.combine( distribution=[ diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py index 147c9b83f866fd364ea23cf7988692a7b5f61b9c..b05aac431f65b4281d9ed9c2fa95c210d55f4008 100644 --- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py +++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py @@ -40,6 +40,7 @@ from tensorflow.python.client import session from tensorflow.python.estimator import run_config from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import coordinator from tensorflow.python.training import server_lib ASSIGNED_PORTS = set() @@ -360,6 +361,7 @@ class IndependentWorkerTestBase(test.TestCase): self._mock_os_env = MockOsEnv() self._mock_context = test.mock.patch.object(os, 'environ', self._mock_os_env) + self._coord = coordinator.Coordinator() super(IndependentWorkerTestBase, self).setUp() self._mock_context.__enter__() @@ -368,8 +370,9 @@ class IndependentWorkerTestBase(test.TestCase): super(IndependentWorkerTestBase, self).tearDown() def _task_thread(self, task_fn, tf_config, *args, **kwargs): - os.environ['TF_CONFIG'] = json.dumps(tf_config) - task_fn(*args, **kwargs) + with self._coord.stop_on_exception(): + os.environ['TF_CONFIG'] = json.dumps(tf_config) + task_fn(*args, **kwargs) def _run_task_in_thread(self, task_fn, cluster_spec, task_type, task_id, *args, **kwargs): @@ -403,3 +406,6 @@ class IndependentWorkerTestBase(test.TestCase): *args, **kwargs) threads[task_type].append(t) return threads + + def join_independent_workers(self, worker_threads): + self._coord.join(worker_threads) diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index 2c7766f95fbcb7b68a53ad0052f21485c763a1db..ca51b07be6601dd615e24137e51c4b34793fdbc0 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -356,7 +356,7 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): self._verify_destinations_not_different_worker(destinations) if not isinstance(value, values.DistributedValues): # pylint: disable=protected-access - return mirrored_strategy._reduce_non_distributed_value( + return cross_device_ops_lib.reduce_non_distributed_value( self, reduce_op, value, destinations) return self._cross_device_ops.reduce( reduce_op, value, destinations=destinations) diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index b6f5b492017fc7dfd329e69ad9ca418ae682bc4b..7ea245eb6eb9738bc95e8ac54c1c43de0ddcef7c 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -465,6 +465,14 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): "Currently only support sum & mean in TPUStrategy.") return tpu_ops.cross_replica_sum(value) + if not isinstance(value, values.DistributedValues): + # This function handles reducing values that are not PerReplica or + # Mirrored values. For example, the same value could be present on all + # replicas in which case `value` would be a single value or value could + # be 0. + return cross_device_ops_lib.reduce_non_distributed_value( + self, reduce_op, value, destinations) + # Validate that the destination is same as the host device # Note we don't do this when in replicate context as the reduction is # performed on the TPU device itself. diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index 33c988fd9065e7fbe7b9aeb85cad82eb3c119f76..8882a863c30d8b222c68d6952279c3744345883c 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -41,6 +41,8 @@ To use, at program startup, call `tf.enable_eager_execution()`. @@add_execution_callback @@clear_execution_callbacks +@@errstate +@@ExecutionCallback @@inf_callback @@inf_nan_callback @@nan_callback @@ -119,6 +121,8 @@ from tensorflow.python.eager.context import set_server_def from tensorflow.python.eager.def_function import function from tensorflow.python.eager.execution_callbacks import add_execution_callback from tensorflow.python.eager.execution_callbacks import clear_execution_callbacks +from tensorflow.python.eager.execution_callbacks import errstate +from tensorflow.python.eager.execution_callbacks import ExecutionCallback from tensorflow.python.eager.execution_callbacks import inf_callback from tensorflow.python.eager.execution_callbacks import inf_nan_callback from tensorflow.python.eager.execution_callbacks import nan_callback diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD index 1cd83bdb5de7c2f6dc91c980750b49aca1a7790b..4c1d1a29f20b5574b63cf87ecf62db95f92902cd 100644 --- a/tensorflow/contrib/feature_column/BUILD +++ b/tensorflow/contrib/feature_column/BUILD @@ -110,8 +110,8 @@ py_test( "//tensorflow/python:parsing_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", - "//tensorflow/python/feature_column", "//tensorflow/python/feature_column:feature_column_py", + "//tensorflow/python/feature_column:feature_column_v2_test", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py index 0d34ad161855476b6a4cd9a258521dbe122b4140..83b93ec332044f754f9dcde8d7c5c19b26e53a4a 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py @@ -203,7 +203,8 @@ def sequence_categorical_column_with_identity( columns = [watches_embedding] features = tf.parse_example(..., features=make_parse_example_spec(columns)) - input_layer, sequence_length = sequence_input_layer(features, columns) + sequence_feature_layer = SequenceFeatureLayer(columns) + input_layer, sequence_length = sequence_feature_layer(features) rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) outputs, state = tf.nn.dynamic_rnn( @@ -219,15 +220,17 @@ def sequence_categorical_column_with_identity( `[0, num_buckets)`, and will replace out-of-range inputs. Returns: - A `_SequenceCategoricalColumn`. + A `SequenceCategoricalColumn`. Raises: ValueError: if `num_buckets` is less than one. ValueError: if `default_value` is not in range `[0, num_buckets)`. """ - return fc_old._SequenceCategoricalColumn( - fc_old._categorical_column_with_identity( - key=key, num_buckets=num_buckets, default_value=default_value)) + return fc.SequenceCategoricalColumn( + fc.categorical_column_with_identity( + key=key, + num_buckets=num_buckets, + default_value=default_value)) def sequence_categorical_column_with_hash_bucket( @@ -247,7 +250,8 @@ def sequence_categorical_column_with_hash_bucket( columns = [tokens_embedding] features = tf.parse_example(..., features=make_parse_example_spec(columns)) - input_layer, sequence_length = sequence_input_layer(features, columns) + sequence_feature_layer = SequenceFeatureLayer(columns) + input_layer, sequence_length = sequence_feature_layer(features) rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) outputs, state = tf.nn.dynamic_rnn( @@ -260,15 +264,17 @@ def sequence_categorical_column_with_hash_bucket( dtype: The type of features. Only string and integer types are supported. Returns: - A `_SequenceCategoricalColumn`. + A `SequenceCategoricalColumn`. Raises: ValueError: `hash_bucket_size` is not greater than 1. ValueError: `dtype` is neither string nor integer. """ - return fc_old._SequenceCategoricalColumn( - fc_old._categorical_column_with_hash_bucket( - key=key, hash_bucket_size=hash_bucket_size, dtype=dtype)) + return fc.SequenceCategoricalColumn( + fc.categorical_column_with_hash_bucket( + key=key, + hash_bucket_size=hash_bucket_size, + dtype=dtype)) def sequence_categorical_column_with_vocabulary_file( @@ -290,7 +296,8 @@ def sequence_categorical_column_with_vocabulary_file( columns = [states_embedding] features = tf.parse_example(..., features=make_parse_example_spec(columns)) - input_layer, sequence_length = sequence_input_layer(features, columns) + sequence_feature_layer = SequenceFeatureLayer(columns) + input_layer, sequence_length = sequence_feature_layer(features) rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) outputs, state = tf.nn.dynamic_rnn( @@ -314,7 +321,7 @@ def sequence_categorical_column_with_vocabulary_file( dtype: The type of features. Only string and integer types are supported. Returns: - A `_SequenceCategoricalColumn`. + A `SequenceCategoricalColumn`. Raises: ValueError: `vocabulary_file` is missing or cannot be opened. @@ -323,8 +330,8 @@ def sequence_categorical_column_with_vocabulary_file( ValueError: `num_oov_buckets` and `default_value` are both specified. ValueError: `dtype` is neither string nor integer. """ - return fc_old._SequenceCategoricalColumn( - fc_old._categorical_column_with_vocabulary_file( + return fc.SequenceCategoricalColumn( + fc.categorical_column_with_vocabulary_file( key=key, vocabulary_file=vocabulary_file, vocabulary_size=vocabulary_size, @@ -351,7 +358,8 @@ def sequence_categorical_column_with_vocabulary_list( columns = [colors_embedding] features = tf.parse_example(..., features=make_parse_example_spec(columns)) - input_layer, sequence_length = sequence_input_layer(features, columns) + sequence_feature_layer = SequenceFeatureLayer(columns) + input_layer, sequence_length = sequence_feature_layer(features) rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) outputs, state = tf.nn.dynamic_rnn( @@ -375,7 +383,7 @@ def sequence_categorical_column_with_vocabulary_list( with `default_value`. Returns: - A `_SequenceCategoricalColumn`. + A `SequenceCategoricalColumn`. Raises: ValueError: if `vocabulary_list` is empty, or contains duplicate keys. @@ -383,8 +391,8 @@ def sequence_categorical_column_with_vocabulary_list( ValueError: `num_oov_buckets` and `default_value` are both specified. ValueError: if `dtype` is not integer or string. """ - return fc_old._SequenceCategoricalColumn( - fc_old._categorical_column_with_vocabulary_list( + return fc.SequenceCategoricalColumn( + fc.categorical_column_with_vocabulary_list( key=key, vocabulary_list=vocabulary_list, dtype=dtype, diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py index ca4398a142065de0be7bee57cd7e54670bbae12e..be012a87690c24c6d9b7808790393e1aa6d01211 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py @@ -26,7 +26,7 @@ from tensorflow.contrib.feature_column.python.feature_column import sequence_fea from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column_v2 as sfc from tensorflow.python.feature_column import feature_column as fc_old from tensorflow.python.feature_column import feature_column_lib as fc -from tensorflow.python.feature_column.feature_column import _LazyBuilder +from tensorflow.python.feature_column.feature_column_v2_test import _TestStateManager from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -131,7 +131,7 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): feature_columns=[embedding_column_b, embedding_column_a]) global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual( + self.assertCountEqual( ('sequence_input_layer/aaa_embedding/embedding_weights:0', 'sequence_input_layer/bbb_embedding/embedding_weights:0'), tuple([v.name for v in global_vars])) @@ -223,7 +223,7 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): feature_columns=shared_embedding_columns) global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual( + self.assertCountEqual( ('sequence_input_layer/aaa_bbb_shared_embedding/embedding_weights:0',), tuple([v.name for v in global_vars])) with monitored_session.MonitoredSession() as sess: @@ -670,6 +670,23 @@ def _assert_sparse_tensor_indices_shape(test_case, expected, actual): test_case.assertAllEqual(expected.dense_shape, actual.dense_shape) +def _get_sequence_dense_tensor(column, features): + return column.get_sequence_dense_tensor( + fc.FeatureTransformationCache(features), None) + + +def _get_sequence_dense_tensor_state(column, features): + state_manager = _TestStateManager() + column.create_state(state_manager) + return column.get_sequence_dense_tensor( + fc.FeatureTransformationCache(features), state_manager) + + +def _get_sparse_tensors(column, features): + return column.get_sparse_tensors( + fc.FeatureTransformationCache(features), None) + + class SequenceCategoricalColumnWithIdentityTest( test.TestCase, parameterized.TestCase): @@ -698,7 +715,7 @@ class SequenceCategoricalColumnWithIdentityTest( expected = sparse_tensor.SparseTensorValue(**expected_args) column = sfc.sequence_categorical_column_with_identity('aaa', num_buckets=9) - id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + id_weight_pair = _get_sparse_tensors(column, {'aaa': inputs}) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: @@ -737,7 +754,7 @@ class SequenceCategoricalColumnWithHashBucketTest( column = sfc.sequence_categorical_column_with_hash_bucket( 'aaa', hash_bucket_size=10) - id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + id_weight_pair = _get_sparse_tensors(column, {'aaa': inputs}) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: @@ -790,7 +807,7 @@ class SequenceCategoricalColumnWithVocabularyFileTest( vocabulary_file=self._wire_vocabulary_file_name, vocabulary_size=self._wire_vocabulary_size) - id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + id_weight_pair = _get_sparse_tensors(column, {'aaa': inputs}) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: @@ -814,8 +831,7 @@ class SequenceCategoricalColumnWithVocabularyFileTest( input_placeholder_shape[1] = None input_placeholder = array_ops.sparse_placeholder( dtypes.string, shape=input_placeholder_shape) - id_weight_pair = column._get_sparse_tensors( - _LazyBuilder({'aaa': input_placeholder})) + id_weight_pair = _get_sparse_tensors(column, {'aaa': input_placeholder}) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: @@ -855,7 +871,7 @@ class SequenceCategoricalColumnWithVocabularyListTest( key='aaa', vocabulary_list=('omar', 'stringer', 'marlo')) - id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + id_weight_pair = _get_sparse_tensors(column, {'aaa': inputs}) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: @@ -922,13 +938,12 @@ class SequenceEmbeddingColumnTest( categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc_old._embedding_column( - categorical_column, - dimension=embedding_dimension, + embedding_column = fc.embedding_column( + categorical_column, dimension=embedding_dimension, initializer=_initializer) - embedding_lookup, _ = embedding_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': inputs})) + embedding_lookup, _ = _get_sequence_dense_tensor_state( + embedding_column, {'aaa': inputs}) global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) self.assertItemsEqual( @@ -961,10 +976,11 @@ class SequenceEmbeddingColumnTest( categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc_old._embedding_column(categorical_column, dimension=2) + embedding_column = fc.embedding_column( + categorical_column, dimension=2) - _, sequence_length = embedding_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': inputs})) + _, sequence_length = _get_sequence_dense_tensor_state( + embedding_column, {'aaa': inputs}) with monitored_session.MonitoredSession() as sess: sequence_length = sess.run(sequence_length) @@ -988,10 +1004,11 @@ class SequenceEmbeddingColumnTest( categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc_old._embedding_column(categorical_column, dimension=2) + embedding_column = fc.embedding_column( + categorical_column, dimension=2) - _, sequence_length = embedding_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) + _, sequence_length = _get_sequence_dense_tensor_state( + embedding_column, {'aaa': sparse_input}) with monitored_session.MonitoredSession() as sess: self.assertAllEqual( @@ -1058,22 +1075,18 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): key='aaa', num_buckets=vocabulary_size) categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc.shared_embedding_columns_v2( [categorical_column_a, categorical_column_b], dimension=embedding_dimension, initializer=_initializer) - embedding_lookup_a = shared_embedding_columns[0]._get_sequence_dense_tensor( - _LazyBuilder({ - 'aaa': sparse_input_a - }))[0] - embedding_lookup_b = shared_embedding_columns[1]._get_sequence_dense_tensor( - _LazyBuilder({ - 'bbb': sparse_input_b - }))[0] + embedding_lookup_a = _get_sequence_dense_tensor( + shared_embedding_columns[0], {'aaa': sparse_input_a})[0] + embedding_lookup_b = _get_sequence_dense_tensor( + shared_embedding_columns[1], {'bbb': sparse_input_b})[0] global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual(('embedding_weights:0',), + self.assertItemsEqual(('aaa_bbb_shared_embedding:0',), tuple([v.name for v in global_vars])) with monitored_session.MonitoredSession() as sess: self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) @@ -1104,17 +1117,13 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): expected_sequence_length_b = [2, 1] categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc.shared_embedding_columns_v2( [categorical_column_a, categorical_column_b], dimension=2) - sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor( - _LazyBuilder({ - 'aaa': sparse_input_a - }))[1] - sequence_length_b = shared_embedding_columns[1]._get_sequence_dense_tensor( - _LazyBuilder({ - 'bbb': sparse_input_b - }))[1] + sequence_length_a = _get_sequence_dense_tensor( + shared_embedding_columns[0], {'aaa': sparse_input_a})[1] + sequence_length_b = _get_sequence_dense_tensor( + shared_embedding_columns[1], {'bbb': sparse_input_b})[1] with monitored_session.MonitoredSession() as sess: sequence_length_a = sess.run(sequence_length_a) @@ -1155,17 +1164,13 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc.shared_embedding_columns_v2( [categorical_column_a, categorical_column_b], dimension=2) - sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor( - _LazyBuilder({ - 'aaa': sparse_input_a - }))[1] - sequence_length_b = shared_embedding_columns[1]._get_sequence_dense_tensor( - _LazyBuilder({ - 'bbb': sparse_input_b - }))[1] + sequence_length_a = _get_sequence_dense_tensor( + shared_embedding_columns[0], {'aaa': sparse_input_a})[1] + sequence_length_b = _get_sequence_dense_tensor( + shared_embedding_columns[1], {'bbb': sparse_input_b})[1] with monitored_session.MonitoredSession() as sess: self.assertAllEqual( @@ -1221,10 +1226,10 @@ class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column = fc_old._indicator_column(categorical_column) + indicator_column = fc.indicator_column(categorical_column) - indicator_tensor, _ = indicator_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': inputs})) + indicator_tensor, _ = _get_sequence_dense_tensor( + indicator_column, {'aaa': inputs}) with monitored_session.MonitoredSession() as sess: self.assertAllEqual(expected, indicator_tensor.eval(session=sess)) @@ -1253,10 +1258,10 @@ class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column = fc_old._indicator_column(categorical_column) + indicator_column = fc.indicator_column(categorical_column) - _, sequence_length = indicator_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': inputs})) + _, sequence_length = _get_sequence_dense_tensor( + indicator_column, {'aaa': inputs}) with monitored_session.MonitoredSession() as sess: sequence_length = sess.run(sequence_length) @@ -1282,19 +1287,14 @@ class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): key='aaa', num_buckets=vocabulary_size) indicator_column = fc.indicator_column(categorical_column) - _, sequence_length = indicator_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) + _, sequence_length = _get_sequence_dense_tensor( + indicator_column, {'aaa': sparse_input}) with monitored_session.MonitoredSession() as sess: self.assertAllEqual( expected_sequence_length, sequence_length.eval(session=sess)) -def _get_sequence_dense_tensor(column, features): - return column.get_sequence_dense_tensor( - fc.FeatureTransformationCache(features), None) - - class SequenceNumericColumnTest(test.TestCase, parameterized.TestCase): def test_defaults(self): diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc index 93b1aaa85e88e00c1b12a388321a4d6fb10f1611..c541c71f996c7a1b36cf28ae9a1783f8dca0a72c 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc @@ -522,7 +522,7 @@ void LaunchFusedConv2DBiasActivationOp:: auto bias_ptr = AsDeviceMemory(bias.template flat().data(), bias.template flat().size()); - static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit( + static int64 ConvolveScratchSize = GetDnnWorkspaceLimit( // default value is in bytes despite the name of the environment variable "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB ); @@ -570,7 +570,7 @@ void LaunchFusedConv2DBiasActivationOp:: for (auto profile_algorithm : algorithms) { // TODO(zhengxq): profile each algorithm multiple times to better // accuracy. - CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); + DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); dnn::ProfileResult profile_result; bool cudnn_launch_status = stream @@ -609,7 +609,7 @@ void LaunchFusedConv2DBiasActivationOp:: algorithm_config); } - CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); + DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); bool cudnn_launch_status = stream ->ThenFusedConvolveWithAlgorithm( diff --git a/tensorflow/contrib/gdr/BUILD b/tensorflow/contrib/gdr/BUILD index e534fdc17749974ebe713c2730682bea6d7a85e4..704be917b3680a1b5712f4f1dc5059b354db8610 100644 --- a/tensorflow/contrib/gdr/BUILD +++ b/tensorflow/contrib/gdr/BUILD @@ -37,7 +37,7 @@ tf_proto_library_cc( ], ) -tf_cuda_library( +cc_library( name = "gdr_memory_manager", srcs = ["gdr_memory_manager.cc"], hdrs = ["gdr_memory_manager.h"], @@ -58,7 +58,7 @@ tf_cuda_library( ], ) -tf_cuda_library( +cc_library( name = "gdr_worker", srcs = ["gdr_worker.cc"], hdrs = ["gdr_worker.h"], diff --git a/tensorflow/contrib/gdr/gdr.proto b/tensorflow/contrib/gdr/gdr.proto index c0b89245b150bfa49cb527d25b6e1f324f353b25..bd438787c3374be6ead4f6233101fd1f548643ea 100644 --- a/tensorflow/contrib/gdr/gdr.proto +++ b/tensorflow/contrib/gdr/gdr.proto @@ -9,5 +9,4 @@ message RemoteMemoryRegion { uint64 addr = 3; uint32 rkey = 4; uint32 tensor_key = 5; - uint64 checksum = 6; } diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.cc b/tensorflow/contrib/gdr/gdr_memory_manager.cc index 53587fcf3050f313c85485f77ce411cba7faccff..ce1875151597f926aeb6392e7fc8307312da123f 100644 --- a/tensorflow/contrib/gdr/gdr_memory_manager.cc +++ b/tensorflow/contrib/gdr/gdr_memory_manager.cc @@ -26,17 +26,14 @@ limitations under the License. #include #include #include -#include #include "tensorflow/contrib/gdr/gdr.pb.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/dma_helper.h" -#include "tensorflow/core/common_runtime/process_state.h" -#if GOOGLE_CUDA #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" -#include "tensorflow/core/common_runtime/gpu/gpu_util.h" -#endif // GOOGLE_CUDA +#include "tensorflow/core/common_runtime/process_state.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/numa.h" @@ -81,10 +78,6 @@ int TryToReadNumaNode(ibv_device* device) { int32 value; if (strings::safe_strto32(content, &value)) { if (value < 0) { - LOG(INFO) << "Successful NUMA node read from SysFS had negative value (" - << value - << "), but there must be at least one NUMA node" - ", so returning NUMA node zero"; return port::kNUMANoAffinity; } LOG(INFO) << "NUMA node for device: " << device->name << " is " << value; @@ -114,7 +107,7 @@ class GdrMemoryManager : public RemoteMemoryManager { public: GdrMemoryManager(const string& host, const string& port); - virtual ~GdrMemoryManager(); + virtual ~GdrMemoryManager() {} virtual Status Init() override; @@ -140,7 +133,7 @@ class GdrMemoryManager : public RemoteMemoryManager { return ptr < reinterpret_cast(other->addr) + other->length; } - ibv_mr* FindMemoryRegion(void* addr, size_t length); + ibv_mr* FindMemoryRegion(const Tensor* tensor); void InsertMemoryRegion(void* addr, size_t length, const std::string& allocator_name); @@ -152,7 +145,6 @@ class GdrMemoryManager : public RemoteMemoryManager { const string port_; RdmaEndpointPtr listening_; std::atomic stopped_; - int epfd_; int numa_node_; // Server side endpoints @@ -163,15 +155,19 @@ class GdrMemoryManager : public RemoteMemoryManager { std::atomic next_key_; // Server side on-the-fly tensor buffers - mutex server_mu_; - std::map tensor_buffers_ - GUARDED_BY(server_mu_); + mutex buf_mu_; + std::map tensor_buffers_ GUARDED_BY(buf_mu_); // Client side endpoints mutex client_mu_; std::map, RdmaEndpointPtr> clients_ GUARDED_BY(client_mu_); + // Client side callbacks + mutex callback_mu_; + std::map tensor_callbacks_ + GUARDED_BY(callback_mu_); + // Managed memory regions mutex alloc_mu_; std::vector mrs_ GUARDED_BY(alloc_mu_); @@ -184,16 +180,9 @@ GdrMemoryManager::GdrMemoryManager(const string& host, const string& port) port_(port), listening_(nullptr, EndpointDeleter), stopped_(true), - next_key_(0) {} - -GdrMemoryManager::~GdrMemoryManager() { close(epfd_); } + next_key_(static_cast(random::New64())) {} Status GdrMemoryManager::Init() { - epfd_ = epoll_create1(0); - if (epfd_ == -1) { - return errors::Unavailable(strerror(errno), ": ", "epoll_create"); - } - rdma_addrinfo* addrinfo; rdma_addrinfo hints = {}; hints.ai_port_space = RDMA_PS_TCP; @@ -206,7 +195,7 @@ Status GdrMemoryManager::Init() { ibv_qp_init_attr init_attr = {}; init_attr.qp_type = IBV_QPT_RC; - init_attr.cap.max_recv_wr = 32; + init_attr.cap.max_recv_wr = 1024; init_attr.cap.max_send_wr = 1; init_attr.cap.max_recv_sge = 1; init_attr.cap.max_send_sge = 1; @@ -239,14 +228,6 @@ Status GdrMemoryManager::Init() { "cannot set server to non-blocking mode"); } - epoll_event event = {}; - event.events = EPOLLIN | EPOLLPRI; - event.data.ptr = listening_.get(); - if (epoll_ctl(epfd_, EPOLL_CTL_ADD, listening_->channel->fd, &event)) { - return errors::Unavailable(strerror(errno), ": ", - "cannot add server to epoll"); - } - numa_node_ = TryToReadNumaNode(listening_->verbs->device); SubAllocator::Visitor alloc_visitor = [this](void* ptr, int numa_node, @@ -265,121 +246,114 @@ Status GdrMemoryManager::Init() { ProcessState::singleton()->AddCPUFreeVisitor(free_visitor); LOG(INFO) << "Instrumenting CPU allocator(s)"; -#if GOOGLE_CUDA for (int numa_idx = 0; numa_idx < port::NUMANumNodes(); ++numa_idx) { GPUProcessState::singleton()->AddCUDAHostAllocVisitor(numa_idx, alloc_visitor); GPUProcessState::singleton()->AddCUDAHostFreeVisitor(numa_idx, free_visitor); } + if (IsGDRAvailable()) { SubAllocator::Visitor cuda_alloc_visitor = [this](void* ptr, int gpu_id, size_t num_bytes) { VLOG(2) << "Registering RDMA capable memory region on GPU " << gpu_id; InsertMemoryRegion(ptr, num_bytes, strings::StrCat("GPU:", gpu_id)); }; - for (int numa_idx = 0; numa_idx < port::NUMANumNodes(); ++numa_idx) { - GPUProcessState::singleton()->AddGPUAllocVisitor(numa_idx, - cuda_alloc_visitor); - } - VLOG(1) << "Instrumenting GPU allocator(s) for all Numas"; + GPUProcessState::singleton()->AddGPUAllocVisitor(numa_node_, + cuda_alloc_visitor); + LOG(INFO) << "Instrumenting GPU allocator for NUMA " << numa_node_; } -#endif // GOOGLE_CUDA + return Status::OK(); } void GdrMemoryManager::Run() { stopped_ = false; while (!stopped_) { - epoll_event events[32]; - int ret = epoll_wait(epfd_, events, 32, 1); - if (ret == -1) { - LOG(ERROR) << "epoll_wait: " << strerror(errno); - return; - } - for (int i = 0; i < ret; i++) { - rdma_cm_id* id = static_cast(events[i].data.ptr); - if (id == listening_.get()) { - // Accept incoming connections - if (!rdma_get_request(listening_.get(), &id)) { - if (!rdma_accept(id, nullptr)) { - LOG(INFO) << "Accepted new RDMA connection"; - if (ibv_req_notify_cq(id->recv_cq, 0)) { - LOG(ERROR) << strerror(errno) << ": ibv_req_notify_cq failed"; - EndpointDeleter(id); - continue; - } - for (int i = 0; i < 32; i++) { - if (rdma_post_recvv(id, nullptr, nullptr, 0)) { - LOG(ERROR) << strerror(errno) << ": rdma_post_recvv failed"; - EndpointDeleter(id); - continue; - } - } - int flags = fcntl(id->recv_cq_channel->fd, F_GETFL, 0); - if (fcntl(id->recv_cq_channel->fd, F_SETFL, flags | O_NONBLOCK)) { - LOG(ERROR) << strerror(errno) - << ": cannot set server_client to non-blocking mode"; - EndpointDeleter(id); - continue; - } - epoll_event event = {}; - event.events = EPOLLIN | EPOLLPRI; - event.data.ptr = id; - if (epoll_ctl(epfd_, EPOLL_CTL_ADD, id->recv_cq_channel->fd, - &event)) { - LOG(ERROR) << strerror(errno) - << ": cannot add server client to epoll"; - EndpointDeleter(id); - continue; - } - server_clients_.push_back({id, EndpointDeleter}); + rdma_cm_id* id = nullptr; + // Accept incoming connections + if (!rdma_get_request(listening_.get(), &id)) { + if (!rdma_accept(id, nullptr)) { + LOG(INFO) << "Accepted new RDMA connection"; + for (int i = 0; i < 1024; i++) { + if (rdma_post_recvv(id, nullptr, nullptr, 0)) { + LOG(ERROR) << strerror(errno) << ": rdma_post_recvv failed"; + EndpointDeleter(id); + continue; } } - } else { - // Polling work completions - ibv_cq* cq; - void* context; - if (!ibv_get_cq_event(id->recv_cq_channel, &cq, &context)) { - ibv_ack_cq_events(id->recv_cq, 1); - if (ibv_req_notify_cq(id->recv_cq, 0)) { - LOG(ERROR) << strerror(errno) << ": ibv_req_notify_cq failed"; - continue; + server_clients_.push_back({id, EndpointDeleter}); + } + } + // Polling server side work completions + for (const auto& client : server_clients_) { + ibv_wc wc[32]; + int ret = ibv_poll_cq(client->recv_cq, 32, wc); + if (ret < 0) { + LOG(ERROR) << "ibv_poll_cq failed"; + continue; + } + for (int i = 0; i < ret; i++) { + if (wc[i].opcode != IBV_WC_RECV_RDMA_WITH_IMM) { + LOG(ERROR) << "Received unknown operation " << wc[i].opcode; + } + if (wc[i].status != 0) { + LOG(ERROR) << ibv_wc_status_str(wc[i].status); + } + TensorKey tensor_key = ntohl(wc[i].imm_data); + + if (rdma_post_recvv(client.get(), nullptr, nullptr, 0)) { + perror("rdma_post_recvv"); + LOG(ERROR) << "rdma_post_recvv failed"; + } + + mutex_lock l(buf_mu_); + auto iter = tensor_buffers_.find(tensor_key); + if (iter == std::end(tensor_buffers_)) { + LOG(ERROR) << "Cannot find tensor buffer for tensor key " + << tensor_key; + } else { + const TensorBuffer* buffer = iter->second; + buffer->Unref(); + tensor_buffers_.erase(iter); + } + } + } + // Polling client side work completions + if (client_mu_.try_lock()) { + for (const auto& client : clients_) { + ibv_wc wc[32]; + int ret = ibv_poll_cq(client.second->send_cq, 32, wc); + for (int i = 0; i < ret; i++) { + Status s; + if (wc[i].status) { + s = errors::Unavailable(ibv_wc_status_str(wc[i].status)); + } else { + s = Status::OK(); } - ibv_wc wc[32]; - int ret = ibv_poll_cq(id->recv_cq, 32, wc); - if (ret < 0) { - LOG(ERROR) << "ibv_poll_cq failed"; - continue; + TensorKey key = wc[i].wr_id; + + ibv_send_wr wr = {}; + wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; + wr.imm_data = htonl(key); + ibv_send_wr* bad_wr; + if (ibv_post_send(client.second->qp, &wr, &bad_wr)) { + LOG(ERROR) << strerror(errno) + << ": ibv_post_send failed for tensor_key " << key; } - for (int i = 0; i < ret; i++) { - if (wc[i].opcode != IBV_WC_RECV_RDMA_WITH_IMM) { - LOG(ERROR) << "Received unknown operation " << wc[i].opcode; - } - if (wc[i].status != 0) { - LOG(ERROR) << ibv_wc_status_str(wc[i].status); - } - TensorKey tensor_key = ntohl(wc[i].imm_data); - { - mutex_lock l(server_mu_); - auto iter = tensor_buffers_.find(tensor_key); - if (iter == std::end(tensor_buffers_)) { - LOG(ERROR) << "Cannot find tensor buffer for tensor key " - << tensor_key; - } else { - const TensorBuffer* buffer = iter->second; - buffer->Unref(); - tensor_buffers_.erase(iter); - } - } - if (rdma_post_recvv(id, nullptr, nullptr, 0)) { - perror("rdma_post_recvv"); - LOG(ERROR) << "rdma_post_recvv failed"; - continue; - } + + mutex_lock l(callback_mu_); + auto iter = tensor_callbacks_.find(key); + if (iter != std::end(tensor_callbacks_)) { + iter->second(s); + tensor_callbacks_.erase(iter); + } else { + LOG(WARNING) << "Cannot find client callback with tensor key " + << key; } } } + client_mu_.unlock(); } } } @@ -390,116 +364,58 @@ void GdrMemoryManager::TransportOptionsFromTensor( ::google::protobuf::Any* mutable_transport_options, const Tensor& tensor, Device* device, DeviceContext* device_context, bool on_host, StatusCallback done) { - auto buffer = DMAHelper::buffer(&tensor); - void* addr = buffer->data(); - size_t length = buffer->size(); - if (length == 0) { - done(errors::Unavailable("Cannot register tensor buffer of size 0")); - return; - } + ibv_mr* mr = FindMemoryRegion(&tensor); + const TensorBuffer* buffer = DMAHelper::buffer(&tensor); - ibv_mr* mr = FindMemoryRegion(addr, length); - -#if GOOGLE_CUDA - if (device->tensorflow_gpu_device_info() && !on_host) { - Allocator* alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0); - Tensor* host_copy = new Tensor(alloc, tensor.dtype(), tensor.shape()); - GPUUtil::CopyGPUTensorToCPU( - device, device_context, &tensor, host_copy, - [done, host_copy, mutable_transport_options, this](const Status& s) { - if (!s.ok()) { - done(s); - delete host_copy; - return; - } - auto buffer = DMAHelper::buffer(host_copy); - void* addr = buffer->data(); - size_t length = buffer->size(); - ibv_mr* mr = FindMemoryRegion(addr, length); - - if (mr == nullptr) { - done(errors::Unavailable("Cannot find pinned memory region")); - delete host_copy; - return; - } - - buffer->Ref(); - TensorKey tensor_key = next_key_++; - { - mutex_lock l(server_mu_); - tensor_buffers_.insert(std::make_pair(tensor_key, buffer)); - } - - uint64_t checksum = 0; - if (VLOG_IS_ON(2)) { - checksum = GPUUtil::Checksum(*host_copy); - } - - RemoteMemoryRegion remote_mr; - remote_mr.set_host(host_); - remote_mr.set_port(port_); - remote_mr.set_addr(reinterpret_cast(addr)); - remote_mr.set_rkey(mr->rkey); - remote_mr.set_tensor_key(tensor_key); - remote_mr.set_checksum(checksum); - mutable_transport_options->PackFrom(remote_mr); - - done(Status::OK()); - delete host_copy; - }); - return; - } -#endif + Tensor* copy = nullptr; if (mr == nullptr) { - Allocator* alloc = ProcessState::singleton()->GetCPUAllocator(numa_node_); - Tensor host_copy(alloc, tensor.dtype(), tensor.shape()); - - std::memcpy(DMAHelper::buffer(&host_copy)->data(), buffer->data(), length); - VLOG(2) << "Copying " << length << " bytes unpinned tensor buffer"; - - buffer = DMAHelper::buffer(&host_copy); - addr = buffer->data(); - length = buffer->size(); - - mr = FindMemoryRegion(addr, length); + AllocatorAttributes alloc_attrs; + alloc_attrs.set_gpu_compatible(true); + alloc_attrs.set_nic_compatible(true); + alloc_attrs.set_on_host(true); + Allocator* alloc = device->GetAllocator(alloc_attrs); + copy = new Tensor(alloc, tensor.dtype(), tensor.shape()); + + mr = FindMemoryRegion(copy); + buffer = DMAHelper::buffer(copy); if (mr == nullptr) { done(errors::Unavailable("Cannot find pinned memory region")); + delete copy; return; } - - buffer->Ref(); - } else { - buffer->Ref(); } TensorKey tensor_key = next_key_++; + buffer->Ref(); { - mutex_lock l(server_mu_); + mutex_lock l(buf_mu_); tensor_buffers_.insert(std::make_pair(tensor_key, buffer)); } - uint64_t checksum = 0; - if (VLOG_IS_ON(2)) { -#ifdef GOOGLE_CUDA - if (device->tensorflow_gpu_device_info() && !on_host) { - checksum = GPUUtil::Checksum(device, device_context, tensor); - } else { - checksum = GPUUtil::Checksum(tensor); - } -#endif - } - RemoteMemoryRegion remote_mr; remote_mr.set_host(host_); remote_mr.set_port(port_); - remote_mr.set_addr(reinterpret_cast(addr)); + remote_mr.set_addr(reinterpret_cast(buffer->data())); remote_mr.set_rkey(mr->rkey); remote_mr.set_tensor_key(tensor_key); - remote_mr.set_checksum(checksum); mutable_transport_options->PackFrom(remote_mr); - done(Status::OK()); + if (copy && device->tensorflow_gpu_device_info() && !on_host) { + device_context->CopyDeviceTensorToCPU(&tensor, "" /* tensor_name */, device, + copy, [done, copy](const Status& s) { + done(s); + delete copy; + }); + return; + } else if (copy) { + std::memcpy(buffer->data(), DMAHelper::buffer(&tensor)->data(), + buffer->size()); + done(Status::OK()); + delete copy; // OK to delete; we have reffed the underlying TensorBuffer + } else { + done(Status::OK()); + } } void GdrMemoryManager::TensorFromTransportOptions( @@ -512,42 +428,10 @@ void GdrMemoryManager::TensorFromTransportOptions( return; } - auto buffer = DMAHelper::buffer(tensor); - void* addr = buffer->data(); - size_t length = buffer->size(); - ibv_mr* mr = FindMemoryRegion(addr, length); - - Tensor host_copy; -#if GOOGLE_CUDA - if (mr == nullptr && !on_host) { - Allocator* alloc = - GPUProcessState::singleton()->GetCUDAHostAllocator(numa_node_); - host_copy = Tensor(alloc, tensor->dtype(), tensor->shape()); - buffer = DMAHelper::buffer(&host_copy); - addr = buffer->data(); - length = buffer->size(); - mr = FindMemoryRegion(addr, length); - } -#endif // GOOGLE_CUDA - - if (mr == nullptr) { - Allocator* alloc = ProcessState::singleton()->GetCPUAllocator(numa_node_); - host_copy = Tensor(alloc, tensor->dtype(), tensor->shape()); - - buffer = DMAHelper::buffer(&host_copy); - addr = buffer->data(); - length = buffer->size(); - - mr = FindMemoryRegion(addr, length); - if (mr == nullptr) { - done(errors::Unavailable("Cannot find pinned memory region")); - return; - } - } - - decltype(clients_)::iterator iter; - bool success; + rdma_cm_id* id = nullptr; { + decltype(clients_)::iterator iter; + bool success; mutex_lock l(client_mu_); std::tie(iter, success) = clients_.insert( std::make_pair(std::make_pair(remote_mr.host(), remote_mr.port()), @@ -560,93 +444,94 @@ void GdrMemoryManager::TensorFromTransportOptions( return; } } - } - rdma_cm_id* id = iter->second.get(); - - uint64_t start = Env::Default()->NowMicros(); - - if (rdma_post_read(id, nullptr, buffer->data(), buffer->size(), mr, 0, - remote_mr.addr(), remote_mr.rkey())) { - done(errors::Unavailable(strerror(errno), ": ", "rdma_post_read failed")); - return; + id = iter->second.get(); } - ibv_send_wr wr = {}; - wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; - wr.imm_data = htonl(remote_mr.tensor_key()); - wr.send_flags = IBV_SEND_SIGNALED; - ibv_send_wr* bad_wr; - if (ibv_post_send(id->qp, &wr, &bad_wr)) { - done(errors::Unavailable(strerror(errno), ": ", "ibv_post_send failed")); - return; - } + ibv_mr* mr = FindMemoryRegion(tensor); + const TensorBuffer* buffer = DMAHelper::buffer(tensor); - ibv_wc wc = {}; - int ret; - while ((ret = ibv_poll_cq(id->send_cq, 1, &wc)) == 0) - ; - if (ret < 0 || wc.status) { - done(errors::Unavailable(ibv_wc_status_str(wc.status))); - return; - } + const Tensor* copy = nullptr; -#if GOOGLE_CUDA - if (device->tensorflow_gpu_device_info() && !on_host && - host_copy.NumElements() > 0) { - uint64_t checksum = 0; - if (VLOG_IS_ON(2)) { - checksum = GPUUtil::Checksum(host_copy); - CHECK(checksum == remote_mr.checksum()) - << "Checksum mismatch: " << checksum << "!=" << remote_mr.checksum(); + if (mr == nullptr) { + AllocatorAttributes alloc_attrs; + alloc_attrs.set_gpu_compatible(true); + alloc_attrs.set_nic_compatible(true); + alloc_attrs.set_on_host(true); + Allocator* alloc = device->GetAllocator(alloc_attrs); + copy = new Tensor(alloc, tensor->dtype(), tensor->shape()); + + mr = FindMemoryRegion(copy); + buffer = DMAHelper::buffer(copy); + if (mr == nullptr) { + done(errors::Unavailable("Cannot find pinned memory region")); + delete copy; + return; } - Tensor* ref = new Tensor; - std::swap(host_copy, *ref); - GPUUtil::CopyCPUTensorToGPU( - ref, device_context, device, tensor, - [ref, done, buffer, remote_mr, start](const Status& s) { - if (!s.ok()) { - done(s); - delete ref; - return; - } - uint64_t end = Env::Default()->NowMicros(); - - VLOG(2) << "RDMA from remote memory region " << remote_mr.rkey() - << " of size " << buffer->size() << " with tensor key " - << remote_mr.tensor_key() << " took " << (end - start) - << " micros"; - done(Status::OK()); - delete ref; - }); - return; } -#endif // GOOGLE_CUDA - if ((on_host || !device->tensorflow_gpu_device_info()) && - host_copy.NumElements() > 0) { - std::memcpy(DMAHelper::buffer(tensor)->data(), addr, length); - VLOG(2) << "Copying " << length << " bytes unpinned tensor buffer"; - } + uint64_t start = Env::Default()->NowMicros(); - uint64_t end = Env::Default()->NowMicros(); + TensorKey tensor_key = remote_mr.tensor_key(); - VLOG(2) << "RDMA from remote memory region " << remote_mr.rkey() - << " of size " << buffer->size() << " with tensor key " - << remote_mr.tensor_key() << " took " << (end - start) << " micros"; + StatusCallback callback = [done, copy, device, device_context, on_host, + tensor, start, tensor_key](const Status& s) { + if (!s.ok()) { + done(s); + if (copy) { + delete copy; + } + return; + } - uint64_t checksum = 0; - if (VLOG_IS_ON(2)) { -#ifdef GOOGLE_CUDA - if (device->tensorflow_gpu_device_info() && !on_host) { - checksum = GPUUtil::Checksum(device, device_context, *tensor); + VLOG(2) << "RDMA of tensor " << tensor_key << " of size " + << DMAHelper::buffer(tensor)->size() << " took " + << (Env::Default()->NowMicros() - start) << " micros"; + + if (copy && device->tensorflow_gpu_device_info() && !on_host) { + device_context->CopyCPUTensorToDevice(copy, device, tensor, + [done, copy](const Status& s) { + done(s); + delete copy; + }); + } else if (copy) { + std::memcpy(DMAHelper::buffer(tensor)->data(), + DMAHelper::buffer(copy)->data(), + DMAHelper::buffer(copy)->size()); + done(s); + delete copy; } else { - checksum = GPUUtil::Checksum(*tensor); + done(s); + } + }; + + { + mutex_lock l(callback_mu_); + if (tensor_callbacks_.find(tensor_key) == std::end(tensor_callbacks_)) { + tensor_callbacks_.insert(std::make_pair(tensor_key, std::move(callback))); + } else { + done(errors::Unavailable("Received duplicated tensor key")); + if (copy) { + delete copy; + } + return; + } + } + + if (rdma_post_read(id, reinterpret_cast(tensor_key), buffer->data(), + buffer->size(), mr, IBV_SEND_SIGNALED, remote_mr.addr(), + remote_mr.rkey())) { + done(errors::Unavailable(strerror(errno), ": ", "rdma_post_read failed")); + { + mutex_lock l(callback_mu_); + auto iter = tensor_callbacks_.find(tensor_key); + if (iter != std::end(tensor_callbacks_)) { + tensor_callbacks_.erase(iter); + } + } + if (copy) { + delete copy; } - CHECK(checksum == remote_mr.checksum()) - << "Checksum mismatch: " << checksum << "!=" << remote_mr.checksum(); -#endif } - done(Status::OK()); } Status GdrMemoryManager::CreateEndpoint(const string& host, const string& port, @@ -663,7 +548,7 @@ Status GdrMemoryManager::CreateEndpoint(const string& host, const string& port, ibv_qp_init_attr init_attr = {}; init_attr.qp_type = IBV_QPT_RC; init_attr.cap.max_recv_wr = 1; - init_attr.cap.max_send_wr = 32; + init_attr.cap.max_send_wr = 1024; init_attr.cap.max_recv_sge = 1; init_attr.cap.max_send_sge = 1; @@ -687,8 +572,8 @@ Status GdrMemoryManager::CreateEndpoint(const string& host, const string& port, return Status::OK(); } -ibv_mr* GdrMemoryManager::FindMemoryRegion(void* addr, size_t length) { - if (length == 0) return nullptr; +ibv_mr* GdrMemoryManager::FindMemoryRegion(const Tensor* tensor) { + const void* addr = DMAHelper::buffer(tensor)->data(); mutex_lock l(alloc_mu_); auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator); if (iter == std::end(mrs_) || iter->get()->addr > addr) { diff --git a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc index fbccbead03fc0d641db40ede661bf3677d44c45d..5f8c300155770ed03ad12a9fa5ac74456edaf024 100644 --- a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc +++ b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc @@ -58,11 +58,9 @@ class GdrRecvTensorCall : public BaseRecvTensorCall { resp_.InitAlloc(dst_device_, recv_args_.alloc_attrs); StatusCallback cb = [this, recv_done](const Status& s) { bool dma_ok = resp_.metadata().has_transport_options(); - if (s.ok() && tensor().TotalBytes() > 0 && (!is_dead()) && dma_ok) { + if (s.ok() && tensor().TotalBytes() > 1024 && (!is_dead()) && dma_ok) { auto transport_options = resp_.metadata().transport_options(); - const bool on_host = - (dst_device_->tensorflow_gpu_device_info() == nullptr) || - recv_args_.alloc_attrs.on_host(); + const bool on_host = recv_args_.alloc_attrs.on_host(); remote_memory_manager_->TensorFromTransportOptions( const_cast(&tensor()), transport_options, dst_device_, recv_args_.device_context, on_host, @@ -70,9 +68,6 @@ class GdrRecvTensorCall : public BaseRecvTensorCall { if (!s.ok()) { mutex_lock l(mu_); status_.Update(s); - LOG(ERROR) << "Cannot find pinned memory region from allocator " - << dst_device_->GetAllocator(recv_args_.alloc_attrs) - ->Name(); } recv_done(); }); diff --git a/tensorflow/contrib/gdr/gdr_server_lib.cc b/tensorflow/contrib/gdr/gdr_server_lib.cc index b3f48ec1dd9c75055f4e1ea76eb203b6ccf94718..dc0d5d548b80d36409778ef34e63171441f10142 100644 --- a/tensorflow/contrib/gdr/gdr_server_lib.cc +++ b/tensorflow/contrib/gdr/gdr_server_lib.cc @@ -74,9 +74,8 @@ Status GdrServer::Start() { } Status GdrServer::Stop() { - TF_RETURN_IF_ERROR(GrpcServer::Stop()); remote_memory_manager_->Stop(); - return Status::OK(); + return GrpcServer::Stop(); } Status GdrServer::Join() { diff --git a/tensorflow/contrib/gdr/gdr_worker.cc b/tensorflow/contrib/gdr/gdr_worker.cc index 867cb83f42034c8e9061e333ea671457745f92c3..016e5ea27b397830c69b6e1761b5994ebcfa9c3d 100644 --- a/tensorflow/contrib/gdr/gdr_worker.cc +++ b/tensorflow/contrib/gdr/gdr_worker.cc @@ -18,9 +18,6 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/dma_helper.h" -#if GOOGLE_CUDA -#include "tensorflow/core/common_runtime/gpu/gpu_util.h" -#endif // GOOGLE_CUDA #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/distributed_runtime/graph_mgr.h" @@ -78,7 +75,7 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, const bool dma_ok = request->dma_ok(); env_->rendezvous_mgr->RecvLocalAsync( step_id, parsed, - [this, opts, response, done, src_dev, dma_ok]( + [this, opts, response, done, src_dev, request, dma_ok]( const Status& status, const Rendezvous::Args& send_args, const Rendezvous::Args&, const Tensor& val, const bool is_dead) { opts->ClearCancelCallback(); @@ -89,10 +86,8 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, // 3) the tensor has the on_host allocation attribute, // i.e. it's in CPU RAM *independent of its assigned // device type*. - const bool on_host = - (src_dev->tensorflow_gpu_device_info() == nullptr) || - send_args.alloc_attrs.on_host(); - if (val.TotalBytes() > 0 && (!is_dead) && + const bool on_host = send_args.alloc_attrs.on_host(); + if (val.TotalBytes() > 1024 && (!is_dead) && DMAHelper::CanUseDMA(&val) && dma_ok) { // DMA cases. RecvTensorResponse* proto = new RecvTensorResponse; @@ -117,8 +112,7 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, } else { // Non-DMA cases. if (src_dev->tensorflow_gpu_device_info() && (!on_host)) { -#if GOOGLE_CUDA - const DeviceContext* send_dev_context = send_args.device_context; + DeviceContext* send_dev_context = send_args.device_context; AllocatorAttributes alloc_attrs; alloc_attrs.set_gpu_compatible(true); alloc_attrs.set_on_host(true); @@ -127,7 +121,8 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, CHECK(send_dev_context) << "send dev name: " << src_dev->name() << " gpu_info: " << src_dev->tensorflow_gpu_device_info(); - // "val" is on a GPU. Uses GPUUtil to fill the response proto. + // "val" is on an accelerator device. Uses the device_context to + // fill the copy on host. StatusCallback copy_ready = [response, done, copy, is_dead](const Status& s) { // The value is now ready to be returned on the wire. @@ -136,11 +131,8 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, delete copy; }; - GPUUtil::CopyGPUTensorToCPU(src_dev, send_dev_context, &val, copy, - copy_ready); -#else - done(errors::Internal("No GPU device in process")); -#endif // GOOGLE_CUDA + send_dev_context->CopyDeviceTensorToCPU( + &val, request->rendezvous_key(), src_dev, copy, copy_ready); } else { grpc::EncodeTensorToByteBuffer(is_dead, val, response); done(Status::OK()); diff --git a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py index 5c5599858ee6879a5703d65658bf4bbd881c7e72..77813519c136665a2fea30d4387f5e7a9776b20b 100644 --- a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py +++ b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py @@ -23,11 +23,16 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.util import deprecation class SequenceFileDataset(dataset_ops.DatasetSource): """A Sequence File Dataset that reads the sequence file.""" + @deprecation.deprecated( + None, + "tf.contrib.hadoop will be removed in 2.0, the support for Apache Hadoop " + "will continue to be provided through the tensorflow/io GitHub project.") def __init__(self, filenames): """Create a `SequenceFileDataset`. diff --git a/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py index e4762c91b193f9c5e32fa2642e702e61e8e5e57f..66e654ca636a5a051c6f9cd35bf9001dfbcbf7f4 100644 --- a/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py +++ b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py @@ -31,6 +31,7 @@ from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.util import deprecation @six.add_metaclass(abc.ABCMeta) @@ -699,6 +700,10 @@ class IgniteDataset(dataset_ops.DatasetSource): Ignite Binary Client Protocol. """ + @deprecation.deprecated( + None, + "tf.contrib.ignite will be removed in 2.0, the support for Apache Ignite " + "will continue to be provided through the tensorflow/io GitHub project.") def __init__(self, cache_name, host="localhost", diff --git a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py index 2b86331099ccae03664462987ee0c141d766c10f..b399e1b6c2ac47db205b5d8bbc81875ef5c08a31 100644 --- a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py +++ b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py @@ -23,12 +23,17 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.util import deprecation class KafkaDataset(dataset_ops.DatasetSource): """A Kafka Dataset that consumes the message. """ + @deprecation.deprecated( + None, + "tf.contrib.kafka will be removed in 2.0, the support for Apache Kafka " + "will continue to be provided through the tensorflow/io GitHub project.") def __init__(self, topics, servers="localhost", diff --git a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py index 20395395281768ac429984a1e3552cfd187527a2..2b1d478a9b0fd12ca25c72da6872acccfd7285fc 100644 --- a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py +++ b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py @@ -23,6 +23,7 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.util import deprecation class KinesisDataset(dataset_ops.DatasetSource): @@ -50,6 +51,10 @@ class KinesisDataset(dataset_ops.DatasetSource): is returned immediately instead. """ + @deprecation.deprecated( + None, + "tf.contrib.kinesis will be removed in 2.0, the support for Kinesis " + "will continue to be provided through the tensorflow/io GitHub project.") def __init__(self, stream, shard="", diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index e52fb5ab1431e086f99b4033a6216636a83bad79..229a72a780d5ccce8263444ffeae7700f6ac8613 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -91,7 +91,7 @@ def index_table_from_tensor(mapping, The bucket ID range is `[mapping size, mapping size + num_oov_buckets - 1]`. The underlying table must be initialized by calling - `tf.tables_initializer.run()` or `table.initializer.run()` once. + `session.run(tf.tables_initializer)` or `session.run(table.init)` once. Elements in `mapping` cannot have duplicates, otherwise when executing the table initializer op, it will throw a `FailedPreconditionError`. @@ -158,7 +158,7 @@ def string_to_index(tensor, mapping, default_value=-1, name=None): will throw a FailedPreconditionError. The underlying table must be initialized by calling - `tf.tables_initializer.run()` once. + `session.run(tf.tables_initializer)` once. For example: @@ -202,7 +202,7 @@ def index_to_string_table_from_tensor(mapping, default_value="UNK", name=None): (an out-of-vocabulary entry) is assigned the `default_value` The underlying table must be initialized by calling - `tf.tables_initializer.run()` or `table.initializer.run()` once. + `session.run(tf.tables_initializer)` or `session.run(table.init)` once. Elements in `mapping` cannot have duplicates, otherwise when executing the table initializer op, it will throw a `FailedPreconditionError`. @@ -257,7 +257,7 @@ def index_to_string(tensor, mapping, default_value="UNK", name=None): (an out-of-vocabulary entry) is assigned the `default_value` The underlying table must be initialized by calling - `tf.tables_initializer.run()` once. + `session.run(tf.tables_initializer)` once. For example: diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py index f6b4373edd0544555dd16a373802d2feb5d674b1..43ea66ac5a178f6ffe87df99ddced3d0442111c1 100644 --- a/tensorflow/contrib/model_pruning/python/pruning.py +++ b/tensorflow/contrib/model_pruning/python/pruning.py @@ -214,7 +214,7 @@ def get_pruning_hparams(): target_sparsity=0.5, sparsity_function_begin_step=0, sparsity_function_end_step=100, - sparsity_function_exponent=3, + sparsity_function_exponent=3.0, use_tpu=False) diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index f4ac70eb1a720c2acc3ef942f269228156749cba..0446e823d95f8ecbed6a0c34a83ade009e68448b 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -14,6 +14,7 @@ py_library( name = "opt_py", srcs = [ "__init__.py", + "python/training/adam_gs_optimizer.py", "python/training/adamax.py", "python/training/addsign.py", "python/training/agn_optimizer.py", @@ -22,6 +23,7 @@ py_library( "python/training/external_optimizer.py", "python/training/ggt.py", "python/training/lars_optimizer.py", + "python/training/lazy_adam_gs_optimizer.py", "python/training/lazy_adam_optimizer.py", "python/training/matrix_functions.py", "python/training/model_average_optimizer.py", @@ -60,6 +62,21 @@ py_library( ], ) +py_test( + name = "adam_gs_optimizer_test", + srcs = ["python/training/adam_gs_optimizer_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:training", + "//third_party/py/numpy", + ], +) + py_test( name = "adamax_test", srcs = ["python/training/adamax_test.py"], @@ -148,6 +165,25 @@ py_test( ], ) +py_test( + name = "lazy_adam_gs_optimizer_test", + srcs = ["python/training/lazy_adam_gs_optimizer_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:variables", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + py_test( name = "lazy_adam_optimizer_test", srcs = ["python/training/lazy_adam_optimizer_test.py"], diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py index c7ea68efa9a13a471bba3f41d0600855793b20a2..e8fc52342ceabb47da97ca0f3c8a01e419a221a1 100644 --- a/tensorflow/contrib/opt/__init__.py +++ b/tensorflow/contrib/opt/__init__.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import +from tensorflow.contrib.opt.python.training.adam_gs_optimizer import * from tensorflow.contrib.opt.python.training.adamax import * from tensorflow.contrib.opt.python.training.addsign import * from tensorflow.contrib.opt.python.training.agn_optimizer import * @@ -28,6 +29,7 @@ from tensorflow.contrib.opt.python.training.external_optimizer import * from tensorflow.contrib.opt.python.training.lars_optimizer import * from tensorflow.contrib.opt.python.training.ggt import * from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import * +from tensorflow.contrib.opt.python.training.lazy_adam_gs_optimizer import * from tensorflow.contrib.opt.python.training.model_average_optimizer import * from tensorflow.contrib.opt.python.training.moving_average_optimizer import * from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import * @@ -44,12 +46,14 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'AdaMaxOptimizer', + 'AdamGSOptimizer', 'PowerSignOptimizer', 'AddSignOptimizer', 'DelayCompensatedGradientDescentOptimizer', 'DropStaleGradientOptimizer', 'ExternalOptimizerInterface', 'LARSOptimizer', + 'LazyAdamGSOptimizer', 'LazyAdamOptimizer', 'NadamOptimizer', 'MovingAverageOptimizer', diff --git a/tensorflow/contrib/opt/python/training/adam_gs_optimizer.py b/tensorflow/contrib/opt/python/training/adam_gs_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..3fb649ea82e79b3bc78a2da6d5c3e9a071adec6d --- /dev/null +++ b/tensorflow/contrib/opt/python/training/adam_gs_optimizer.py @@ -0,0 +1,217 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Adam rewrite to use global step for computing beta1 & beta2 accumulation.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.training import optimizer +from tensorflow.python.training import training_ops +from tensorflow.python.util.tf_export import tf_export + + +@tf_export("train.AdamOptimizer") +class AdamGSOptimizer(optimizer.Optimizer): + """Optimizer that implements the Adam algorithm. + + See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980) + ([pdf](http://arxiv.org/pdf/1412.6980.pdf)). + """ + + def __init__(self, global_step=0, learning_rate=0.001, + beta1=0.9, beta2=0.999, epsilon=1e-8, + use_locking=False, name="Adam"): + """Construct a new Adam optimizer. + + Branched from tf.train.AdamOptimizer. The only difference is to pass + global step for computing beta1 and beta2 accumulators, instead of having + optimizer keep its own independent beta1 and beta2 accumulators as non-slot + variables. + + Initialization: + + $$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$ + $$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$ + $$t := 0 \text{(Initialize timestep)}$$ + + The update rule for `variable` with gradient `g` uses an optimization + described at the end of section2 of the paper: + + $$t := t + 1$$ + $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$ + + $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$ + $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$ + $$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$ + + The default value of 1e-8 for epsilon might not be a good default in + general. For example, when training an Inception network on ImageNet a + current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the + formulation just before Section 2.1 of the Kingma and Ba paper rather than + the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon + hat" in the paper. + + The sparse implementation of this algorithm (used when the gradient is an + IndexedSlices object, typically because of `tf.gather` or an embedding + lookup in the forward pass) does apply momentum to variable slices even if + they were not used in the forward pass (meaning they have a gradient equal + to zero). Momentum decay (beta1) is also applied to the entire momentum + accumulator. This means that the sparse behavior is equivalent to the dense + behavior (in contrast to some momentum implementations which ignore momentum + unless a variable slice was actually used). + + Args: + global_step: tensorflow variable indicating the step. + learning_rate: A Tensor or a floating point value. The learning rate. + beta1: A float value or a constant float tensor. + The exponential decay rate for the 1st moment estimates. + beta2: A float value or a constant float tensor. + The exponential decay rate for the 2nd moment estimates. + epsilon: A small constant for numerical stability. This epsilon is + "epsilon hat" in the Kingma and Ba paper (in the formula just before + Section 2.1), not the epsilon in Algorithm 1 of the paper. + use_locking: If True use locks for update operations. + name: Optional name for the operations created when applying gradients. + Defaults to "Adam". + + @compatibility(eager) + When eager execution is enabled, `learning_rate`, `beta1`, `beta2`, and + `epsilon` can each be a callable that takes no arguments and returns the + actual value to use. This can be useful for changing these values across + different invocations of optimizer functions. + @end_compatibility + """ + super(AdamGSOptimizer, self).__init__(use_locking, name) + self._lr = learning_rate + self._beta1 = beta1 + self._beta2 = beta2 + self._epsilon = epsilon + self._global_step = global_step + self._global_step_on_worker = None + + # Tensor versions of the constructor arguments, created in _prepare(). + self._lr_t = None + self._beta1_t = None + self._beta2_t = None + self._epsilon_t = None + + # Created in SparseApply if needed. + self._updated_lr = None + + def _get_beta_accumulators(self): + return (math_ops.pow(self._beta1_t, self._global_step_on_worker), + math_ops.pow(self._beta2_t, self._global_step_on_worker)) + + def _create_slots(self, var_list): + # Create slots for the first and second moments. + for v in var_list: + self._zeros_slot(v, "m", self._name) + self._zeros_slot(v, "v", self._name) + + def _prepare(self): + lr = self._call_if_callable(self._lr) + beta1 = self._call_if_callable(self._beta1) + beta2 = self._call_if_callable(self._beta2) + epsilon = self._call_if_callable(self._epsilon) + + self._lr_t = ops.convert_to_tensor(lr, name="learning_rate") + self._beta1_t = ops.convert_to_tensor(beta1, name="beta1") + self._beta2_t = ops.convert_to_tensor(beta2, name="beta2") + self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon") + + # Performance optimization so that worker creates a copy of the global step + # to avoid overloading the parameter server holding the global step. + self._global_step_on_worker = math_ops.cast( + array_ops.identity(self._global_step) + 1, dtypes.float32) + + def _apply_dense(self, grad, var): + m = self.get_slot(var, "m") + v = self.get_slot(var, "v") + beta1_power, beta2_power = self._get_beta_accumulators() + return training_ops.apply_adam( + var, m, v, + math_ops.cast(beta1_power, var.dtype.base_dtype), + math_ops.cast(beta2_power, var.dtype.base_dtype), + math_ops.cast(self._lr_t, var.dtype.base_dtype), + math_ops.cast(self._beta1_t, var.dtype.base_dtype), + math_ops.cast(self._beta2_t, var.dtype.base_dtype), + math_ops.cast(self._epsilon_t, var.dtype.base_dtype), + grad, use_locking=self._use_locking).op + + def _resource_apply_dense(self, grad, var): + m = self.get_slot(var, "m") + v = self.get_slot(var, "v") + beta1_power, beta2_power = self._get_beta_accumulators() + return training_ops.resource_apply_adam( + var.handle, m.handle, v.handle, + math_ops.cast(beta1_power, grad.dtype.base_dtype), + math_ops.cast(beta2_power, grad.dtype.base_dtype), + math_ops.cast(self._lr_t, grad.dtype.base_dtype), + math_ops.cast(self._beta1_t, grad.dtype.base_dtype), + math_ops.cast(self._beta2_t, grad.dtype.base_dtype), + math_ops.cast(self._epsilon_t, grad.dtype.base_dtype), + grad, use_locking=self._use_locking) + + def _apply_sparse_shared(self, grad, var, indices, scatter_add): + beta1_power, beta2_power = self._get_beta_accumulators() + beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) + beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) + lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) + beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) + beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) + epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) + lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) + # m_t = beta1 * m + (1 - beta1) * g_t + m = self.get_slot(var, "m") + m_scaled_g_values = grad * (1 - beta1_t) + m_t = state_ops.assign(m, m * beta1_t, + use_locking=self._use_locking) + with ops.control_dependencies([m_t]): + m_t = scatter_add(m, indices, m_scaled_g_values) + # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) + v = self.get_slot(var, "v") + v_scaled_g_values = (grad * grad) * (1 - beta2_t) + v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking) + with ops.control_dependencies([v_t]): + v_t = scatter_add(v, indices, v_scaled_g_values) + v_sqrt = math_ops.sqrt(v_t) + var_update = state_ops.assign_sub(var, + lr * m_t / (v_sqrt + epsilon_t), + use_locking=self._use_locking) + return control_flow_ops.group(*[var_update, m_t, v_t]) + + def _apply_sparse(self, grad, var): + return self._apply_sparse_shared( + grad.values, var, grad.indices, + lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda + x, i, v, use_locking=self._use_locking)) + + def _resource_scatter_add(self, x, i, v): + with ops.control_dependencies( + [resource_variable_ops.resource_scatter_add( + x.handle, i, v)]): + return x.value() + + def _resource_apply_sparse(self, grad, var, indices): + return self._apply_sparse_shared( + grad, var, indices, self._resource_scatter_add) diff --git a/tensorflow/contrib/opt/python/training/adam_gs_optimizer_test.py b/tensorflow/contrib/opt/python/training/adam_gs_optimizer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c68c965aef3729bebe7d0e0dd707c344321d9e3f --- /dev/null +++ b/tensorflow/contrib/opt/python/training/adam_gs_optimizer_test.py @@ -0,0 +1,382 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for AdamGS.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.opt.python.training import adam_gs_optimizer +from tensorflow.python.client import session +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def adam_update_numpy(param, + g_t, + t, + m, + v, + alpha=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-8): + alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t) + + m_t = beta1 * m + (1 - beta1) * g_t + v_t = beta2 * v + (1 - beta2) * g_t * g_t + + param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon) + return param_t, m_t, v_t + + +class AdamGSOptimizerTest(test.TestCase): + + def doTestSparse(self, use_resource=False): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64)) + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + else: + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0_np_indices = np.array([0, 1], dtype=np.int32) + grads0 = ops.IndexedSlices( + constant_op.constant(grads0_np), + constant_op.constant(grads0_np_indices), constant_op.constant([2])) + grads1_np_indices = np.array([0, 1], dtype=np.int32) + grads1 = ops.IndexedSlices( + constant_op.constant(grads1_np), + constant_op.constant(grads1_np_indices), constant_op.constant([2])) + opt = adam_gs_optimizer.AdamGSOptimizer(global_step=global_step) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**t, + self.evaluate(beta2_power)) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + + def testSparse(self): + self.doTestSparse(use_resource=False) + + def testResourceSparse(self): + self.doTestSparse(use_resource=True) + + def testSparseDevicePlacement(self): + for index_dtype in [dtypes.int32, dtypes.int64]: + with self.cached_session(force_gpu=test.is_gpu_available()): + # If a GPU is available, tests that all optimizer ops can be placed on + # it (i.e. they have GPU kernels). + var = variables.Variable([[1.0], [2.0]]) + indices = constant_op.constant([0, 1], dtype=index_dtype) + gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices)) + optimizer = adam_gs_optimizer.AdamGSOptimizer(3.0) + minimize_op = optimizer.minimize(gathered_sum) + variables.global_variables_initializer().run() + minimize_op.run() + + def testSparseRepeatedIndices(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + repeated_index_global_step = variables.Variable( + array_ops.zeros([], dtypes.int64)) + aggregated_global_step = variables.Variable( + array_ops.zeros([], dtypes.int64)) + repeated_index_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + aggregated_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + grad_repeated_index = ops.IndexedSlices( + constant_op.constant( + [0.1, 0.1], shape=[2, 1], dtype=dtype), + constant_op.constant([1, 1]), + constant_op.constant([2, 1])) + grad_aggregated = ops.IndexedSlices( + constant_op.constant( + [0.2], shape=[1, 1], dtype=dtype), + constant_op.constant([1]), + constant_op.constant([2, 1])) + repeated_update = adam_gs_optimizer.AdamGSOptimizer( + global_step=repeated_index_global_step).apply_gradients( + [(grad_repeated_index, repeated_index_update_var)], + global_step=repeated_index_global_step) + aggregated_update = adam_gs_optimizer.AdamGSOptimizer( + global_step=aggregated_global_step).apply_gradients( + [(grad_aggregated, aggregated_update_var)], + global_step=aggregated_global_step) + variables.global_variables_initializer().run() + self.assertAllClose(aggregated_update_var.eval(), + self.evaluate(repeated_index_update_var)) + for _ in range(3): + repeated_update.run() + aggregated_update.run() + self.assertAllClose(aggregated_update_var.eval(), + self.evaluate(repeated_index_update_var)) + + def doTestBasic(self, use_resource=False, use_callable_params=False): + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + with self.session(graph=ops.Graph()): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64), name="global_step_%d" % i) + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + else: + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + learning_rate = lambda: 0.001 + beta1 = lambda: 0.9 + beta2 = lambda: 0.999 + epsilon = lambda: 1e-8 + if not use_callable_params: + learning_rate = learning_rate() + beta1 = beta1() + beta2 = beta2() + epsilon = epsilon() + + opt = adam_gs_optimizer.AdamGSOptimizer(global_step=global_step, + learning_rate=learning_rate) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + opt_variables = opt.variables() + beta1_power, beta2_power = opt._get_beta_accumulators() + self.assertTrue(beta1_power is not None) + self.assertTrue(beta2_power is not None) + self.assertNotIn(beta1_power, opt_variables) + self.assertNotIn(beta2_power, opt_variables) + + if not context.executing_eagerly(): + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. + self.assertEqual(0, len(opt.variables())) + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + # Run 3 steps of Adam + for t in range(1, 4): + if not context.executing_eagerly(): + self.evaluate(update) + self.assertAllCloseAccordingToType( + 0.9**(t + 1), self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType( + 0.999**(t + 1), self.evaluate(beta2_power)) + else: + if t > 1: + opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + beta1_power, beta2_power = opt._get_beta_accumulators() + self.assertAllCloseAccordingToType( + 0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType( + 0.999**t, self.evaluate(beta2_power)) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + if use_resource: + self.assertEqual("var0_%d/Adam:0" % (i,), + opt.get_slot(var=var0, name="m").name) + + def testBasic(self): + with self.cached_session(): + self.doTestBasic(use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTestBasic(use_resource=True) + + def testBasicCallableParams(self): + with context.eager_mode(): + self.doTestBasic(use_resource=True, use_callable_params=True) + + def testTensorLearningRate(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = adam_gs_optimizer.AdamGSOptimizer( + global_step=global_step, learning_rate=constant_op.constant(0.001)) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**t, + self.evaluate(beta2_power)) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + + def testSharing(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = adam_gs_optimizer.AdamGSOptimizer(global_step=global_step) + update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + # Run 3 steps of intertwined Adam1 and Adam2. + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**t, + self.evaluate(beta2_power)) + if t % 2 == 0: + update1.run() + else: + update2.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + + def testTwoSessions(self): + optimizer = adam_gs_optimizer.AdamGSOptimizer() + + with context.eager_mode(): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + + g = ops.Graph() + with g.as_default(): + with session.Session(): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + + gg = ops.Graph() + with gg.as_default(): + with session.Session(): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + + # If the optimizer saves any state not keyed by graph the following line + # fails. + optimizer.apply_gradients([(grads0, var0)]) + + def testSlotsUniqueEager(self): + with context.eager_mode(): + v1 = resource_variable_ops.ResourceVariable(1.) + v2 = resource_variable_ops.ResourceVariable(1.) + opt = adam_gs_optimizer.AdamGSOptimizer(1.) + opt.minimize(lambda: v1 + v2) + # There should be two unique slot variables for v1 and v2 respectively. + self.assertEqual(4, len(set(opt.variables()))) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer.py b/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..8827007e4d7f6722398a8e36bd626377842d92ef --- /dev/null +++ b/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer.py @@ -0,0 +1,114 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""LazyAdam rewrite to use global step for computing beta1 & beta2 accumulation. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.opt.python.training import adam_gs_optimizer +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops + + +class LazyAdamGSOptimizer(adam_gs_optimizer.AdamGSOptimizer): + """Variant of the Adam optimizer that handles sparse updates more efficiently. + + Branched from tf.contrib.opt.LazyAdamGSOptimizer. The only difference is to + pass global step for computing beta1 and beta2 accumulators, instead of having + optimizer keep its own independent beta1 and beta2 accumulators as non-slot + variables. + + The original Adam algorithm maintains two moving-average accumulators for + each trainable variable; the accumulators are updated at every step. + This class provides lazier handling of gradient updates for sparse variables. + It only updates moving-average accumulators for sparse variable indices that + appear in the current batch, rather than updating the accumulators for all + indices. Compared with the original Adam optimizer, it can provide large + improvements in model training throughput for some applications. However, it + provides slightly different semantics than the original Adam algorithm, and + may lead to different empirical results. + """ + + def _apply_sparse(self, grad, var): + beta1_power, beta2_power = self._get_beta_accumulators() + beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) + beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) + lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) + beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) + beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) + epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) + lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) + + # \\(m := beta1 * m + (1 - beta1) * g_t\\) + m = self.get_slot(var, "m") + m_t = state_ops.scatter_update(m, grad.indices, + beta1_t * array_ops.gather(m, grad.indices) + + (1 - beta1_t) * grad.values, + use_locking=self._use_locking) + + # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) + v = self.get_slot(var, "v") + v_t = state_ops.scatter_update(v, grad.indices, + beta2_t * array_ops.gather(v, grad.indices) + + (1 - beta2_t) * math_ops.square(grad.values), + use_locking=self._use_locking) + + # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\) + m_t_slice = array_ops.gather(m_t, grad.indices) + v_t_slice = array_ops.gather(v_t, grad.indices) + denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t + var_update = state_ops.scatter_sub(var, grad.indices, + lr * m_t_slice / denominator_slice, + use_locking=self._use_locking) + return control_flow_ops.group(var_update, m_t, v_t) + + def _resource_apply_sparse(self, grad, var, indices): + beta1_power, beta2_power = self._get_beta_accumulators() + beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) + beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) + lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) + beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) + beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) + epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) + lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) + + # \\(m := beta1 * m + (1 - beta1) * g_t\\) + m = self.get_slot(var, "m") + m_t_slice = beta1_t * array_ops.gather(m, indices) + (1 - beta1_t) * grad + m_update_op = resource_variable_ops.resource_scatter_update(m.handle, + indices, + m_t_slice) + + # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) + v = self.get_slot(var, "v") + v_t_slice = (beta2_t * array_ops.gather(v, indices) + + (1 - beta2_t) * math_ops.square(grad)) + v_update_op = resource_variable_ops.resource_scatter_update(v.handle, + indices, + v_t_slice) + + # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\) + var_slice = lr * m_t_slice / (math_ops.sqrt(v_t_slice) + epsilon_t) + var_update_op = resource_variable_ops.resource_scatter_sub(var.handle, + indices, + var_slice) + + return control_flow_ops.group(var_update_op, m_update_op, v_update_op) diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer_test.py b/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..bdc9a02a546c8399172d0c5b58941b4d80179955 --- /dev/null +++ b/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer_test.py @@ -0,0 +1,402 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for LazyAdamGSOptimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.opt.python.training import lazy_adam_gs_optimizer +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def adam_update_numpy(param, + g_t, + t, + m, + v, + alpha=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-8): + alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t) + + m_t = beta1 * m + (1 - beta1) * g_t + v_t = beta2 * v + (1 - beta2) * g_t * g_t + + param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon) + return param_t, m_t, v_t + + +class LazyAdamGSOptimizerTest(test.TestCase, parameterized.TestCase): + + @parameterized.parameters([False, True]) + def testSparse(self, use_resource): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64)) + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + else: + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + + grads0_np_indices = np.array([0, 1], dtype=np.int32) + grads0 = ops.IndexedSlices( + constant_op.constant(grads0_np), + constant_op.constant(grads0_np_indices), constant_op.constant([2])) + grads1_np_indices = np.array([0, 1], dtype=np.int32) + grads1 = ops.IndexedSlices( + constant_op.constant(grads1_np), + constant_op.constant(grads1_np_indices), constant_op.constant([2])) + opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=global_step) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + @parameterized.parameters([False, True]) + def testSparseDevicePlacement(self, use_resource): + for index_dtype in [dtypes.int32, dtypes.int64]: + with self.cached_session(force_gpu=test.is_gpu_available()): + # If a GPU is available, tests that all optimizer ops can be placed on + # it (i.e. they have GPU kernels). + if use_resource: + global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64)) + var = resource_variable_ops.ResourceVariable([[1.0], [2.0]]) + else: + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + var = variables.Variable([[1.0], [2.0]]) + + indices = constant_op.constant([0, 1], dtype=index_dtype) + gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices)) + optimizer = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=global_step, learning_rate=3.0) + minimize_op = optimizer.minimize(gathered_sum, global_step=global_step) + variables.global_variables_initializer().run() + minimize_op.run() + + @parameterized.parameters([False, True]) + def testSparseRepeatedIndices(self, use_resource): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + if use_resource: + repeated_index_global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64)) + aggregated_global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64)) + repeated_index_update_var = resource_variable_ops.ResourceVariable( + [[1.0], [2.0]], dtype=dtype) + aggregated_update_var = resource_variable_ops.ResourceVariable( + [[1.0], [2.0]], dtype=dtype) + else: + repeated_index_global_step = variables.Variable( + array_ops.zeros([], dtypes.int64)) + aggregated_global_step = variables.Variable( + array_ops.zeros([], dtypes.int64)) + repeated_index_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + aggregated_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + + grad_repeated_index = ops.IndexedSlices( + constant_op.constant( + [0.1, 0.1], shape=[2, 1], dtype=dtype), + constant_op.constant([1, 1]), + constant_op.constant([2, 1])) + grad_aggregated = ops.IndexedSlices( + constant_op.constant( + [0.2], shape=[1, 1], dtype=dtype), + constant_op.constant([1]), + constant_op.constant([2, 1])) + repeated_update_opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=repeated_index_global_step) + repeated_update = repeated_update_opt.apply_gradients( + [(grad_repeated_index, repeated_index_update_var)], + global_step=repeated_index_global_step) + aggregated_update_opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=aggregated_global_step) + aggregated_update = aggregated_update_opt.apply_gradients( + [(grad_aggregated, aggregated_update_var)], + global_step=aggregated_global_step) + variables.global_variables_initializer().run() + self.assertAllClose(aggregated_update_var.eval(), + repeated_index_update_var.eval()) + for _ in range(3): + repeated_update.run() + aggregated_update.run() + self.assertAllClose(aggregated_update_var.eval(), + repeated_index_update_var.eval()) + + def doTestBasic(self, use_resource=False, use_callable_params=False): + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + with self.session(graph=ops.Graph()): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64), name="global_step_%d" % i) + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + else: + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + learning_rate = lambda: 0.001 + beta1 = lambda: 0.9 + beta2 = lambda: 0.999 + epsilon = lambda: 1e-8 + if not use_callable_params: + learning_rate = learning_rate() + beta1 = beta1() + beta2 = beta2() + epsilon = epsilon() + + opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=global_step, learning_rate=learning_rate) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + opt_variables = opt.variables() + beta1_power, beta2_power = opt._get_beta_accumulators() + self.assertIsNotNone(beta1_power) + self.assertIsNotNone(beta2_power is not None) + self.assertNotIn(beta1_power, opt_variables) + self.assertNotIn(beta2_power, opt_variables) + + if not context.executing_eagerly(): + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. + self.assertEqual(0, len(opt.variables())) + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + # Run 3 steps of Adam + for t in range(1, 4): + if not context.executing_eagerly(): + self.evaluate(update) + self.assertAllCloseAccordingToType( + 0.9**(t + 1), self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType( + 0.999**(t + 1), self.evaluate(beta2_power)) + else: + if t > 1: + opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + beta1_power, beta2_power = opt._get_beta_accumulators() + self.assertAllCloseAccordingToType( + 0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType( + 0.999**t, self.evaluate(beta2_power)) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + if use_resource: + self.assertEqual("var0_%d/Adam:0" % (i,), + opt.get_slot(var=var0, name="m").name) + + def testBasic(self): + with self.cached_session(): + self.doTestBasic(use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTestBasic(use_resource=True) + + def testBasicCallableParams(self): + with context.eager_mode(): + self.doTestBasic(use_resource=True, use_callable_params=True) + + def testTensorLearningRate(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=global_step, learning_rate=constant_op.constant(0.001)) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testSharing(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=global_step) + update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run 3 steps of intertwined Adam1 and Adam2. + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + if t % 2 == 0: + update1.run() + else: + update2.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testTwoSessions(self): + optimizer = lazy_adam_gs_optimizer.LazyAdamGSOptimizer() + + with context.eager_mode(): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + + g = ops.Graph() + with g.as_default(): + with self.session(graph=g): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + + gg = ops.Graph() + with gg.as_default(): + with self.session(graph=gg): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + + # If the optimizer saves any state not keyed by graph the following line + # fails. + optimizer.apply_gradients([(grads0, var0)]) + + def testSlotsUniqueEager(self): + with context.eager_mode(): + v1 = resource_variable_ops.ResourceVariable(1.) + v2 = resource_variable_ops.ResourceVariable(1.) + opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer(1.) + opt.minimize(lambda: v1 + v2) + # There should be two non-slot variables, and two unique slot variables + # for v1 and v2 respectively. + self.assertLen(set(opt.variables()), 4) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index 21d1b1213090273b5abd8e012f8711db98c94347..7c973fe597181b822e617db1f85a08f1b678e26f 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -685,7 +685,7 @@ def _InsertQuantOp(context, [1; 2^bits - 1] or wide range [0; 2^bits - 1]. producer_scope: The restriction of producer scope. If not None, the new op will be inserted only when the producer is in this scope. - consumer_scope: The restriction of producer scope. If not None, the new op + consumer_scope: The restriction of consumer scope. If not None, the new op will be inserted only when all the consumers are in this scope. Raises: ValueError: When producer operation is not directly connected to the diff --git a/tensorflow/contrib/receptive_field/README.md b/tensorflow/contrib/receptive_field/README.md index 79b015a9163f5727caa40b54579c71e57621c92f..d1c41e4c0a11028765c9fc0dc345cb29453baa31 100644 --- a/tensorflow/contrib/receptive_field/README.md +++ b/tensorflow/contrib/receptive_field/README.md @@ -185,5 +185,4 @@ Effective padding (vertical) = 1482 ## Authors -André Araujo (github id: andrefaraujo) and Mark Sandler (github id: -marksandler) +André Araujo (@andrefaraujo) and Mark Sandler (@marksandler) diff --git a/tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py b/tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py index d6fdd12bbe37fb0e0cb12f1d0adc3fce29b19e8a..72f98ccc32e945b48b5f1b570bcca323a5b5f48a 100644 --- a/tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py +++ b/tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Computes Receptive Field (RF) information given a graph protobuf. - -For an example of usage, see accompanying file compute_rf.sh -""" +"""Computes Receptive Field (RF) information given a graph protobuf.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py b/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py index a298b4d49038468299b58140758c69675368e855..325929a5937ac60a6134fae064e7633a4c57473d 100644 --- a/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py +++ b/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py @@ -16,8 +16,6 @@ The receptive field (and related parameters) for the different models are printed to stdout, and may also optionally be written to a CSV file. - -For an example of usage, see rf_benchmark.sh """ from __future__ import absolute_import @@ -262,11 +260,11 @@ def _model_rf(graphdef, information will be computed. model_type: Type of model to be used, used only for printing purposes. csv_writer: A CSV writer for RF parameters, which is used if it is not None. - input_resolution: Input resolution to use when computing RF - parameters. This is important for the case where padding can only be - defined if the input resolution is known, which may happen if using SAME - padding. This is assumed the resolution for both height and width. If - None, we consider the resolution is unknown. + input_resolution: Input resolution to use when computing RF parameters. This + is important for the case where padding can only be defined if the input + resolution is known, which may happen if using SAME padding. This is + assumed the resolution for both height and width. If None, we consider the + resolution is unknown. """ for desired_end_point_key in desired_end_point_keys: print('- %s:' % desired_end_point_key) @@ -283,10 +281,10 @@ def _model_rf(graphdef, if (receptive_field_x == receptive_field_y) and ( effective_stride_x == effective_stride_y) and ( effective_padding_x == effective_padding_y): - print('Receptive field size = %5s, effective stride = %5s, effective ' - 'padding = %5s' % (str(receptive_field_x), - str(effective_stride_x), - str(effective_padding_x))) + print( + 'Receptive field size = %5s, effective stride = %5s, effective ' + 'padding = %5s' % (str(receptive_field_x), str(effective_stride_x), + str(effective_padding_x))) else: print('Receptive field size: horizontal = %5s, vertical = %5s. ' 'Effective stride: horizontal = %5s, vertical = %5s. Effective ' @@ -362,9 +360,8 @@ def _process_model_rf(model_type='resnet_v1_50', defined if the input resolution is known, which may happen if using SAME padding. The entries in the list are assumed the resolution for both height and width. If one of the elements in the list is None, we consider - it to mean that the resolution is unknown. If the list itself is None, - we use the default list [None, 224, 321]. - + it to mean that the resolution is unknown. If the list itself is None, we + use the default list [None, 224, 321]. """ # Process default value for this list. if input_resolutions is None: @@ -477,8 +474,8 @@ def _mobilenet_v1_rf(csv_writer=None): csv_writer: A CSV writer for RF parameters, which is used if it is not None. """ for model_type in _SUPPORTED_MOBILENETV1_VARIANTS: - with slim.arg_scope( - [slim.batch_norm, slim.dropout], is_training=False) as arg_sc: + with slim.arg_scope([slim.batch_norm, slim.dropout], + is_training=False) as arg_sc: _process_model_rf(model_type, csv_writer, arg_sc) diff --git a/tensorflow/contrib/receptive_field/python/util/receptive_field.py b/tensorflow/contrib/receptive_field/python/util/receptive_field.py index b9bd2f09761ab10a62d37e8e2580b93b9b8a4453..9127c772c75279d9c8eacc5a17680beba9247d01 100644 --- a/tensorflow/contrib/receptive_field/python/util/receptive_field.py +++ b/tensorflow/contrib/receptive_field/python/util/receptive_field.py @@ -12,12 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Functions to compute receptive field of a fully-convolutional network. - -Please refer to the following g3doc for detailed explanation on how this -computation is performed, and why it is important: -g3doc/photos/vision/features/delf/g3doc/rf_computation.md -""" +"""Functions to compute receptive field of a fully-convolutional network.""" from __future__ import absolute_import from __future__ import division @@ -96,8 +91,8 @@ class ReceptiveField(object): Args: y: An array of feature coordinates with shape `(..., d)`, where `d` is the number of dimensions of the coordinates. - axis: The dimensions for which to compute the input center coordinates. - If `None` (the default), compute the input center coordinates for all + axis: The dimensions for which to compute the input center coordinates. If + `None` (the default), compute the input center coordinates for all dimensions. Returns: @@ -127,8 +122,8 @@ class ReceptiveField(object): Args: x: An array of input center coordinates with shape `(..., d)`, where `d` is the number of dimensions of the coordinates. - axis: The dimensions for which to compute the feature coordinates. - If `None` (the default), compute the feature coordinates for all + axis: The dimensions for which to compute the feature coordinates. If + `None` (the default), compute the feature coordinates for all dimensions. Returns: @@ -274,14 +269,15 @@ def compute_receptive_field_from_graph_def(graph_def, continue # Get params for this layer. - (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, - padding_y, _, _) = parse_layer_parameters.get_layer_params( + (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, padding_y, + _, _) = parse_layer_parameters.get_layer_params( node, name_to_node, node_info[node.name].input_size) - logging.vlog(3, "kernel_size_x = %s, kernel_size_y = %s, " - "stride_x = %s, stride_y = %s, " - "padding_x = %s, padding_y = %s, input size = %s" % - (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, - padding_y, node_info[node.name].input_size)) + logging.vlog( + 3, "kernel_size_x = %s, kernel_size_y = %s, " + "stride_x = %s, stride_y = %s, " + "padding_x = %s, padding_y = %s, input size = %s" % + (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, + padding_y, node_info[node.name].input_size)) if padding_x is None or padding_y is None: undefined_padding = True @@ -352,15 +348,15 @@ def compute_receptive_field_from_graph_def(graph_def, raise ValueError( "Graph is not aligned since effective stride from different " "paths is different in vertical direction") - if (rf_sizes_x[inp_name] - 1 - ) / 2 - effective_paddings_x[inp_name] != ( - rf_size_input_x - 1) / 2 - effective_padding_input_x: + if (rf_sizes_x[inp_name] - + 1) / 2 - effective_paddings_x[inp_name] != ( + rf_size_input_x - 1) / 2 - effective_padding_input_x: raise ValueError( "Graph is not aligned since center shift from different " "paths is different in horizontal direction") - if (rf_sizes_y[inp_name] - 1 - ) / 2 - effective_paddings_y[inp_name] != ( - rf_size_input_y - 1) / 2 - effective_padding_input_y: + if (rf_sizes_y[inp_name] - + 1) / 2 - effective_paddings_y[inp_name] != ( + rf_size_input_y - 1) / 2 - effective_padding_input_y: raise ValueError( "Graph is not aligned since center shift from different " "paths is different in vertical direction") diff --git a/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py b/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py index d8ca0eab276b39f025d018edebb78eed7a8433bb..cec4c3c23305034d167a248a637425507750064e 100644 --- a/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py +++ b/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py @@ -164,6 +164,15 @@ class ResamplerOpsTest(xla_test.XLATestCase): expected = [[[0.0], [27.62]]] self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + expected_grad_data = [[[[0.12], [0.27999997]], [[0.18000001], + [0.42000002]]]] + expected_grad_warp = [[[0., 0.], [22.60000038, 35.20000076]]] + + grad_output = np.ones([1, 2, 1], dtype=dtype) + self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output, + expected_grad_data, + expected_grad_warp) + # One of (x, y) is less than 0. for dtype in self.float_types: input_shape = [1, 2, 2, 1] @@ -171,11 +180,21 @@ class ResamplerOpsTest(xla_test.XLATestCase): input_np = np.array(input_data, dtype=dtype).reshape(input_shape) warp_shape = [1, 2, 2] + # -1 is out of bound for grad_warp. warp_data = [-1, 0.1, 0.7, 0.6] warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape) expected = [[[0.0], [27.62]]] self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + expected_grad_data = [[[[0.12], [0.27999997]], [[0.18000001], + [0.42000002]]]] + expected_grad_warp = [[[0., 0.], [22.60000038, 35.20000076]]] + + grad_output = np.ones([1, 2, 1], dtype=dtype) + self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output, + expected_grad_data, + expected_grad_warp) + # Both of (x, y) are greater than image size. for dtype in self.float_types: input_shape = [1, 2, 2, 1] @@ -183,11 +202,20 @@ class ResamplerOpsTest(xla_test.XLATestCase): input_np = np.array(input_data, dtype=dtype).reshape(input_shape) warp_shape = [1, 2, 2] + # -0.1 is *inbound* for grad_warp and grad_data, 2.1 is out of bound. warp_data = [-0.1, 0.1, 1.2, 2.1] warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape) expected = [[[0.0], [0.0]]] self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + expected_grad_data = [[[[0.81], [0.0]], [[0.09], [0.0]]]] + expected_grad_warp = [[[10.30, 2.7], [0.0, 0.0]]] + + grad_output = np.ones([1, 2, 1], dtype=dtype) + self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output, + expected_grad_data, + expected_grad_warp) + # One of (x, y) is greater than image size. for dtype in self.float_types: input_shape = [1, 2, 2, 1] @@ -200,6 +228,14 @@ class ResamplerOpsTest(xla_test.XLATestCase): expected = [[[0.0], [0.0]]] self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + expected_grad_data = [[[[0.81], [0.81]], [[0.0], [0.08]]]] + expected_grad_warp = [[[-4.5, 9.5], [-9.9, 39.20]]] + + grad_output = np.ones([1, 2, 1], dtype=dtype) + self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output, + expected_grad_data, + expected_grad_warp) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py index ffba514bb96f5ce8d963cb0a0482738eafe88355..2a4b6eae367fe617e9a19d80f16eb3fda9ade1c0 100644 --- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py +++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py @@ -22,53 +22,57 @@ import os import six from tensorflow.python.client import session -from tensorflow.python.estimator import keras as estimator_keras_util -from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator.export import export as export_helpers from tensorflow.python.framework import ops from tensorflow.python.keras import backend as K from tensorflow.python.keras import models as models_lib from tensorflow.python.keras import optimizers from tensorflow.python.keras.engine import sequential +from tensorflow.python.keras.engine import training_utils from tensorflow.python.keras.metrics import Metric from tensorflow.python.keras.models import model_from_json from tensorflow.python.lib.io import file_io from tensorflow.python.ops import variables -from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import builder as saved_model_builder from tensorflow.python.saved_model import constants +from tensorflow.python.saved_model import save as save_lib from tensorflow.python.saved_model import utils_impl as saved_model_utils from tensorflow.python.training import saver as saver_lib from tensorflow.python.training.checkpointable import util as checkpointable_utils from tensorflow.python.util import compat +from tensorflow.python.util import nest +from tensorflow_estimator.python.estimator import keras as estimator_keras_util +from tensorflow_estimator.python.estimator import model_fn as model_fn_lib +from tensorflow_estimator.python.estimator.export import export as export_helpers def save_keras_model( - model, saved_model_path, custom_objects=None, as_text=None): - """Save a `tf.keras.Model` into Tensorflow SavedModel format. + model, saved_model_path, custom_objects=None, as_text=None, + input_signature=None, serving_only=False): + """Saves a `tf.keras.Model` into Tensorflow SavedModel format. `save_model` generates new files/folders under the `saved_model_path` folder: - 1) an asset folder containing the json string of the model's - configuration (topology). - 2) a checkpoint containing the model weights. - 3) a saved_model.pb file containing the model's MetaGraphs. The prediction + 1) a checkpoint containing the model weights. + 2) a saved_model.pb file containing the model's MetaGraphs. The prediction graph is always exported. The evaluaton and training graphs are exported if the following conditions are met: - Evaluation: model loss is defined. - Training: model is compiled with an optimizer defined under `tf.train`. This is because `tf.keras.optimizers.Optimizer` instances cannot be saved to checkpoints. - - Model Requirements: - - Model must be a sequential model or functional model. Subclassed models can - not be saved via this function, unless you provide an implementation for - get_config() and from_config(). - - All variables must be saveable by the model. In general, this condition is - met through the use of layers defined in the keras library. However, - there is currently a bug with variables created in Lambda layer functions - not being saved correctly (see - https://github.com/keras-team/keras/issues/9740). + 3) Model's json configuration, if model.get_config() has been implemented. + This file can be used to reload the model using + tf.keras.models.model_from_json(). Note that if any custom objects were + used, they should be passed to the `custom_object` argument when loading + the model. + + Model limitations: + - Sequential and functional models can always be saved. + - Subclassed models can only be saved when `serving_only=True`. This is due to + the current implementation copying the model in order to export the training + and evaluation graphs. Because the topology of subclassed models cannot be + determined, the subclassed models cannot be cloned. Subclassed models will + be entirely exportable in the future. Note that each mode is exported in separate graphs, so different modes do not share variables. To use the train graph with evaluation or prediction graphs, @@ -94,38 +98,88 @@ def save_keras_model( ``` Args: - model: A `tf.keras.Model` to be saved. + model: A `tf.keras.Model` to be saved. If the model is subclassed, the flag + `serving_only` must be set to True. saved_model_path: a string specifying the path to the SavedModel directory. The SavedModel will be saved to a timestamped folder created within this directory. custom_objects: Optional dictionary mapping string names to custom classes or functions (e.g. custom loss functions). - as_text: whether to write the `SavedModel` proto in text format. + as_text: whether to write the `SavedModel` proto in text format. Currently + unavailable in serving-only mode. + input_signature: A possibly nested sequence of `tf.TensorSpec` objects, used + to specify the expected model inputs. `input_signature`'s nested structure + should match the expected nested structure of the inputs to the model. If + this is not set, this function will attempt to infer the input shapes and + dtypes from the model. Note that if the model is subclassed, the tensor + inputs to the call function should be nested in the first argument (this + is a general requirement for using subclassed models with Keras functions + .fit(), .predict(), etc.). + serving_only: Export only the outputs produced from calling the model in + predict mode. The losses, optimizer, and other training configurations are + not saved. If the SavedModel will only be used for serving (rather than + retraining), or if the model is subclassed, this can be set to True. Returns: String path to the SavedModel folder, a subdirectory of `saved_model_path`. Raises: - NotImplementedError: If the model is a subclassed model. - ValueError: If a Sequential model does not have input shapes defined by the - user, and is not built. + NotImplementedError: If the model is a subclassed model, and serving_only is + False. + ValueError: If the input signature cannot be inferred from the model. """ + export_dir = export_helpers.get_timestamped_export_dir(saved_model_path) + + if serving_only: + save_lib.save( + model, export_dir, + signatures=training_utils.trace_model_call(model, input_signature)) + else: + _save_v1_format(model, export_dir, custom_objects, as_text, input_signature) + + try: + _export_model_json(model, export_dir) + except NotImplementedError: + logging.warning('Skipped saving model JSON, subclassed model does not have ' + 'get_config() defined.') + + return export_dir + + +def _export_model_json(model, saved_model_path): + """Saves model configuration as a json string under assets folder.""" + model_json = model.to_json() + model_json_filepath = os.path.join( + saved_model_utils.get_or_create_assets_dir(saved_model_path), + compat.as_text(constants.SAVED_MODEL_FILENAME_JSON)) + file_io.write_string_to_file(model_json_filepath, model_json) + + +def _export_model_variables(model, saved_model_path): + """Saves model weights in checkpoint format under variables folder.""" + saved_model_utils.get_or_create_variables_dir(saved_model_path) + checkpoint_prefix = saved_model_utils.get_variables_path(saved_model_path) + model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True) + return checkpoint_prefix + + +def _save_v1_format(model, path, custom_objects, as_text, input_signature): + """Exports model to v1 SavedModel format.""" if not model._is_graph_network: if isinstance(model, sequential.Sequential): # If input shape is not directly set in the model, the exported model - # will assume that the inputs have the same shape as the shape the model - # was built model with. - if not model.built: + # will infer the expected shapes of the input from the model. + if not model.built and input_signature is None: raise ValueError( - 'Sequential model must be built before it can be exported.') + 'Sequential model\'s input shape is unknown. Please build the ' + 'model, or use the input_signature argument to specify the ' + 'model inputs.') else: raise NotImplementedError( - 'Exporting subclassed models is not yet supported.') + 'Subclassed models can only be exported for serving. Please set ' + 'argument serving_only=True.') - export_dir = export_helpers.get_timestamped_export_dir(saved_model_path) - temp_export_dir = export_helpers.get_temp_export_dir(export_dir) - - builder = saved_model_builder._SavedModelBuilder(temp_export_dir) + builder = saved_model_builder._SavedModelBuilder(path) # Manually save variables to export them in an object-based checkpoint. This # skips the `builder.add_meta_graph_and_variables()` step, which saves a @@ -133,7 +187,7 @@ def save_keras_model( # TODO(b/113134168): Add fn to Builder to save with object-based saver. # TODO(b/113178242): This should only export the model json structure. Only # one save is needed once the weights can be copied from the model to clone. - checkpoint_path = _export_model_json_and_variables(model, temp_export_dir) + checkpoint_path = _export_model_variables(model, path) # Export each mode. Use ModeKeys enums defined for `Estimator` to ensure that # Keras models and `Estimator`s are exported with the same format. @@ -143,10 +197,12 @@ def save_keras_model( export_args = {'builder': builder, 'model': model, 'custom_objects': custom_objects, - 'checkpoint_path': checkpoint_path} + 'checkpoint_path': checkpoint_path, + 'input_signature': input_signature} has_saved_vars = False if model.optimizer: + # TODO(kathywu): Verify this works with v2 optimizer. if isinstance(model.optimizer, optimizers.TFOptimizer): _export_mode(model_fn_lib.ModeKeys.TRAIN, has_saved_vars, **export_args) has_saved_vars = True @@ -161,34 +217,20 @@ def save_keras_model( builder.save(as_text) - gfile.Rename(temp_export_dir, export_dir) - return export_dir - - -def _export_model_json_and_variables(model, saved_model_path): - """Save model variables and json structure into SavedModel subdirectories.""" - # Save model configuration as a json string under assets folder. - model_json = model.to_json() - model_json_filepath = os.path.join( - saved_model_utils.get_or_create_assets_dir(saved_model_path), - compat.as_text(constants.SAVED_MODEL_FILENAME_JSON)) - file_io.write_string_to_file(model_json_filepath, model_json) - - # Save model weights in checkpoint format under variables folder. - saved_model_utils.get_or_create_variables_dir(saved_model_path) - checkpoint_prefix = saved_model_utils.get_variables_path(saved_model_path) - model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True) - return checkpoint_prefix - def _get_var_list(model): - """Return list of all checkpointed saveable objects in the model.""" + """Returns list of all checkpointed saveable objects in the model.""" return checkpointable_utils.named_saveables(model) +def create_placeholder(spec): + return K.placeholder(shape=spec.shape, dtype=spec.dtype, name=spec.name) + + def _export_mode( - mode, has_saved_vars, builder, model, custom_objects, checkpoint_path): - """Export a model, and optionally save new vars from the clone model. + mode, has_saved_vars, builder, model, custom_objects, checkpoint_path, + input_signature): + """Exports a model, and optionally saves new vars from the clone model. Args: mode: A `tf.estimator.ModeKeys` string. @@ -199,6 +241,8 @@ def _export_mode( custom_objects: A dictionary mapping string names to custom classes or functions. checkpoint_path: String path to checkpoint. + input_signature: Nested TensorSpec containing the expected inputs. Can be + `None`, in which case the signature will be inferred from the model. Raises: ValueError: If the train/eval mode is being exported, but the model does @@ -214,10 +258,16 @@ def _export_mode( K.set_learning_phase(mode == model_fn_lib.ModeKeys.TRAIN) + if input_signature is None: + input_tensors = None + else: + input_tensors = nest.map_structure(create_placeholder, input_signature) + # Clone the model into blank graph. This will create placeholders for inputs # and targets. clone = models_lib.clone_and_build_model( - model, custom_objects=custom_objects, compile_clone=compile_clone) + model, input_tensors=input_tensors, custom_objects=custom_objects, + compile_clone=compile_clone) # Make sure that iterations variable is added to the global step collection, # to ensure that, when the SavedModel graph is loaded, the iterations @@ -271,7 +321,7 @@ def _export_mode( def _create_signature_def_map(model, mode): - """Create a SignatureDef map from a Keras model.""" + """Creates a SignatureDef map from a Keras model.""" inputs_dict = {name: x for name, x in zip(model.input_names, model.inputs)} if model.optimizer: targets_dict = {x.name.split(':')[0]: x @@ -309,14 +359,14 @@ def _create_signature_def_map(model, mode): def _assert_same_non_optimizer_objects(model, model_graph, clone, clone_graph): # pylint: disable=unused-argument - """Assert model and clone contain the same checkpointable objects.""" + """Asserts model and clone contain the same checkpointable objects.""" # TODO(fchollet, kathywu): make sure this works in eager mode. return True def load_keras_model(saved_model_path): - """Load a keras.Model from SavedModel. + """Loads a keras.Model from SavedModel. load_model reinstantiates model state by: 1) loading model topology from json (this will eventually come diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py index 93d73e1b484ed810fb347b13e95022dfca3584c2..fbf8138493362d4a3c8a75e1ee1bb2fbe8096499 100644 --- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py +++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py @@ -29,7 +29,9 @@ from tensorflow.python import keras from tensorflow.python.client import session from tensorflow.python.eager import context from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.keras.engine import training from tensorflow.python.keras.utils import tf_utils @@ -215,7 +217,7 @@ class LayerWithLearningPhase(keras.engine.base_layer.Layer): return input_shape -def functional_model(uses_learning_phase): +def functional_model(uses_learning_phase=True): inputs = keras.layers.Input(shape=(3,)) x = keras.layers.Dense(2)(inputs) x = keras.layers.Dense(3)(x) @@ -224,7 +226,7 @@ def functional_model(uses_learning_phase): return keras.models.Model(inputs, x) -def sequential_model(uses_learning_phase): +def sequential_model(uses_learning_phase=True): model = keras.models.Sequential() model.add(keras.layers.Dense(2, input_shape=(3,))) model.add(keras.layers.Dense(3)) @@ -233,7 +235,7 @@ def sequential_model(uses_learning_phase): return model -def sequential_model_without_input_shape(uses_learning_phase): +def sequential_model_without_input_shape(uses_learning_phase=True): model = keras.models.Sequential() model.add(keras.layers.Dense(2)) model.add(keras.layers.Dense(3)) @@ -242,10 +244,30 @@ def sequential_model_without_input_shape(uses_learning_phase): return model +class Subclassed(keras.models.Model): + + def __init__(self): + super(Subclassed, self).__init__() + self.dense1 = keras.layers.Dense(2) + self.dense2 = keras.layers.Dense(3) + + def call(self, inputs): + x = self.dense1(inputs) + x = self.dense2(x) + return x + + +def subclassed_model(): + return Subclassed() + + def load_model(sess, path, mode): tags = model_fn_lib.EXPORT_TAG_MAP[mode] - sig_def_key = (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY - if mode == model_fn_lib.ModeKeys.PREDICT else mode) + if mode == model_fn_lib.ModeKeys.PREDICT: + sig_def_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + else: + sig_def_key = mode + meta_graph_def = loader_impl.load(sess, tags, path) inputs = { k: sess.graph.get_tensor_by_name(v.name) @@ -463,13 +485,54 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase): clone.compile(loss='mse', optimizer=keras.optimizers.RMSprop(lr=0.0001)) clone.train_on_batch(input_arr, target_arr) - def testSaveSeqModelWithoutInputShapesRaisesError(self): - """A Sequential model that hasn't been built should raise an error.""" + def testSaveSequentialModelWithoutInputShapes(self): model = sequential_model_without_input_shape(True) - with self.assertRaisesRegexp( - ValueError, 'must be built'): + # A Sequential model that hasn't been built should raise an error. + with self.assertRaisesRegexp(ValueError, 'Please build the model'): keras_saved_model.save_keras_model(model, '') + saved_model_path = self._save_model_dir() + output_path = keras_saved_model.save_keras_model( + model, saved_model_path, + input_signature=tensor_spec.TensorSpec(shape=(10, 11, 12, 13, 14), + dtype=dtypes.float32, + name='spec_input')) + + with session.Session(graph=ops.Graph()) as sess: + inputs, outputs, _ = load_model(sess, output_path, + model_fn_lib.ModeKeys.PREDICT) + self.assertEqual(5, inputs[next(iter(inputs.keys()))].shape.ndims) + self.assertEqual(5, outputs[next(iter(outputs.keys()))].shape.ndims) + self.assertEqual(3, outputs[next(iter(outputs.keys()))].shape[-1]) + + @test_util.run_v2_only + @parameterized.parameters( + { + 'model_builder': sequential_model_without_input_shape, + 'input_signature': [tensor_spec.TensorSpec(shape=[None, 3], + dtype=dtypes.float32)]}, + { + 'model_builder': subclassed_model, + 'input_signature': [tensor_spec.TensorSpec(shape=[None, 3], + dtype=dtypes.float32)]}) + def testServingOnly(self, model_builder, input_signature): + saved_model_path = self._save_model_dir() + input_arr = np.random.random((5, 3)).astype(np.float32) + model = model_builder() + ref_predict = model.predict(input_arr) + + output_path = keras_saved_model.save_keras_model( + model, saved_model_path, serving_only=True, + input_signature=input_signature) + + # Load predict graph, and test predictions + with session.Session(graph=ops.Graph()) as sess: + inputs, outputs, _ = load_model(sess, output_path, + model_fn_lib.ModeKeys.PREDICT) + predictions = sess.run(outputs[next(iter(outputs.keys()))], + {inputs[next(iter(inputs.keys()))]: input_arr}) + self.assertAllClose(ref_predict, predictions, atol=1e-05) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py index 922f21b98b35dfff19c8c605a25e89c5d2da8d98..d815f81f847ad79ddcc6c6ecf5c050598e185d8d 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell from tensorflow.python.ops import variables from tensorflow.python.ops import variable_scope as vs @@ -992,5 +993,67 @@ class AttentionWrapperTest(test.TestCase): expected_final_alignment_history=expected_final_alignment_history, name='testMultiAttention') + def testCustomizedAttention(self): + batch_size = 2 + max_time = 3 + num_units = 2 + memory = constant_op.constant([[[1., 1.], [2., 2.], [3., 3.]], + [[4., 4.], [5., 5.], [6., 6.]]]) + memory_sequence_length = constant_op.constant([3, 2]) + attention_mechanism = wrapper.BahdanauAttention(num_units, memory, + memory_sequence_length) + + # Sets all returned values to be all ones. + def _customized_attention(unused_attention_mechanism, unused_cell_output, + unused_attention_state, unused_attention_layer): + """Customized attention. + + Returns: + attention: `Tensor` of shape [batch_size, num_units], attention output. + alignments: `Tensor` of shape [batch_size, max_time], sigma value for + each input memory (prob. function of input keys). + next_attention_state: A `Tensor` representing the next state for the + attention. + """ + attention = array_ops.ones([batch_size, num_units]) + alignments = array_ops.ones([batch_size, max_time]) + next_attention_state = alignments + return attention, alignments, next_attention_state + + attention_cell = wrapper.AttentionWrapper( + rnn_cell.LSTMCell(2), + attention_mechanism, + attention_layer_size=None, # don't use attention layer. + output_attention=False, + alignment_history=(), + attention_fn=_customized_attention, + name='attention') + self.assertEqual(num_units, attention_cell.output_size) + + initial_state = attention_cell.zero_state( + batch_size=2, dtype=dtypes.float32) + source_input_emb = array_ops.ones([2, 3, 2]) + source_input_length = constant_op.constant([3, 2]) + + # 'state' is a tuple of + # (cell_state, h, attention, alignments, alignment_history, attention_state) + output, state = rnn.dynamic_rnn( + attention_cell, + inputs=source_input_emb, + sequence_length=source_input_length, + initial_state=initial_state, + dtype=dtypes.float32) + + with self.session() as sess: + sess.run(variables.global_variables_initializer()) + output_value, state_value = sess.run([output, state], feed_dict={}) + self.assertAllEqual(np.array([2, 3, 2]), output_value.shape) + self.assertAllClose(np.array([[1., 1.], [1., 1.]]), state_value.attention) + self.assertAllClose( + np.array([[1., 1., 1.], [1., 1., 1.]]), state_value.alignments) + self.assertAllClose( + np.array([[1., 1., 1.], [1., 1., 1.]]), state_value.attention_state) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index 77e9f848b137911b53e1b4df5dd740fe38af55bb..60ec3efffe771a3a6d6f36ed4b51a34ef9509612 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -1088,7 +1088,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): output_attention=True, initial_cell_state=None, name=None, - attention_layer=None): + attention_layer=None, + attention_fn=None): """Construct the `AttentionWrapper`. **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in @@ -1132,7 +1133,9 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): feed the context and cell output into the attention layer to generate attention at each time step. If attention_mechanism is a list, attention_layer_size must be a list of the same length. If - attention_layer is set, this must be None. + attention_layer is set, this must be None. If attention_fn is set, + it must guaranteed that the outputs of attention_fn also meet the + above requirements. alignment_history: Python boolean, whether to store alignment history from all time steps in the final output state (currently stored as a time major `TensorArray` on which you must call `stack()`). @@ -1158,6 +1161,12 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): the context as attention at each time step. If attention_mechanism is a list, attention_layer must be a list of the same length. If attention_layers_size is set, this must be None. + attention_fn: An optional callable function that allows users to provide + their own customized attention function, which takes input + (attention_mechanism, cell_output, attention_state, attention_layer) and + outputs (attention, alignments, next_attention_state). If provided, + the attention_layer_size should be the size of the outputs of + attention_fn. Raises: TypeError: `attention_layer_size` is not None and (`attention_mechanism` @@ -1240,6 +1249,10 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): tensor_shape.dimension_value(attention_mechanism.values.shape[-1]) for attention_mechanism in attention_mechanisms) + if attention_fn is None: + attention_fn = _compute_attention + self._attention_fn = attention_fn + self._cell = cell self._attention_mechanisms = attention_mechanisms self._cell_input_fn = cell_input_fn @@ -1443,7 +1456,7 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): all_attention_states = [] maybe_all_histories = [] for i, attention_mechanism in enumerate(self._attention_mechanisms): - attention, alignments, next_attention_state = _compute_attention( + attention, alignments, next_attention_state = self._attention_fn( attention_mechanism, cell_output, previous_attention_state[i], self._attention_layers[i] if self._attention_layers else None) alignment_history = previous_alignment_history[i].write( diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index ab36848f13ab3078cd232c18f140188e12db703b..8f8f057702951094758b277ce060955f3dc6e99d 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -921,6 +921,7 @@ def _get_scores(log_probs, sequence_lengths, length_penalty_weight, """ length_penalty_ = _length_penalty( sequence_lengths=sequence_lengths, penalty_factor=length_penalty_weight) + length_penalty_ = math_ops.cast(length_penalty_, dtype=log_probs.dtype) scores = log_probs / length_penalty_ coverage_penalty_weight = ops.convert_to_tensor( diff --git a/tensorflow/contrib/tensorrt/README.md b/tensorflow/contrib/tensorrt/README.md index caf8b6db0dc0a220d593f9c0afc9464ca51a1e05..a9c2ad78a3db409e6e8669c48c4df37c8db19c4b 100644 --- a/tensorflow/contrib/tensorrt/README.md +++ b/tensorflow/contrib/tensorrt/README.md @@ -1,8 +1,46 @@ -# Using TensorRT in TensorFlow +# Using TensorRT in TensorFlow (TF-TRT) -This module provides necessary bindings and introduces TRT_engine_op operator -that wraps a subgraph in TensorRT. This is still a work in progress but should -be useable with most common graphs. +This module provides necessary bindings and introduces `TRTEngineOp` operator +that wraps a subgraph in TensorRT. This module is under active development. + +## Installing TF-TRT + +Currently TensorFlow nightly builds include TF-TRT by default, which means you +don't need to install TF-TRT separately. You can pull the latest TF containers +from docker hub or install the latest TF pip package to get access to the latest +TF-TRT. + +If you want to use TF-TRT on NVIDIA Jetson platform, you can find the download +links for the relevant TensorFlow pip packages here: +https://docs.nvidia.com/deeplearning/dgx/index.html#installing-frameworks-for-jetson + +## Installing TensorRT + +In order to make use of TF-TRT, you will need a local installation of TensorRT. +Installation instructions for compatibility with TensorFlow are provided on the +[TensorFlow GPU support](https://www.tensorflow.org/install/gpu) guide. + +## Examples + +You can find example scripts for running inference on deep learning models in +this repository: https://github.com/tensorflow/tensorrt + +We have used these examples to verify the accuracy and performance of TF-TRT. +For more information see +[Verified Models](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#verified-models). + +## Documentation + +[TF-TRT documentation](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html) +gives an overview of the supported functionalities, provides tutorials and +verified models, explains best practices with troubleshooting guides. + +## Tests + +TF-TRT includes both Python tests and C++ unit tests. Most of Python tests are +located in the test directory and they can be executed using `bazel test` or +directly with the Python command. Most of the C++ unit tests are used to test +the conversion functions that convert each TF op to a number of TensorRT layers. ## Compilation @@ -18,12 +56,3 @@ bazel build --config=cuda --config=opt //tensorflow/tools/pip_package:build_pip_ bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/ ``` -After the installation of tensorflow package, TensorRT transformation will be -available. An example use can be found in test/test_tftrt.py script - -## Installing TensorRT 3.0.4 - -In order to make use of TensorRT integration, you will need a local installation -of TensorRT 3.0.4 from the [NVIDIA Developer website](https://developer.nvidia.com/tensorrt). -Installation instructions for compatibility with TensorFlow are provided on the -[TensorFlow GPU support](https://www.tensorflow.org/install/gpu) guide. diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index ae211a93c3279ff1d6de2f9c9a4b849fc8cd578d..746514b930c6c4c602c727a51313a8c5da271fa6 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -89,51 +89,52 @@ Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { // TODO(laigd): move this set to TrtNodeValidator where it should belong. // LINT.IfChange static const std::set candidate_ops = { - "Identity", - "Snapshot", - "Const", - "Conv2D", - "MaxPool", - "BiasAdd", - "Relu", - "Sigmoid", - "Tanh", + "Abs", "Add", - "Mul", - "Sub", - "Rsqrt", - "Pad", - "Mean", "AvgPool", + "BatchMatMul", + "BiasAdd", "ConcatV2", + "Const", + "Conv2D", "DepthwiseConv2dNative", - "FusedBatchNorm", - "FusedBatchNormV2", "Div", - "RealDiv", - "Rsqrt", - "Reciprocal", "Exp", + "ExpandDims", + "FusedBatchNorm", + "FusedBatchNormV2", + "Identity", "Log", - "Sqrt", - "Abs", - "Neg", - "Transpose", - "Reshape", "MatMul", - "BatchMatMul", - "Softmax", - "Minimum", - "Maximum", - "TopKV2", - "Sum", - "Prod", "Max", + "MaxPool", + "Maximum", + "Mean", "Min", + "Minimum", + "Mul", + "Neg", + "Pad", + "Prod", + "RealDiv", + "Reciprocal", + "Relu", "Relu6", + "Reshape", + "Rsqrt", + "Rsqrt", + "Sigmoid", + "Snapshot", + "Softmax", + "Sqrt", "Square", - "ExpandDims", "Squeeze", + "StridedSlice", + "Sub", + "Sum", + "Tanh", + "TopKV2", + "Transpose", }; bool is_supported_op_type = (candidate_ops.count(node->type_string()) || @@ -322,6 +323,13 @@ tensorflow::Status ConvertGraphDefToTensorRT( return Status::OK(); } +struct EdgePtrCompare { + bool operator()(const tensorflow::Edge* lhs, + const tensorflow::Edge* rhs) const { + return lhs->id() < rhs->id(); + } +}; + // Function to get subsegment information structure. tensorflow::Status GetEngineInfo( const tensorflow::Graph* g, @@ -360,8 +368,12 @@ tensorflow::Status GetEngineInfo( } const int node_id = node->id(); subgraph_node_ids.push_back(node_id); - // Create input connections. - for (const auto edge : node->in_edges()) { + // Create input connections. Sort edges first to make determnistic since + // in_edges is a set of pointers. + std::vector in_edges(node->in_edges().begin(), + node->in_edges().end()); + std::sort(in_edges.begin(), in_edges.end(), EdgePtrCompare()); + for (const auto edge : in_edges) { auto input_node = edge->src(); if (input_node->IsSource() || segment_nodes.count(input_node->name())) { continue; @@ -409,8 +421,12 @@ tensorflow::Status GetEngineInfo( node_id, edge->dst_input(), /*input_edge=*/true, port); } } - // Create output connections. - for (const auto edge : node->out_edges()) { + // Create output connections. Sort edges first to make determnistic since + // out_edges is a set of pointers. + std::vector out_edges(node->out_edges().begin(), + node->out_edges().end()); + std::sort(out_edges.begin(), out_edges.end(), EdgePtrCompare()); + for (const auto edge : out_edges) { auto output_node = edge->dst(); if (output_node->IsSink() || segment_nodes.count(output_node->name())) { continue; diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 777a80bbc4da7a260cf85d0a7bc5ec16f4cd3cab..adf8831b960172fc29b5d631e5b0533318d4764d 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -632,6 +632,11 @@ bool TFAttrs::get(const string& key) const { return this->at(key)->b(); } +template <> +int TFAttrs::get(const string& key) const { + return this->at(key)->i(); +} + // TODO(jie): reorder4 & reorder2 should be merged? // TODO(aaroey): fix the order of parameters. template @@ -1533,6 +1538,24 @@ enum class ConvolutionType { DEFAULT, DEPTHWISE_CONV }; tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; + if (inputs.at(0).is_weights()) { + return tensorflow::errors::Unimplemented( + node_def.op(), " is only implemented for tensors, not weights, at ", + node_def.name()); + } + if (inputs.at(1).is_tensor()) { + return tensorflow::errors::Unimplemented("Kernel for ", node_def.op(), + " must be constant weights, at ", + node_def.name()); + } + TRT_ShapedWeights weights_rsck = inputs.at(1).weights(); + VLOG(2) << "weight shape: " << weights_rsck.DebugString(); + if (weights_rsck.shape_.nbDims != 4) { + return tensorflow::errors::Internal( + "Conv2D expects kernel of dimension 4, at: " + node_def.name()); + } + if (params->validation_only) return tensorflow::Status::OK(); + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); TFAttrs attrs(node_def); @@ -1554,12 +1577,6 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) { if (num_groups == 0) num_groups = tensor_dim.d[0]; // depthwise convolution VLOG(2) << "groups count: " << num_groups; - TRT_ShapedWeights weights_rsck = inputs.at(1).weights(); - VLOG(2) << "weight shape: " << weights_rsck.DebugString(); - if (weights_rsck.shape_.nbDims != 4) { - return tensorflow::errors::Internal( - "Conv2D expects kernel of dimension 4, at: " + node_def.name()); - } if (params->converter->precision_mode() == FP16MODE) { weights_rsck = ConvertFP32ToFP16(params->weight_store, inputs.at(1).weights()); @@ -1646,7 +1663,7 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, case ConvolutionType::DEPTHWISE_CONV: return ConvertConv2DHelper(params, 0); } - return tensorflow::errors::Unimplemented("unsupported convolution type at, " + + return tensorflow::errors::Unimplemented("Unsupported convolution type, at ", params->node_def.name()); } @@ -2016,6 +2033,245 @@ tensorflow::Status ConvertSqueeze(OpConverterParams* params) { return tensorflow::Status::OK(); } +// Gets the bounds (start or end) from the weights of a StridedSlice op. +tensorflow::Status GetStridedSliceBound(const std::vector& input_dims, + const TRT_ShapedWeights& bound_weights, + int mask, bool begin, string node_name, + std::vector* output_bound) { + const string bound_name = (begin) ? "begin" : "end"; + const int* weights_ptr = static_cast(bound_weights.GetValues()); + *output_bound = + std::vector(weights_ptr, weights_ptr + bound_weights.count()); + if (output_bound->size() != input_dims.size()) { + return tensorflow::errors::InvalidArgument( + "StridedSlice \"", bound_name, "\" specified ", + std::to_string(output_bound->size()), " dimensions, but input rank is ", + std::to_string(input_dims.size()), ", at ", node_name); + } + for (int i = 0; i < output_bound->size(); i++) { + if ((1 << i) & mask) { + // Apply mask. + (*output_bound)[i] = (begin) ? 0 : input_dims[i]; + // Masked bound will always result in a valid, non-negative bound, so we + // don't need the following checks. For the common case of using masks on + // a undefined batch dim (-1), we specifically don't want to do the + // following checks because they will erroneously detect an out of range + // bound or try to correct the negative value. + continue; + } + // Make sure bound is valid. + if (((*output_bound)[i] < -input_dims[i]) || + ((*output_bound)[i] > input_dims[i])) { + return tensorflow::errors::InvalidArgument( + bound_name, " value of ", std::to_string((*output_bound)[i]), + " for StridedSlice is invalid, must be in the range " + "[-dim_size(i), dim_size(i)], at ", + node_name); + } + // Convert negative values to their positive equivalent. + if ((*output_bound)[i] < 0) { + (*output_bound)[i] += input_dims[i]; + } + } + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + if (inputs.size() != 4) { + return tensorflow::errors::InvalidArgument( + "StridedSlice expects 4 inputs, at ", node_def.name()); + } + if (!inputs.at(1).is_weights() || !inputs.at(2).is_weights() || + !inputs.at(3).is_weights()) { + return tensorflow::errors::InvalidArgument( + "StridedSlice expects weights for begin, end, and strides, at ", + node_def.name()); + } + if (!inputs.at(0).is_tensor()) { + return tensorflow::errors::Unimplemented( + "StridedSlice is only implemented for tensors, at ", node_def.name()); + } + // Get input dims. + nvinfer1::Dims dims = inputs.at(0).GetTrtDims(); + std::vector input_dims(dims.d, dims.d + dims.nbDims); + if (inputs.at(0).is_tensor()) { + // Temporarily add batch dimension so that indexes line up properly. + input_dims.insert(input_dims.begin(), inputs.at(0).batch_size()); + } + if (input_dims.size() > 4) { + return tensorflow::errors::Unimplemented( + "StridedSlice is not implemented for tensors with rank > 4, at ", + node_def.name()); + } + TFAttrs attrs(node_def); + // Get begin and end bounds per axis. + std::vector begin, end; + TF_RETURN_IF_ERROR(GetStridedSliceBound(input_dims, inputs.at(1).weights(), + attrs.get("begin_mask"), true, + node_def.name(), &begin)); + TF_RETURN_IF_ERROR(GetStridedSliceBound(input_dims, inputs.at(2).weights(), + attrs.get("end_mask"), false, + node_def.name(), &end)); + // Get strides per axis (must all be 1). + TRT_ShapedWeights stride_weights = inputs.at(3).weights(); + const int* stride_weights_ptr = static_cast(stride_weights.GetValues()); + std::vector strides(stride_weights_ptr, + stride_weights_ptr + stride_weights.count()); + for (int x : strides) { + if (x != 1) { + return tensorflow::errors::Unimplemented( + "StridedSlice is only implemented for stride of 1, at ", + node_def.name()); + } + } + // Unsupported mask options. + for (const string& attr : + {"ellipsis_mask", "new_axis_mask", "shrink_axis_mask"}) { + int attr_val = attrs.get(attr); + if (attr_val != 0) { + return tensorflow::errors::Unimplemented( + attr, " is not supported for StridedSlice, at ", node_def.name()); + } + } + + nvinfer1::ITensor* tensor = + const_cast(inputs.at(0).tensor()); + // Reshape if necessary to 4-D, since IPaddingLayer requires a 4-D input. + const bool need_reshape = (input_dims.size() != 4); + int reshape_dims_added = 0; + nvinfer1::Dims reshape_dims; + if (need_reshape) { + // Add new dims after batch dim until tensor is 4D. + while (input_dims.size() < 4) { + input_dims.insert(input_dims.begin() + 1, 1); + begin.insert(begin.begin() + 1, 0); + end.insert(end.begin() + 1, 1); + reshape_dims_added++; + } + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &reshape_dims, + /*ignore_first_dim=*/true)); + } + // Find dimensions which need to be sliced. + std::vector pad_dims; + for (int i = 0; i < input_dims.size(); i++) { + if ((begin[i] != 0) || (end[i] != input_dims[i])) { + if (i == 0) { + return tensorflow::errors::Unimplemented( + "StridedSlice can't modify batch dim, at ", node_def.name()); + } else if ((end[i] - begin[i]) < 0) { + return tensorflow::errors::InvalidArgument( + "New size of sliced dimension is negative, at ", node_def.name()); + } + pad_dims.push_back(i); + } + } + if (pad_dims.size() == 0) { + // No dimensions are changed. We could create a padding layer anyway with + // values of 0. + if (params->validation_only) return Status::OK(); + params->outputs->push_back(inputs.at(0)); + return tensorflow::Status::OK(); + } else if (pad_dims.size() == 1) { + // Only one dim is modified but we have to have 2, mark a second dim which + // will have padding of 0. The dim we add is chosen to avoid an unecessary + // transpose. + if (pad_dims[0] != 2) { + pad_dims.push_back(2); + } else { + pad_dims.push_back(3); + } + } else if (pad_dims.size() > 2) { + return tensorflow::errors::Unimplemented( + "StridedSlice can only modify 2 dimensions, at ", node_def.name()); + } + std::sort(pad_dims.begin(), pad_dims.end()); + // Convert to pre/post padding values. Since TRT does not have a StridedSlice + // or Slice layer, we instead create an IPaddingLayer with negative padding. + nvinfer1::DimsHW pre_padding, post_padding; + for (int i = 0; i < pad_dims.size(); i++) { + const int axis = pad_dims[i]; + pre_padding.d[i] = -begin[axis]; + post_padding.d[i] = end[axis] - input_dims[axis]; + } + + // IPaddingLayer will always apply the padding to dims 2,3 (input format is + // NCHW). + const bool need_transpose = !(pad_dims[0] == 2 && pad_dims[1] == 3); + std::vector transpose_order(input_dims.size()); + std::vector inv_transpose_order(input_dims.size()); + if (need_transpose) { + if (pad_dims[0] == 1 && pad_dims[1] == 3) { + transpose_order = {0, 2, 1, 3}; + inv_transpose_order = {0, 2, 1, 3}; + } else if (pad_dims[0] == 1 && pad_dims[1] == 2) { + transpose_order = {0, 3, 1, 2}; + inv_transpose_order = {0, 2, 3, 1}; + } + } + if (params->validation_only) return Status::OK(); + + // Start conversion. + if (need_reshape) { + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + inputs.at(0), reshape_dims, &output_tensor)); + tensor = const_cast(output_tensor); + } + if (need_transpose) { + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->TransposeTensor( + tensor, transpose_order, &output_tensor)); + tensor = const_cast(output_tensor); + } + + // Add padding layer + nvinfer1::IPaddingLayer* layer = params->converter->network()->addPadding( + *const_cast(tensor), pre_padding, post_padding); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + params->converter->MarkQuantizationRangesAsInferrable(tensor, + layer->getOutput(0)); + tensor = layer->getOutput(0); + + // Restore transpose + if (need_transpose) { + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->TransposeTensor( + tensor, inv_transpose_order, &output_tensor)); + tensor = const_cast(output_tensor); + } + // Restore reshape + if (need_reshape) { + // Calculate output dimensions + for (int i = 0; i < pad_dims.size(); i++) { + const int axis = pad_dims[i]; + input_dims[axis] = end[axis] - begin[axis]; + } + // Remove added 1 dimensions + for (int i = 0; i < reshape_dims_added; i++) { + int value = input_dims[1]; + if (value != 1) { + return tensorflow::errors::Internal( + "StridedSlice error when reshaping, at ", node_def.name()); + } + input_dims.erase(input_dims.begin() + 1); + } + + nvinfer1::Dims new_dims; + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims, + /*ignore_first_dim=*/true)); + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + TRT_TensorOrWeights(tensor), new_dims, &output_tensor)); + tensor = const_cast(output_tensor); + } + + params->outputs->push_back( + TRT_TensorOrWeights(const_cast(tensor))); + return tensorflow::Status::OK(); +} + tensorflow::Status ConvertConv2D(OpConverterParams* params) { return ConvertConv2DHelper(params, ConvolutionType::DEFAULT); } @@ -2027,9 +2283,29 @@ tensorflow::Status ConvertConv2DDepthwise(OpConverterParams* params) { tensorflow::Status ConvertPool(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + if (inputs.at(0).is_weights()) { + return tensorflow::errors::Unimplemented( + node_def.op(), " is only implemented for tensors, not weights, at ", + node_def.name()); + } + nvinfer1::PoolingType type; + if (node_def.op() == "MaxPool") { + type = nvinfer1::PoolingType::kMAX; + } else if (node_def.op() == "AvgPool") { + type = nvinfer1::PoolingType::kAVERAGE; + } else { + return tensorflow::errors::Unimplemented( + "Unsupported pooling type: ", node_def.op(), ", at ", node_def.name()); + } TFAttrs attrs(node_def); + const string padding_type = attrs.get("padding"); + if ((padding_type != "SAME") && (padding_type != "VALID")) { + return tensorflow::errors::Unimplemented( + "Unsupported padding type: ", padding_type, ", at ", node_def.name()); + } + if (params->validation_only) return Status::OK(); + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); int h_index = 2; int w_index = 3; const auto data_format = attrs.get("data_format"); @@ -2040,16 +2316,6 @@ tensorflow::Status ConvertPool(OpConverterParams* params) { const_cast(tensor), {0, 3, 1, 2}, &tensor)); } - nvinfer1::PoolingType type; - if (node_def.op() == "MaxPool") { - type = nvinfer1::PoolingType::kMAX; - } else if (node_def.op() == "AvgPool") { - type = nvinfer1::PoolingType::kAVERAGE; - } else { - return tensorflow::errors::Unimplemented("Unsupported pool type: ", - node_def.op()); - } - const auto tf_stride = attrs.get>("strides"); const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); @@ -2058,7 +2324,6 @@ tensorflow::Status ConvertPool(OpConverterParams* params) { auto tensor_dim = tensor->getDimensions(); std::vector> padding; - const string padding_type = attrs.get("padding"); if (padding_type == "SAME") { // This is NCHW tensor with no batch dimension. // 1 -> h @@ -2068,9 +2333,6 @@ tensorflow::Status ConvertPool(OpConverterParams* params) { {static_cast(tensor_dim.d[1]), static_cast(tensor_dim.d[2])}); } else if (padding_type == "VALID") { padding = {{0, 0}, {0, 0}}; - } else { - return tensorflow::errors::Unimplemented("Unsupported padding type: ", - padding_type); } if (padding[0].first != padding[0].second || @@ -2837,6 +3099,7 @@ tensorflow::Status ConvertPad(OpConverterParams* params) { return tensorflow::errors::Unimplemented( "Padding layer does not support padding on dimension 1 and 3 yet"); } + if (params->validation_only) return Status::OK(); bool legit_pad = true; nvinfer1::DimsHW pre_padding(0, 0); @@ -2940,6 +3203,7 @@ tensorflow::Status ConvertConcat(OpConverterParams* params) { inputs_vec.push_back(tensor_i); } + if (params->validation_only) return tensorflow::Status::OK(); // nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); nvinfer1::IConcatenationLayer* layer = @@ -2961,12 +3225,35 @@ tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) { auto data_format = attrs.get("data_format"); if (data_format != "NCHW") { return tensorflow::errors::Unimplemented( - "only data_format=NCHW is supported, at " + node_def.name()); + node_def.op(), " only supports data_format=NCHW, at ", node_def.name()); } bool is_training = attrs.get("is_training"); if (is_training) { + // Trying to use batchnorm in training mode is a very common problem. + // Because the error message will only be printed in VLOG(1) by the + // segmenter, we issue a special warning so that users will actually see it. + LOG(WARNING) << node_def.op() << " only supports is_training=false. If you " + << "are using Keras, please call " + << "keras.backend.set_learning_phase(0) before constructing " + << "your model. At " << node_def.name(); return tensorflow::errors::Unimplemented( - "only is_training=false is supported, at " + node_def.name()); + node_def.op(), " only supports is_training=false, at ", + node_def.name()); + } + if (inputs.at(0).is_weights()) { + return tensorflow::errors::Unimplemented( + node_def.op(), + " is only implemented for tensor inputs, not weights, at ", + node_def.name()); + } + for (int i = 1; i < 5; i++) { + if (inputs.at(i).is_tensor()) { + return tensorflow::errors::Unimplemented( + node_def.op(), + " must have constant inputs for scale, offset, mean and variance, " + "at ", + node_def.name()); + } } nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); @@ -2981,7 +3268,7 @@ tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) { for (int i = 1; i < 5; i++) { if (inputs.at(i).weights().type_ != parameter_type) { return tensorflow::errors::Unimplemented( - "Inconsistent parameter type for batchnormis not supported, at: " + + "Inconsistent parameter type for batchnorm is not supported, at: " + node_def.name()); } } @@ -3001,6 +3288,8 @@ tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) { "Inconsistent batchnorm parameter count, at: " + node_def.name()); } } + if (params->validation_only) return Status::OK(); + // We could technically have two weights with different shape. // that requires two addScale op, arguably less performant TRT_ShapedWeights combined_scale_weights = @@ -3286,14 +3575,19 @@ static void RegisterValidatableOpConverters( std::unordered_map* registration) { // TODO(laigd): support all op types. (*registration)["BiasAdd"] = ConvertBiasAdd; + (*registration)["ConcatV2"] = ConvertConcat; (*registration)["Const"] = ConvertConst; - (*registration)["Transpose"] = ConvertTranspose; - (*registration)["Reshape"] = ConvertReshape; + (*registration)["Conv2D"] = ConvertConv2D; + (*registration)["DepthwiseConv2dNative"] = ConvertConv2DDepthwise; + (*registration)["ExpandDims"] = ConvertExpandDims; (*registration)["MatMul"] = ConvertMatMul; + (*registration)["Pad"] = ConvertPad; (*registration)["Relu6"] = ConvertRelu6; + (*registration)["Reshape"] = ConvertReshape; (*registration)["Square"] = ConvertSquare; - (*registration)["ExpandDims"] = ConvertExpandDims; (*registration)["Squeeze"] = ConvertSqueeze; + (*registration)["StridedSlice"] = ConvertStridedSlice; + (*registration)["Transpose"] = ConvertTranspose; for (auto quantization_op_type : {"QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3", @@ -3307,6 +3601,12 @@ static void RegisterValidatableOpConverters( for (auto activation_op_type : {"Relu", "Sigmoid", "Tanh"}) { (*registration)[activation_op_type] = ConvertActivation; } + for (auto pool_op_type : {"AvgPool", "MaxPool"}) { + (*registration)[pool_op_type] = ConvertPool; + } + for (auto normalization_op_type : {"FusedBatchNorm", "FusedBatchNormV2"}) { + (*registration)[normalization_op_type] = ConvertFusedBatchNorm; + } } void TrtNodeValidator::RegisterOpValidators() { @@ -3315,21 +3615,10 @@ void TrtNodeValidator::RegisterOpValidators() { void Converter::RegisterOpConverters() { RegisterValidatableOpConverters(&op_registry_); - - op_registry_["Conv2D"] = ConvertConv2D; - op_registry_["DepthwiseConv2dNative"] = ConvertConv2DDepthwise; - op_registry_["MaxPool"] = ConvertPool; - op_registry_["AvgPool"] = ConvertPool; // TODO(ben,jie): this is a temp hack. op_registry_["Identity"] = ConvertIdentity; // Identity should be removed op_registry_["Snapshot"] = ConvertIdentity; // Snapshot should be removed - op_registry_["Pad"] = ConvertPad; - - op_registry_["ConcatV2"] = ConvertConcat; - op_registry_["FusedBatchNorm"] = ConvertFusedBatchNorm; - op_registry_["FusedBatchNormV2"] = ConvertFusedBatchNorm; - op_registry_["Rsqrt"] = ConvertUnary; op_registry_["Reciprocal"] = ConvertUnary; op_registry_["Exp"] = ConvertUnary; diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc index c37a43dd5def9daf3c5d70720c6db2aab20db077..a2ddfbffa5b0d8c421bcfe054097a9e42b79fe8f 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc @@ -2129,7 +2129,6 @@ TEST_F(OpConverterTest, ConvertExpandDims) { auto expanddims = ops::ExpandDims(s.WithOpName("my_expanddims"), input, weights); const NodeDef& node_def = expanddims.operation.node()->def(); - { // Input is weights, should fail. Reset(); @@ -2349,6 +2348,277 @@ TEST_F(OpConverterTest, ConvertSqueeze) { } } +TEST_F(OpConverterTest, ConvertStridedSlice) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_strided_slice", "StridedSlice", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "StridedSlice expects 4 inputs, at my_strided_slice"); + } + + // Get nodedef for StridedSlice layer. + auto get_strided_slice_nodedef = + [](int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, + int new_axis_mask = 0, int shrink_axis_mask = 0) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto begin = ops::Placeholder(s.WithOpName("begin"), DT_INT32); + auto end = ops::Placeholder(s.WithOpName("end"), DT_INT32); + auto strides = ops::Placeholder(s.WithOpName("strides"), DT_INT32); + ops::StridedSlice::Attrs attrs = ops::StridedSlice::Attrs() + .BeginMask(begin_mask) + .EndMask(end_mask) + .EllipsisMask(ellipsis_mask) + .NewAxisMask(new_axis_mask) + .ShrinkAxisMask(shrink_axis_mask); + auto strided_slice = ops::StridedSlice(s.WithOpName("my_strided_slice"), + input, begin, end, strides, attrs); + return strided_slice.operation.node()->def(); + }; + + { + NodeDef node_def = get_strided_slice_nodedef(); + AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "StridedSlice is only implemented for tensors, at my_strided_slice"); + } + { + // Begin, end, strides are tensors, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestTensor("begin", {4}); + AddTestTensor("end", {4}); + AddTestTensor("strides", {4}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "StridedSlice expects weights for begin, end, and strides, at " + "my_strided_slice"); + } + { + // Non-zero ellipsis_mask, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef( + /*begin_mask=*/0, /*end_mask=*/0, /*ellipsis_mask=*/2, + /*new_axis_mask=*/0, /*shrink_axis_mask=*/0); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "ellipsis_mask is not supported for StridedSlice, at " + "my_strided_slice"); + } + { + // Modify batch dim, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {0, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "StridedSlice can't modify batch dim, at my_strided_slice"); + } + { + // Stride is not 1, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 2, -1, 3}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "StridedSlice is only implemented for stride of " + "1, at my_strided_slice"); + } + { + // Begin out of bounds, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {1, 2, 3, 4}); + AddTestWeights("end", {4}, {0, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "begin value of 2 for StridedSlice is invalid, must be in the range " + "[-dim_size(i), dim_size(i)], at my_strided_slice"); + } + { + // End out of bounds, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 2, 3, 4}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "end value of 2 for StridedSlice is invalid, must be in the range " + "[-dim_size(i), dim_size(i)], at my_strided_slice"); + } + { + // Size of sliced dim is negative, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 2, 0}); + AddTestWeights("end", {4}, {1, 1, 0, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "New size of sliced dimension is negative, at my_strided_slice"); + } + + struct TestParams { + TestParams(const std::vector& input_dims, + const std::vector& expected_output_dims, + const std::vector& begin, const std::vector& end, + const std::vector& begin_mask, + const std::vector& end_mask, + const std::vector& expected_output) + : input_dims(input_dims), + expected_output_dims(expected_output_dims), + begin(begin), + end(end), + expected_output(expected_output) { + // Masks are provided in terms of vectors for readability. Convert them to + // binary here. + this->begin_mask = 0; + for (int i = 0; i < begin_mask.size(); i++) { + if (begin_mask[i]) this->begin_mask |= (1 << i); + } + this->end_mask = 0; + for (int i = 0; i < end_mask.size(); i++) { + if (end_mask[i]) this->end_mask |= (1 << i); + } + } + + std::vector input_dims; + std::vector expected_output_dims; + std::vector begin; + std::vector end; + int begin_mask; + int end_mask; + std::vector expected_output; + }; + + // Ok. + const int kStridedSliceOKCases = 18; + TestParams ok_params[kStridedSliceOKCases] = { + // 2D Crop. + TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 0, 1, 2}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 0, 0}, + /*expected_output=*/{1, 2}}, + TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2}, + /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 0, 0, 0}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 1}, + /*expected_output=*/{5, 6}}, + TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2}, + /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 1, 2, 3}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 0, 0}, + /*expected_output=*/{5, 6}}, + // 2D Crop, with transpose. + TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 2, 1}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 2, 1}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0}, + /*expected_output=*/{1, 2}}, + TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 2, 1}, + /*begin=*/{0, 1, 1, 0}, /*end=*/{0, 2, 3, 1}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0}, + /*expected_output=*/{5, 6}}, + TestParams{/*input_dims=*/{2, 1, 3}, /*expected_output_dims=*/{1, 1, 2}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 1, 2}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0}, + /*expected_output=*/{1, 2}}, + TestParams{/*input_dims=*/{2, 1, 3}, /*expected_output_dims=*/{1, 1, 2}, + /*begin=*/{0, 1, 0, 1}, /*end=*/{0, 2, 1, 3}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0}, + /*expected_output=*/{5, 6}}, + // 2D Crop, with reshape. + TestParams{/*input_dims=*/{2, 3}, /*expected_output_dims=*/{1, 2}, + /*begin=*/{0, 0, 0}, /*end=*/{0, 1, 2}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 0}, + /*expected_output=*/{1, 2}}, + TestParams{/*input_dims=*/{2, 3}, /*expected_output_dims=*/{1, 2}, + /*begin=*/{0, 1, 1}, /*end=*/{0, 0, 0}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 1, 1}, + /*expected_output=*/{5, 6}}, + // 1D Crop. + TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 2, 2}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 0, 0, 2}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 0}, + /*expected_output=*/{1, 2, 4, 5}}, + TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 3}, + /*begin=*/{0, 0, 1, 0}, /*end=*/{0, 0, 0, 0}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 1}, + /*expected_output=*/{4, 5, 6}}, + // 1D Crop, with transpose. + TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 3, 1}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 0, 0}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 1, 1}, + /*expected_output=*/{1, 2, 3}}, + TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 3, 1}, + /*begin=*/{0, 1, 0, 0}, /*end=*/{0, 0, 0, 0}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 1}, + /*expected_output=*/{4, 5, 6}}, + // 1D Crop, with reshape. + TestParams{/*input_dims=*/{6}, /*expected_output_dims=*/{3}, + /*begin=*/{0, 0}, /*end=*/{0, 3}, + /*begin_mask=*/{0, 0}, /*end_mask=*/{1, 0}, + /*expected_output=*/{1, 2, 3}}, + TestParams{/*input_dims=*/{1, 6}, /*expected_output_dims=*/{1, 3}, + /*begin=*/{0, 0, 2}, /*end=*/{0, 0, 5}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 1, 0}, + /*expected_output=*/{3, 4, 5}}, + TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{3, 1}, + /*begin=*/{0, 2, 0}, /*end=*/{0, 5, 0}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 1}, + /*expected_output=*/{3, 4, 5}}, + // Negative axis. + TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{3, 1}, + /*begin=*/{0, -6, 0}, /*end=*/{0, -3, 0}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 1}, + /*expected_output=*/{1, 2, 3}}, + TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{5, 1}, + /*begin=*/{0, 0, 0}, /*end=*/{0, -1, 0}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 1}, + /*expected_output=*/{1, 2, 3, 4, 5}}, + }; + + for (int i = 0; i < kStridedSliceOKCases; i++) { + Reset(); + NodeDef node_def = get_strided_slice_nodedef(ok_params[i].begin_mask, + ok_params[i].end_mask); + AddTestTensor("input", ok_params[i].input_dims); + AddTestWeights("begin", + {static_cast(ok_params[i].begin.size())}, + ok_params[i].begin); + AddTestWeights("end", {static_cast(ok_params[i].end.size())}, + ok_params[i].end); + std::vector strides(ok_params[i].input_dims.size(), 1); + AddTestWeights("strides", {static_cast(strides.size())}, + strides); + RunValidationAndConversion(node_def); + + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_strided_slice", &output)); + std::vector output_data(ok_params[i].expected_output.size()); + BuildAndRun({{"input", {1, 2, 3, 4, 5, 6}}}, "my_strided_slice", + &output_data); + EXPECT_THAT(output_data, ElementsAreArray(ok_params[i].expected_output)); + } +} + } // namespace convert } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc index c1688d4db88a270dcd202989f89a677ed10576d9..d57f2300f8e6e6ce79c538133da6bc5cf5ead2f5 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc @@ -226,8 +226,9 @@ tensorflow::Status TRTOptimizationPass::Optimize( tensorflow::tensorrt::convert::ConversionParams cp; if (use_calibration_ && precision_mode_ != INT8MODE) { - LOG(ERROR) << "Calibration with FP32 or FP16 is not implemented. " - << "Falling back to use_calibration = False."; + VLOG(1) << "Calibration with FP32 or FP16 is not implemented. " + << "Falling back to use_calibration = False." + << "Note that the default value of use_calibration is True."; use_calibration_ = false; } diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc b/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc index ad6b1d7d4c57d696d3dee3b479733e152e669211..beb1284208e4c10ffe1d36ef411cf08f11dbcb78 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc +++ b/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc @@ -48,11 +48,14 @@ TEST(TRTAllocatorTest, Align) { 513ul, 700ul, 12345ul, 1ul << 32}) { for (uint64_t alignment = 1; alignment <= space * 4; alignment *= 2) { for (const uintptr_t ptr_val : - {1ul, alignment == 1 ? 1ul : alignment - 1, alignment, alignment + 1, - alignment + (alignment / 2)}) { + {static_cast(1), + alignment == 1 ? static_cast(1) : alignment - 1, + alignment, alignment + 1, alignment + (alignment / 2)}) { if (ptr_val % alignment == 0) { for (const uint64_t size : - {1ul, space == 1 ? 1ul : space - 1, space, space + 1}) { + {static_cast(1), + space == 1 ? static_cast(1) : space - 1, space, + space + 1}) { EXPECT_EQ(space >= size, RunTest(alignment, size, ptr_val, space)); } } else { @@ -62,8 +65,10 @@ TEST(TRTAllocatorTest, Align) { EXPECT_TRUE( RunTest(alignment, space - diff, ptr_val + diff, space - diff)); for (const uint64_t size : - {1ul, space - diff > 1 ? space - diff - 1 : 1ul, space - diff, - space - diff + 1, space - 1}) { + {static_cast(1), + space - diff > 1 ? space - diff - 1 + : static_cast(1), + space - diff, space - diff + 1, space - 1}) { EXPECT_EQ(space - diff >= size, RunTest(alignment, size, ptr_val, space)); } diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc index 6abc5226ccf96e472df77269bee6186726e5768d..084a96e0fa5c97edc58adf2590ed94e5ef0e4d85 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.cc +++ b/tensorflow/contrib/tensorrt/segment/segment.cc @@ -225,6 +225,24 @@ SimpleGraph::~SimpleGraph() { for (auto x : edges_) delete x; } +// Define comparison functions for std::set with pointer keys so that behavior +// is deterministic. When using std::set with pointer key types, the items are +// sorted by pointer address which is non-deterministic. This can cause issues +// for INT8 mode because the graph is converted twice and non-determinism may +// cause a mismatch between the calibration tables of the conversions. +struct SimpleEdgePtrCompare { + bool operator()(const SimpleEdge* lhs, const SimpleEdge* rhs) const { + return lhs->id() < rhs->id(); + } +}; + +struct NodePtrCompare { + bool operator()(const tensorflow::Node* lhs, + const tensorflow::Node* rhs) const { + return lhs->name() < rhs->name(); + } +}; + namespace { // Copied from TF ReverseDFS, which only works for tensorflow::Graph. @@ -476,7 +494,7 @@ tensorflow::Status SegmentGraph( // nodes. Iterate since combining two nodes may unblock other // combining. while (true) { - std::set contract_edges; + std::set contract_edges; for (const SimpleEdge* out_edge : node->out_edges()) { VLOG(3) << "... out node " << out_edge->dst()->name() << " ( " << out_edge->dst()->id() << " <- " << node->id() << " )"; @@ -530,7 +548,7 @@ tensorflow::Status SegmentGraph( // A map from the segment identifier (currently the name of the root node of // the segment tree) to the segment nodes set. - std::map> sg_map; + std::map> sg_map; // A map from the segment identifier (currently the name of the root node of // the segment tree) to the device names that the nodes in the segment are @@ -566,7 +584,8 @@ tensorflow::Status SegmentGraph( // --------------------------------- Step 2 --------------------------------- // Remove ineligible input/output nodes. for (auto& itr : sg_map) { - std::set& segment_nodes = itr.second; + std::set& segment_nodes = + itr.second; VLOG(1) << "Segment original size: " << segment_nodes.size(); while (true) { std::deque in_nodes_que, out_nodes_que; @@ -618,8 +637,9 @@ tensorflow::Status SegmentGraph( bool is_input_nodes, std::deque* que) { // Run a BFS on the queue to find all the input/output nodes. - std::set visited; - std::set logged(que->begin(), que->end()); + std::set visited; + std::set logged(que->begin(), + que->end()); while (!que->empty()) { auto node = que->front(); que->pop_front(); @@ -653,7 +673,8 @@ tensorflow::Status SegmentGraph( // --------------------------------- Step 3 --------------------------------- // Convert the segments into the expected return format for (const auto& itr : sg_map) { - const std::set& segment_nodes = itr.second; + const std::set& segment_nodes = + itr.second; if (VLOG_IS_ON(1)) { string s = "parent=" + itr.first + ":"; for (auto node : segment_nodes) s += " " + node->name(); diff --git a/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py b/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py index 31cbef89e23949ba5ceaab34e0f683fd906bf0ce..e7d6ec4ad395d38a06f97020f2f363009f2286c7 100644 --- a/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py +++ b/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py @@ -191,7 +191,7 @@ class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase): batch_size=batch_size, num_parallel_calls=8)) dataset = dataset.repeat(count=1) - iterator = data.make_one_shot_iterator(dataset) + iterator = dataset.make_one_shot_iterator() features, labels = iterator.get_next() return features, labels @@ -205,7 +205,7 @@ class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase): batch_size=batch_size, num_parallel_calls=8)) dataset = dataset.repeat(count=num_epochs) - iterator = data.make_one_shot_iterator(dataset) + iterator = dataset.make_one_shot_iterator() features, labels = iterator.get_next() return features, labels diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index 4bf3a0463d9046eea2f60e9154fca1357e728215..007aeaec15d6db7ea4581ab9825da2dbe8b37163 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -102,6 +102,7 @@ tf_gen_op_libs( "replication_ops", "tpu_configuration_ops", "tpu_embedding_ops", + "tpu_ordinal_selector_op", ], deps = [ "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_cc", @@ -153,6 +154,13 @@ tf_gen_op_wrapper_py( ], ) +tf_gen_op_wrapper_py( + name = "tpu_ordinal_selector_op", + deps = [ + ":tpu_ordinal_selector_op_op_lib", + ], +) + py_library( name = "profiler", srcs = ["python/profiler/__init__.py"], diff --git a/tensorflow/contrib/tpu/ops/tpu_ordinal_selector_op.cc b/tensorflow/contrib/tpu/ops/tpu_ordinal_selector_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..54e6b20f7f388b67a96ac8acfe814a4202b56a18 --- /dev/null +++ b/tensorflow/contrib/tpu/ops/tpu_ordinal_selector_op.cc @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("TPUOrdinalSelector") + .Output("device_ordinals: int32") + .SetIsStateful() + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, + c->Vector(shape_inference::InferenceContext::kUnknownDim)); + return Status::OK(); + }) + .Doc(R"doc( +A TPU core selector Op. + +This Op produces a set of TPU cores (for warm-up) or a single TPU core +(for regular inference) to execute the TPU program on. The output is +consumed by TPUPartitionedCall. + +device_ordinals: A vector 1 or more TPU cores. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc index ef35e84ba5205fb76e5afe77e670d87197ca8405..b4b06a40a2c8aaa97ff82baf93c8f2d55a587e37 100644 --- a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc @@ -98,7 +98,7 @@ Status DumpOpProfileToLogDirectory(StringPiece run_dir, if (!status.ok()) { return errors::Internal( "Failed to convert op profile to json. Skipping... ", - string(status.message())); + string(status.error_message())); } TF_RETURN_IF_ERROR(WriteStringToFile(Env::Default(), path, json)); if (os) { diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index 4ce194590342555a7c4e9e119bf51e516a37a715..cf9672f8d867f4ad5cb0281abe710f6e3bcdf1f2 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -2069,6 +2069,8 @@ class KerasTPUModel(models.Model): # tpu_model may not be compiled, e.g., loading weights and then predict. return for k, v in six.iteritems(cpu_optimizer_config): + if k == 'name': + continue opt_var = getattr(self._tpu_model.optimizer, k) if isinstance(opt_var, variables.Variable): logging.info('CPU -> TPU %s: %s {%s}', k, v, K.get_value(opt_var)) @@ -2097,6 +2099,8 @@ class KerasTPUModel(models.Model): self._cpu_model.set_weights(tpu_weights) for k, v in six.iteritems(tpu_optimizer_config): logging.info('TPU -> CPU %s: %s', k, v) + if k == 'name': + continue opt_var = getattr(self.cpu_optimizer, k) if isinstance(opt_var, variables.Variable): K.get_session().run(opt_var.assign(v)) diff --git a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py index 8b0b240dc7302c203a22349d583323327fc4480b..de425626c813784ef657d17eac0c7bb77599a155 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py @@ -69,6 +69,7 @@ class ReplicatedVariable(object): def __init__(self, name, variables): self._name = name self._primary_var = variables[0] + self._common_name = self._primary_var.name.split(":")[0] self._vars = variables self._cached_value = None self._dtype = variables[0].dtype diff --git a/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py index 70baea203cc6174bebc7d90646045efae5f2391d..a1494e3660bc09e3af45e81097151a35990810fb 100644 --- a/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py +++ b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py @@ -21,44 +21,56 @@ from __future__ import print_function import os import os.path import re +import sys from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops +from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging _TRACER_LOG_PREFIX = ' [>>>TT>>>]' _DEVICE_TYPE_TPU = 'tpu' _DEVICE_TYPE_CPU = 'cpu' -_GLOBAL_STEP_OP_NAME = 'GLOBAL-STEP' _TRACE_MODE_NAN_INF = 'nan-inf' _TRACE_MODE_PART_TENSOR = 'part-tensor' _TRACE_MODE_PART_TENSOR_SIZE = 3 _TRACE_MODE_FULL_TENSOR = 'full-tensor' -_RECORD_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range' -_RECORD_SHOULD_NOT_TRACE = 'not-traced-should-not-trace' -_RECORD_FILTERED_OUT = 'not-traced-filtered-out' -_RECORD_SCALAR = 'not-traced-scalar' -_RECORD_DYNAMIC_SHAPE = 'not-traced-dynamic-shape' -_RECORD_GET_TRACED = 'get-traced' +_TRACE_MODE_NORM = 'norm' +_TRACE_MODE_MAX_ABS = 'max-abs' +_REASON_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range' +_REASON_UNSAFE_OP = 'not-traced-unsafe-op' +_REASON_UNSAFE_SCALAR = 'not-traced-unsafe-scalar' +_REASON_LESS_INTERESTING_OP = 'not-traced-less-interesting-op' +_REASON_DEVICE_MISMATCH = 'not-traced-device-mismatch' +_REASON_DYNAMIC_SHAPE = 'not-traced-dynamic-shape' +_REASON_SCALAR_GET_TRACED = 'traced-scalar' +_REASON_TENSOR_GET_TRACED = 'traced-tensor' +_REASON_USER_INCLUDED = 'traced-user-included' +_REASON_USER_EXCLUDED = 'not-traced-user-excluded' +_REASON_NON_NUMERIC_TENSOR = 'not-traced-non-numeric-tensor' _MARKER_SECTION_BEGIN = '!!!!!!! section-begin:' _MARKER_SECTION_END = '!!!!!!! section-end:' _SECTION_NAME_CONFIG = 'configuration' _SECTION_NAME_REASON = 'reason' _SECTION_NAME_OP_LIST = 'op-list' +_SECTION_NAME_TENSOR_LIST = 'tensor-list' _SECTION_NAME_GRAPH = 'graph' _FIELD_NAME_VERSION = 'version:' _FIELD_NAME_DEVICE = 'device:' _FIELD_NAME_TRACE_MODE = 'trace-mode:' _FIELD_NAME_NUM_REPLICAS = 'num-replicas:' _FIELD_NAME_NUM_OPS = 'number-of-ops:' +_FIELD_NAME_NUM_TENSORS = 'number-of-tensors:' _FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED = 'topological-sort-succeed:' _FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS' _FLAG_SINGLE_QUOTE_PAT = re.compile(r"\s*--([^=]+)='([^']*)'") @@ -66,13 +78,72 @@ _FLAG_DOUBLE_QUOTE_PAT = re.compile(r'\s*--([^=]+)="([^"]*)"') _FLAG_NO_QUOTE_PAT = re.compile(r'\s*--([^=]+)=(\S*)') _FLAG_NAME_ENABLE = 'enable' _FLAG_NAME_TRACE_MODE = 'trace_mode' -_FLAG_NAME_INTERESTING_OPS = 'interesting_ops' +_FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS = 'include_less_interesting_ops' +_FLAG_NAME_EXCLUDED_OPNAMES = 'excluded_opnames' +_FLAG_NAME_EXCLUDED_OPTYPES = 'excluded_optypes' +_FLAG_NAME_INCLUDED_OPNAMES = 'included_opnames' +_FLAG_NAME_INCLUDED_OPTYPES = 'included_optypes' _FLAG_NAME_TRACE_FILE = 'trace_file_path' +_FLAG_NAME_REPORT_FILE = 'report_file_path' _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR = 'use_test_undeclared_outputs_dir' _FLAG_NAME_OP_RANGE = 'op_range' _OP_RANGE_PAT = re.compile(r'(\d+):(\d+)') _OUTPUT_STREAM_ESCAPE = 'file://' _TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR' +_TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables' +_TENSOR_TRACER_CHECKPOINT = 'tensor_tracer_checkpoint' + + +def tensor_checkpoint(tensor, checkpoint_name): + """Adds a checkpoint with the given checkpoint name for the given tensor. + + The tensor will be added to the list of tensors that will be traced by the + tensor tracer. + + Args: + tensor: the tensor object for which the tracing is requested. + checkpoint_name: a string name for the checkpoint. This name has to be a + unique name if used within model comparison. The tensors that have the same + checkpoint identifier is compared in model comparison. + Returns: + The provided tensor. + """ + + tensor.graph.get_collection(_TENSOR_TRACER_COLLECTION) + tensor.graph.add_to_collection(_TENSOR_TRACER_COLLECTION, + (tensor, checkpoint_name)) + return tensor + + +def keras_layer_checkpoint(layer, checkpoint_name): + """An interface for adding the tensor outputs of a keras layer. + + Encapsulates tensor_checkpoint. + + Args: + layer: A keras layer. + checkpoint_name: a string name for the checkpoint. This name has to be a + unique name if used within model comparison. The tensors that have the same + checkpoint identifier is compared in model comparison. + + Returns: + The provided layer. + """ + try: + outputs = layer.output + if tensor_util.is_tensor(outputs): + tensor_checkpoint(outputs, '%s' % (checkpoint_name)) + else: + idx = 0 + for output_tensor in outputs: + if tensor_util.is_tensor(outputs): + tensor_checkpoint(output_tensor, '%s_%d' % (checkpoint_name, idx)) + idx += 1 + except AttributeError: + pass + except RuntimeError: + pass + return layer class TensorTracer(object): @@ -105,6 +176,34 @@ class TensorTracer(object): match = _FLAG_NO_QUOTE_PAT.match(flags, pos) return match + @staticmethod + def validate_flag_names(): + """Validates if the TensorTrace flags passed are valid.""" + valid_flag_names = [_FLAG_NAME_ENABLE, _FLAG_NAME_TRACE_MODE, + _FLAG_NAME_EXCLUDED_OPNAMES, + _FLAG_NAME_EXCLUDED_OPTYPES, + _FLAG_NAME_INCLUDED_OPNAMES, + _FLAG_NAME_INCLUDED_OPTYPES, + _FLAG_NAME_TRACE_FILE, _FLAG_NAME_REPORT_FILE, + _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR, + _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS, + _FLAG_NAME_OP_RANGE] + tensor_tracer_flags = os.environ.get(_FLAGS_ENV_VAR) + if not tensor_tracer_flags: + return + pos = 0 + while True: + match = TensorTracer._match_next_flag(tensor_tracer_flags, pos) + if not match: + break + flag_name = match.group(1) + if flag_name not in valid_flag_names: + raise ValueError( + 'The flag name "%s" passed via the environment variable "%s" ' + 'is invalid. Valid flag names are:' + '\n%s'%(flag_name, _FLAGS_ENV_VAR, valid_flag_names)) + pos = match.end() + @staticmethod def print_flag_values(): """Prints all TensorTracer flags passed via environment variables.""" @@ -146,6 +245,20 @@ class TensorTracer(object): pos = match.end() return '' + @staticmethod + def flag_value_to_re_list(flag_name): + """Converts list of strings to compiled RE.""" + + re_list = [] + flag_value = TensorTracer.get_flag_value(flag_name) + if not flag_value: + return re_list + list_of_values = flag_value.split() + for v in list_of_values: + r = re.compile(v) + re_list.append(r) + return re_list + @staticmethod def is_enabled(): """Returns True if TensorTracer is enabled.""" @@ -186,29 +299,67 @@ class TensorTracer(object): """Checks if the given trace mode is valid.""" valid_trace_modes = [_TRACE_MODE_NAN_INF, _TRACE_MODE_PART_TENSOR, - _TRACE_MODE_FULL_TENSOR] + _TRACE_MODE_FULL_TENSOR, _TRACE_MODE_NORM, + _TRACE_MODE_MAX_ABS] if trace_mode not in valid_trace_modes: raise ValueError('Invalid trace mode "%s" given to the Tensor_Tracer.' 'Valid trace modes are: %s'%(trace_mode, valid_trace_modes)) @staticmethod - def should_trace(device_type, op): - """Returns True if the given Op should be traced.""" + def unsafe_op(op): + """Returns True if this op is not safe to be traced.""" - if device_type != _DEVICE_TYPE_TPU: - raise ValueError('Non TPU device type is not supported') if control_flow_util.IsInCond(op): + return True + # Reasons for not including following op types: + # Assign: cause incorrect result with CPU tracing. + # others: compilation problems. + if op.type in ['Assign', 'Pack', 'Shape', 'Reshape', 'ArgMin', 'ArgMax']: + return True + return False + + @staticmethod + def device_mismatch(device_type, op): + if device_type == _DEVICE_TYPE_TPU: + # pylint: disable=protected-access + return tpu._TPU_REPLICATE_ATTR not in op.node_def.attr + # pylint: enable=protected-access + return False + + @staticmethod + def unsafe_scalar_trace(op): + """Return true if scalar output tensor from Op is not safe to be traced.""" + + # Tracing the following causes cycle in the graph on TPU. + if op.type in ['LoopCond', 'Enter', 'Merge', 'Const', + 'Switch', 'Less', 'ReadVariableOp']: + return True + # Tracing the following will cause casting-issue + # with the norm tracing mode or other compilation issues on CPU. + if op.type in ['VarHandleOp', 'IteratorToStringHandle', + 'IteratorGetNext', 'OneShotIterator', + 'IteratorV2', 'MakeIterator', + 'BatchDatasetV2', 'MapDataset', + 'FixedLengthRecordDataset', 'TakeDataset', 'ZipDataset', + 'Placeholder', 'PlaceholderWithDefault', 'StridedSlice']: + return True + return False + + @staticmethod + def less_interesting_op(op): + """Returns True if the given Op is not an interesting one to be traced.""" + + include_less_interesting = TensorTracer.get_flag_value( + _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS) + if include_less_interesting: return False - if op.type in ['Reshape', 'ArgMin', 'ArgMax']: - return False - # pylint: disable=protected-access - return tpu._TPU_REPLICATE_ATTR in op.node_def.attr - # pylint: enable=protected-access + return op.type in ['Const', 'Identity', 'Cast', 'Shape'] @staticmethod def reason(op_idx, details): - """Returns why the Op at op_idx is traced or not.""" + """Returns reason why the Op at op_idx is traced or not.""" + return '%d %s'%(op_idx, details) @staticmethod @@ -274,6 +425,33 @@ class TensorTracer(object): assert len(unsorted_ops) == len(sorted_ops) return (True, sorted_ops) + @staticmethod + def _make_op_and_tensor_maps(op_list): + """Creates various maps and lists from op_list. + + Args: + op_list: a list of Ops + + Returns: + opname_idx_map: a map from Op's name to its index in op_list. + tensor_list: a list of output tensors of the Ops in op_list. + tensorname_idx_map: a map from output tensor name to its index + in tensor_list. + """ + + opname_idx_map = {} + tensor_list = [] + tensorname_idx_map = {} + for op_id, op in enumerate(op_list): + if op.name in opname_idx_map: + raise ValueError('Duplicated Op name: %s'%op.name) + opname_idx_map[op.name] = op_id + for output_tensor in op.outputs: + if output_tensor.name not in tensorname_idx_map: + tensor_list.append(output_tensor) + tensorname_idx_map[output_tensor.name] = len(tensor_list)-1 + return (opname_idx_map, tensor_list, tensorname_idx_map) + def __init__(self): """Initializes a TensorTracer. @@ -281,16 +459,20 @@ class TensorTracer(object): """ self._version = 'use-outside-compilation' self._device_type = None + TensorTracer.validate_flag_names() self._trace_mode = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_MODE) if not self._trace_mode: self._trace_mode = _TRACE_MODE_NAN_INF TensorTracer.check_trace_mode(self._trace_mode) self._part_tensor_size = _TRACE_MODE_PART_TENSOR_SIZE self._instrument_records = {} - interesting_ops = TensorTracer.get_flag_value(_FLAG_NAME_INTERESTING_OPS) - self._selected_ops = interesting_ops.split() self._set_trace_file_path() + self._set_report_file() self._set_op_range() + self._set_excluded_opnames() + self._set_excluded_optypes() + self._set_included_opnames() + self._set_included_optypes() self._num_replicas = None self._replica_id = None @@ -318,10 +500,7 @@ class TensorTracer(object): """Sets the path of the output trace file.""" self._trace_file_path = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_FILE) - if not self._trace_file_path: - raise ValueError('--%s is not set in the environment variable %s' - %(_FLAG_NAME_TRACE_FILE, _FLAGS_ENV_VAR)) - elif TensorTracer.use_test_undeclared_outputs_dir(): + if self._trace_file_path and TensorTracer.use_test_undeclared_outputs_dir(): if os.path.isabs(self._trace_file_path): raise ValueError('If use_test_undeclared_outputs_dir is set,' 'trace_file_path cannot be an absolute path (%s)' @@ -330,6 +509,22 @@ class TensorTracer(object): self._trace_file_path = os.path.join(outputs_dir, self._trace_file_path) + def _set_report_file(self): + """Sets the path of the output report file.""" + + self._report_file_path = TensorTracer.get_flag_value(_FLAG_NAME_REPORT_FILE) + if not self._report_file_path: + self._report_file = None + return + try: + self._report_file = gfile.Open(self._report_file_path, 'w') + except IOError as e: + raise e + + def _close_report_file(self): + if self._report_file: + self._report_file.close() + def _set_op_range(self): """Sets the index range of the Ops that we will consider tracing.""" @@ -350,19 +545,48 @@ class TensorTracer(object): return False return self._op_range[1] < 0 or idx <= self._op_range[1] - def _write_report(self, content): - """Writes the given content to the report.""" + def _set_excluded_opnames(self): + self._excluded_opname_re_list = TensorTracer.flag_value_to_re_list( + _FLAG_NAME_EXCLUDED_OPNAMES) + + def _set_excluded_optypes(self): + self._excluded_optype_re_list = TensorTracer.flag_value_to_re_list( + _FLAG_NAME_EXCLUDED_OPTYPES) + + def _set_included_opnames(self): + self._included_opname_re_list = TensorTracer.flag_value_to_re_list( + _FLAG_NAME_INCLUDED_OPNAMES) + + def _set_included_optypes(self): + self._included_optype_re_list = TensorTracer.flag_value_to_re_list( + _FLAG_NAME_INCLUDED_OPTYPES) + + def _is_user_included_op(self, op): + for opname_re in self._included_opname_re_list: + if opname_re.match(op.name): + return True + for optype_re in self._included_optype_re_list: + if optype_re.match(op.type): + return True + return False - logging.info('%s %s'%(_TRACER_LOG_PREFIX, content)) + def _is_user_excluded_op(self, op): + for opname_re in self._excluded_opname_re_list: + if opname_re.match(op.name): + return True + for optype_re in self._excluded_optype_re_list: + if optype_re.match(op.type): + return True + return False - def _is_selected_op(self, op_name): - """Returns True if the Op with op_name is selected to be traced.""" + def _write_report(self, content): + """Writes the given content to the report.""" - if not self._selected_ops: - return True - if op_name in self._selected_ops: - return True - return False + line = '%s %s'%(_TRACER_LOG_PREFIX, content) + if self._report_file: + self._report_file.write(line) + else: + logging.info(line) def _write_config_section(self): """Writes the config section of the report.""" @@ -382,15 +606,42 @@ class TensorTracer(object): self._write_report('"%s" %s\n'%(key, self._instrument_records[key])) self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_REASON)) - def _write_op_list_section(self, op_list): + def _write_op_list_section(self, op_list, tensorname_idx_map): """Writes the Op-list section of the report.""" self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_OP_LIST)) self._write_report('%s %d\n'%(_FIELD_NAME_NUM_OPS, len(op_list))) for i in range(0, len(op_list)): - self._write_report('%d "%s" %s\n'%(i, op_list[i].name, op_list[i].type)) + op = op_list[i] + line = '%d "%s" %s'%(i, op.name, op.type) + for out_tensor in op.outputs: + if out_tensor.name not in tensorname_idx_map: + raise ValueError( + 'out_tensor %s is not in tensorname_idx_map'%out_tensor.name) + line += ' %d'%tensorname_idx_map[out_tensor.name] + line += '\n' + self._write_report(line) self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_OP_LIST)) + def _write_tensor_list_section(self, tensor_list, opname_idx_map): + """Writes the tensor-list section of the report.""" + + self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, + _SECTION_NAME_TENSOR_LIST)) + self._write_report('%s %d\n'%(_FIELD_NAME_NUM_TENSORS, len(tensor_list))) + for i in range(0, len(tensor_list)): + tensor = tensor_list[i] + line = '%d "%s"'%(i, tensor.name) + for consumer_op in tensor.consumers(): + if consumer_op.name not in opname_idx_map: + raise ValueError( + 'consumer_op %s is not in opname_idx_map'%consumer_op.name) + line += ' %d'%opname_idx_map[consumer_op.name] + line += '\n' + self._write_report(line) + self._write_report('%s %s\n'%(_MARKER_SECTION_END, + _SECTION_NAME_TENSOR_LIST)) + def _write_graph_section(self, succeed, sorted_or_cycle): """Writes the graph section of the report.""" @@ -422,7 +673,7 @@ class TensorTracer(object): Args: op_name: the name of the Op that outputs the tensor to be printed. output_idx: which output of the Op it is (0 means the first output). - num_elements: number of elements to print. + num_elements: number of elements to print (-1 means print all). tensor: the tensor needs to be returned. output_tensor: the tensor needs to be printed. @@ -430,10 +681,13 @@ class TensorTracer(object): The same tensor passed via the "tensor" argument. """ msg = '"%s:%d" '%(op_name, output_idx) - output_stream = _OUTPUT_STREAM_ESCAPE + self._trace_file_path + if self._trace_file_path: + output_stream = _OUTPUT_STREAM_ESCAPE + self._trace_file_path + else: + output_stream = sys.stderr print_op = logging_ops.print_v2(msg, array_ops.shape(output_tensor), ' @', self._replica_id, - '\n', output_tensor, + '\n', output_tensor, '\n', summarize=num_elements, output_stream=output_stream) with ops.control_dependencies([print_op]): @@ -442,7 +696,8 @@ class TensorTracer(object): def _detect_nan_inf(tensor): """Trace function for detecting any NaN/Inf in the tensor.""" - if tensor.dtype.is_floating: + if tensor.dtype.__eq__(dtypes.bfloat16) or tensor.dtype.__eq__( + dtypes.float16): # Since host can't handle bf16, always convert tensor to f32. tensor = math_ops.cast(tensor, dtypes.float32) output_tensor = math_ops.reduce_any( @@ -450,12 +705,19 @@ class TensorTracer(object): gen_math_ops.is_inf(tensor))) else: output_tensor = constant_op.constant(0) - return _print_tensor(op_name, output_idx, 1, tensor, output_tensor) + return _print_tensor(op_name, output_idx, -1, tensor, output_tensor) - def _show_global_step(tensor): - """Trace function for printing the global step count.""" + def _show_norm(tensor): + tensor = math_ops.cast(tensor, dtypes.float64) + output_tensor = linalg_ops.norm(tensor) + return _print_tensor(op_name, output_idx, -1, tensor, output_tensor) - return _print_tensor(op_name, output_idx, 1, tensor, tensor) + def _show_max_abs(tensor): + output_tensor = math_ops.cast(math_ops.reduce_max(math_ops.abs(tensor)), + dtypes.float64) + zero = constant_op.constant(0, dtypes.float64) + output_tensor = gen_math_ops.maximum(zero, output_tensor) + return _print_tensor(op_name, output_idx, -1, tensor, output_tensor) def _show_part_tensor(tensor): """Trace function for printing part of the tensor.""" @@ -468,23 +730,139 @@ class TensorTracer(object): return _print_tensor(op_name, output_idx, -1, tensor, tensor) - if op_name == _GLOBAL_STEP_OP_NAME: - return _show_global_step if self._trace_mode == _TRACE_MODE_NAN_INF: return _detect_nan_inf if self._trace_mode == _TRACE_MODE_PART_TENSOR: return _show_part_tensor if self._trace_mode == _TRACE_MODE_FULL_TENSOR: return _show_full_tensor + if self._trace_mode == _TRACE_MODE_NORM: + return _show_norm + if self._trace_mode == _TRACE_MODE_MAX_ABS: + return _show_max_abs raise RuntimeError('Tensor trace fun for %s is not yet implemented' %self._trace_mode) + def _skip_op(self, op_id, op, user_included, user_excluded): + """Returns True if we should not trace Op.""" + + if user_included: + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _REASON_USER_INCLUDED) + return False + if user_excluded: + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _REASON_USER_EXCLUDED) + return True + if not self._inside_op_range(op_id): + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _REASON_OUTSIDE_OP_RANGE) + return True + if TensorTracer.unsafe_op(op): + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _REASON_UNSAFE_OP) + return True + if TensorTracer.device_mismatch(self._device_type, op): + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _REASON_DEVICE_MISMATCH) + return True + if TensorTracer.less_interesting_op(op): + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _REASON_LESS_INTERESTING_OP) + return True + return False + + def _skip_tensor(self, op_id, out_tensor, user_included, + user_excluded): + """Returns True if we should not trace out_tensor.""" + + # Skips a tensor if the tensor has a non-numeric type. + # Note: we cannot use check_ops.is_numeric_tensor(out_tensor) + # because it also excludes tensors with dtypes, bool, and + # float32_ref, which we actually want to trace. + non_numeric_tensor_types = set([dtypes.variant, dtypes.resource, + dtypes.string]) + if out_tensor.dtype in non_numeric_tensor_types: + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_NON_NUMERIC_TENSOR) + return True + + if user_included: + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_USER_INCLUDED) + return False + if user_excluded: + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_USER_EXCLUDED) + return True + if not out_tensor.get_shape().is_fully_defined(): + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_DYNAMIC_SHAPE) + return True + rank = len(out_tensor.shape) + if rank < 1: + # scalar + if TensorTracer.unsafe_scalar_trace(out_tensor.op): + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_UNSAFE_SCALAR) + return True + else: + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_SCALAR_GET_TRACED) + return False + else: + # tensor + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_TENSOR_GET_TRACED) + return False + + def _pre_tracing(self, graph): + """Work needs to be done prior to TPU or CPU tracing.""" + + operations = graph.get_operations() + (opname_idx_map, tensor_list, tensorname_idx_map) = ( + TensorTracer._make_op_and_tensor_maps(operations)) + self._write_config_section() + self._write_op_list_section(operations, tensorname_idx_map) + self._write_tensor_list_section(tensor_list, opname_idx_map) + # Does the topological sort before adding any nodes to the graph. + (succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph) + return (operations, succeed, sorted_or_cycle) + + def _post_tracing(self, succeed, sorted_or_cycle): + """Work needs to be done after TPU or CPU tracing.""" + + self._write_reason_section() + self._write_graph_section(succeed, sorted_or_cycle) + self._close_report_file() + + def _get_checkpoints(self, graph): + """Returns the list of Ops that produce the tensors traced with API. + + Args: + graph: the graph of Ops. + + Returns: + A set of operation names which should be traced. + """ + + self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, + _TENSOR_TRACER_CHECKPOINT)) + checkpoint_operations = set() + tensor_tracer_variables = graph.get_collection(_TENSOR_TRACER_COLLECTION) + for (tensor, checkpoint_name) in tensor_tracer_variables: + self._write_report('%s %s\n'%(tensor.name, checkpoint_name)) + checkpoint_operations.add(tensor.op.name) + self._write_report('%s %s\n'%(_MARKER_SECTION_END, + _TENSOR_TRACER_CHECKPOINT)) + return checkpoint_operations + def trace_tpu(self, graph, result_tensor, num_replicas=None): """Traces the tensors generated by TPU Ops in a TF graph. Args: - graph: the graph of Ops. + graph: the graph of Ops executed on the TPU. result_tensor: a result tensor of evaluating the graph. num_replicas: number of replicas used on the TPU. @@ -502,38 +880,22 @@ class TensorTracer(object): TensorTracer.check_device_type(self._device_type) result_tensor_copy = self._add_replica_id_to_graph(num_replicas, result_tensor) - self._write_config_section() + (operations, succeed, sorted_or_cycle) = self._pre_tracing(graph) tracing_ops = [] - operations = graph.get_operations() - self._write_op_list_section(operations) - # Does the topological sort before adding any nodes to the graph. - (succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph) + checkpoint_operations = self._get_checkpoints(graph) + for op_id, op in enumerate(operations): - if not self._inside_op_range(op_id): - self._instrument_records[op.name] = TensorTracer.reason( - op_id, _RECORD_OUTSIDE_OP_RANGE) + if checkpoint_operations and op.name not in checkpoint_operations: continue - if not TensorTracer.should_trace(self._device_type, op): - self._instrument_records[op.name] = TensorTracer.reason( - op_id, _RECORD_SHOULD_NOT_TRACE) - continue - if not self._is_selected_op(op.name): - self._instrument_records[op.name] = TensorTracer.reason( - op_id, _RECORD_FILTERED_OUT) + user_included = self._is_user_included_op(op) + user_excluded = self._is_user_excluded_op(op) + if self._skip_op(op_id, op, user_included, user_excluded): continue for i in range(len(op.outputs)): out_tensor = op.outputs[i] - if not out_tensor.get_shape().is_fully_defined(): - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _RECORD_DYNAMIC_SHAPE) - continue # cannot trace tensors with dynamic shape. - rank = len(out_tensor.shape) - if rank < 1: - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _RECORD_SCALAR) - continue # cannot trace scalar. - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _RECORD_GET_TRACED) + if self._skip_tensor(op_id, out_tensor, user_included, + user_excluded): + continue consumers = out_tensor.consumers() trace_op = tpu.outside_compilation( self._make_tensor_trace_fun(op.name, i), out_tensor) @@ -546,8 +908,45 @@ class TensorTracer(object): # if there is no consumer, we will add the control dependence later # when we add the control dependency to the output operations. tracing_ops.append(trace_op) + self._post_tracing(succeed, sorted_or_cycle) + return (result_tensor_copy, tracing_ops) - self._write_reason_section() - self._write_graph_section(succeed, sorted_or_cycle) + def trace_cpu(self, graph): + """Traces the tensors generated by CPU Ops in a TF graph. - return (result_tensor_copy, tracing_ops) + Args: + graph: the graph of Ops executed on the CPU. + + Returns: + tracing_calls: a map from keys to trace calls. + A key is constructed from an Op's name. + A trace call consists of a function and a tensor ( + the function will be invoked with the tensor). + """ + + self._device_type = _DEVICE_TYPE_CPU + TensorTracer.check_device_type(self._device_type) + self._num_replicas = 1 + self._replica_id = 0 + (operations, succeed, sorted_or_cycle) = self._pre_tracing(graph) + tracing_calls = {} + checkpoint_operations = self._get_checkpoints(graph) + + for op_id, op in enumerate(operations): + if checkpoint_operations and op.name not in checkpoint_operations: + continue + user_included = self._is_user_included_op(op) + user_excluded = self._is_user_excluded_op(op) + if self._skip_op(op_id, op, user_included, user_excluded): + continue + for i in range(len(op.outputs)): + out_tensor = op.outputs[i] + if self._skip_tensor(op_id, out_tensor, user_included, + user_excluded): + continue + trace_fun = self._make_tensor_trace_fun(op.name, i) + trace_call = (trace_fun, [out_tensor]) + trace_call_key = 'tensor_tracing_cpu-%s:%d'%(op.name, i) + tracing_calls[trace_call_key] = trace_call + self._post_tracing(succeed, sorted_or_cycle) + return tracing_calls diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index def57da20d6018dcf27ccb7a9d04592f38ce2f7c..9266d81cf5fc035790062f0e307a5da0b01a9fc1 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -646,6 +646,10 @@ def split_compile_and_replicate(computation, array_ops.identity(x, name="replicated_input_{}".format(i)) for i, x in enumerate(computation_inputs) ] + for i in computation_inputs: + # pylint: disable=protected-access + i.op._set_attr("_tpu_input_identity", attr_value_pb2.AttrValue(b=True)) + # pylint: enable=protected-access # If there is an infeed queue, adds the dequeued values to the # computation's inputs. @@ -726,7 +730,11 @@ def split_compile_and_replicate(computation, new_output_tensors = [] for t in output_tensors: with ops.device(t.device if t.device else core(0)): - new_output_tensors.append(array_ops.identity(t)) + o = array_ops.identity(t) + # pylint: disable=protected-access + o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True)) + # pylint: enable=protected-access + new_output_tensors.append(o) output_tensors = new_output_tensors context.ExitResult(output_tensors) finally: @@ -777,15 +785,15 @@ def split_compile_and_replicate(computation, ] -def shard(computation, - inputs=None, - num_shards=1, - input_shard_axes=None, - outputs_from_all_shards=True, - output_shard_axes=None, - infeed_queue=None, - device_assignment=None, - name=None): +def split_compile_and_shard(computation, + inputs=None, + num_shards=1, + input_shard_axes=None, + outputs_from_all_shards=True, + output_shard_axes=None, + infeed_queue=None, + device_assignment=None, + name=None): """Shards `computation` for parallel execution. `inputs` must be a list of Tensors or None (equivalent to an empty list), each @@ -839,7 +847,7 @@ def shard(computation, is equal to the number of cores in the TPU system. name: (Deprecated) Does nothing. Returns: - A list of output tensors. + A tuple of (compile op, [output tensors]). Raises: ValueError: If num_shards <= 0 ValueError: If len(input_shard_axes) != len(inputs) @@ -874,7 +882,7 @@ def shard(computation, else: transposed_inputs = [[]] * num_shards - outputs = replicate( + compile_op, outputs = split_compile_and_replicate( computation, transposed_inputs, infeed_queue=infeed_queue, @@ -891,7 +899,7 @@ def shard(computation, # one so it can be used as a control dependency or fetch node. # TODO(b/36647078) remove disable when pylint bug is fixed. # pylint: disable=indexing-exception - return [outputs[0]] + return compile_op, [outputs[0]] # pylint: enable=indexing-exception # TODO(b/36647078) remove disable when pylint bug is fixed. @@ -925,7 +933,87 @@ def shard(computation, # TODO(phawkins): use a smarter policy, e.g., round-robin across shards. results.append(x[0]) - return results + return compile_op, results + + +def shard(computation, + inputs=None, + num_shards=1, + input_shard_axes=None, + outputs_from_all_shards=True, + output_shard_axes=None, + infeed_queue=None, + device_assignment=None, + name=None): + """Shards `computation` for parallel execution. + + `inputs` must be a list of Tensors or None (equivalent to an empty list), each + of which has a corresponding split axis (from `input_shard_axes`). Each input + is split into `num_shards` pieces along the corresponding axis, and + computation is applied to each shard in parallel. + + Tensors are broadcast to all shards if they are lexically captured by + `computation`. e.g., + + x = tf.constant(7) + def computation(): + return x + 3 + ... = shard(computation, ...) + + TODO(phawkins): consider adding support for broadcasting Tensors passed + as inputs. + + If `outputs_from_all_shards` is true, the outputs from all shards of + `computation` are concatenated back together along their `output_shards_axes`. + Otherwise, each output is taken from an arbitrary shard. + + Inputs and outputs of the computation must be at least rank-1 Tensors. + + Args: + computation: A Python function that builds a computation to apply to each + shard of the input. + inputs: A list of input tensors or None (equivalent to an empty list). Each + input tensor has a corresponding shard axes, given by `input_shard_axes`, + which must have size divisible by `num_shards`. + num_shards: The number of shards. + input_shard_axes: A list of dimensions along which to shard `inputs`, or + `None`. `None` means "shard all inputs along dimension 0". If not `None`, + there must be one dimension per input. + outputs_from_all_shards: Boolean or list of boolean. For each output, if + `True`, outputs from all shards are concatenated along the corresponding + `output_shard_axes` entry. Otherwise, each output is taken + from an arbitrary shard. If the argument is a boolean, the argument's + value is used for each output. + output_shard_axes: A list of dimensions along which to concatenate the + outputs of `computation`, or `None`. `None` means "concatenate all outputs + along dimension 0". If not `None`, there must be one dimension per output. + Ignored if `outputs_from_all_shards` is False. + infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs + of `computation`. + device_assignment: If not `None`, a `DeviceAssignment` describing the + mapping between logical cores in the computation with physical cores in + the TPU topology. Uses a default device assignment if `None`. The + `DeviceAssignment` may be omitted if each shard of the computation uses + only one core, and there is either only one shard, or the number of shards + is equal to the number of cores in the TPU system. + name: (Deprecated) Does nothing. + Returns: + A list of output tensors. + Raises: + ValueError: If num_shards <= 0 + ValueError: If len(input_shard_axes) != len(inputs) + ValueError: If len(output_shard_axes) != len(outputs from `computation`) + """ + return split_compile_and_shard( + computation, + inputs=inputs, + num_shards=num_shards, + input_shard_axes=input_shard_axes, + outputs_from_all_shards=outputs_from_all_shards, + output_shard_axes=output_shard_axes, + infeed_queue=infeed_queue, + device_assignment=device_assignment, + name=name)[1] def batch_parallel(computation, diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 96b9556e137effcaaa5916b9723142f737a6dc33..44a8f7ce0e5794ec95b5d0c25adca14b194a25d1 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -31,6 +31,7 @@ import six from six.moves import queue as Queue # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result from tensorflow.contrib.tpu.python.tpu import tensor_tracer from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import error_handling @@ -336,6 +337,16 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote hooks = None if self.host_call is not None: hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])] + if tensor_tracer.TensorTracer.is_enabled(): + tt = tensor_tracer.TensorTracer() + tracing_calls = tt.trace_cpu(ops.get_default_graph()) + tracing_call_ret = _OutfeedHostCall.create_cpu_hostcall(tracing_calls) + tracing_functions = tracing_call_ret.values() + if tracing_functions: + if hooks: + hooks.extend([_OutfeedHostCallHook(tracing_functions)]) + else: + hooks = [_OutfeedHostCallHook(tracing_functions)] hooks = tuple(hooks or []) scaffold = self.scaffold_fn() if self.scaffold_fn else None return model_fn_lib.EstimatorSpec( @@ -412,6 +423,7 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): ctx, enqueue_ops, dequeue_ops, + tpu_compile_op, run_infeed_loop_on_coordinator=True, rendezvous=None, master=None, @@ -429,6 +441,7 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): self._feed_error = None self._finished = False self._should_initialize_tpu = True + self._tpu_compile_op = tpu_compile_op def begin(self): logging.info('TPU job name %s', self._master_job) @@ -477,6 +490,15 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): def _create_infeed_controller(self, name, target, args): return _OpQueueContext(name=name, target=target, args=args) + def _assertCompilationSucceeded(self, result, coord): + proto = tpu_compilation_result.CompilationResultProto() + proto.ParseFromString(result) + if proto.status_error_message: + logging.error('Compilation failed: {}'.format(proto.status_error_message)) + coord.request_stop() + else: + logging.info('Compilation succeeded') + def after_create_session(self, session, coord): if self._should_initialize_tpu: logging.info('Init TPU system') @@ -490,6 +512,10 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): session.run(self._init_ops, options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000)) + if os.environ.get('TPU_SPLIT_COMPILE_AND_EXECUTE', '') == '1': + logging.info('Compiling user program: this may take a while...') + self._assertCompilationSucceeded(session.run(self._tpu_compile_op), coord) + self._infeed_controller = self._create_infeed_controller( name='InfeedController', target=self._run_infeed, args=(session,)) @@ -530,12 +556,13 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): class TPUInfeedOutfeedSessionHookForPrediction(TPUInfeedOutfeedSessionHook): - def __init__(self, ctx, enqueue_ops, dequeue_ops, rendezvous=None, - master=None, session_config=None): + def __init__(self, ctx, enqueue_ops, dequeue_ops, tpu_compile_op, + rendezvous=None, master=None, session_config=None): super(TPUInfeedOutfeedSessionHookForPrediction, self).__init__( ctx, enqueue_ops, dequeue_ops, + tpu_compile_op=tpu_compile_op, run_infeed_loop_on_coordinator=False, rendezvous=rendezvous, master=master, @@ -2234,7 +2261,7 @@ class TPUEstimator(estimator_lib.Estimator): def computation(): """Compute tpu tensors used in export_outputs. - Passed to rewrite so that model_fn will be called under + Passed to rewrite_for_inference so that model_fn will be called under the rewriting contexts. Only tpu tensors are returned, but export_outputs and scaffold are captured. @@ -2243,7 +2270,7 @@ class TPUEstimator(estimator_lib.Estimator): outside_compilation. """ # We should only call model fn once and it should be inside `computation` - # so that building the graph will happen under `rewrite`. + # so that building the graph will happen under `rewrite_for_inference`. mode = model_fn_lib.ModeKeys.PREDICT estimator_spec = self._call_model_fn(features, labels, mode, config) @@ -2253,32 +2280,24 @@ class TPUEstimator(estimator_lib.Estimator): (k, _export_output_to_tensors(v)) for k, v in six.iteritems(estimator_spec.export_outputs)) tensors = nest.flatten(tensors_dict) - tpu_tensors = [t for t in tensors if _is_tpu_tensor(t)] + tpu_tensors = [t for t in tensors if t is not None] # We cannot return anything other than `tpu_tensors` here so we capture # the rest for later use. capture.capture((estimator_spec, tensors_dict, tensors)) return tpu_tensors - tpu_tensors_on_cpu = tpu.rewrite(computation) + tpu_tensors_on_cpu = tpu.rewrite_for_inference(computation) estimator_spec, tensors_dict, tensors = capture.get() # Reconstruct `tensors`, but with `tpu_tensors` replaced with # `tpu_tensors_on_cpu`. new_tensors = [] for t in tensors: - if _is_tpu_tensor(t): - new_tensors.append(tpu_tensors_on_cpu.pop(0)) - elif t is None: + if t is None: new_tensors.append(None) else: - # Only fetching `tpu_tensors_on_cpu` does not trigger - # TPU computation and blocks, so we add the control dependency here. - control_inputs = ( - tpu_tensors_on_cpu if _is_iterable(tpu_tensors_on_cpu) else - (tpu_tensors_on_cpu,)) - with ops.control_dependencies(control_inputs): - new_tensors.append(array_ops.identity(t)) + new_tensors.append(tpu_tensors_on_cpu.pop(0)) # Reconstruct `tensors_dict`. new_tensors_dict = nest.pack_sequence_as(tensors_dict, new_tensors) @@ -2535,7 +2554,7 @@ class TPUEstimator(estimator_lib.Estimator): graph.add_to_collection(_TPU_ENQUEUE_OPS, enqueue_op) if mode == model_fn_lib.ModeKeys.TRAIN: - loss, host_call, scaffold, training_hooks = ( + compile_op, loss, host_call, scaffold, training_hooks = ( _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn)) host_ops = host_call.create_tpu_hostcall() if host_ops is None: @@ -2570,6 +2589,7 @@ class TPUEstimator(estimator_lib.Estimator): ctx, enqueue_ops, host_ops, + tpu_compile_op=compile_op, run_infeed_loop_on_coordinator=( run_infeed_loop_on_coordinator), rendezvous=self._rendezvous[mode], @@ -2627,8 +2647,8 @@ class TPUEstimator(estimator_lib.Estimator): scaffold=scaffold) if mode == model_fn_lib.ModeKeys.EVAL: - total_loss, host_calls, scaffold, eval_hooks = _eval_on_tpu_system( - ctx, model_fn_wrapper, dequeue_fn) + compile_op, total_loss, host_calls, scaffold, eval_hooks = ( + _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn)) iterations_per_loop_var = _create_or_get_iterations_per_loop() mean_loss = math_ops.div( total_loss, @@ -2675,6 +2695,7 @@ class TPUEstimator(estimator_lib.Estimator): ctx, enqueue_ops, eval_update_ops + host_ops, + tpu_compile_op=compile_op, run_infeed_loop_on_coordinator=( run_infeed_loop_on_coordinator), rendezvous=self._rendezvous[mode], @@ -2695,7 +2716,7 @@ class TPUEstimator(estimator_lib.Estimator): # Predict assert mode == model_fn_lib.ModeKeys.PREDICT - (dummy_predict_op, host_calls, + (compile_op, dummy_predict_op, host_calls, scaffold, prediction_hooks) = _predict_on_tpu_system( ctx, model_fn_wrapper, dequeue_fn) with ops.control_dependencies([dummy_predict_op]): @@ -2752,6 +2773,7 @@ class TPUEstimator(estimator_lib.Estimator): _StoppingPredictHook(scalar_stopping_signal), TPUInfeedOutfeedSessionHookForPrediction( ctx, enqueue_ops, host_ops, rendezvous=self._rendezvous[mode], + tpu_compile_op=compile_op, master=self._config.master, session_config=self._session_config), ] + input_hooks @@ -2768,17 +2790,6 @@ class TPUEstimator(estimator_lib.Estimator): return _model_fn -def _is_tpu_tensor(tensor): - if not isinstance(tensor, ops.Tensor): - return False - try: - tensor.op.get_attr(tpu._OUTSIDE_COMPILATION_ATTR) # pylint: disable=protected-access - except ValueError: - return True - else: - return False - - def _export_output_to_tensors(export_output): """Get a list of `Tensors` used in `export_output`. @@ -2850,15 +2861,16 @@ def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): return training_loop.repeat(iterations_per_loop_var, single_tpu_eval_step, [_ZERO_LOSS]) - (loss,) = tpu.shard( + (compile_op, loss,) = tpu.split_compile_and_shard( multi_tpu_eval_steps_on_single_shard, inputs=[], num_shards=ctx.num_replicas, outputs_from_all_shards=False, device_assignment=ctx.device_assignment) + loss = loss[0] scaffold = _get_scaffold(captured_scaffold_fn) - return loss, host_calls, scaffold, captured_eval_hooks.get() + return compile_op, loss, host_calls, scaffold, captured_eval_hooks.get() def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): @@ -2873,15 +2885,16 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): return training_loop.repeat(iterations_per_loop_var, single_tpu_train_step, [_INITIAL_LOSS]) - (loss,) = tpu.shard( + (compile_op, loss,) = tpu.split_compile_and_shard( multi_tpu_train_steps_on_single_shard, inputs=[], num_shards=ctx.num_replicas, outputs_from_all_shards=False, device_assignment=ctx.device_assignment) + loss = loss[0] scaffold = _get_scaffold(captured_scaffold_fn) - return loss, host_call, scaffold, captured_training_hooks.get() + return compile_op, loss, host_call, scaffold, captured_training_hooks.get() def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): @@ -2901,15 +2914,17 @@ def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): cond, single_tpu_predict_step, inputs=inputs, name=b'loop') return outputs - (dummy_predict_op,) = tpu.shard( + (compile_op, dummy_predict_op,) = tpu.split_compile_and_shard( multi_tpu_predict_steps_on_single_shard, inputs=[], num_shards=ctx.num_replicas, outputs_from_all_shards=False, device_assignment=ctx.device_assignment) + dummy_predict_op = dummy_predict_op[0] scaffold = _get_scaffold(captured_scaffold_fn) - return dummy_predict_op, host_calls, scaffold, captured_predict_hooks.get() + return (compile_op, dummy_predict_op, host_calls, scaffold, + captured_predict_hooks.get()) def _wrap_computation_in_while_loop(device, op_fn): diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py index 3beb7bfe3048a8f0294f7e9149b5a07b5fcc7d17..bcc177601b95172b05d327247bd370c2f8b65d59 100644 --- a/tensorflow/contrib/training/python/training/hparam.py +++ b/tensorflow/contrib/training/python/training/hparam.py @@ -187,7 +187,7 @@ def _cast_to_type_if_compatible(name, param_type, value): return param_type(value) -def parse_values(values, type_map): +def parse_values(values, type_map, ignore_unknown=False): """Parses hyperparameter values from a string into a python map. `values` is a string containing comma-separated `name=value` pairs. @@ -233,6 +233,9 @@ def parse_values(values, type_map): type T if either V has type T, or V is a list of elements of type T. Hence, for a multidimensional parameter 'x' taking float values, 'x=[0.1,0.2]' will parse successfully if type_map['x'] = float. + ignore_unknown: Bool. Whether values that are missing a type in type_map + should be ignored. If set to True, a ValueError will not be raised for + unknown hyperparameter type. Returns: A python map mapping each name to either: @@ -260,6 +263,8 @@ def parse_values(values, type_map): m_dict = m.groupdict() name = m_dict['name'] if name not in type_map: + if ignore_unknown: + continue raise ValueError('Unknown hyperparameter type for %s' % name) type_ = type_map[name] diff --git a/tensorflow/contrib/training/python/training/hparam_test.py b/tensorflow/contrib/training/python/training/hparam_test.py index 660c97f25e8458c345c8914bcaf98f37d047e50e..a990e04711ce68bd928a508484f0d6f657dd2f8c 100644 --- a/tensorflow/contrib/training/python/training/hparam_test.py +++ b/tensorflow/contrib/training/python/training/hparam_test.py @@ -216,6 +216,14 @@ class HParamsTest(test.TestCase): self.assertTrue(isinstance(parse_dict['arr'], dict)) self.assertDictEqual(parse_dict['arr'], {1: 10}) + def testParseValuesWithIndexAssigment1_IgnoreUnknown(self): + """Assignment to an index position.""" + parse_dict = hparam.parse_values( + 'arr[1]=10,b=5', {'arr': int}, ignore_unknown=True) + self.assertEqual(len(parse_dict), 1) + self.assertTrue(isinstance(parse_dict['arr'], dict)) + self.assertDictEqual(parse_dict['arr'], {1: 10}) + def testParseValuesWithIndexAssigment2(self): """Assignment to multiple index positions.""" parse_dict = hparam.parse_values('arr[0]=10,arr[5]=20', {'arr': int}) @@ -223,6 +231,14 @@ class HParamsTest(test.TestCase): self.assertTrue(isinstance(parse_dict['arr'], dict)) self.assertDictEqual(parse_dict['arr'], {0: 10, 5: 20}) + def testParseValuesWithIndexAssigment2_IgnoreUnknown(self): + """Assignment to multiple index positions.""" + parse_dict = hparam.parse_values( + 'arr[0]=10,arr[5]=20,foo=bar', {'arr': int}, ignore_unknown=True) + self.assertEqual(len(parse_dict), 1) + self.assertTrue(isinstance(parse_dict['arr'], dict)) + self.assertDictEqual(parse_dict['arr'], {0: 10, 5: 20}) + def testParseValuesWithIndexAssigment3(self): """Assignment to index positions in multiple names.""" parse_dict = hparam.parse_values('arr[0]=10,arr[1]=20,L[5]=100,L[10]=200', @@ -234,6 +250,17 @@ class HParamsTest(test.TestCase): self.assertTrue(isinstance(parse_dict['L'], dict)) self.assertDictEqual(parse_dict['L'], {5: 100, 10: 200}) + def testParseValuesWithIndexAssigment3_IgnoreUnknown(self): + """Assignment to index positions in multiple names.""" + parse_dict = hparam.parse_values( + 'arr[0]=10,C=5,arr[1]=20,B[0]=kkk,L[5]=100,L[10]=200', + {'arr': int, 'L': int}, ignore_unknown=True) + self.assertEqual(len(parse_dict), 2) + self.assertTrue(isinstance(parse_dict['arr'], dict)) + self.assertDictEqual(parse_dict['arr'], {0: 10, 1: 20}) + self.assertTrue(isinstance(parse_dict['L'], dict)) + self.assertDictEqual(parse_dict['L'], {5: 100, 10: 200}) + def testParseValuesWithIndexAssigment4(self): """Assignment of index positions and scalars.""" parse_dict = hparam.parse_values('x=10,arr[1]=20,y=30', @@ -246,6 +273,17 @@ class HParamsTest(test.TestCase): self.assertEqual(parse_dict['x'], 10) self.assertEqual(parse_dict['y'], 30) + def testParseValuesWithIndexAssigment4_IgnoreUnknown(self): + """Assignment of index positions and scalars.""" + parse_dict = hparam.parse_values( + 'x=10,foo[0]=bar,arr[1]=20,zzz=78,y=30', + {'x': int, 'y': int, 'arr': int}, ignore_unknown=True) + self.assertEqual(len(parse_dict), 3) + self.assertTrue(isinstance(parse_dict['arr'], dict)) + self.assertDictEqual(parse_dict['arr'], {1: 20}) + self.assertEqual(parse_dict['x'], 10) + self.assertEqual(parse_dict['y'], 30) + def testParseValuesWithIndexAssigment5(self): """Different variable types.""" parse_dict = hparam.parse_values('a[0]=5,b[1]=true,c[2]=abc,d[3]=3.14', { @@ -264,24 +302,55 @@ class HParamsTest(test.TestCase): self.assertTrue(isinstance(parse_dict['d'], dict)) self.assertDictEqual(parse_dict['d'], {3: 3.14}) + def testParseValuesWithIndexAssigment5_IgnoreUnknown(self): + """Different variable types.""" + parse_dict = hparam.parse_values( + 'a[0]=5,cc=4,b[1]=true,c[2]=abc,mm=2,d[3]=3.14', + {'a': int, 'b': bool, 'c': str, 'd': float}, + ignore_unknown=True) + self.assertEqual(set(parse_dict.keys()), {'a', 'b', 'c', 'd'}) + self.assertTrue(isinstance(parse_dict['a'], dict)) + self.assertDictEqual(parse_dict['a'], {0: 5}) + self.assertTrue(isinstance(parse_dict['b'], dict)) + self.assertDictEqual(parse_dict['b'], {1: True}) + self.assertTrue(isinstance(parse_dict['c'], dict)) + self.assertDictEqual(parse_dict['c'], {2: 'abc'}) + self.assertTrue(isinstance(parse_dict['d'], dict)) + self.assertDictEqual(parse_dict['d'], {3: 3.14}) + def testParseValuesWithBadIndexAssigment1(self): """Reject assignment of list to variable type.""" with self.assertRaisesRegexp(ValueError, r'Assignment of a list to a list index.'): hparam.parse_values('arr[1]=[1,2,3]', {'arr': int}) + def testParseValuesWithBadIndexAssigment1_IgnoreUnknown(self): + """Reject assignment of list to variable type.""" + with self.assertRaisesRegexp(ValueError, + r'Assignment of a list to a list index.'): + hparam.parse_values( + 'arr[1]=[1,2,3],c=8', {'arr': int}, ignore_unknown=True) + def testParseValuesWithBadIndexAssigment2(self): """Reject if type missing.""" with self.assertRaisesRegexp(ValueError, r'Unknown hyperparameter type for arr'): hparam.parse_values('arr[1]=5', {}) + def testParseValuesWithBadIndexAssigment2_IgnoreUnknown(self): + """Ignore missing type.""" + hparam.parse_values('arr[1]=5', {}, ignore_unknown=True) + def testParseValuesWithBadIndexAssigment3(self): """Reject type of the form name[index].""" with self.assertRaisesRegexp(ValueError, 'Unknown hyperparameter type for arr'): hparam.parse_values('arr[1]=1', {'arr[1]': int}) + def testParseValuesWithBadIndexAssigment3_IgnoreUnknown(self): + """Ignore type of the form name[index].""" + hparam.parse_values('arr[1]=1', {'arr[1]': int}, ignore_unknown=True) + def testWithReusedVariables(self): with self.assertRaisesRegexp(ValueError, 'Multiple assignments to variable \'x\''): diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 66714235b535c14a8f13c40bb2a4df8d7494dc05..8bf1480d33b2d2117fb5c7ddf046262cfeb8a8ab 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -49,7 +49,7 @@ # filegroup ":android_proto_srcs" - Protos # filegroup ":android_srcs" - Core sources # cc_library ":android_tensorflow_lib" - Native library -# cc_library ":android_tensorflow_lib_selective_registration" - Native library +# cc_library ":android_tensorflow_lib_lite" - Native library, without ops, # supporting SELECTIVE_REGISTRATION feature. # portable_proto_library ":android_proto_lib" (Google-internal) # @@ -113,7 +113,6 @@ load( "tf_additional_device_tracer_test_flags", "tf_additional_gdr_lib_defines", "tf_additional_human_readable_json_deps", - "tf_additional_logger_deps", "tf_additional_lib_defines", "tf_additional_lib_deps", "tf_additional_lib_hdrs", @@ -446,15 +445,31 @@ cc_library( ) cc_library( - name = "logger", - srcs = tf_platform_srcs(["logger.cc"]), - hdrs = ["platform/logger.h"] + tf_platform_hdrs(["logger.h"]), + name = "logger_interface", + hdrs = ["platform/logger.h"], copts = tf_copts(), visibility = ["//visibility:public"], deps = [ - ":lib", - ":lib_internal", - ] + tf_additional_logger_deps(), + ":lib_proto_parsing", + "@protobuf_archive//:protobuf", + ], +) + +cc_library( + name = "default_logger", + srcs = ["platform/default/logger.cc"], + hdrs = ["platform/logger.h"], + deps = [ + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:logger_interface", + ], +) + +cc_library( + name = "logger", + hdrs = ["platform/logger.h"], + visibility = ["//visibility:public"], + deps = ["//tensorflow/core/platform/default/build_config:logger"], ) filegroup( @@ -1611,6 +1626,9 @@ filegroup( "**/*main.cc", "debug/**/*", "framework/op_gen_*", + "framework/node_def_util.*", + "framework/op_kernel.*", + "framework/dataset.*", "lib/jpeg/**/*", "lib/png/**/*", "lib/gif/**/*", @@ -1619,7 +1637,6 @@ filegroup( "util/reporter.*", "platform/**/cuda_libdevice_path.*", "platform/**/logger.cc", - "platform/**/logger.h", "platform/default/test_benchmark.*", "platform/cuda.h", "platform/google/**/*", @@ -1654,6 +1671,9 @@ filegroup( "common_runtime/**/*.cc", "graph/**/*.h", "graph/**/*.cc", + "framework/node_def_util.*", + "framework/op_kernel.*", + "framework/dataset.*", ], exclude = [ "**/*test.*", @@ -1829,27 +1849,6 @@ cc_library( alwayslink = 1, ) -# Android library for use with the SELECTIVE_REGISTRATION feature. -# Does not contain operators. In contrast to android_tensorflow_lib_lite, -# this links in framework support for all types, relying on selective -# registration of ops to prune code size. -# -# TODO(gonnet): Move all users of these aliases to the corresponding -# :android_tensorflow_lib_lite* targets and remove. -alias( - name = "android_tensorflow_lib_selective_registration", - actual = ":android_tensorflow_lib_lite", - visibility = ["//visibility:public"], -) - -# Android library for use with the SELECTIVE_REGISTRATION feature with -# no proto_rtti. -alias( - name = "android_tensorflow_lib_selective_registration_nortti", - actual = ":android_tensorflow_lib_lite_nortti", - visibility = ["//visibility:public"], -) - filegroup( name = "android_op_registrations_and_gradients", srcs = glob( @@ -4059,20 +4058,6 @@ tf_cuda_cc_test( ], ) -tf_cc_test_gpu( - name = "cuda_libdevice_path_test", - size = "small", - srcs = ["platform/cuda_libdevice_path_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - tags = tf_cuda_tests_tags(), - deps = [ - ":cuda_libdevice_path", - ":lib", - ":test", - ":test_main", - ], -) - tf_cuda_only_cc_test( name = "util_cuda_kernel_helper_test", srcs = [ @@ -4928,7 +4913,7 @@ filegroup( cc_library( name = "cuda_libdevice_path", - srcs = ["platform/cuda_libdevice_path.cc"] + tf_additional_libdevice_srcs(), + srcs = tf_additional_libdevice_srcs(), hdrs = ["platform/cuda_libdevice_path.h"], copts = tf_copts(), data = tf_additional_libdevice_data(), diff --git a/tensorflow/core/api_def/base_api/api_def_OneHot.pbtxt b/tensorflow/core/api_def/base_api/api_def_OneHot.pbtxt index 807b8ae31015e4bcb73e54e98d879460f0d92f62..b325df1c8c2b231f03a1960babd2d915b1b0e72d 100644 --- a/tensorflow/core/api_def/base_api/api_def_OneHot.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_OneHot.pbtxt @@ -66,7 +66,6 @@ Examples ========= Suppose that - ``` indices = [0, 2, -1, 1] depth = 3 @@ -76,16 +75,15 @@ Suppose that ``` Then output is `[4 x 3]`: - - ```output = - [5.0 0.0 0.0] // one_hot(0) - [0.0 0.0 5.0] // one_hot(2) - [0.0 0.0 0.0] // one_hot(-1) - [0.0 5.0 0.0] // one_hot(1) - ``` +``` +output = + [5.0 0.0 0.0] // one_hot(0) + [0.0 0.0 5.0] // one_hot(2) + [0.0 0.0 0.0] // one_hot(-1) + [0.0 5.0 0.0] // one_hot(1) +``` Suppose that - ``` indices = [0, 2, -1, 1] depth = 3 @@ -95,19 +93,19 @@ Suppose that ``` Then output is `[3 x 4]`: +``` +output = + [0.0 3.0 3.0 3.0] + [3.0 3.0 3.0 0.0] + [3.0 3.0 3.0 3.0] + [3.0 0.0 3.0 3.0] +// ^ one_hot(0) +// ^ one_hot(2) +// ^ one_hot(-1) +// ^ one_hot(1) +``` - ```output = - [0.0 3.0 3.0 3.0] - [3.0 3.0 3.0 0.0] - [3.0 3.0 3.0 3.0] - [3.0 0.0 3.0 3.0] - // ^ one_hot(0) - // ^ one_hot(2) - // ^ one_hot(-1) - // ^ one_hot(1) - ``` Suppose that - ``` indices = [[0, 2], [1, -1]] depth = 3 @@ -117,14 +115,15 @@ Suppose that ``` Then output is `[2 x 2 x 3]`: - - ```output = - [ - [1.0, 0.0, 0.0] // one_hot(0) - [0.0, 0.0, 1.0] // one_hot(2) - ][ - [0.0, 1.0, 0.0] // one_hot(1) - [0.0, 0.0, 0.0] // one_hot(-1) - ]``` +``` +output = + [ + [1.0, 0.0, 0.0] // one_hot(0) + [0.0, 0.0, 1.0] // one_hot(2) + ][ + [0.0, 1.0, 0.0] // one_hot(1) + [0.0, 0.0, 0.0] // one_hot(-1) + ] +``` END } diff --git a/tensorflow/core/api_def/base_api/api_def_UnicodeDecode.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnicodeDecode.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..9b3f69023f1167fc3964a82a1e425d619ecc5521 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_UnicodeDecode.pbtxt @@ -0,0 +1,76 @@ +op { + graph_op_name: "UnicodeDecode" + in_arg { + name: "input" + description: <