diff --git a/tools/bazel.rc b/.bazelrc
similarity index 95%
rename from tools/bazel.rc
rename to .bazelrc
index 1fdf51f53e29c7111cf89c016400b710051cf9c6..cd7e13ddfc146208f79be900917b05b694869d72 100644
--- a/tools/bazel.rc
+++ b/.bazelrc
@@ -76,7 +76,6 @@ build:nonccl --define=no_nccl_support=true
build --define=use_fast_cpp_protos=true
build --define=allow_oversize_protos=true
-build --define=grpc_no_ares=true
build --spawn_strategy=standalone
build --genrule_strategy=standalone
@@ -93,3 +92,11 @@ build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS
build --define=PREFIX=/usr
build --define=LIBDIR=$(PREFIX)/lib
build --define=INCLUDEDIR=$(PREFIX)/include
+
+# Default options should come above this line
+
+# Options from ./configure
+try-import %workspace%/.tf_configure.bazelrc
+
+# Put user-specific options in .bazelrc.user
+try-import %workspace%/.bazelrc.user
diff --git a/.gitignore b/.gitignore
index 90324058600bee46af56e49028977971848a80de..e1d352c238a1b2d4febe0f5d4a30cfa0c942f7e7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,7 +1,7 @@
.DS_Store
.ipynb_checkpoints
node_modules
-/.bazelrc
+/.bazelrc.user
/.tf_configure.bazelrc
/bazel-*
/bazel_pip
diff --git a/README.md b/README.md
index 044174947a094d43a51f7140dd40ec0f17801d40..519815d006cc33be10132909baf414a4bd843435 100644
--- a/README.md
+++ b/README.md
@@ -113,11 +113,12 @@ The TensorFlow project strives to abide by generally accepted best practices in
Build Type | Status | Artifacts
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
**IBM s390x** | [](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA
-**IBM ppc64le CPU** | [](http://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | TBA
-**IBM ppc64le GPU** Nightly | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
-**IBM ppc64le GPU** Stable Release | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/)
+**Linux ppc64le CPU** Nightly | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/)
+**Linux ppc64le CPU** Stable Release | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/)
+**Linux ppc64le GPU** Nightly | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
+**Linux ppc64le GPU** Stable Release | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/)
**Linux CPU with Intel® MKL-DNN** Nightly | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)
-**Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.4
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.11.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp27-cp27mu-linux_x86_64.whl)
[1.11.0 py3.4](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp34-cp34m-linux_x86_64.whl)
[1.11.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp35-cp35m-linux_x86_64.whl)
[1.11.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp36-cp36m-linux_x86_64.whl)
+**Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.4
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.12.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp27-cp27mu-linux_x86_64.whl)
[1.12.0 py3.4](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp34-cp34m-linux_x86_64.whl)
[1.12.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp35-cp35m-linux_x86_64.whl)
[1.12.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp36-cp36m-linux_x86_64.whl)
## For more information
diff --git a/RELEASE.md b/RELEASE.md
index b13b071bd6cf4d3a260c8e248a67d23e1a688498..32abdcea497618918964174a661a6ba872598f65 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -7,6 +7,8 @@
Serving.
* Keras models now support evaluating with a `tf.data.Dataset`.
* TensorFlow binaries are built with XLA support linked in by default.
+* Ignite Dataset added to contrib/ignite, which allows working with Apache
+ Ignite.
## Bug Fixes and Other Changes
diff --git a/WORKSPACE b/WORKSPACE
index 7cc08e0164a202581ad7ebbe107a9e19410e70e4..7057d3f149e766cd2983ecc89509f84c37075602 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -16,30 +16,27 @@ load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories")
closure_repositories()
-http_archive(
- name = "base_images_docker",
- sha256 = "e2b1b7254270bb7605e814a9dbf6d1e4ae04a11136ff1714fbfdabe3f87f7cf9",
- strip_prefix = "base-images-docker-12801524f867e657fbb5d1a74f31618aff181ac6",
- urls = ["https://github.com/GoogleCloudPlatform/base-images-docker/archive/12801524f867e657fbb5d1a74f31618aff181ac6.tar.gz"],
-)
+load("//third_party/toolchains/preconfig/generate:archives.bzl",
+ "bazel_toolchains_archive")
-http_archive(
- name = "bazel_toolchains",
- sha256 = "15b5858b1b5541ec44df31b94c3b8672815b31d71215a98398761ea9f4c4eedb",
- strip_prefix = "bazel-toolchains-6200b238c9c2d137c0d9a7262c80cc71d98e692b",
- urls = [
- "https://github.com/bazelbuild/bazel-toolchains/archive/6200b238c9c2d137c0d9a7262c80cc71d98e692b.tar.gz",
- ],
+bazel_toolchains_archive()
+
+load(
+ "@bazel_toolchains//repositories:repositories.bzl",
+ bazel_toolchains_repositories = "repositories",
)
-http_archive(
- name = "io_bazel_rules_docker",
- sha256 = "29d109605e0d6f9c892584f07275b8c9260803bf0c6fcb7de2623b2bedc910bd",
- strip_prefix = "rules_docker-0.5.1",
- urls = ["https://github.com/bazelbuild/rules_docker/archive/v0.5.1.tar.gz"],
+bazel_toolchains_repositories()
+
+load(
+ "@io_bazel_rules_docker//container:container.bzl",
+ container_repositories = "repositories",
)
-load("//third_party/toolchains/preconfig/generate:workspace.bzl", "remote_config_workspace")
+container_repositories()
+
+load("//third_party/toolchains/preconfig/generate:workspace.bzl",
+ "remote_config_workspace")
remote_config_workspace()
@@ -47,7 +44,7 @@ remote_config_workspace()
# files, in case the parsing of those build files depends on the bazel
# version we require here.
load("//tensorflow:version_check.bzl", "check_bazel_version_at_least")
-check_bazel_version_at_least("0.15.0")
+check_bazel_version_at_least("0.18.0")
load("//tensorflow:workspace.bzl", "tf_workspace")
diff --git a/tensorflow/opensource_only/arm_compiler.BUILD b/arm_compiler.BUILD
similarity index 100%
rename from tensorflow/opensource_only/arm_compiler.BUILD
rename to arm_compiler.BUILD
diff --git a/configure.py b/configure.py
index 6c905a0be3d685b5921dfbc5bddfbe6471a82625..1e732db26404906901a9eeab97a5e75137ee8388 100644
--- a/configure.py
+++ b/configure.py
@@ -255,18 +255,6 @@ def setup_python(environ_cp):
def reset_tf_configure_bazelrc():
"""Reset file that contains customized config settings."""
open(_TF_BAZELRC, 'w').close()
- bazelrc_path = os.path.join(_TF_WORKSPACE_ROOT, '.bazelrc')
-
- data = []
- if os.path.exists(bazelrc_path):
- with open(bazelrc_path, 'r') as f:
- data = f.read().splitlines()
- with open(bazelrc_path, 'w') as f:
- for l in data:
- if _TF_BAZELRC_FILENAME in l:
- continue
- f.write('%s\n' % l)
- f.write('import %%workspace%%/%s\n' % _TF_BAZELRC_FILENAME)
def cleanup_makefile():
"""Delete any leftover BUILD files from the Makefile build.
@@ -488,11 +476,12 @@ def check_bazel_version(min_version, max_version):
if curr_version_int < min_version_int:
print('Please upgrade your bazel installation to version %s or higher to '
'build TensorFlow!' % min_version)
- sys.exit(0)
- if curr_version_int > max_version_int:
+ sys.exit(1)
+ if (curr_version_int > max_version_int and
+ 'TF_IGNORE_MAX_BAZEL_VERSION' not in os.environ):
print('Please downgrade your bazel installation to version %s or lower to '
'build TensorFlow!' % max_version)
- sys.exit(0)
+ sys.exit(1)
return curr_version
@@ -1565,11 +1554,9 @@ def main():
# environment variables.
environ_cp = dict(os.environ)
- check_bazel_version('0.15.0', '0.20.0')
+ check_bazel_version('0.19.0', '0.20.0')
reset_tf_configure_bazelrc()
- # Explicitly import tools/bazel.rc, this is needed for Bazel 0.19.0 or later
- write_to_bazelrc('import %workspace%/tools/bazel.rc')
cleanup_makefile()
setup_python(environ_cp)
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index fd4b94202aad24a82abef8abd16431f61a8326f0..449a1372edb031c68786d8672e2a1499c2b3d047 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -267,6 +267,15 @@ config_setting(
visibility = ["//visibility:public"],
)
+# By default, XLA GPU is compiled into tensorflow when building with
+# --config=cuda even when `with_xla_support` is false. The config setting
+# here allows us to override the behavior if needed.
+config_setting(
+ name = "no_xla_deps_in_cuda",
+ define_values = {"no_xla_deps_in_cuda": "true"},
+ visibility = ["//visibility:public"],
+)
+
config_setting(
name = "with_gdr_support",
define_values = {"with_gdr_support": "true"},
@@ -606,9 +615,11 @@ py_library(
name = "tensorflow_py",
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
- deps = [
+ deps = select({
+ "api_version_2": [],
+ "//conditions:default": ["//tensorflow/contrib:contrib_py"],
+ }) + [
":tensorflow_py_no_contrib",
- "//tensorflow/contrib:contrib_py",
"//tensorflow/python/estimator:estimator_py",
],
)
diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py
index d81cf067eb07e88e2b8a86cf5643674235eb3f3b..4eba763129a6aef40e3c130d56bf8ab19638b7ca 100644
--- a/tensorflow/api_template.__init__.py
+++ b/tensorflow/api_template.__init__.py
@@ -20,14 +20,14 @@ from __future__ import print_function as _print_function
import os as _os
+# API IMPORTS PLACEHOLDER
+
# pylint: disable=g-bad-import-order
from tensorflow.python.tools import component_api_helper as _component_api_helper
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=('tensorflow_estimator.python.estimator.api.estimator'))
-# API IMPORTS PLACEHOLDER
-
# Make sure directory containing top level submodules is in
# the __path__ so that "from tensorflow.foo import bar" works.
# We're using bitwise, but there's nothing special about that.
diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py
index 65bdb6cb1b5e6fb0656a12b932d767aeacfccd29..21b5277614667bdbd7271ac3e57f5b69d5a19264 100644
--- a/tensorflow/api_template_v1.__init__.py
+++ b/tensorflow/api_template_v1.__init__.py
@@ -23,13 +23,13 @@ import os as _os
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
+# API IMPORTS PLACEHOLDER
+
from tensorflow.python.tools import component_api_helper as _component_api_helper
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=('tensorflow_estimator.python.estimator.api.estimator'))
-# API IMPORTS PLACEHOLDER
-
from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top
contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
del LazyLoader
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 94d18eb8b04e3534be547aca5cfbb32da40ffbf6..9580215a317b1a6b1cdacbd430a1764af61be990 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -488,6 +488,7 @@ static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) {
// Non-static for testing.
TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
TF_Status* status) {
+ TF_SetStatus(status, TF_OK, "");
if (!src.IsInitialized()) {
status->status = FailedPrecondition(
"attempt to use a tensor with an uninitialized value");
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index f80ae5a6d02d4d613c95cf8486e0fc0aeed3affc..120748ab763a3358b6e38e64bb3b6fd2ea32f7c3 100755
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -170,23 +170,11 @@ TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h,
int dim_index,
TF_Status* status);
-// Returns the device of the operation that produced `h`.
-// If `h` was produced by a copy, returns the destination device of
-// the copy. Note that returned device name is not always the device
-// holding the tensor handle's memory. If you want the latter, use
-// TFE_TensorHandleBackingDeviceName.
-// This function will block till the operation that produces `h` has completed.
-//
-// Device on which the kernel of the operation that produced `h` ran.
-//
-// If `h` was produced by a copy, returns the destination device of
-// the copy.
-//
-// Note that returned device name is not always the device that owns the memory
-// that backs the tensor handle. For the latter see
-// TFE_TensorHandleBackingDeviceName.
-//
-// This function will block till the operation that produces `h` has completed.
+// Returns the device of the operation that produced `h`. If `h` was produced by
+// a copy, returns the destination device of the copy. Note that the returned
+// device name is not always the device holding the tensor handle's memory. If
+// you want the latter, use TFE_TensorHandleBackingDeviceName. This function
+// will block till the operation that produces `h` has completed.
TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName(
TFE_TensorHandle* h, TF_Status* status);
diff --git a/tensorflow/c/env.cc b/tensorflow/c/env.cc
index 07b9e8b940c55caf62ae0b81b884bf313d335459..1c35ff9001d0ee1ab0fbae9e1bcc07116fab1065 100644
--- a/tensorflow/c/env.cc
+++ b/tensorflow/c/env.cc
@@ -159,3 +159,25 @@ TF_CAPI_EXPORT extern uint64_t TF_NowMicros(void) {
TF_CAPI_EXPORT extern uint64_t TF_NowSeconds(void) {
return ::tensorflow::Env::Default()->NowSeconds();
}
+
+void TF_DefaultThreadOptions(TF_ThreadOptions* options) {
+ options->stack_size = 0;  // Zero: use the system default stack size.
+ options->guard_size = 0;  // Zero: use the system default guard size.
+ options->numa_node = -1;  // -1: no NUMA affinity for the thread.
+}
+
+TF_Thread* TF_StartThread(const TF_ThreadOptions* options,
+ const char* thread_name, void (*work_func)(void*),
+ void* param) {
+ ::tensorflow::ThreadOptions cc_options;  // C++ counterpart of `options`.
+ cc_options.stack_size = options->stack_size;
+ cc_options.guard_size = options->guard_size;
+ cc_options.numa_node = options->numa_node;
+ return reinterpret_cast<TF_Thread*>(::tensorflow::Env::Default()->StartThread(
+ cc_options, thread_name, [=]() { (*work_func)(param); }));
+}
+
+void TF_JoinThread(TF_Thread* thread) {
+ // ::tensorflow::Thread joins on destruction
+ delete reinterpret_cast<::tensorflow::Thread*>(thread);
+}
diff --git a/tensorflow/c/env.h b/tensorflow/c/env.h
index 9d27c5da37735042c7476b591e57486dbde33152..15652353cd7e1f1e7d7a4c665703c0166682d790 100644
--- a/tensorflow/c/env.h
+++ b/tensorflow/c/env.h
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <stdbool.h>
+#include <stddef.h>
+
#ifndef TENSORFLOW_C_ENV_H_
#define TENSORFLOW_C_ENV_H_
@@ -23,6 +26,7 @@ limitations under the License.
struct TF_WritableFileHandle;
struct TF_StringStream;
+struct TF_Thread;
#ifdef __cplusplus
extern "C" {
@@ -37,6 +41,20 @@ typedef struct TF_FileStatistics {
bool is_directory;
} TF_FileStatistics;
+typedef struct TF_ThreadOptions {
+ // Thread stack size to use (in bytes), zero implies that the system default
+ // will be used.
+ size_t stack_size;
+
+ // Guard area size to use near thread stacks (in bytes), zero implies
+ // that the system default will be used.
+ size_t guard_size;
+
+ // The NUMA node to use, -1 implies that there should be no NUMA affinity for
+ // this thread.
+ int numa_node;
+} TF_ThreadOptions;
+
// Creates the specified directory. Typical status code are:
// * TF_OK - successfully created the directory
// * TF_ALREADY_EXISTS - directory already exists
@@ -150,6 +168,25 @@ TF_CAPI_EXPORT extern uint64_t TF_NowMicros(void);
// Returns the number of seconds since the Unix epoch.
TF_CAPI_EXPORT extern uint64_t TF_NowSeconds(void);
+// Populates a TF_ThreadOptions struct with system-default values.
+TF_CAPI_EXPORT extern void TF_DefaultThreadOptions(TF_ThreadOptions* options);
+
+// Returns a new thread that is running work_func and is identified
+// (for debugging/performance-analysis) by thread_name.
+//
+// The given param (which may be null) is passed to work_func when the thread
+// starts. In this way, data may be passed from the thread back to the caller.
+//
+// Caller takes ownership of the result and must call TF_JoinThread on it
+// eventually.
+TF_CAPI_EXPORT extern TF_Thread* TF_StartThread(const TF_ThreadOptions* options,
+ const char* thread_name,
+ void (*work_func)(void*),
+ void* param);
+
+// Waits for the given thread to finish execution, then deletes it.
+TF_CAPI_EXPORT extern void TF_JoinThread(TF_Thread* thread);
+
#ifdef __cplusplus
}
#endif
diff --git a/tensorflow/c/env_test.cc b/tensorflow/c/env_test.cc
index e2206c6befd2167346c64032940d6e8c631e4a3e..687ad024137352662759ec1f43df87e89faca353 100644
--- a/tensorflow/c/env_test.cc
+++ b/tensorflow/c/env_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -98,3 +99,29 @@ TEST(TestEnv, TestTimeFunctions) {
ASSERT_GE(TF_NowMicros(), 946684800 * 1e6);
ASSERT_GE(TF_NowNanos(), 946684800 * 1e9);
}
+
+namespace {
+
+struct SomeThreadData {
+ ::tensorflow::mutex mu;
+ bool did_work = false;
+};
+
+void SomeThreadFunc(void* data) {
+ auto* real_data = static_cast<SomeThreadData*>(data);
+ ::tensorflow::mutex_lock l(real_data->mu);
+ real_data->did_work = true;  // Read back by the test after the join.
+}
+
+} // namespace
+
+TEST(TestEnv, TestThreads) {
+ TF_ThreadOptions options;
+ TF_DefaultThreadOptions(&options);
+ SomeThreadData data;
+ TF_Thread* thread =
+ TF_StartThread(&options, "SomeThreadName", &SomeThreadFunc, &data);
+ TF_JoinThread(thread);  // Waits for the thread to finish, then deletes it.
+ ::tensorflow::mutex_lock l(data.mu);  // Same mutex SomeThreadFunc locks.
+ ASSERT_TRUE(data.did_work);
+}
diff --git a/tensorflow/compat_template_v1.__init__.py b/tensorflow/compat_template_v1.__init__.py
index 7df80ec01245a7fe820c79d5879458c4cd0a93cb..d58acde09f007bc9df40b08b0ef79c6031ca7941 100644
--- a/tensorflow/compat_template_v1.__init__.py
+++ b/tensorflow/compat_template_v1.__init__.py
@@ -23,12 +23,12 @@ import os as _os
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
+# API IMPORTS PLACEHOLDER
+
from tensorflow.python.tools import component_api_helper as _component_api_helper
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=('tensorflow_estimator.python.estimator.api.estimator'))
-# API IMPORTS PLACEHOLDER
-
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
app.flags = flags # pylint: disable=undefined-variable
diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl
index 2dc3e8c9113b37bf9d575ad66783f4ab49478af4..4051664c24cacad4a2d151ad3ac9009015900609 100644
--- a/tensorflow/compiler/aot/tfcompile.bzl
+++ b/tensorflow/compiler/aot/tfcompile.bzl
@@ -283,7 +283,7 @@ def tf_library(
)
# Variables used for gen_test and gen_benchmark.
- cpp_class_split = cpp_class.rsplit("::", maxsplit = 2)
+ cpp_class_split = cpp_class.rsplit("::", 2)
if len(cpp_class_split) == 1:
no_ns_name = cpp_class_split[0]
else:
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 15dcbb2641eca031e82db9aa58dee6a14ab0a2cc..d8c88a9fca2db74265b4962e07a66ab214b1d994 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -515,6 +515,7 @@ cc_library(
"//tensorflow/compiler/jit/ops:xla_ops",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:resource_operation_table",
+ "//tensorflow/compiler/tf2xla:side_effect_util",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
@@ -613,6 +614,7 @@ tf_cc_test(
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
+ "//tensorflow/cc:functional_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/cc:scope",
@@ -625,6 +627,7 @@ tf_cc_test(
"//tensorflow/compiler/tf2xla/cc:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
+ "//tensorflow/compiler/xla:test",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index f478832781cb1dc045d9163d4a6f5e5f64a8a705..03aba97bbe81a11f6366d118ee5bc573d0c6b31b 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -779,7 +779,8 @@ Status Encapsulator::Subgraph::RecordArg(
if (inserted) {
NodeDef arg_def;
NodeDefBuilder builder(
- absl::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp);
+ absl::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp,
+ NodeDebugInfo(src_node->def()));
DataType dtype = edge->dst()->input_type(edge->dst_input());
builder.Attr("T", dtype);
builder.Attr("index", arg_index);
@@ -814,7 +815,8 @@ Status Encapsulator::Subgraph::RecordResult(
if (inserted) {
NodeDef ret_def;
NodeDefBuilder builder(
- absl::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp);
+ absl::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp,
+ NodeDebugInfo(src_node->def()));
DataType dtype = src_node->output_type(src_slot);
builder.Attr("T", dtype);
builder.Attr("index", ret_index);
@@ -974,6 +976,7 @@ Status Encapsulator::Subgraph::AddHostComputes(
}
NodeDef host_compute_def;
+ // TODO(shikharagarwal): What source node should we use for errors?
NodeDefBuilder builder(absl::StrCat("outside_compilation_",
oc_subgraph_name, "_host_compute"),
kHostComputeOp);
@@ -1040,6 +1043,7 @@ Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name,
Graph* graph_out) {
if (sequencer_ == nullptr) {
NodeDef seq_def;
+ // TODO(shikharagarwal): What source node should we use for errors?
NodeDefBuilder builder(absl::StrCat(subgraph_name, "_sequencer"), "NoOp");
builder.Attr(kXlaHostTransferSequencerAttr, subgraph_name);
builder.Device(device_);
@@ -1214,7 +1218,8 @@ Status Encapsulator::Subgraph::AddHostComputeKeyPlaceholder(
GraphDefBuilder::Options options(graph_out, /*status=*/nullptr);
NodeDef key_def;
NodeDefBuilder builder(
- absl::StrCat(call_node_def_.name(), "_key_placeholder"), "Placeholder");
+ absl::StrCat(call_node_def_.name(), "_key_placeholder"), "Placeholder",
+ NodeDebugInfo(call_node_def_));
builder.Attr("dtype", DT_STRING);
builder.Attr("shape", shape_proto);
builder.Attr("_host_compute_call_node", call_node_def_.name());
@@ -1248,6 +1253,7 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode(
}
NodeDef recv_def;
+ // TODO(shikharagarwal): What source node should we use for errors?
NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name,
"_", oc_subgraph_name, "_recv"),
kRecvAtHostOp);
@@ -1303,6 +1309,7 @@ Status Encapsulator::Subgraph::AddSendFromHostNode(
}
NodeDef send_def;
+ // TODO(shikharagarwal): What source node should we use for errors?
NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name,
"_", oc_subgraph_name, "_send"),
kSendFromHostOp);
@@ -1833,8 +1840,9 @@ Node* AddDummyShapedNode(const Node* src_node, int src_port,
// Add any Enter nodes required to bring the constant to the correct control
// flow frame.
while (!control_flow_info[src_node->id()].frame_name.empty()) {
+ NodeDebugInfo debug_info(*src_node);
NodeBuilder enter_builder(options.GetNameForOp("Enter"), "Enter",
- options.op_registry());
+ options.op_registry(), &debug_info);
enter_builder.Attr("frame_name",
control_flow_info[src_node->id()].frame_name);
enter_builder.Attr("is_constant", true);
@@ -2018,7 +2026,8 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
return errors::InvalidArgument(
"Shape inference is not possible for outside_compilation "
"SendFromHost node ",
- send_node->name(), " because shape of node ", n->name(),
+ send_node->name(), " because shape of node ",
+ FormatNodeForError(*n),
" will not be known at compilation time.");
}
}
@@ -2047,8 +2056,7 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
return errors::Internal(
"Internal assumption failed while rewriting an outside_compilation "
"cluster that contains a while loop. Logic assumes back-edge is to "
- "port 1 of a 2-input "
- "Merge node.");
+ "port 1 of a 2-input Merge node.");
}
// Connect the existing edge to both inputs of the Merge node so that the
// graph will be well-formed.
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
index de89be9a3555960dabe7bacd17226c15ae888ae6..8617beec004d0fe912155f054442c5b6249bb6b5 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
@@ -299,7 +299,7 @@ REGISTER_OP("XlaHostCompute")
.Attr("Toutputs: list(type) >= 0")
.Attr("ancestors: list(string) >= 0")
.Attr("key: string")
- .Attr("shape_inference_graph: string = ''")
+ .Attr("shape_inference_graph: func")
.Attr("shapes: list(shape) >= 0")
.SetShapeFn(::tensorflow::shape_inference::UnknownShape);
@@ -510,11 +510,7 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library,
s = ConvertGraphDefToGraph(options, *graphdef, graph.get());
if (!s.ok()) return s;
- s = PerformStaticShapeInferenceBeforeEncapsulation(
- graph.get(), "_encapsulate", "_outside");
- if (!s.ok()) return s;
-
- s = PreprocessForEncapsulation(graph.get(), "_encapsulate", "_outside");
+ s = PerformStaticShapeInferenceBeforeEncapsulation(graph.get());
if (!s.ok()) return s;
std::unique_ptr graph_out;
@@ -550,6 +546,14 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library,
graphdef->Swap(&graphdef_out);
*library = lib_def->ToProto();
+ // Remove "_xla_inferred_shapes" attr. They are added by
+ // `PerformStaticShapeInferenceBeforeEncapsulation`.
+ for (FunctionDef& fdef : *library->mutable_function()) {
+ for (NodeDef& node_def : *fdef.mutable_node_def()) {
+ node_def.mutable_attr()->erase("_xla_inferred_shapes");
+ }
+ }
+
return s;
}
@@ -901,18 +905,22 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
{
GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
Node* key_constant = KeyPlaceholder("F1", shape.opts());
- Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
- {DT_FLOAT, DT_FLOAT}, shape.opts());
+ Node* recv = RecvAtHost(
+ ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT},
+ shape.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
shape.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
- SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts());
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+ shape.opts().WithAttr(kXlaHasHostTransferAttrName, true));
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected));
}
+ NameAttrList shape_inference_graph;
+ shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1");
*library_expected.add_function() = test::function::XTimesTwo();
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {},
@@ -931,8 +939,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
{"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O1"},
- {"shape_inference_graph",
- "_outside_compilation_shape_inference_F1_O1"},
+ {"shape_inference_graph", shape_inference_graph},
{"shapes", absl::Span({})},
{"_outside_compilation_subgraph", "O1"}},
{"c"}},
@@ -948,16 +955,18 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
- Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
- {DT_FLOAT, DT_FLOAT}, b2.opts());
+ Node* recv = RecvAtHost(
+ ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
b2.opts()
.WithName("E")
- .WithControlInputs({recv, b})
+ .WithControlInputs({recv})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* send = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
- b2.opts().WithControlInput(e));
+ b2.opts().WithControlInput(e).WithAttr(
+ kXlaHasHostTransferAttrName, true));
Node* s = Sequencer(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}),
@@ -966,9 +975,9 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
NodeBuilder node_builder("F1", "F1", lib_def.get());
node_builder.Input(a).Input(b);
Node* call =
- b2.opts().WithControlInputs({s}).FinalizeBuilder(&node_builder);
+ b2.opts().WithControlInputs({s, b}).FinalizeBuilder(&node_builder);
- Binary(a, call, b2.opts().WithName("G").WithControlInputs({e}));
+ Binary(a, call, b2.opts().WithName("G").WithControlInputs({call}));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
@@ -1022,14 +1031,16 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
{
GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
Node* key_constant = KeyPlaceholder("F1", shape1.opts());
- Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
- {DT_FLOAT, DT_FLOAT}, shape1.opts());
+ Node* recv = RecvAtHost(
+ ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT},
+ shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
shape1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
- SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts());
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+ shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected));
}
@@ -1037,33 +1048,45 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
{
GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately);
Node* key_constant = KeyPlaceholder("F1", shape2.opts());
- Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
- {DT_FLOAT, DT_FLOAT}, shape2.opts());
+ Node* recv1 = RecvAtHost(
+ ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT},
+ shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
shape2.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
- Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2",
- {DT_FLOAT, DT_FLOAT}, shape2.opts());
+ Node* recv2 = RecvAtHost(
+ ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT, DT_FLOAT},
+ shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* g = Binary(e, ops::NodeOut(recv2, 0),
+ shape2.opts()
+ .WithName("G")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O2"));
Node* h = Binary(ops::NodeOut(recv2, 1), e,
shape2.opts()
.WithName("H")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2"));
- SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {h}, shape2.opts());
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g, h},
+ shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape2, "F1_O2", &library_expected));
}
+ NameAttrList shape_inference_graph1, shape_inference_graph2;
+ shape_inference_graph1.set_name("_outside_compilation_shape_inference_F1_O1");
+ shape_inference_graph2.set_name("_outside_compilation_shape_inference_F1_O2");
*library_expected.add_function() = FunctionDefHelper::Create(
- "F1", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval_retval:float"}, {},
+ "F1", {"a_0_arg:float", "b_0_arg:float"},
+ {"g_0_retval_retval:float", "i_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}},
{{"I"},
"UnaryTest",
- {"outside_compilation_O2_host_compute:outputs:0"}},
+ {"outside_compilation_O2_host_compute:outputs:1"}},
{{"F"},
"BinaryTest",
{"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
@@ -1073,11 +1096,10 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
"XlaHostCompute",
{"F:o:0", "D:o:0"},
{{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})},
- {"Toutputs", absl::Span({DT_FLOAT})},
+ {"Toutputs", absl::Span({DT_FLOAT, DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O2"},
- {"shape_inference_graph",
- "_outside_compilation_shape_inference_F1_O2"},
+ {"shape_inference_graph", shape_inference_graph2},
{"shapes", absl::Span({})},
{"_outside_compilation_subgraph", "O2"}},
{"F"}},
@@ -1088,13 +1110,13 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
{"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O1"},
- {"shape_inference_graph",
- "_outside_compilation_shape_inference_F1_O1"},
+ {"shape_inference_graph", shape_inference_graph1},
{"shapes", absl::Span({})},
{"_outside_compilation_subgraph", "O1"}},
{"D"}},
},
- {{"i_0_retval_retval", "I:o:0"}});
+ {{"g_0_retval_retval", "outside_compilation_O2_host_compute:outputs:0"},
+ {"i_0_retval_retval", "I:o:0"}});
{
std::unique_ptr lib_def(
@@ -1105,19 +1127,22 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
- Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
- {DT_FLOAT, DT_FLOAT}, b2.opts());
+ Node* recv1 = RecvAtHost(
+ ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
b2.opts()
.WithName("E")
- .WithControlInputs({recv1, b})
+ .WithControlInputs({recv1})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
- b2.opts().WithControlInput(e));
+ b2.opts().WithControlInput(e).WithAttr(
+ kXlaHasHostTransferAttrName, true));
- Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2",
- {DT_FLOAT, DT_FLOAT}, b2.opts());
+ Node* recv2 = RecvAtHost(
+ ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT, DT_FLOAT},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* g = Binary(e, ops::NodeOut(recv2, 0),
b2.opts()
.WithName("G")
@@ -1130,7 +1155,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2"));
Node* send2 =
- SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {h}, b2.opts());
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g, h},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* s = Sequencer(b2.opts()
.WithName("F1_sequencer")
@@ -1139,12 +1165,13 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
NodeBuilder node_builder("F1", "F1", lib_def.get());
node_builder.Input(a).Input(b);
- Node* call = b2.opts().WithControlInput(s).FinalizeBuilder(&node_builder);
+ Node* call =
+ b2.opts().WithControlInputs({s, b}).FinalizeBuilder(&node_builder);
- Binary(g, call, b2.opts().WithName("J"));
+ Binary(ops::NodeOut(call, 0), ops::NodeOut(call, 1),
+ b2.opts().WithName("J"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
-
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
@@ -1196,7 +1223,9 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"},
- {"f_0_retval_retval:float", "d_0_retval_retval:float"}, {},
+ {"e_0_retval_retval:float", "f_0_retval_retval:float",
+ "d_0_retval_retval:float"},
+ {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
@@ -1212,35 +1241,37 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
{"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O1"},
- {"shape_inference_graph", ""},
+ {"shape_inference_graph", NameAttrList()},
{"shapes",
absl::Span({shape_proto_expected})},
{"_outside_compilation_subgraph", "O1"}},
{"D"}},
},
- {{"d_0_retval_retval", "D:o:0"}, {"f_0_retval_retval", "F:o:0"}});
+ {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
+ {"d_0_retval_retval", "D:o:0"},
+ {"f_0_retval_retval", "F:o:0"}});
*library_expected.add_function() = FunctionDefHelper::Create(
- "F2", {"f_0_arg:float", "bridge_e_g_0_arg:float"},
- {"i_0_retval_retval:float", "g_0_retval_retval:float"}, {},
+ "F2", {"e_0_arg:float", "f_0_arg:float", "d_0_arg:float"},
+ {"g_0_retval_retval:float", "i_0_retval_retval:float"}, {},
{
- {{"G"}, "BinaryTest", {"bridge_e_g_0_arg", "f_0_arg"}},
+ {{"G"}, "BinaryTest", {"e_0_arg", "f_0_arg"}},
{{"I"},
"BinaryTest",
{"f_0_arg", "outside_compilation_O1_host_compute:outputs:0"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
- {"G:o:0"},
- {{"Tinputs", absl::Span({DT_FLOAT})},
+ {"d_0_arg", "G:o:0"},
+ {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})},
{"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F2_O1"},
- {"shape_inference_graph", ""},
+ {"shape_inference_graph", NameAttrList()},
{"shapes",
absl::Span({shape_proto_expected})},
{"_outside_compilation_subgraph", "O1"}}},
},
- {{"i_0_retval_retval", "I:o:0"}, {"g_0_retval_retval", "G:o:0"}});
+ {{"g_0_retval_retval", "G:o:0"}, {"i_0_retval_retval", "I:o:0"}});
{
std::unique_ptr lib_def(
@@ -1251,16 +1282,18 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
Node* key_constant1 =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
- Node* recv1 = RecvAtHost(ops::NodeOut(key_constant1, 0), "F1", "O1",
- {DT_FLOAT, DT_FLOAT}, b2.opts());
+ Node* recv1 = RecvAtHost(
+ ops::NodeOut(key_constant1, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
b2.opts()
.WithName("E")
- .WithControlInputs({recv1, b})
+ .WithControlInputs({recv1})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* send1 = SendFromHost(ops::NodeOut(key_constant1, 0), "F1", "O1", {e},
- b2.opts().WithControlInput(e));
+ b2.opts().WithControlInput(e).WithAttr(
+ kXlaHasHostTransferAttrName, true));
Node* s1 = Sequencer(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}),
"F1");
@@ -1268,29 +1301,33 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 =
- b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1);
+ b2.opts().WithControlInputs({s1, b}).FinalizeBuilder(&node_builder1);
Node* key_constant2 =
KeyPlaceholder("F2", b2.opts().WithName("F2_key_placeholder"));
- Node* recv2 = RecvAtHost(ops::NodeOut(key_constant2, 0), "F2", "O1",
- {DT_FLOAT}, b2.opts());
- Node* h = Binary(ops::NodeOut(call1, 1), recv2,
+ Node* recv2 = RecvAtHost(
+ ops::NodeOut(key_constant2, 0), "F2", "O1", {DT_FLOAT, DT_FLOAT},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* h = Binary(recv2, ops::NodeOut(recv2, 1),
b2.opts()
.WithName("H")
.WithAttr("_encapsulate", "F2")
.WithAttr("_outside", "O1"));
- Node* send2 = SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "O1", {h},
- b2.opts());
+ Node* send2 =
+ SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "O1", {h},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* s2 = Sequencer(
b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2}),
"F2");
NodeBuilder node_builder2("F2", "F2", lib_def.get());
- node_builder2.Input(call1).Input(e);
+ node_builder2.Input(call1)
+ .Input(ops::NodeOut(call1, 1))
+ .Input(ops::NodeOut(call1, 2));
Node* call2 = b2.opts()
- .WithControlInputs({s2, e, call1})
+ .WithControlInputs({s2, call1})
.FinalizeBuilder(&node_builder2);
- Binary(ops::NodeOut(call2, 1), call2, b2.opts().WithName("J"));
+ Binary(call2, ops::NodeOut(call2, 1), b2.opts().WithName("J"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
@@ -1326,8 +1363,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
Node* h = Unary(g, b1.opts()
.WithName("H")
.WithAttr("_encapsulate", "F2")
- .WithAttr("_outside", "O1")
- .WithControlInput(e));
+ .WithAttr("_outside", "O1"));
Node* i = Unary(h, b1.opts().WithName("I").WithAttr("_encapsulate", "F2"));
Binary(f, i, b1.opts().WithName("J"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
@@ -1358,7 +1394,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
{"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O1"},
- {"shape_inference_graph", ""},
+ {"shape_inference_graph", NameAttrList()},
{"shapes",
absl::Span({shape_proto_expected})},
{"_outside_compilation_subgraph", "O1"}},
@@ -1380,7 +1416,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
{"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F2_O1"},
- {"shape_inference_graph", ""},
+ {"shape_inference_graph", NameAttrList()},
{"shapes",
absl::Span({shape_proto_expected})},
{"_outside_compilation_subgraph", "O1"}}},
@@ -1401,7 +1437,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
b2.opts()
.WithName("E")
- .WithControlInputs({recv1, b})
+ .WithControlInputs({recv1})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* send1 = SendFromHost(ops::NodeOut(key_constant1, 0), "F1", "O1", {e},
@@ -1413,7 +1449,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 =
- b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1);
+ b2.opts().WithControlInputs({s1, b}).FinalizeBuilder(&node_builder1);
Node* key_constant2 =
KeyPlaceholder("F2", b2.opts().WithName("F2_key_placeholder"));
@@ -1422,8 +1458,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
Node* h = Unary(recv2, b2.opts()
.WithName("H")
.WithAttr("_encapsulate", "F2")
- .WithAttr("_outside", "O1")
- .WithControlInput(e));
+ .WithAttr("_outside", "O1"));
Node* send2 = SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "O1", {h},
b2.opts());
@@ -1484,12 +1519,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
{"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
- {},
- {{"Tinputs", absl::Span({})},
+ {"a_0_arg"},
+ {{"Tinputs", absl::Span({DT_FLOAT})},
{"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O1"},
- {"shape_inference_graph", ""},
+ {"shape_inference_graph", NameAttrList()},
{"shapes",
absl::Span({shape_proto_expected})},
{"_outside_compilation_subgraph", "O1"}}},
@@ -1503,16 +1538,19 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
Node* a = InputShaped(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
- Node* e = Unary(a, b2.opts()
- .WithName("E")
- .WithAttr("_encapsulate", "F1")
- .WithAttr("_outside", "O1"));
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
+ Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
+ {DT_FLOAT}, b2.opts());
+ Node* e = Unary(recv1, b2.opts()
+ .WithName("E")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
Node* send1 =
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts());
Node* s1 = Sequencer(
- b2.opts().WithName("F1_sequencer").WithControlInput(send1), "F1");
+ b2.opts().WithName("F1_sequencer").WithControlInputs({send1, recv1}),
+ "F1");
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 =
@@ -1569,12 +1607,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
{"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
- {},
- {{"Tinputs", absl::Span({})},
+ {"a_0_arg"},
+ {{"Tinputs", absl::Span({DT_FLOAT})},
{"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O1"},
- {"shape_inference_graph", ""},
+ {"shape_inference_graph", NameAttrList()},
{"shapes",
absl::Span({shape_proto_expected})},
{"_outside_compilation_subgraph", "O1"}},
@@ -1591,13 +1629,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
- Node* recv1 =
- RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {}, b2.opts());
- Node* e = Unary(a, b2.opts()
- .WithName("E")
- .WithControlInput(recv1)
- .WithAttr("_encapsulate", "F1")
- .WithAttr("_outside", "O1"));
+ Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
+ {DT_FLOAT}, b2.opts());
+ Node* e = Unary(recv1, b2.opts()
+ .WithName("E")
+ .WithControlInput(recv1)
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
Node* send1 =
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts());
Node* s1 = Sequencer(
@@ -1644,8 +1682,27 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
+ {
+ GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
+ Node* key_constant = KeyPlaceholder("F1", shape1.opts());
+ Node* recv1 =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT},
+ shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts()
+ .WithName("E")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+ shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ TF_EXPECT_OK(
+ AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected));
+ }
+
+ NameAttrList shape_inference_graph;
+ shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1");
*library_expected.add_function() = FunctionDefHelper::Create(
- "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {},
+ "F1", {"a_0_arg:float", "b_0_arg:float"},
+ {"e_0_retval_retval:float", "f_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
@@ -1654,14 +1711,15 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
"XlaHostCompute",
{"D:o:0"},
{{"Tinputs", absl::Span({DT_FLOAT})},
- {"Toutputs", absl::Span({})},
+ {"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O1"},
- {"shape_inference_graph", ""},
+ {"shape_inference_graph", shape_inference_graph},
{"shapes", absl::Span({})},
{"_outside_compilation_subgraph", "O1"}}},
},
- {{"f_0_retval_retval", "F:o:0"}});
+ {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
+ {"f_0_retval_retval", "F:o:0"}});
{
std::unique_ptr lib_def(
@@ -1678,14 +1736,17 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
+ Node* send1 =
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts());
Node* s1 = Sequencer(
- b2.opts().WithName("F1_sequencer").WithControlInput(recv1), "F1");
+ b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}),
+ "F1");
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 =
b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1);
- Binary(e, call1, b2.opts().WithName("G"));
+ Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
@@ -1722,8 +1783,27 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
+ {
+ GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
+ Node* key_constant = KeyPlaceholder("F1", shape1.opts());
+ Node* recv1 =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT},
+ shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts()
+ .WithName("E")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+ shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ TF_EXPECT_OK(
+ AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected));
+ }
+
+ NameAttrList shape_inference_graph;
+ shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1");
*library_expected.add_function() = FunctionDefHelper::Create(
- "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {},
+ "F1", {"a_0_arg:float", "b_0_arg:float"},
+ {"e_0_retval_retval:float", "f_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
@@ -1736,14 +1816,15 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
"XlaHostCompute",
{"D:o:0"},
{{"Tinputs", absl::Span({DT_FLOAT})},
- {"Toutputs", absl::Span({})},
+ {"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O1"},
- {"shape_inference_graph", ""},
+ {"shape_inference_graph", shape_inference_graph},
{"shapes", absl::Span({})},
{"_outside_compilation_subgraph", "O1"}}},
},
- {{"f_0_retval_retval", "F:o:0"}});
+ {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
+ {"f_0_retval_retval", "F:o:0"}});
{
std::unique_ptr lib_def(
@@ -1760,7 +1841,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
- Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {},
+ Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
b2.opts().WithControlInput(e));
Node* s1 = Sequencer(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}),
@@ -1770,7 +1851,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
Node* call1 =
b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1);
- Binary(e, call1, b2.opts().WithName("G"));
+ Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
@@ -1813,22 +1894,45 @@ TEST(EncapsulateSubgraphsTest,
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
+ {
+ GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
+ Node* key_constant = KeyPlaceholder("F1", shape1.opts());
+ Node* recv1 =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT},
+ shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts()
+ .WithName("E")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+ shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ TF_EXPECT_OK(
+ AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected));
+ }
+
{
GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately);
Node* key_constant = KeyPlaceholder("F1", shape2.opts());
- Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2",
- {DT_FLOAT}, shape2.opts());
+ Node* recv2 =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT},
+ shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* g = Unary(ops::NodeOut(recv2, 0), shape2.opts()
.WithName("G")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2"));
- SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, shape2.opts());
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g},
+ shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape2, "F1_O2", &library_expected));
}
+ NameAttrList shape_inference_graph1;
+ shape_inference_graph1.set_name("_outside_compilation_shape_inference_F1_O1");
+ NameAttrList shape_inference_graph2;
+ shape_inference_graph2.set_name("_outside_compilation_shape_inference_F1_O2");
*library_expected.add_function() = FunctionDefHelper::Create(
- "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval_retval:float"}, {},
+ "F1", {"a_0_arg:float", "b_0_arg:float"},
+ {"e_0_retval_retval:float", "h_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
@@ -1836,6 +1940,16 @@ TEST(EncapsulateSubgraphsTest,
{{"H"},
"UnaryTest",
{"outside_compilation_O2_host_compute:outputs:0"}},
+ {{"outside_compilation_O1_host_compute"},
+ "XlaHostCompute",
+ {"a_0_arg"},
+ {{"Tinputs", absl::Span({DT_FLOAT})},
+ {"Toutputs", absl::Span({DT_FLOAT})},
+ {"ancestors", absl::Span({})},
+ {"key", "host_compute_channel_F1_O1"},
+ {"shape_inference_graph", shape_inference_graph1},
+ {"shapes", absl::Span({})},
+ {"_outside_compilation_subgraph", "O1"}}},
{{"outside_compilation_O2_host_compute"},
"XlaHostCompute",
{"F:o:0"},
@@ -1843,12 +1957,12 @@ TEST(EncapsulateSubgraphsTest,
{"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O2"},
- {"shape_inference_graph",
- "_outside_compilation_shape_inference_F1_O2"},
+ {"shape_inference_graph", shape_inference_graph2},
{"shapes", absl::Span({})},
{"_outside_compilation_subgraph", "O2"}}},
},
- {{"h_0_retval_retval", "H:o:0"}});
+ {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
+ {"h_0_retval_retval", "H:o:0"}});
{
std::unique_ptr lib_def(
@@ -1856,30 +1970,39 @@ TEST(EncapsulateSubgraphsTest,
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
-
- Node* e = Unary(a, b2.opts()
- .WithName("E")
- .WithAttr("_encapsulate", "F1")
- .WithAttr("_outside", "O1"));
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
- Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2",
- {DT_FLOAT}, b2.opts());
- Node* g = Unary(recv, b2.opts()
- .WithName("G")
- .WithAttr("_encapsulate", "F1")
- .WithAttr("_outside", "O2")
- .WithControlInput(e));
- Node* send =
- SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, b2.opts());
- Node* s1 = Sequencer(
- b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}),
- "F1");
+ Node* recv1 =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+
+ Node* e = Unary(recv1, b2.opts()
+ .WithName("E")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
+ Node* send1 =
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* recv2 =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* g = Unary(recv2, b2.opts()
+ .WithName("G")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O2")
+ .WithControlInput(e));
+ Node* send2 =
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* s1 = Sequencer(b2.opts()
+ .WithName("F1_sequencer")
+ .WithControlInputs({recv1, send1, recv2, send2}),
+ "F1");
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b).ControlInput(s1);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
- Binary(e, call1, b2.opts().WithName("I"));
+ Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("I"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
@@ -1925,19 +2048,24 @@ TEST(EncapsulateSubgraphsTest,
{
GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
Node* key_constant = KeyPlaceholder("F1", shape1.opts());
- Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
- {DT_FLOAT}, shape1.opts());
+ Node* recv2 =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT},
+ shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
- SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts());
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+ shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected));
}
+ NameAttrList shape_inference_graph;
+ shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1");
*library_expected.add_function() = FunctionDefHelper::Create(
- "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval_retval:float"}, {},
+ "F1", {"a_0_arg:float", "b_0_arg:float"},
+ {"e_0_retval_retval:float", "h_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
@@ -1945,6 +2073,16 @@ TEST(EncapsulateSubgraphsTest,
"UnaryTest",
{"outside_compilation_O1_host_compute:outputs:0"}},
{{"H"}, "UnaryTest", {"F:o:0"}},
+ {{"outside_compilation_O2_host_compute"},
+ "XlaHostCompute",
+ {"a_0_arg"},
+ {{"Tinputs", absl::Span({DT_FLOAT})},
+ {"Toutputs", absl::Span({})},
+ {"ancestors", absl::Span({})},
+ {"key", "host_compute_channel_F1_O2"},
+ {"shape_inference_graph", NameAttrList()},
+ {"shapes", absl::Span({})},
+ {"_outside_compilation_subgraph", "O2"}}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"D:o:0"},
@@ -1952,12 +2090,12 @@ TEST(EncapsulateSubgraphsTest,
{"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O1"},
- {"shape_inference_graph",
- "_outside_compilation_shape_inference_F1_O1"},
+ {"shape_inference_graph", shape_inference_graph},
{"shapes", absl::Span({})},
{"_outside_compilation_subgraph", "O1"}}},
},
- {{"h_0_retval_retval", "H:o:0"}});
+ {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
+ {"h_0_retval_retval", "H:o:0"}});
{
std::unique_ptr lib_def(
@@ -1968,27 +2106,33 @@ TEST(EncapsulateSubgraphsTest,
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
- Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
- {DT_FLOAT}, b2.opts());
- Node* e = Unary(recv, b2.opts()
- .WithName("E")
- .WithAttr("_encapsulate", "F1")
- .WithAttr("_outside", "O1"));
+ Node* recv1 =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* e = Unary(recv1, b2.opts()
+ .WithName("E")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
Node* send =
- SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts());
- /*Node* g =*/Unary(a, b2.opts()
- .WithName("G")
- .WithAttr("_encapsulate", "F1")
- .WithAttr("_outside", "O2")
- .WithControlInput(e));
- Node* s1 = Sequencer(
- b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}),
- "F1");
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* recv2 =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ /*Node* g =*/Unary(recv2, b2.opts()
+ .WithName("G")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O2")
+ .WithControlInput(e));
+ Node* s1 = Sequencer(b2.opts()
+ .WithName("F1_sequencer")
+ .WithControlInputs({recv1, recv2, send}),
+ "F1");
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b).ControlInput(s1);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
- Binary(e, call1, b2.opts().WithName("I"));
+ Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("I"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
@@ -2039,19 +2183,24 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
{
GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
Node* key_constant = KeyPlaceholder("F1", shape1.opts());
- Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
- {DT_FLOAT}, shape1.opts());
+ Node* recv2 =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT},
+ shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
- SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts());
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+ shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected));
}
+ NameAttrList shape_inference_graph;
+ shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1");
*library_expected.add_function() = FunctionDefHelper::Create(
- "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval_retval:float"}, {},
+ "F1", {"a_0_arg:float", "b_0_arg:float"},
+ {"e_0_retval_retval:float", "h_0_retval_retval:float"}, {},
{{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"}, "UnaryTest", {"outside_compilation_O1_host_compute:outputs:0"}},
@@ -2063,8 +2212,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
{"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O1"},
- {"shape_inference_graph",
- "_outside_compilation_shape_inference_F1_O1"},
+ {"shape_inference_graph", shape_inference_graph},
{"shapes", absl::Span({})},
{"_outside_compilation_subgraph", "O1"}}},
{{"outside_compilation_O2_host_compute"},
@@ -2074,7 +2222,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
{"Toutputs", absl::Span({})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O2"},
- {"shape_inference_graph", ""},
+ {"shape_inference_graph", NameAttrList()},
{"shapes", absl::Span({})},
{"_outside_compilation_subgraph", "O2"}},
{}},
@@ -2085,11 +2233,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
{"Toutputs", absl::Span({})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O3"},
- {"shape_inference_graph", ""},
+ {"shape_inference_graph", NameAttrList()},
{"shapes", absl::Span({})},
{"_outside_compilation_subgraph", "O3"}},
{}}},
- {{"h_0_retval_retval", "H:o:0"}});
+ {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
+ {"h_0_retval_retval", "H:o:0"}});
{
std::unique_ptr lib_def(
@@ -2100,23 +2249,27 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
- Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
- {DT_FLOAT}, b2.opts());
+ Node* recv1 =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Unary(recv1, b2.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* send =
- SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts());
- Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2",
- {DT_FLOAT}, b2.opts());
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* recv2 =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* g = Unary(recv2, b2.opts()
.WithName("G")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2")
.WithControlInput(e));
- Node* recv3 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O3",
- {DT_FLOAT}, b2.opts());
+ Node* recv3 =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O3", {DT_FLOAT},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
/*Node* i =*/Binary(recv3, e,
b2.opts()
.WithName("I")
@@ -2131,7 +2284,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
node_builder1.Input(a).Input(b).ControlInput(s1);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
- Binary(e, call1, b2.opts().WithName("J"));
+ Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("J"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
@@ -2167,14 +2320,44 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) {
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
+ {
+ GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
+ Node* key_constant = KeyPlaceholder("F1", shape1.opts());
+ Node* recv2 =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT},
+ shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts()
+ .WithName("E")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+ shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ TF_EXPECT_OK(
+ AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected));
+ }
+
+ NameAttrList shape_inference_graph;
+ shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1");
*library_expected.add_function() = FunctionDefHelper::Create(
- "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {},
+ "F1", {"a_0_arg:float", "b_0_arg:float"},
+ {"e_0_retval_retval:float", "f_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"}, "UnaryTest", {"D:o:0"}},
+ {{"outside_compilation_O1_host_compute"},
+ "XlaHostCompute",
+ {"a_0_arg"},
+ {{"Tinputs", absl::Span({DT_FLOAT})},
+ {"Toutputs", absl::Span({DT_FLOAT})},
+ {"ancestors", absl::Span({})},
+ {"key", "host_compute_channel_F1_O1"},
+ {"shape_inference_graph", shape_inference_graph},
+ {"shapes", absl::Span({})},
+ {"_outside_compilation_subgraph", "O1"}}},
},
- {{"f_0_retval_retval", "F:o:0"}});
+ {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
+ {"f_0_retval_retval", "F:o:0"}});
{
std::unique_ptr lib_def(
@@ -2183,15 +2366,26 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) {
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
- Node* e = Unary(a, b2.opts()
- .WithName("E")
- .WithAttr("_encapsulate", "F1")
- .WithAttr("_outside", "O1"));
+ Node* key_constant =
+ KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
+ Node* recv =
+ RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* e = Unary(recv, b2.opts()
+ .WithName("E")
+ .WithAttr("_encapsulate", "F1")
+ .WithAttr("_outside", "O1"));
+ Node* send =
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* s = Sequencer(
+ b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}),
+ "F1");
NodeBuilder node_builder1("F1", "F1", lib_def.get());
- node_builder1.Input(a).Input(b);
+ node_builder1.Input(a).Input(b).ControlInput(s);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
- Binary(e, call1, b2.opts().WithName("G"));
+ Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
@@ -2236,20 +2430,22 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
{
GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
Node* key_constant = KeyPlaceholder("F1", shape.opts());
- Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
- {DT_FLOAT}, shape.opts());
- Node* a = InputShaped(shape.opts().WithName("A"));
- Node* c = Unary(a, shape.opts().WithName("C"));
- Node* e = BinaryUnknownShape(c, recv,
+ Node* recv = RecvAtHost(
+ ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT},
+ shape.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* e = BinaryUnknownShape(recv, ops::NodeOut(recv, 1),
shape.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
- SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts());
+ SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
+ shape.opts().WithAttr(kXlaHasHostTransferAttrName, true));
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected));
}
+ NameAttrList shape_inference_graph;
+ shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1");
*library_expected.add_function() = test::function::XTimesTwo();
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"b_0_arg:float", "c_0_arg:float"}, {"f_0_retval_retval:float"}, {},
@@ -2262,13 +2458,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
{"outside_compilation_O1_host_compute"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
- {"c:o:0"},
- {{"Tinputs", absl::Span({DT_FLOAT})},
+ {"c_0_arg", "c:o:0"},
+ {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})},
{"Toutputs", absl::Span({DT_FLOAT})},
{"ancestors", absl::Span({})},
{"key", "host_compute_channel_F1_O1"},
- {"shape_inference_graph",
- "_outside_compilation_shape_inference_F1_O1"},
+ {"shape_inference_graph", shape_inference_graph},
{"shapes", absl::Span({})},
{"_outside_compilation_subgraph", "O1"}},
{"c"}},
@@ -2285,16 +2480,18 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
- Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
- {DT_FLOAT}, b2.opts());
- Node* e = BinaryUnknownShape(c, ops::NodeOut(recv, 0),
+ Node* recv = RecvAtHost(
+ ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT},
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
+ Node* e = BinaryUnknownShape(recv, ops::NodeOut(recv, 1),
b2.opts()
.WithName("E")
- .WithControlInputs({recv, b})
+ .WithControlInputs({recv})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* send = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e},
- b2.opts().WithControlInput(e));
+ b2.opts().WithControlInput(e).WithAttr(
+ kXlaHasHostTransferAttrName, true));
Node* s = Sequencer(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}),
@@ -2303,9 +2500,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
NodeBuilder node_builder("F1", "F1", lib_def.get());
node_builder.Input(b).Input(c);
Node* call =
- b2.opts().WithControlInputs({s, c}).FinalizeBuilder(&node_builder);
+ b2.opts().WithControlInputs({s, b, c}).FinalizeBuilder(&node_builder);
- Binary(a, call, b2.opts().WithName("G").WithControlInputs({e}));
+ Binary(a, call, b2.opts().WithName("G").WithControlInputs({call}));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
diff --git a/tensorflow/compiler/jit/encapsulate_util.cc b/tensorflow/compiler/jit/encapsulate_util.cc
index 1f4b9c90a4ff0b1166cdb7b5942771b350740ef3..2264806d6bdabd9f26d9f83b681524399f996317 100644
--- a/tensorflow/compiler/jit/encapsulate_util.cc
+++ b/tensorflow/compiler/jit/encapsulate_util.cc
@@ -62,517 +62,6 @@ void ReplaceAttr(Node* n, const string& attr_name, const T& value) {
n->AddAttr(attr_name, value);
}
-// Step 1a ~ 1d for PreprocessForEncapsulation(). See comments of
-// PreprocessForEncapsulation() for details.
-Status ProcessControlEdges(Graph* g, const string& xla_computation_attr_name,
- const string& outside_compilation_attr_name) {
- // Gather edges to remove. We should not remove the edge while iterating.
- std::vector edges_to_remove;
- for (const Edge* e : g->edges()) {
- if (!e->IsControlEdge()) {
- continue;
- }
-
- auto src_xla_computation =
- GetStringAttr(*e->src(), xla_computation_attr_name);
- auto dst_xla_computation =
- GetStringAttr(*e->dst(), xla_computation_attr_name);
- auto src_outside_compilation =
- GetStringAttr(*e->src(), outside_compilation_attr_name);
- auto dst_outside_compilation =
- GetStringAttr(*e->dst(), outside_compilation_attr_name);
-
- if (!src_xla_computation && !dst_xla_computation) {
- continue;
- } else if (src_xla_computation && !dst_xla_computation) {
- if (src_outside_compilation) {
- // Case 1c: outside compilation to host computation control edge.
- edges_to_remove.push_back(e);
-
- TF_RETURN_IF_ERROR(AppendToListAttr(
- e->dst(), kXlaControlDependenciesAttrName, e->src()->name()));
- }
- } else if (!src_xla_computation && dst_xla_computation) {
- if (dst_outside_compilation) {
- // Case 1c: host computation control to outside compilation edge.
- edges_to_remove.push_back(e);
-
- TF_RETURN_IF_ERROR(AppendToListAttr(
- e->dst(), kXlaControlDependenciesAttrName, e->src()->name()));
- }
- } else { // src_xla_computation && dst_xla_computation
- if (*src_xla_computation != *dst_xla_computation) {
- if (src_outside_compilation && dst_outside_compilation) {
- // Case 1b: outside compilation to outside compilation control edge.
- edges_to_remove.push_back(e);
-
- TF_RETURN_IF_ERROR(AppendToListAttr(
- e->dst(), kXlaControlDependenciesAttrName, e->src()->name()));
- } else if (src_outside_compilation && !dst_outside_compilation) {
- // Case 1a: outside compilation to another XLA computaition control
- // edge.
- TF_RETURN_IF_ERROR(AppendToListAttr(
- e->src(), kXlaConnectedToOtherXlaComputationAttrName,
- *dst_xla_computation));
- } else if (!src_outside_compilation && dst_outside_compilation) {
- // Case 1a: another XLA computaition to outside compilation control
- // edge.
- TF_RETURN_IF_ERROR(AppendToListAttr(
- e->dst(), kXlaConnectedFromOtherXlaComputationAttrName,
- *src_xla_computation));
- }
- }
- }
- }
-
- for (auto e : edges_to_remove) {
- g->RemoveEdge(e);
- }
- return Status::OK();
-}
-
-// Step 2 for PreprocessForEncapsulation(). See comments of
-// PreprocessForEncapsulation() for details.
-Status ProcessXlaToXlaDataEdges(Graph* g,
- const string& xla_computation_attr_name,
- const string& outside_compilation_attr_name) {
- // Gather edges between XLA computations. Notice that we do not store `Edge*`
- // directly because we remove some nodes while adding Identity nodes, and
- // those Edge pointers might be invalidated.
- struct EdgeInfo {
- int dst_input, dst_node_id;
- };
- std::vector edges;
- for (const Edge* e : g->edges()) {
- if (e->IsControlEdge()) {
- continue;
- }
-
- auto src_xla_computation =
- GetStringAttr(*e->src(), xla_computation_attr_name);
- auto dst_xla_computation =
- GetStringAttr(*e->dst(), xla_computation_attr_name);
- auto src_outside_compilation =
- GetStringAttr(*e->src(), outside_compilation_attr_name);
- auto dst_outside_compilation =
- GetStringAttr(*e->dst(), outside_compilation_attr_name);
- if (!src_xla_computation || !dst_xla_computation) {
- continue;
- }
-
- if (*src_xla_computation != *dst_xla_computation) {
- if (src_outside_compilation || dst_outside_compilation) {
- edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()});
- VLOG(4) << "XLA -> XLA edge: " << e->DebugString();
- }
- }
- }
-
- // For each XLA -> XLA edge, add an Identity node between src and dst.
- for (int i = 0; i < edges.size(); i++) {
- Node* dst = g->FindNodeId(edges[i].dst_node_id);
- const Edge* e;
- TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e));
- Node* src = e->src();
- int src_output = e->src_output(), dst_input = e->dst_input();
- g->RemoveEdge(e);
-
- // Create Identity node, and connect it between `src` and `dst`.
- string identity_node_name =
- absl::StrCat("bridge_", src->name(), "_", dst->name());
- DataType dtype = src->output_type(src_output);
- TF_ASSIGN_OR_RETURN(Node * identity_node,
- BuildIdentityNode(g, identity_node_name, dtype, src,
- /*requested_device=*/absl::nullopt));
- identity_node->AddAttr(kBridgeSourceNodeAttrName, src->name());
- g->AddEdge(src, src_output, identity_node, 0);
- g->AddEdge(identity_node, 0, dst, dst_input);
-
- // Replace `e->dst()` because its input node changed.
- NodeDef new_def = dst->def();
- *new_def.mutable_input(dst_input) = identity_node->name();
- TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def));
-
- // Other edge in `edges` might have `e->dst()` as src or dst
- // node. Before removing `e->dst()`, replace those edges with corresponding
- // edges for `dst_replace_node`.
- for (int j = i + 1; j < edges.size(); j++) {
- if (edges[j].dst_node_id == edges[i].dst_node_id) {
- edges[j].dst_node_id = dst_replace_node->id();
- }
- }
- }
- return Status::OK();
-}
-
-// Step 3 for PreprocessForEncapsulation(). See comments of
-// PreprocessForEncapsulation() for details.
-Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation(
- Graph* g, const string& xla_computation_attr_name,
- const string& outside_compilation_attr_name) {
- // Gather edges between outside compilation and host computation. Notice that
- // we do not store `Edge*` directly because we remove some nodes while adding
- // Identity nodes, and those Edge pointers might be invalidated.
- struct EdgeInfo {
- int dst_input, dst_node_id;
- bool is_host_to_outside_compilation;
- };
- std::vector edges;
- for (const Edge* e : g->edges()) {
- if (e->IsControlEdge()) {
- continue;
- }
-
- if (e->src()->attrs().Find(xla_computation_attr_name) == nullptr &&
- e->dst()->attrs().Find(xla_computation_attr_name) != nullptr &&
- e->dst()->attrs().Find(outside_compilation_attr_name) != nullptr) {
- edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id(),
- /*is_host_to_outside_compilation=*/true});
- VLOG(4) << "Host -> oc edge: " << e->DebugString();
- } else if (e->dst()->attrs().Find(xla_computation_attr_name) == nullptr &&
- e->src()->attrs().Find(xla_computation_attr_name) != nullptr &&
- e->src()->attrs().Find(outside_compilation_attr_name) !=
- nullptr) {
- edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id(),
- /*is_host_to_outside_compilation=*/false});
- VLOG(4) << "Oc -> host edge: " << e->DebugString();
- }
- }
-
- // Remove the edge from host to outside compilation. Add a placeholder as
- // outside compilation node input.
- std::map, Node*> placeholders;
- for (int i = 0; i < edges.size(); i++) {
- Node* dst = g->FindNodeId(edges[i].dst_node_id);
- const Edge* e;
- TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e));
- Node* src = e->src();
- int src_output = e->src_output(), dst_input = e->dst_input();
- g->RemoveEdge(e);
-
- // Find or create placeholder node.
- string new_name =
- edges[i].is_host_to_outside_compilation
- ? absl::StrCat(src->name(), "_host_to_oc_placeholder_", src_output)
- : absl::StrCat(src->name(), "_oc_to_host_placeholder_", src_output);
- auto placeholder_index = std::make_pair(src->name(), src_output);
- auto iter = placeholders.find(placeholder_index);
- Node* placeholder_node;
- if (iter == placeholders.end()) {
- NodeDefBuilder placeholder_builder(new_name, "Placeholder");
- placeholder_builder.Attr("dtype", src->output_type(src_output));
- if (edges[i].is_host_to_outside_compilation) {
- placeholder_builder.Attr(kHostToOutsideCompilationOriginalNodeAttrName,
- src->name());
- placeholder_builder.Attr(kHostToOutsideCompilationSrcOutputAttrName,
- src_output);
- // If this placeholder node is in outside compilation, we need to set
- // `xla_computation_attr_name` and `outside_compilation_attr_name`.
- string xla_computation_attr, outside_compilation_attr;
- TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(), xla_computation_attr_name,
- &xla_computation_attr));
- TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(),
- outside_compilation_attr_name,
- &outside_compilation_attr));
- placeholder_builder.Attr(xla_computation_attr_name,
- xla_computation_attr);
- placeholder_builder.Attr(outside_compilation_attr_name,
- outside_compilation_attr);
- } else {
- placeholder_builder.Attr(kOutsideCompilationToHostOriginalNodeAttrName,
- src->name());
- placeholder_builder.Attr(kOutsideCompilationToHostSrcOutputAttrName,
- src_output);
- }
- NodeDef placeholder_def;
- TF_RETURN_IF_ERROR(placeholder_builder.Finalize(&placeholder_def));
- Status s;
- placeholder_node = g->AddNode(placeholder_def, &s);
- TF_RETURN_IF_ERROR(s);
- placeholders[placeholder_index] = placeholder_node;
- } else {
- placeholder_node = iter->second;
- }
- g->AddEdge(placeholder_node, 0, dst, dst_input);
-
- // Replace `e->dst()` because its input node changed.
- NodeDef new_def = dst->def();
- *new_def.mutable_input(dst_input) = placeholder_node->name();
- TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def));
-
- // Other edge in `edges` might have `e->dst()` as src or dst
- // node. Before removing `e->dst()`, replace those edges with corresponding
- // edges for `dst_replace_node`.
- for (int j = i + 1; j < edges.size(); j++) {
- if (edges[j].dst_node_id == edges[i].dst_node_id) {
- edges[j].dst_node_id = dst_replace_node->id();
- }
- }
- }
- return Status::OK();
-}
-
-// Step 1 for `PostprocessForEncapsulation`. See comments of
-// `PostprocessForEncapsulation` for details.
-Status RemovePlaceholderBetweenOutsideCompilationAndHostComputation(Graph* g) {
- // Gather all outside compilation to host computation nodes.
- struct PlaceHolderNodeInfo {
- Node* n;
- bool is_host_to_oc;
- };
- std::vector placeholder_nodes;
- for (Node* n : g->nodes()) {
- if (n->type_string() == "Placeholder") {
- if (HasNodeAttr(n->def(),
- kOutsideCompilationToHostOriginalNodeAttrName)) {
- placeholder_nodes.push_back({n, false});
- } else if (HasNodeAttr(n->def(),
- kHostToOutsideCompilationOriginalNodeAttrName)) {
- placeholder_nodes.push_back({n, true});
- }
- }
- }
-
- // Remove the placeholder nodes, and reconnect original edge.
- auto node_name_index = g->BuildNodeNameIndex();
- for (auto placeholder_iter : placeholder_nodes) {
- Node* n = placeholder_iter.n;
-
- string node_name;
- int node_src_output;
- if (placeholder_iter.is_host_to_oc) {
- TF_RETURN_IF_ERROR(
- GetNodeAttr(n->attrs(), kHostToOutsideCompilationOriginalNodeAttrName,
- &node_name));
- TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(),
- kHostToOutsideCompilationSrcOutputAttrName,
- &node_src_output));
- } else {
- TF_RETURN_IF_ERROR(
- GetNodeAttr(n->attrs(), kOutsideCompilationToHostOriginalNodeAttrName,
- &node_name));
- TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(),
- kOutsideCompilationToHostSrcOutputAttrName,
- &node_src_output));
- }
- auto iter = node_name_index.find(node_name);
- if (iter == node_name_index.end()) {
- return errors::Internal(
- "Cannot find original node for oc -> host placeholder node ",
- node_name);
- }
-
- // Change all usage node to use the original node instead.
- Node* original_node = iter->second;
- std::vector control_edges;
- std::vector data_edges;
- for (auto e : n->out_edges()) {
- if (e->IsControlEdge()) {
- control_edges.push_back(e);
- } else {
- data_edges.push_back({e->dst(), e->src_output(), e->dst_input()});
- }
- }
- for (const Edge* e : control_edges) {
- g->AddControlEdge(original_node, e->dst());
- g->RemoveEdge(e);
- }
- for (int i = 0; i < data_edges.size(); i++) {
- Node* dst = data_edges[i].dst;
- NodeDef new_def = dst->def();
- int dst_input = data_edges[i].dst_input;
- *new_def.mutable_input(dst_input) =
- absl::StrCat(original_node->name(), ":", node_src_output);
- TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def));
-
- const Edge* edge_to_replace = nullptr;
- TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace));
- g->RemoveEdge(edge_to_replace);
- g->AddEdge(original_node, node_src_output, replace_node, dst_input);
-
- // Other edges might have `dst` as dst node. Update those edges with
- // `replace_node`.
- for (int j = i + 1; j < data_edges.size(); j++) {
- if (data_edges[j].dst == dst) {
- data_edges[j].dst = replace_node;
- }
- }
-
- // Other placeholder node might have `dst` as original node. Update
- // `node_name_index` with `replace_node`.
- node_name_index[replace_node->name()] = replace_node;
- }
-
- // Remove placeholder node.
- g->RemoveNode(n);
- }
- return Status::OK();
-}
-
-// Step 2 for `PostprocessForEncapsulation`. See comments of
-// `PostprocessForEncapsulation` for details.
-Status RemoveIdentityBetweenDifferentXlaComputation(Graph* g) {
- // Gather Identity nodes to remove.
- std::vector bridge_nodes;
- for (Node* n : g->nodes()) {
- if (n->type_string() == "Identity" &&
- HasNodeAttr(n->def(), kBridgeSourceNodeAttrName)) {
- bridge_nodes.push_back(n);
- }
- }
-
- // Remove the identity nodes, and reconnect the original edge.
- for (int i = 0; i < bridge_nodes.size(); i++) {
- Node* n = bridge_nodes[i];
- const Edge* src_edge = nullptr;
- TF_RETURN_IF_ERROR(n->input_edge(0, &src_edge));
-
- // Change all usage node to use the original node instead.
- std::vector control_edges;
- std::vector data_edges;
- for (auto e : n->out_edges()) {
- if (e->IsControlEdge()) {
- control_edges.push_back(e);
- } else {
- data_edges.push_back({e->dst(), e->src_output(), e->dst_input()});
- }
- }
- for (const Edge* e : control_edges) {
- g->AddControlEdge(src_edge->src(), e->dst());
- g->RemoveEdge(e);
- }
- for (int j = 0; j < data_edges.size(); j++) {
- Node* dst = data_edges[j].dst;
- NodeDef new_def = dst->def();
- int dst_input = data_edges[j].dst_input;
- *new_def.mutable_input(dst_input) =
- absl::StrCat(src_edge->src()->name(), ":", src_edge->src_output());
- TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def));
-
- const Edge* edge_to_replace = nullptr;
- TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace));
- g->RemoveEdge(edge_to_replace);
- g->AddEdge(src_edge->src(), src_edge->src_output(), replace_node,
- dst_input);
-
- // Other edges might have `dst` as dst node. Update those edges with
- // `replace_node`.
- for (int k = j + 1; k < data_edges.size(); k++) {
- if (data_edges[k].dst == dst) {
- data_edges[k].dst = replace_node;
- }
- }
-
- // The node we replaced might be in `bridge_nodes`. If so, update
- // `bridge_nodes` to use the replaced node.
- for (int k = i + 1; k < bridge_nodes.size(); k++) {
- if (bridge_nodes[k] == dst) {
- bridge_nodes[k] = replace_node;
- }
- }
- }
-
- // Remove Identity node.
- g->RemoveNode(n);
- }
- return Status::OK();
-}
-
-// Step 3 for `PostprocessForEncapsulation`. See comments of
-// `PostprocessForEncapsulation` for details.
-// We do not need to worry about removed nodes in step 1 and 2;
-// `PreprocessForEncapsulation` will not record control dependencies for those
-// remvoed nodes in the first place.
-Status AddControlDependencies(
- Graph* g, const std::unordered_map& cluster_node_names) {
- auto node_name_index = g->BuildNodeNameIndex();
-
- // Reconnect outside compilation to outside compilation control edge.
- for (Node* n : g->nodes()) {
- std::vector control_deps;
- Status s =
- GetNodeAttr(n->attrs(), kXlaControlDependenciesAttrName, &control_deps);
- if (!s.ok()) {
- if (s.code() != error::NOT_FOUND) {
- return s;
- } else {
- continue;
- }
- } else {
- n->ClearAttr(kXlaControlDependenciesAttrName);
- for (const string& control_input : control_deps) {
- auto iter = node_name_index.find(control_input);
- if (iter == node_name_index.end()) {
- return errors::Internal("Cannot find original node for ",
- control_input);
- }
- g->AddControlEdge(iter->second, n);
- }
- }
- }
-
- // Reconnect outside compilation to XLA computation control edge.
- for (Node* n : g->nodes()) {
- std::vector control_deps;
- Status s = GetNodeAttr(
- n->attrs(), kXlaConnectedToOtherXlaComputationAttrName, &control_deps);
- if (!s.ok()) {
- if (s.code() != error::NOT_FOUND) {
- return s;
- } else {
- continue;
- }
- } else {
- n->ClearAttr(kXlaConnectedToOtherXlaComputationAttrName);
- for (const string& control_input : control_deps) {
- auto iter = cluster_node_names.find(control_input);
- if (iter == cluster_node_names.end()) {
- return errors::Internal("Cannot find cluster node for ",
- control_input);
- }
- auto iter2 = node_name_index.find(iter->second);
- if (iter2 == node_name_index.end()) {
- return errors::Internal("Cannot find cluster node for ",
- iter->second);
- }
- g->AddControlEdge(n, iter2->second);
- }
- }
- }
-
- // Reconnect XLA computation to outside compilation control edge.
- for (Node* n : g->nodes()) {
- std::vector control_deps;
- Status s =
- GetNodeAttr(n->attrs(), kXlaConnectedFromOtherXlaComputationAttrName,
- &control_deps);
- if (!s.ok()) {
- if (s.code() != error::NOT_FOUND) {
- return s;
- } else {
- continue;
- }
- } else {
- n->ClearAttr(kXlaConnectedFromOtherXlaComputationAttrName);
- for (const string& control_input : control_deps) {
- auto iter = cluster_node_names.find(control_input);
- if (iter == cluster_node_names.end()) {
- return errors::Internal("Cannot find cluster node for ",
- control_input);
- }
- auto iter2 = node_name_index.find(iter->second);
- if (iter2 == node_name_index.end()) {
- return errors::Internal("Cannot find cluster node for ",
- iter->second);
- }
- g->AddControlEdge(iter2->second, n);
- }
- }
- }
-
- return Status::OK();
-}
-
// Step 1 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of
// `PreprocessEdgesBetweenOutsideCompilations` for details.
Status PreprocessControlEdgesBetweenOutsideCompilations(
@@ -811,20 +300,6 @@ Status PostprocessControlEdgesBetweenOutsideCompilations(
const char kXlaInferredShapesAttrName[] = "_xla_inferred_shapes";
-const char kXlaConnectedToOtherXlaComputationAttrName[] =
- "_xla_connected_to_other_xla_computation";
-const char kXlaConnectedFromOtherXlaComputationAttrName[] =
- "_xla_connected_from_other_xla_computation";
-const char kXlaControlDependenciesAttrName[] = "_xla_control_dependencies";
-const char kBridgeSourceNodeAttrName[] = "_xla_bridge_src";
-const char kOutsideCompilationToHostOriginalNodeAttrName[] =
- "_xla_oc_to_host_node_name";
-const char kOutsideCompilationToHostSrcOutputAttrName[] =
- "_xla_oc_to_host_src_output";
-const char kHostToOutsideCompilationOriginalNodeAttrName[] =
- "_xla_host_to_oc_node_name";
-const char kHostToOutsideCompilationSrcOutputAttrName[] =
- "_xla_host_to_oc_src_output";
const char kXlaConnectedToXlaComputationAttrName[] =
"_xla_connected_to_xla_computation";
const char kXlaConnectedFromXlaComputationAttrName[] =
@@ -835,32 +310,7 @@ const char kOutsideCompilationSrcOutputAttrName[] = "_xla_oc_to_oc_src_output";
const char kXlaControlDependenciesWithinXlaClusterAttrName[] =
"_xla_control_dependencies_within_xla_cluster";
-Status PerformStaticShapeInferenceBeforeEncapsulation(
- Graph* g, const string& xla_computation_attr_name,
- const string& outside_compilation_attr_name) {
- // Find all outside compilation to XLA computation data edges.
- std::unordered_set outside_compilation_send_nodes;
- for (auto e : g->edges()) {
- if (e->IsControlEdge()) {
- continue;
- }
-
- auto src_computation = GetStringAttr(*e->src(), xla_computation_attr_name);
- auto dst_computation = GetStringAttr(*e->dst(), xla_computation_attr_name);
- if (!src_computation || !dst_computation ||
- *src_computation != *dst_computation) {
- continue;
- }
-
- auto src_outside_compilation =
- GetStringAttr(*e->src(), outside_compilation_attr_name);
- auto dst_outside_compilation =
- GetStringAttr(*e->dst(), outside_compilation_attr_name);
- if (src_outside_compilation && !dst_outside_compilation) {
- outside_compilation_send_nodes.insert(e->src());
- }
- }
-
+Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g) {
// Perform shape inference.
std::map arg_shapes;
GraphShapeInfo shape_info;
@@ -868,55 +318,21 @@ Status PerformStaticShapeInferenceBeforeEncapsulation(
InferShapes(g, arg_shapes, /*fnlib_def=*/nullptr, &shape_info));
// Add attribute for output shapes.
- for (Node* n : outside_compilation_send_nodes) {
- auto iter = shape_info.find(n->name());
- if (iter == shape_info.end()) {
- continue;
- }
-
+ auto node_name_index = g->BuildNodeNameIndex();
+ for (auto iter : shape_info) {
std::vector output_shapes;
- std::transform(iter->second.begin(), iter->second.end(),
+ std::transform(iter.second.begin(), iter.second.end(),
std::back_inserter(output_shapes),
[](const InferredShape& inferred_shape) {
return inferred_shape.shape;
});
+ Node* n = node_name_index[iter.first];
n->AddAttr(kXlaInferredShapesAttrName, output_shapes);
}
return Status::OK();
}
-Status PreprocessForEncapsulation(Graph* g,
- const string& xla_computation_attr_name,
- const string& outside_compilation_attr_name) {
- TF_RETURN_IF_ERROR(ProcessControlEdges(g, xla_computation_attr_name,
- outside_compilation_attr_name));
- TF_RETURN_IF_ERROR(ProcessXlaToXlaDataEdges(g, xla_computation_attr_name,
- outside_compilation_attr_name));
- TF_RETURN_IF_ERROR(ProcessDataEdgeBetweenOutsideCompilationAndHostComputation(
- g, xla_computation_attr_name, outside_compilation_attr_name));
- return Status::OK();
-}
-
-Status PostprocessForEncapsulation(
- Graph* g, const string& xla_computation_attr_name,
- const string& outside_compilation_attr_name,
- const std::unordered_map& clusters) {
- // The `node` pointer in `XlaClusterInfo` might be invalidated in step 1/2,
- // but the node name won't change. Record cluster node name for
- // `AddControlDependencies`.
- std::unordered_map cluster_node_names;
- for (const auto& iter : clusters) {
- cluster_node_names[iter.first] = iter.second.node->name();
- }
-
- TF_RETURN_IF_ERROR(
- RemovePlaceholderBetweenOutsideCompilationAndHostComputation(g));
- TF_RETURN_IF_ERROR(RemoveIdentityBetweenDifferentXlaComputation(g));
- TF_RETURN_IF_ERROR(AddControlDependencies(g, cluster_node_names));
- return Status::OK();
-}
-
Status PreprocessEdgesBetweenOutsideCompilations(
Graph* g, const string& outside_compilation_attr_name) {
// Remove edges from source node to outside compilation nodes, and edges
diff --git a/tensorflow/compiler/jit/encapsulate_util.h b/tensorflow/compiler/jit/encapsulate_util.h
index e363bc5754ac395bae262dc67a780a0173efaf5e..c9f16d14168163e11bb19092f566f1de8724aca3 100644
--- a/tensorflow/compiler/jit/encapsulate_util.h
+++ b/tensorflow/compiler/jit/encapsulate_util.h
@@ -27,51 +27,13 @@ namespace tensorflow {
// a list of PartialTensorShape objects.
extern const char kXlaInferredShapesAttrName[];
-// Infer output shapes for outside compilation nodes which have output data
-// edges to XLA computation nodes. These shapes will be used later by XLA
-// compiler as output shapes of the outside compilation's XlaHostCompute op.
-// XLA computation nodes will be mark by attr `xla_computation_attr_name`;
-// outside compilation nodes will be marked by both attr
-// `xla_computation_attr_name` and `outside_compilation_attr_name`.
-//
-// Those outside compilation nodes will be marked with attribute
-// `kXlaInferredShapesAttrName`.
+// Infers output shapes for all nodes in graph `g`. The output shapes will be
+// stored in node attribute `kXlaInferredShapesAttrName`.
//
// We have to perform shape inference before encapsulation because after
// encapsulation, some nodes will be encapsulated into function call, and shape
// inference does not handle function call at the moment.
-Status PerformStaticShapeInferenceBeforeEncapsulation(
- Graph* g, const string& xla_computation_attr_name,
- const string& outside_compilation_attr_name);
-
-// Attribute indicating that some ops in other XLA computation has control
-// dependency on this node. Attribute value will be a list of string (XLA
-// computation names).
-extern const char kXlaConnectedToOtherXlaComputationAttrName[];
-
-// Attribute indicating that this node has control dependency on some ops in
-// other XLA computation. Attribute value will be a list of string (XLA
-// computation names).
-extern const char kXlaConnectedFromOtherXlaComputationAttrName[];
-
-// Attribute indicating that this node has control dependencies on some other
-// nodes. Attribute value will be a list of string (node names).
-extern const char kXlaControlDependenciesAttrName[];
-
-// Attribute indicating that this is an Identity node added to act as a bridge
-// between different XLA computations. Attribute value will be string (source
-// node name).
-extern const char kBridgeSourceNodeAttrName[];
-
-// Attribute indicating that this is an Placeholder node added to act as a
-// temporary input node for an outside compilation node. Attribute value will be
-// string (original input node name).
-extern const char kOutsideCompilationToHostOriginalNodeAttrName[];
-
-// Attribute indicating that this is an Placeholder node added to act as a
-// temporary input node for an outside compilation node. Attribute value will be
-// int (src_output for original edge).
-extern const char kOutsideCompilationToHostSrcOutputAttrName[];
+Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g);
// Attribute indicating that some ops in this node's XLA computation has control
// dependency on this node. Attribute value will always be "true".
@@ -81,16 +43,6 @@ extern const char kXlaConnectedToXlaComputationAttrName[];
// this node's XLA computation. Attribute value will always be "true".
extern const char kXlaConnectedFromXlaComputationAttrName[];
-// Attribute indicating that this is an Placeholder node added to act as a
-// temporary input node for an host node. Attribute value will be string
-// (original input node name).
-extern const char kHostToOutsideCompilationOriginalNodeAttrName[];
-
-// Attribute indicating that this is an Placeholder node added to act as a
-// temporary input node for a host node. Attribute value will be int (src_output
-// for original edge).
-extern const char kHostToOutsideCompilationSrcOutputAttrName[];
-
// Attribute indicating that this is an Placeholder node added to act as a
// temporary input node for an outside compilation node. Attribute value will be
// string (original input node name).
@@ -106,27 +58,6 @@ extern const char kOutsideCompilationSrcOutputAttrName[];
// (node names).
extern const char kXlaControlDependenciesWithinXlaClusterAttrName[];
-// Preprocesses edges between different XLA clusters for encapsulation. It will
-// perform the following operations in order:
-//
-// 1a. For control edges between outside compilation and another XLA
-// computation, add attr "kXlaConnected{From, To}OtherXlaComputationAttrName
-// = XLA computation node name" to the outside compilation node.
-// 1b. For control edges between different outside compilations (in different
-// XLA computations), remove the edge and add attr
-// "kXlaControlDependenciesAttrName = src node name" to dst node.
-// 1c. For control edges between outside compilation and host computation,
-// remove the edge and add attr "kXlaControlDependenciesAttrName = src node
-// name" to dst node.
-// 2. For data edges between different XLA computations, if either src or dst
-// is outside compilation, add an Identity node in between the edge. The
-// identity node will have attr kBridgeSourceNodeAttrName.
-// 3. For data edges between outside compilation and host computation, remove
-// the edge and create a Placeholder node as dst node's input.
-Status PreprocessForEncapsulation(Graph* g,
- const string& xla_computation_attr_name,
- const string& outside_compilation_attr_name);
-
// Information for XLA computation.
struct XlaClusterInfo {
// Add an explicitly-defined default constructor for this class.
@@ -158,24 +89,6 @@ struct XlaClusterInfo {
const std::map host_compute_core;
};
-// Postprocesses edges between different XLA clusters for encapsulation. This
-// function reverts what `PreprocessForEncapsulation` did. It will perform the
-// following operations in order:
-//
-// 1. Remove Placeholder nodes between outside compilation and host computation
-// (created in `PreprocessForEncapsulation` step 3).
-// 2. Remove Identity nodes created in `PreprocessForEncapsulation` step 2.
-// 3a. Reconnect control edges between outside compilation and another XLA
-// computation (marked by `PreprocessForEncapsulation` step 1a).
-// 3b. Reconnect control edges between different outside compilations (marked by
-// `PreprocessForEncapsulation` step 1b).
-// 3c. Reconnect control edges between outside compilation and host computation
-// (marked by `PreprocessForEncapsulation` step 1c).
-Status PostprocessForEncapsulation(
- Graph* g, const string& xla_computation_attr_name,
- const string& outside_compilation_attr_name,
- const std::unordered_map& clusters);
-
// Preprocesses edges within the same XLA cluster. It will perform the following
// operations in order:
//
diff --git a/tensorflow/compiler/jit/encapsulate_util_test.cc b/tensorflow/compiler/jit/encapsulate_util_test.cc
index 3b8b49cb92f3e453883a8e64e12ce3748a5173f6..3bb979e0698d2d6be42ed5bae66c25267928192c 100644
--- a/tensorflow/compiler/jit/encapsulate_util_test.cc
+++ b/tensorflow/compiler/jit/encapsulate_util_test.cc
@@ -38,24 +38,11 @@ TEST(PerformStaticShapeInferenceBeforeEncapsulationTest, Basic) {
Graph g(OpRegistry::Global());
TF_CHECK_OK(s.ToGraph(&g));
- // "add" node is outside compilation node, "identity" node is XLA node.
- auto node_index = g.BuildNodeNameIndex();
- Node *add_node = node_index["add"], *identity_node = node_index["identity"];
- add_node->AddAttr("_xla", "cluster");
- add_node->AddAttr("_oc", "cluster");
- identity_node->AddAttr("_xla", "cluster");
- TF_CHECK_OK(
- PerformStaticShapeInferenceBeforeEncapsulation(&g, "_xla", "_oc"));
+ TF_CHECK_OK(PerformStaticShapeInferenceBeforeEncapsulation(&g));
- // Check that only "add" node now has _xla_inferred_shapes attr.
- std::vector nodes_with_inferred_shape;
- for (Node *n : g.nodes()) {
- if (HasNodeAttr(n->def(), kXlaInferredShapesAttrName)) {
- nodes_with_inferred_shape.push_back(n);
- }
- }
- EXPECT_EQ(nodes_with_inferred_shape.size(), 1);
- EXPECT_EQ(nodes_with_inferred_shape[0], add_node);
+ // Check that "add" node now has _xla_inferred_shapes attr.
+ auto node_index = g.BuildNodeNameIndex();
+ Node *add_node = node_index["add"];
std::vector output_shapes;
TF_CHECK_OK(GetNodeAttr(add_node->attrs(), kXlaInferredShapesAttrName,
&output_shapes));
@@ -66,329 +53,4 @@ TEST(PerformStaticShapeInferenceBeforeEncapsulationTest, Basic) {
EXPECT_EQ(shape_proto.dim(0).size(), 2);
}
-TEST(PreprocessForEncapsulationTest, ControlEdges) {
- // Build the graph:
- // "const_0" and "const_1" in host computation
- // "add" = "const_0" + "const_1" in XLA computation 0
- // "identity0" = "add" in XLA computation 0 & outside compilation 0
- // "identity1" = "identity0" in XLA computation 0
- // "identity2" = "identity1" in host computation
- // "identity3" = "identity2" in XLA computation 1
- // "identity4" = "identity3" in XLA computation 1 & outside compilation 1
- // "identity5" = "identity4" in XLA computation 1
- // "identity6" = "identity5" in host computation
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {});
- Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {});
- Output add = ops::Add(s.WithOpName("add"), const_0, const_1);
- Output identity0 = ops::Identity(s.WithOpName("identity0"), add);
- Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0);
- Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
- Output identity3 = ops::Identity(s.WithOpName("identity3"), identity2);
- Output identity4 = ops::Identity(s.WithOpName("identity4"), identity3);
- Output identity5 = ops::Identity(s.WithOpName("identity5"), identity4);
- Graph g(OpRegistry::Global());
- TF_CHECK_OK(s.ToGraph(&g));
- auto node_index = g.BuildNodeNameIndex();
-
- // Set XLA computation/outside compilation attr, and add control edges.
- Node *const0_node = node_index["const_0"], *add_node = node_index["add"],
- *identity0_node = node_index["identity0"],
- *identity1_node = node_index["identity1"],
- *identity2_node = node_index["identity2"],
- *identity3_node = node_index["identity3"],
- *identity4_node = node_index["identity4"],
- *identity5_node = node_index["identity5"];
- add_node->AddAttr("_xla", "0");
- identity0_node->AddAttr("_xla", "0");
- identity0_node->AddAttr("_oc", "0");
- identity1_node->AddAttr("_xla", "0");
- identity3_node->AddAttr("_xla", "1");
- identity4_node->AddAttr("_xla", "1");
- identity4_node->AddAttr("_oc", "0");
- identity5_node->AddAttr("_xla", "1");
- // Case 1a: control edges between outside compilation and another XLA
- // computation.
- g.AddControlEdge(identity0_node, identity3_node);
- g.AddControlEdge(identity1_node, identity4_node);
- // Case 1b: control edges between different outside compilations.
- g.AddControlEdge(identity0_node, identity4_node);
- // Case 1c: control edges between outside compilation and host computation.
- g.AddControlEdge(const0_node, identity0_node);
- g.AddControlEdge(identity0_node, identity2_node);
-
- TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc"));
-
- // Case 1a: add attr "_xla_control_deps_{from/to} = XLA computation node name"
- // to the outside compilation node.
- std::vector attr;
- TF_CHECK_OK(GetNodeAttr(identity0_node->def(),
- kXlaConnectedToOtherXlaComputationAttrName, &attr));
- EXPECT_EQ(attr.size(), 1);
- EXPECT_EQ(attr[0], "1");
- attr.clear();
- TF_CHECK_OK(GetNodeAttr(identity4_node->def(),
- kXlaConnectedFromOtherXlaComputationAttrName, &attr));
- EXPECT_EQ(attr.size(), 1);
- EXPECT_EQ(attr[0], "0");
- // Case 1b: add attr "_xla_control_deps = src node name" to dst node.
- attr.clear();
- TF_CHECK_OK(GetNodeAttr(identity4_node->def(),
- kXlaControlDependenciesAttrName, &attr));
- EXPECT_EQ(attr.size(), 1);
- EXPECT_EQ(attr[0], "identity0");
- // Case 1c: add attr "_xla_control_deps = src node name" to dst node.
- attr.clear();
- TF_CHECK_OK(GetNodeAttr(identity0_node->def(),
- kXlaControlDependenciesAttrName, &attr));
- EXPECT_EQ(attr.size(), 1);
- EXPECT_EQ(attr[0], "const_0");
- attr.clear();
- TF_CHECK_OK(GetNodeAttr(identity2_node->def(),
- kXlaControlDependenciesAttrName, &attr));
- EXPECT_EQ(attr.size(), 1);
- EXPECT_EQ(attr[0], "identity0");
-}
-
-TEST(PreprocessForEncapsulationTest, DataEdges) {
- // Build the graph:
- // "const_0" and "const_1" in host computation
- // "identityn0" = ("const_0", "const_1") in host computation 0
- // "add0" = "const_0" + "const_1" in XLA computation 0
- // "add1" = "add0" + "const_0" in XLA computation 0 & outside compilation 0
- // "identity0" = "add1" in XLA computation 0
- // "add2" = "add1" + "identity0" in host computation
- // "add3" = "add1" + "add2" in XLA computation 1
- // "add4" = "identity0" + "add2" in XLA computation 1 & outside compilation 0
- // "add5" = "identityn0"[0] + "identityn0"[1] in XLA computation 1 &
- // outside compilation 0
- // "identityn1" = ("identityn0"[0], "identityn0"[1]) in XLA computation 1 &
- // outside compilation 0
- // "identity1" = "add4" in XLA computation 1
- // "identity2" = "identity1" in host computation
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {});
- Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {});
- auto identityn0 =
- ops::IdentityN(s.WithOpName("identityn_0"), {const_0, const_1});
- Output add0 = ops::Add(s.WithOpName("add0"), const_0, const_1);
- Output add1 = ops::Add(s.WithOpName("add1"), add0, const_0);
- Output identity0 = ops::Identity(s.WithOpName("identity0"), add1);
- Output add2 = ops::Add(s.WithOpName("add2"), add1, identity0);
- Output add3 = ops::Add(s.WithOpName("add3"), add1, add2);
- Output add4 = ops::Add(s.WithOpName("add4"), identity0, add2);
- Output add5 = ops::Add(s.WithOpName("add5"), identityn0[0], identityn0[1]);
- auto identityn1 = ops::IdentityN(s.WithOpName("identityn_1"),
- {identityn0[0], identityn0[1]});
- Output identity1 = ops::Identity(s.WithOpName("identity1"), add4);
- Output identity2 = ops::Identity(s.WithOpName("identity2"), add4);
- Graph g(OpRegistry::Global());
- TF_CHECK_OK(s.ToGraph(&g));
- auto node_index = g.BuildNodeNameIndex();
-
- // Set XLA computation/outside compilation attr.
- Node *add0_node = node_index["add0"], *add1_node = node_index["add1"],
- *identity0_node = node_index["identity0"],
- *add3_node = node_index["add3"], *add4_node = node_index["add4"],
- *add5_node = node_index["add5"],
- *identityn1_node = node_index["identityn_1"],
- *identity1_node = node_index["identity1"];
- add0_node->AddAttr("_xla", "0");
- add1_node->AddAttr("_xla", "0");
- add1_node->AddAttr("_oc", "0");
- identity0_node->AddAttr("_xla", "0");
- add3_node->AddAttr("_xla", "1");
- add4_node->AddAttr("_xla", "1");
- add4_node->AddAttr("_oc", "0");
- add5_node->AddAttr("_xla", "1");
- add5_node->AddAttr("_oc", "0");
- identityn1_node->AddAttr("_xla", "1");
- identityn1_node->AddAttr("_oc", "0");
- identity1_node->AddAttr("_xla", "1");
-
- TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc"));
-
- // Check input nodes for related data edges.
- node_index = g.BuildNodeNameIndex();
- // Step 2: add an Identity node between different XLA computations.
- Node *bridge_add1_add3 = node_index["bridge_add1_add3"];
- EXPECT_NE(bridge_add1_add3, nullptr);
- string str;
- TF_CHECK_OK(
- GetNodeAttr(bridge_add1_add3->attrs(), kBridgeSourceNodeAttrName, &str));
- EXPECT_EQ(str, "add1");
- Node *bridge_identity0_add4 = node_index["bridge_identity0_add4"];
- EXPECT_NE(bridge_identity0_add4, nullptr);
- // Step 3: add placeholder for edges between host computation and outside
- // compilation.
- EXPECT_EQ(bridge_add1_add3->def().input(0), "add1_oc_to_host_placeholder_0");
- Node *add1_oc_to_host_placeholder =
- node_index["add1_oc_to_host_placeholder_0"];
- TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(),
- kOutsideCompilationToHostOriginalNodeAttrName, &str));
- EXPECT_EQ(str, "add1");
- int i;
- TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(),
- kOutsideCompilationToHostSrcOutputAttrName, &i));
- EXPECT_EQ(i, 0);
- add4_node = node_index["add4"];
- ASSERT_NE(add4_node, nullptr);
- EXPECT_EQ(add4_node->def().input(0),
- "bridge_identity0_add4_host_to_oc_placeholder_0");
- Node *identity0_host_to_oc_placeholder =
- node_index["bridge_identity0_add4_host_to_oc_placeholder_0"];
- TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(),
- kHostToOutsideCompilationOriginalNodeAttrName, &str));
- EXPECT_EQ(str, "bridge_identity0_add4");
- TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(),
- kHostToOutsideCompilationSrcOutputAttrName, &i));
- EXPECT_EQ(i, 0);
-
- // Check different placeholder nodes are created for different src_output.
- Node *placeholder0 = node_index["identityn_0_host_to_oc_placeholder_0"],
- *placeholder1 = node_index["identityn_0_host_to_oc_placeholder_1"];
- EXPECT_NE(placeholder0, nullptr);
- EXPECT_NE(placeholder1, nullptr);
- // Check we only have 2 placeholder nodes created for "identityn_0".
- int placeholder_count = 0;
- for (Node *n : g.nodes()) {
- if (HasNodeAttr(n->def(), kHostToOutsideCompilationOriginalNodeAttrName)) {
- string attr;
- TF_CHECK_OK(GetNodeAttr(
- n->attrs(), kHostToOutsideCompilationOriginalNodeAttrName, &attr));
- if (attr == "identityn_0") {
- ++placeholder_count;
- }
- }
- }
- EXPECT_EQ(placeholder_count, 2);
-}
-
-TEST(PostprocessForEncapsulationTest, ControlEdges) {
- // Build the graph:
- // "const0"
- // "identity0" = "const0" (XLA computation 0)
- // "identity1" = "identity0"
- // "identity2" = "identity1" (XLA computation 1)
- // "identity3" = "identity2"
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output const0 = ops::Const(s.WithOpName("const0"), 1, {});
- Output identity0 = ops::Identity(s.WithOpName("identity0"), const0);
- Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0);
- Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
- Output identity3 = ops::Identity(s.WithOpName("identity3"), identity2);
- Graph g(OpRegistry::Global());
- TF_CHECK_OK(s.ToGraph(&g));
- auto node_index = g.BuildNodeNameIndex();
-
- // Set XLA computation/outside compilation attr, and add control edges.
- Node *const0_node = node_index["const0"],
- *identity0_node = node_index["identity0"],
- *identity1_node = node_index["identity1"],
- *identity2_node = node_index["identity2"],
- *identity3_node = node_index["identity3"];
- identity1_node->AddAttr(kXlaConnectedFromOtherXlaComputationAttrName,
- std::vector{"0"});
- identity1_node->AddAttr(kXlaConnectedToOtherXlaComputationAttrName,
- std::vector{"1"});
- identity3_node->AddAttr(kXlaControlDependenciesAttrName,
- std::vector{"const0", "identity1"});
-
- std::unordered_map clusters;
- clusters["0"].node = identity0_node;
- clusters["1"].node = identity2_node;
- TF_CHECK_OK(PostprocessForEncapsulation(&g, "_xla", "_oc", clusters));
-
- // Case 3a: we have control edge identity0 -> identity1, and identity1 ->
- // identity2.
- bool edge_identity0_identity1 = false, edge_identity1_identity2 = false;
- for (const Edge *e : g.edges()) {
- if (!e->IsControlEdge()) {
- continue;
- }
- if (e->src() == identity0_node && e->dst() == identity1_node) {
- edge_identity0_identity1 = true;
- } else if (e->src() == identity1_node && e->dst() == identity2_node) {
- edge_identity1_identity2 = true;
- }
- }
- EXPECT_TRUE(edge_identity0_identity1);
- EXPECT_TRUE(edge_identity1_identity2);
- // Case 3b: we have control edge const0 -> identity3, and identity1 ->
- // identity3.
- bool edge_const0_identity3 = false, edge_identity1_identity3 = false;
- for (const Edge *e : g.edges()) {
- if (!e->IsControlEdge()) {
- continue;
- }
- if (e->src() == const0_node && e->dst() == identity3_node) {
- edge_const0_identity3 = true;
- } else if (e->src() == identity1_node && e->dst() == identity3_node) {
- edge_identity1_identity3 = true;
- }
- }
- EXPECT_TRUE(edge_const0_identity3);
- EXPECT_TRUE(edge_identity1_identity3);
-}
-
-TEST(PostprocessForEncapsulationTest, DataEdges) {
- // Build the graph:
- // "const0" in outside compilation "0"
- // "placeholder0" (for "const0") in host computation
- // "add0" = "placeholder0" + "placeholder0" in host computation
- // "placeholder1" (for "add0") in outside compilation 1
- // "add1" = "placeholder1" + "placeholder1" in outside compilation 1
- //
- // "bridge" = "placeholder0" in host computation
- // "placeholder2" (for "bridge") in outside compilation 1
- // "add2" = "placeholder2" + "placeholder2" in outside compilation 1
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output const0 = ops::Const(s.WithOpName("const0"), 1, {});
- Output placeholder0 =
- ops::Placeholder(s.WithOpName("placeholder0"), DT_INT32);
- Output add0 = ops::Add(s.WithOpName("add0"), placeholder0, placeholder0);
- Output placeholder1 =
- ops::Placeholder(s.WithOpName("placeholder1"), DT_INT32);
- Output add1 = ops::Add(s.WithOpName("add1"), placeholder1, placeholder1);
- Output bridge = ops::Identity(s.WithOpName("bridge"), placeholder0);
- Output placeholder2 =
- ops::Placeholder(s.WithOpName("placeholder2"), DT_INT32);
- Output add2 = ops::Add(s.WithOpName("add2"), placeholder2, placeholder2);
- Graph g(OpRegistry::Global());
- TF_CHECK_OK(s.ToGraph(&g));
- auto node_index = g.BuildNodeNameIndex();
-
- // Set related attributes.
- Node *placeholder0_node = node_index["placeholder0"];
- placeholder0_node->AddAttr(kOutsideCompilationToHostOriginalNodeAttrName,
- "const0");
- placeholder0_node->AddAttr(kOutsideCompilationToHostSrcOutputAttrName, 0);
- Node *placeholder1_node = node_index["placeholder1"];
- placeholder1_node->AddAttr(kHostToOutsideCompilationOriginalNodeAttrName,
- "add0");
- placeholder1_node->AddAttr(kHostToOutsideCompilationSrcOutputAttrName, 0);
- Node *bridge_node = node_index["bridge"];
- bridge_node->AddAttr(kBridgeSourceNodeAttrName, "const0");
- Node *placeholder2_node = node_index["placeholder2"];
- placeholder2_node->AddAttr(kHostToOutsideCompilationOriginalNodeAttrName,
- "bridge");
- placeholder2_node->AddAttr(kHostToOutsideCompilationSrcOutputAttrName, 0);
-
- std::unordered_map clusters;
- TF_CHECK_OK(PostprocessForEncapsulation(&g, "_xla", "_oc", clusters));
-
- // Result graph should be:
- // "add0" = "const0" + "const0"
- // "add1" = "add0" + "add0"
- // "add2" = "const0" + "const0"
- node_index = g.BuildNodeNameIndex();
- EXPECT_EQ(node_index.size(), 6);
- EXPECT_EQ(node_index["add0"]->def().input(0), "const0:0");
- EXPECT_EQ(node_index["add0"]->def().input(1), "const0:0");
- EXPECT_EQ(node_index["add1"]->def().input(0), "add0:0");
- EXPECT_EQ(node_index["add1"]->def().input(1), "add0:0");
- EXPECT_EQ(node_index["add2"]->def().input(0), "const0:0");
- EXPECT_EQ(node_index["add2"]->def().input(1), "const0:0");
-}
-
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
index d334100aa4a915a87fb05d371e0e3379a7ee05f2..ec745cdbb7e237f8b4935dd41e9791fc75f5355d 100644
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
@@ -297,6 +297,7 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors,
NodeDef def;
def.set_name(launch->name());
+ MergeDebugInfo(NodeDebugInfo(launch->def()), &def);
// Target the XLA CPU/GPU backends.
VLOG(2) << "Replacing with XlaLaunch";
diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc
index e3c7e2f89be9b37b51a633dabb099969c181013f..1906f1ac850095a27add10d6b22d3bbb0f811ce9 100644
--- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc
+++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc
@@ -20,8 +20,10 @@ limitations under the License.
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/encapsulate_util.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
@@ -98,9 +100,12 @@ xla::StatusOr BuildRecvAtHostNode(
recv_at_host_builder.Attr("Toutputs", recv_at_host_dtypes);
// The correct device_ordinal will be inserted during replication in a
// subsequent rewrite.
- recv_at_host_builder.Attr("device_ordinal", 0);
+ AttrValue device_ordinal_value;
+ device_ordinal_value.set_placeholder("device_ordinal");
+ recv_at_host_builder.Attr("device_ordinal", device_ordinal_value);
recv_at_host_builder.Attr(
"key", absl::StrCat("host_compute_channel_", oc_cluster_name));
+ recv_at_host_builder.Attr(kXlaHasHostTransferAttrName, true);
recv_at_host_builder.Input(key_placeholder->name(), 0, DT_STRING);
TF_RETURN_IF_ERROR(recv_at_host_builder.Finalize(&recv_at_host_def));
Status s;
@@ -197,9 +202,12 @@ xla::StatusOr BuildSendFromHostNode(
send_from_host_builder.Attr("Tinputs", send_from_host_dtypes);
// The correct device_ordinal will be inserted during replication in a
// subsequent rewrite.
- send_from_host_builder.Attr("device_ordinal", 0);
+ AttrValue device_ordinal_value;
+ device_ordinal_value.set_placeholder("device_ordinal");
+ send_from_host_builder.Attr("device_ordinal", device_ordinal_value);
send_from_host_builder.Attr(
"key", absl::StrCat("host_compute_channel_", oc_cluster_name));
+ send_from_host_builder.Attr(kXlaHasHostTransferAttrName, true);
std::vector inputs(send_from_host_dtypes.size());
for (auto* n : ret_nodes) {
int index;
@@ -357,6 +365,47 @@ Status ReplaceOrRemoveOutsideCompilationCallNode(
return Status::OK();
}
+// Resets "device_ordinal" attr to placeholder value for related nodes
+// (XlaRecvAtHost nodes; XlaSendFromHost nodes; If/While nodes containing
+// XlaRecvAtHost/XlaSendFromHost).
+Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) {
+ AttrValue device_ordinal_value;
+ device_ordinal_value.set_placeholder("device_ordinal");
+ for (Node* n : g->nodes()) {
+ if (!HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) {
+ continue;
+ }
+
+ if (n->type_string() == "_XlaRecvAtHost" ||
+ n->type_string() == "_XlaSendFromHost") {
+ n->ClearAttr("device_ordinal");
+ n->AddAttr("device_ordinal", device_ordinal_value);
+ } else if (n->type_string() == "If") {
+ for (const string& attr_name :
+ std::vector{"then_branch", "else_branch"}) {
+ NameAttrList branch_func;
+ TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func));
+ (*branch_func.mutable_attr())["device_ordinal"] = device_ordinal_value;
+ n->ClearAttr(attr_name);
+ n->AddAttr(attr_name, branch_func);
+ }
+ } else if (n->type_string() == "While") {
+ for (const string& attr_name : std::vector{"cond", "body"}) {
+ NameAttrList branch_func;
+ TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func));
+ (*branch_func.mutable_attr())["device_ordinal"] = device_ordinal_value;
+ n->ClearAttr(attr_name);
+ n->AddAttr(attr_name, branch_func);
+ }
+ } else {
+ return errors::Internal("Unknown node marked with ",
+ kXlaHasHostTransferAttrName, ": ",
+ n->DebugString());
+ }
+ }
+ return Status::OK();
+}
+
// For an XLA computation, builds host side graph given all outside compilation
// graphs inside it. The host side graph contains:
// 1) a "sequencer" node (we will add control edge between XlaRecvAtHost and
@@ -368,8 +417,8 @@ Status ReplaceOrRemoveOutsideCompilationCallNode(
Status ConstructHostGraph(
const string& xla_cluster_name, const string& outside_compilation_attr_name,
const std::vector& outside_compilation_host_graphs,
- FunctionLibraryDefinition* fld, std::unique_ptr* host_graph) {
- host_graph->reset(new Graph(fld));
+ FunctionLibraryDefinition* fld, const string& host_graph_func_name) {
+ Graph host_graph(fld);
// Create sequencer node in host graph.
NodeDefBuilder sequencer_builder(absl::StrCat(xla_cluster_name, "_sequencer"),
@@ -378,24 +427,34 @@ Status ConstructHostGraph(
NodeDef sequencer_def;
TF_RETURN_IF_ERROR(sequencer_builder.Finalize(&sequencer_def));
Status s;
- Node* sequencer = (*host_graph)->AddNode(sequencer_def, &s);
+ Node* sequencer = host_graph.AddNode(sequencer_def, &s);
TF_RETURN_IF_ERROR(s);
// Create key placeholder in host graph.
TF_ASSIGN_OR_RETURN(
Node * key_placeholder,
- AddHostComputeKeyPlaceholder(xla_cluster_name, host_graph->get()));
+ AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));
// For each outside compilation graph, copy them to host graph with the
// following changes:
// a) Use key_placeholder in host graph instead of its own.
- // b) Add control edge from RecvAtHost/SendFromHost to sequencer.
+ // b) Add control edge from host transfer nodes (XlaRecvAtHost,
+ // XlaSendFromHost, If/While nodes containing
+ // XlaRecvAtHost/XlaSendFromHost) to sequencer node.
// c) Clear node_def.device(), so device placer won't get confused.
for (const string& host_func : outside_compilation_host_graphs) {
VLOG(4) << "Expanding host graph " << host_func;
+ // Temporarily use "0" as "device_ordinal". It will be reset to placeholder
+ // value after we expanded all host graphs. We cannot just use placeholder
+ // value here because FunctionDef instantiation does not allow placeholder
+ // value for attributes.
+ AttrValue device_ordinal_attr;
+ device_ordinal_attr.set_i(0);
+ protobuf::Map attrs;
+ attrs["device_ordinal"] = device_ordinal_attr;
FunctionBody* host_fbody = nullptr;
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
- *fld->Find(host_func), AttrSlice(), fld,
+ *fld->Find(host_func), AttrSlice(&attrs), fld,
[&](const string& op, const OpDef** sig) {
return fld->LookUpOpDef(op, sig);
},
@@ -408,8 +467,8 @@ Status ConstructHostGraph(
FixupSourceAndSinkEdges(host_fbody->graph);
std::map node_map;
- node_map[host_fbody->graph->source_node()] = (*host_graph)->source_node();
- node_map[host_fbody->graph->sink_node()] = (*host_graph)->sink_node();
+ node_map[host_fbody->graph->source_node()] = host_graph.source_node();
+ node_map[host_fbody->graph->sink_node()] = host_graph.sink_node();
Status s;
ReverseDFS(
*host_fbody->graph, /*enter=*/nullptr,
@@ -431,7 +490,7 @@ Status ConstructHostGraph(
NodeDef copy_def = n->def();
// Change c).
copy_def.clear_device();
- copy = (*host_graph)->AddNode(copy_def, &s);
+ copy = host_graph.AddNode(copy_def, &s);
if (!s.ok()) {
return;
}
@@ -446,22 +505,23 @@ Status ConstructHostGraph(
e->src()->DebugString());
return;
}
- (*host_graph)
- ->AddEdge(node_map[e->src()], e->src_output(), copy,
- e->dst_input());
+ host_graph.AddEdge(node_map[e->src()], e->src_output(), copy,
+ e->dst_input());
}
// Change b).
- if (copy->type_string() == "_XlaRecvAtHost" ||
- copy->type_string() == "_XlaSendFromHost") {
- (*host_graph)->AddControlEdge(copy, sequencer);
+ if (HasNodeAttr(copy->def(), kXlaHasHostTransferAttrName)) {
+ host_graph.AddControlEdge(copy, sequencer);
}
},
NodeComparatorID());
+
if (!s.ok()) {
return s;
}
}
+ // Reset "device_ordinal" to placeholder value.
+ TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(&host_graph));
// sequencer and key_placeholder might be dead nodes. Prune them if necessary.
// - sequencer should be pruned iff it has no input control edges from
@@ -470,21 +530,30 @@ Status ConstructHostGraph(
// - key_placeholder should be pruned iff there's no RecvAtHost/SendFromHost.
// We don't need to do anything special.
if (!sequencer->in_edges().empty()) {
- (*host_graph)->AddControlEdge(sequencer, (*host_graph)->sink_node());
+ host_graph.AddControlEdge(sequencer, host_graph.sink_node());
}
PruneForReverseReachability(
- host_graph->get(),
- std::unordered_set{(*host_graph)->sink_node()});
+ &host_graph, std::unordered_set{host_graph.sink_node()});
// Postprocess edges between different outside compilations.
TF_RETURN_IF_ERROR(PostprocessEdgesBetweenOutsideCompilations(
- host_graph->get(), outside_compilation_attr_name));
+ &host_graph, outside_compilation_attr_name));
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile(
absl::StrCat("extract_outside_compilation_host_graph_for_",
xla_cluster_name),
- **host_graph, fld);
+ host_graph, fld);
+ }
+
+ FunctionDef host_graph_fdef;
+ TF_RETURN_IF_ERROR(
+ GraphToFunctionDef(host_graph, host_graph_func_name, &host_graph_fdef));
+ if (fld->Find(host_graph_func_name)) {
+ TF_RETURN_IF_ERROR(
+ fld->ReplaceFunction(host_graph_func_name, host_graph_fdef));
+ } else {
+ TF_RETURN_IF_ERROR(fld->AddFunctionDef(host_graph_fdef));
}
return Status::OK();
@@ -492,8 +561,28 @@ Status ConstructHostGraph(
// Expand XLA computation's outside compilation host side graph into main graph.
// Add a control edge between sequencer node and the XLA computation node.
-Status ExpandHostGraphIntoMainGraph(Graph* main_graph, Graph* host_graph,
+Status ExpandHostGraphIntoMainGraph(Graph* main_graph,
+ FunctionLibraryDefinition* fld,
+ const string& host_graph_func_name,
Node* xla_computation_node) {
+ // Temporarily use "0" as "device_ordinal". It will be rewritten with the
+ // correct value in a later pass. We cannot just use placeholder value here
+ // because FunctionDef instantiation does not allow placeholder value for
+ // attributes.
+ AttrValue device_ordinal_attr;
+ device_ordinal_attr.set_i(0);
+ protobuf::Map attrs;
+ attrs["device_ordinal"] = device_ordinal_attr;
+ FunctionBody* fbody = nullptr;
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
+ *fld->Find(host_graph_func_name), AttrSlice(&attrs), fld,
+ [&](const string& op, const OpDef** sig) {
+ return fld->LookUpOpDef(op, sig);
+ },
+ &fbody));
+ std::unique_ptr fbody_deleter(fbody);
+ Graph* host_graph = fbody->graph;
+
// We use ReverseDFS() to copy nodes. Make sure all nodes are reverse
// reachable from sink node so all nodes will be copied.
// TODO(b/77601805): consolidate copy graph functions.
@@ -545,23 +634,25 @@ Status ExpandHostGraphIntoMainGraph(Graph* main_graph, Graph* host_graph,
return s;
}
-// Rewrites shape inference graph for outside compilation.
-// 1. If the outside compilation is a "top-level" one (not in a function of any
-// If/While/etc.), this shape inference graph might have host computation to
-// outside compilation placeholder nodes, which will cause shape inference to
-// fail. However, those nodes are not in `host_graph` any more (because we
-// have executed `PostprocessForEncapsultion`). In this case, we clear the
-// graph, and copy SendFromHost with all its predecessors from `host_graph`.
-// This case is detected by whether the SendFromHost node exists in
-// `host_graph` as well.
-// 2. Remove control edges, and prune nodes that are not useful for shape
-// inference.
+// Rewrites shape inference graph for outside compilation:
+// 1) If XlaSendFromHost also exists in `host_graph`, copy nodes from
+// `host_graph`. Because we might still have outside compilation to outside
+// compilation placeholder nodes in shape inference graph, which will prevent
+// us from inferring XlaSendFromHost shape. But in `host_graph`, we already
+// removed those placeholder nodes.
+// 2) Remove control edges.
+// 3) Prune nodes that are not useful for shape inference.
Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name,
Graph* host_graph,
FunctionLibraryDefinition* fld) {
+ // Use "0" as "device_ordinal". It does not matter for shape inference.
+ AttrValue device_ordinal_attr;
+ device_ordinal_attr.set_i(0);
+ protobuf::Map attrs;
+ attrs["device_ordinal"] = device_ordinal_attr;
FunctionBody* fbody = nullptr;
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
- *fld->Find(shape_inference_graph_name), AttrSlice(), fld,
+ *fld->Find(shape_inference_graph_name), AttrSlice(&attrs), fld,
[&](const string& op, const OpDef** sig) {
return fld->LookUpOpDef(op, sig);
},
@@ -650,6 +741,7 @@ Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name,
g->RemoveEdge(e);
}
}
+
// Nodes that are not reverse reachable from SendFromHost are not useful for
// shape inference. Prune them.
PruneForReverseReachability(g,
@@ -669,6 +761,572 @@ Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name,
return Status::OK();
}
+// Builds an XlaSendToHost node which sends the cond predicate to host.
+//
+// The new node is named `name`, takes output 0 of `pred_node` as its DT_BOOL
+// input, and uses `host_transfer_key` with "_dtoh_0" appended as its "key"
+// attr. The node and the data edge from `pred_node` are added to `g`.
+// Returns the new node, or the error from finalizing / adding the NodeDef.
+xla::StatusOr<Node*> BuildSendIfPredNode(const string& name,
+                                         const string& host_transfer_key,
+                                         Node* pred_node, Graph* g) {
+  NodeDefBuilder send_pred_builder(name, "XlaSendToHost");
+  send_pred_builder.Attr("Tinput", DT_BOOL);
+  send_pred_builder.Attr("key", absl::StrCat(host_transfer_key, "_dtoh_0"));
+  send_pred_builder.Attr(kXlaTokenInputNodesAttrName,
+                         std::vector<string>{kXlaTokenArgNodeName});
+  send_pred_builder.Input(pred_node->name(), 0, DT_BOOL);
+  NodeDef send_pred_def;
+  TF_RETURN_IF_ERROR(send_pred_builder.Finalize(&send_pred_def));
+  Status s;
+  Node* send_pred_node = g->AddNode(send_pred_def, &s);
+  TF_RETURN_IF_ERROR(s);
+  // NodeDef inputs only record names; the graph edge must be added explicitly.
+  g->AddEdge(pred_node, 0, send_pred_node, 0);
+  return send_pred_node;
+}
+
+// Replaces the key placeholder node in function `func_name` with an _Arg node
+// (type DT_STRING, index 0) and writes the rewritten function back into `fld`.
+//
+// If the function has no key placeholder node, one is first created via
+// AddHostComputeKeyPlaceholder() for `xla_cluster_name` and then replaced.
+// Returns an error if `func_name` is not in `fld` or if any rewrite step fails.
+Status ReplaceKeyPlaceholderWithArgNode(const string& xla_cluster_name,
+                                        const string& func_name,
+                                        FunctionLibraryDefinition* fld) {
+  // Fail cleanly instead of dereferencing a null FunctionDef.
+  const FunctionDef* fdef = fld->Find(func_name);
+  if (!fdef) {
+    return errors::Internal("Cannot find function ", func_name,
+                            " when replacing key placeholder");
+  }
+
+  // Temporarily use "0" as "device_ordinal". It will be reset to placeholder
+  // value after rewriting.
+  AttrValue device_ordinal_attr;
+  device_ordinal_attr.set_i(0);
+  protobuf::Map<string, AttrValue> attrs;
+  attrs["device_ordinal"] = device_ordinal_attr;
+  FunctionBody* fbody = nullptr;
+  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
+      *fdef, AttrSlice(&attrs), fld,
+      [&](const string& op, const OpDef** sig) {
+        return fld->LookUpOpDef(op, sig);
+      },
+      &fbody));
+  std::unique_ptr<FunctionBody> fbody_deleter(fbody);
+  Graph* g = fbody->graph;
+
+  // Find or create the key placeholder node.
+  Node* key_placeholder = nullptr;
+  for (Node* n : g->nodes()) {
+    if (IsKeyPlaceholderNode(*n)) {
+      key_placeholder = n;
+      break;
+    }
+  }
+  if (!key_placeholder) {
+    TF_ASSIGN_OR_RETURN(key_placeholder,
+                        AddHostComputeKeyPlaceholder(xla_cluster_name, g));
+  }
+
+  // Build the _Arg node, and replace key placeholder node with it.
+  NodeDefBuilder arg_builder("key_arg", FunctionLibraryDefinition::kArgOp);
+  arg_builder.Attr("T", DT_STRING);
+  arg_builder.Attr("index", 0);
+  NodeDef arg_def;
+  TF_RETURN_IF_ERROR(arg_builder.Finalize(&arg_def));
+  TF_RETURN_IF_ERROR(ReplaceNode(g, key_placeholder, arg_def).status());
+
+  // Reset "device_ordinal" to placeholder value.
+  TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(g));
+
+  FunctionDef replace_fdef;
+  TF_RETURN_IF_ERROR(GraphToFunctionDef(*g, func_name, &replace_fdef));
+  TF_RETURN_IF_ERROR(fld->ReplaceFunction(func_name, replace_fdef));
+  return Status::OK();
+}
+
+// Builds the host side graph for an If node and stores it in `fld` as
+// function `host_graph_func_name` (replacing an existing function of that
+// name, if any).
+//
+// The host graph consists of:
+//   * a key placeholder node;
+//   * an _XlaRecvAtHost node receiving the DT_BOOL predicate under
+//     `host_transfer_key`;
+//   * an "If" node dispatching to `then_branch_host_func_name` /
+//     `else_branch_host_func_name`, with the key placeholder as its only
+//     non-predicate input and no outputs.
+// Both branch functions are rewritten in `fld` so that their key placeholder
+// becomes an _Arg node.
+Status BuildHostGraphForIfNode(const string& xla_cluster_attr_name,
+                               const string& outside_compilation_attr_name,
+                               const string& xla_cluster_name,
+                               const string& if_node_name,
+                               const string& host_transfer_key,
+                               const string& host_graph_func_name,
+                               FunctionLibraryDefinition* fld,
+                               const string& then_branch_host_func_name,
+                               const string& else_branch_host_func_name) {
+  Graph host_graph(fld);
+  string outside_compilation_name = absl::StrCat("oc_if_", if_node_name);
+  // "device_ordinal" stays a placeholder here; it is not resolved to a
+  // concrete value in this function.
+  AttrValue device_ordinal_value;
+  device_ordinal_value.set_placeholder("device_ordinal");
+
+  // Step 1: add key placeholder node.
+  TF_ASSIGN_OR_RETURN(
+      Node * key_placeholder,
+      AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));
+
+  // Step 2: build XlaRecvAtHost node to recv predicate.
+  NodeDefBuilder recv_pred_builder(
+      absl::StrCat("recv_oc_if_pred_", if_node_name), "_XlaRecvAtHost");
+  recv_pred_builder.Attr("Toutputs", std::vector<DataType>{DT_BOOL});
+  recv_pred_builder.Attr("key", host_transfer_key);
+  recv_pred_builder.Attr("device_ordinal", device_ordinal_value);
+  recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
+  recv_pred_builder.Attr(outside_compilation_attr_name,
+                         outside_compilation_name);
+  recv_pred_builder.Attr(kXlaHasHostTransferAttrName, true);
+  recv_pred_builder.Input(key_placeholder->name(), 0, DT_STRING);
+  NodeDef recv_pred_def;
+  TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def));
+  Status s;
+  Node* recv_pred_node = host_graph.AddNode(recv_pred_def, &s);
+  TF_RETURN_IF_ERROR(s);
+  host_graph.AddEdge(key_placeholder, 0, recv_pred_node, 0);
+
+  // Step 3: rewrite `{then, else}_branch_host_func_name`, replace key
+  // placeholder with an _Arg node.
+  TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
+      xla_cluster_name, then_branch_host_func_name, fld));
+  TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
+      xla_cluster_name, else_branch_host_func_name, fld));
+
+  // Step 4: build If node to choose between `{then, else}_branch_host_graph`.
+  NodeDefBuilder if_builder(absl::StrCat("oc_if_", if_node_name), "If");
+  if_builder.Attr("Tcond", DT_BOOL);
+  if_builder.Attr("Tin", std::vector<DataType>{DT_STRING});
+  if_builder.Attr("Tout", std::vector<DataType>{});
+  NameAttrList host_then_branch, host_else_branch;
+  host_then_branch.set_name(then_branch_host_func_name);
+  (*host_then_branch.mutable_attr())["device_ordinal"] = device_ordinal_value;
+  host_else_branch.set_name(else_branch_host_func_name);
+  (*host_else_branch.mutable_attr())["device_ordinal"] = device_ordinal_value;
+  if_builder.Attr("then_branch", host_then_branch);
+  if_builder.Attr("else_branch", host_else_branch);
+  if_builder.Attr(kXlaHasHostTransferAttrName, true);
+  if_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
+  if_builder.Attr(outside_compilation_attr_name, outside_compilation_name);
+  if_builder.Input(recv_pred_node->name(), 0, DT_BOOL);
+  std::vector<NodeDefBuilder::NodeOut> if_inputs{
+      {key_placeholder->name(), 0, DT_STRING}};
+  if_builder.Input(if_inputs);
+  NodeDef if_def;
+  TF_RETURN_IF_ERROR(if_builder.Finalize(&if_def));
+  Node* if_node = host_graph.AddNode(if_def, &s);
+  TF_RETURN_IF_ERROR(s);
+  // Input 0 is the predicate, input 1 is the host compute key.
+  host_graph.AddEdge(recv_pred_node, 0, if_node, 0);
+  host_graph.AddEdge(key_placeholder, 0, if_node, 1);
+
+  // Convert `host_graph` to a function and register it in `fld`.
+  FunctionDef oc_host_graph_fdef;
+  TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name,
+                                        &oc_host_graph_fdef));
+  if (fld->Find(host_graph_func_name)) {
+    TF_RETURN_IF_ERROR(
+        fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef));
+  } else {
+    TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef));
+  }
+
+  return Status::OK();
+}
+
+// Rewrites the loop cond function `loop_cond_func` in `fld`, adding an
+// XlaSendToHost node which sends the loop predicate to host.
+//
+// The predicate is taken from input 0 of the function's single _Retval node.
+// The send node is named "send_oc_while_cond_" + `while_node_name` and uses
+// `host_transfer_key` with "_dtoh_0" appended as its "key" attr.
+// Returns an error if the function is missing from `fld`, or if it does not
+// have exactly one _Retval node.
+Status AddSendLoopPredToLoopCond(FunctionLibraryDefinition* fld,
+                                 const NameAttrList& loop_cond_func,
+                                 const string& while_node_name,
+                                 const string& host_transfer_key) {
+  // Fail cleanly instead of dereferencing a null FunctionDef.
+  const FunctionDef* loop_cond_fdef = fld->Find(loop_cond_func.name());
+  if (!loop_cond_fdef) {
+    return errors::Internal("Cannot find loop cond function ",
+                            loop_cond_func.name());
+  }
+
+  // Instantiate the loop cond function.
+  FunctionBody* fbody = nullptr;
+  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
+      *loop_cond_fdef, AttrSlice(&loop_cond_func.attr()), fld,
+      [&](const string& op, const OpDef** sig) {
+        return fld->LookUpOpDef(op, sig);
+      },
+      &fbody));
+  std::unique_ptr<FunctionBody> fbody_deleter(fbody);
+  Graph* g = fbody->graph;
+
+  // Find the _Retval node and the loop cond node.
+  Node* ret_node = nullptr;
+  for (Node* n : g->nodes()) {
+    if (n->type_string() == "_Retval") {
+      if (ret_node) {
+        return errors::Internal("Multiple return node for loop cond function ",
+                                loop_cond_func.name(), ": ",
+                                ret_node->DebugString(), " and ",
+                                n->DebugString());
+      } else {
+        ret_node = n;
+      }
+    }
+  }
+  if (!ret_node) {
+    return errors::Internal("No _Retval node for loop cond function ",
+                            loop_cond_func.name());
+  }
+  Node* loop_cond;
+  TF_RETURN_IF_ERROR(ret_node->input_node(0, &loop_cond));
+
+  // Build the XlaSendToHost node.
+  // NOTE(review): output 0 of `loop_cond` is assumed to be the predicate;
+  // confirm the _Retval input edge never uses a different source output.
+  NodeDefBuilder send_loop_cond_builder(
+      absl::StrCat("send_oc_while_cond_", while_node_name), "XlaSendToHost");
+  send_loop_cond_builder.Attr("Tinput", DT_BOOL);
+  send_loop_cond_builder.Attr("key",
+                              absl::StrCat(host_transfer_key, "_dtoh_0"));
+  send_loop_cond_builder.Attr(kXlaTokenInputNodesAttrName,
+                              std::vector<string>{kXlaTokenArgNodeName});
+  send_loop_cond_builder.Input(loop_cond->name(), 0, DT_BOOL);
+  NodeDef send_loop_cond_def;
+  TF_RETURN_IF_ERROR(send_loop_cond_builder.Finalize(&send_loop_cond_def));
+  Status s;
+  Node* send_loop_cond_node = g->AddNode(send_loop_cond_def, &s);
+  TF_RETURN_IF_ERROR(s);
+  g->AddEdge(loop_cond, 0, send_loop_cond_node, 0);
+
+  // Replace original function.
+  FunctionDef replace_fdef;
+  TF_RETURN_IF_ERROR(
+      GraphToFunctionDef(*g, loop_cond_func.name(), &replace_fdef));
+  TF_RETURN_IF_ERROR(fld->ReplaceFunction(loop_cond_func.name(), replace_fdef));
+
+  return Status::OK();
+}
+
+// Rewrites the while loop cond function `cond_host_func_name` for host.
+//
+// The key placeholder is first replaced with an _Arg node. Then an
+// _XlaRecvAtHost node (receiving a DT_BOOL under `host_transfer_key`) and a
+// _Retval node returning that value are added, so the host cond function
+// returns the predicate received from the device. The rewritten function
+// replaces the original one in `fld`.
+Status RewriteHostWhileLoopCond(
+    const string& cond_host_func_name, const string& while_node_name,
+    const string& host_transfer_key, const string& xla_cluster_attr_name,
+    const string& xla_cluster_name, const string& outside_compilation_attr_name,
+    const string& outside_compilation_name, FunctionLibraryDefinition* fld) {
+  // Replace key placeholder node with _Arg node.
+  TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
+      xla_cluster_name, cond_host_func_name, fld));
+
+  // Fail cleanly instead of dereferencing a null FunctionDef.
+  const FunctionDef* cond_fdef = fld->Find(cond_host_func_name);
+  if (!cond_fdef) {
+    return errors::Internal("Cannot find host cond function ",
+                            cond_host_func_name);
+  }
+
+  // Instantiate cond function. "0" is only a temporary "device_ordinal"; it
+  // is reset to the placeholder value below.
+  AttrValue device_ordinal_temp_value;
+  device_ordinal_temp_value.set_i(0);
+  protobuf::Map<string, AttrValue> attrs;
+  attrs["device_ordinal"] = device_ordinal_temp_value;
+  FunctionBody* cond_fbody = nullptr;
+  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
+      *cond_fdef, AttrSlice(&attrs), fld,
+      [&](const string& op, const OpDef** sig) {
+        return fld->LookUpOpDef(op, sig);
+      },
+      &cond_fbody));
+  std::unique_ptr<FunctionBody> cond_fbody_deleter(cond_fbody);
+  Graph* cond_graph = cond_fbody->graph;
+  // If several _Arg nodes exist, the last one visited is used.
+  Node* key_arg = nullptr;
+  for (Node* n : cond_graph->nodes()) {
+    if (n->type_string() == "_Arg") {
+      key_arg = n;
+    }
+  }
+  if (!key_arg) {
+    return errors::Internal(
+        "No _Arg node found for host compute key in function ",
+        cond_host_func_name);
+  }
+
+  // Add an XlaRecvAtHost node to use as cond function return value.
+  // We don't need to set kXlaHasHostTransferAttrName for this node, because
+  // it's already added for the "While" node on the host.
+  NodeDefBuilder recv_pred_builder(
+      absl::StrCat("recv_oc_while_cond_", while_node_name), "_XlaRecvAtHost");
+  recv_pred_builder.Attr("Toutputs", std::vector<DataType>{DT_BOOL});
+  recv_pred_builder.Attr("key", host_transfer_key);
+  AttrValue device_ordinal_value;
+  device_ordinal_value.set_placeholder("device_ordinal");
+  recv_pred_builder.Attr("device_ordinal", device_ordinal_value);
+  recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
+  recv_pred_builder.Attr(outside_compilation_attr_name,
+                         outside_compilation_name);
+  recv_pred_builder.Input(key_arg->name(), 0, DT_STRING);
+  NodeDef recv_pred_def;
+  TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def));
+  Status s;
+  Node* recv_pred_node = cond_graph->AddNode(recv_pred_def, &s);
+  TF_RETURN_IF_ERROR(s);
+  cond_graph->AddEdge(key_arg, 0, recv_pred_node, 0);
+  NodeDefBuilder ret_builder(
+      absl::StrCat("recv_oc_while_cond_ret_", while_node_name), "_Retval");
+  ret_builder.Attr("T", DT_BOOL);
+  ret_builder.Attr("index", 0);
+  ret_builder.Input(recv_pred_node->name(), 0, DT_BOOL);
+  NodeDef ret_def;
+  TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
+  Node* ret_node = cond_graph->AddNode(ret_def, &s);
+  TF_RETURN_IF_ERROR(s);
+  cond_graph->AddEdge(recv_pred_node, 0, ret_node, 0);
+
+  // Reset device_ordinal to placeholder value.
+  TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(cond_graph));
+
+  // Replace original function.
+  FunctionDef cond_replace_fdef;
+  TF_RETURN_IF_ERROR(
+      GraphToFunctionDef(*cond_graph, cond_host_func_name, &cond_replace_fdef));
+  TF_RETURN_IF_ERROR(
+      fld->ReplaceFunction(cond_host_func_name, cond_replace_fdef));
+
+  return Status::OK();
+}
+
+// Rewrites the while loop body function `body_host_func_name` for host.
+//
+// The key placeholder is first replaced with an _Arg node. Then a DT_STRING
+// _Retval node (index 0) fed directly by that _Arg is added, so the body
+// passes the host compute key through unchanged. The rewritten function
+// replaces the original one in `fld`.
+//
+// `host_transfer_key`, `xla_cluster_attr_name`, `outside_compilation_attr_name`
+// and `outside_compilation_name` are not used by this function's body.
+Status RewriteHostWhileLoopBody(
+    const string& body_host_func_name, const string& while_node_name,
+    const string& host_transfer_key, const string& xla_cluster_attr_name,
+    const string& xla_cluster_name, const string& outside_compilation_attr_name,
+    const string& outside_compilation_name, FunctionLibraryDefinition* fld) {
+  // Replace key placeholder node with _Arg node.
+  TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
+      xla_cluster_name, body_host_func_name, fld));
+
+  // Fail cleanly instead of dereferencing a null FunctionDef.
+  const FunctionDef* body_fdef = fld->Find(body_host_func_name);
+  if (!body_fdef) {
+    return errors::Internal("Cannot find host body function ",
+                            body_host_func_name);
+  }
+
+  // Instantiate body function. "0" is only a temporary "device_ordinal"; it
+  // is reset to the placeholder value below.
+  AttrValue device_ordinal_temp_value;
+  device_ordinal_temp_value.set_i(0);
+  protobuf::Map<string, AttrValue> attrs;
+  attrs["device_ordinal"] = device_ordinal_temp_value;
+  FunctionBody* body_fbody = nullptr;
+  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
+      *body_fdef, AttrSlice(&attrs), fld,
+      [&](const string& op, const OpDef** sig) {
+        return fld->LookUpOpDef(op, sig);
+      },
+      &body_fbody));
+  std::unique_ptr<FunctionBody> body_fbody_deleter(body_fbody);
+  Graph* body_graph = body_fbody->graph;
+  // If several _Arg nodes exist, the last one visited is used.
+  Node* key_arg = nullptr;
+  for (Node* n : body_graph->nodes()) {
+    if (n->type_string() == "_Arg") {
+      key_arg = n;
+    }
+  }
+  if (!key_arg) {
+    return errors::Internal(
+        "No _Arg node found for host compute key in function ",
+        body_host_func_name);
+  }
+
+  // Add a _Retval node to loop body.
+  NodeDefBuilder ret_builder(
+      absl::StrCat("recv_oc_while_body_ret_", while_node_name), "_Retval");
+  ret_builder.Attr("T", DT_STRING);
+  ret_builder.Attr("index", 0);
+  ret_builder.Input(key_arg->name(), 0, DT_STRING);
+  NodeDef ret_def;
+  TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
+  Status s;
+  Node* ret_node = body_graph->AddNode(ret_def, &s);
+  TF_RETURN_IF_ERROR(s);
+  body_graph->AddEdge(key_arg, 0, ret_node, 0);
+
+  // Reset device_ordinal to placeholder value.
+  TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(body_graph));
+
+  // Replace original function.
+  FunctionDef body_replace_fdef;
+  TF_RETURN_IF_ERROR(
+      GraphToFunctionDef(*body_graph, body_host_func_name, &body_replace_fdef));
+  TF_RETURN_IF_ERROR(
+      fld->ReplaceFunction(body_host_func_name, body_replace_fdef));
+
+  return Status::OK();
+}
+
+// Builds the host side graph for a While node and stores it in `fld` as
+// function `host_graph_func_name` (replacing an existing function of that
+// name, if any).
+//
+// The host graph consists of a key placeholder node feeding a "While" node
+// whose single loop variable is the DT_STRING host compute key. The While
+// node's "cond" / "body" are `cond_host_func_name` / `body_host_func_name`,
+// both of which are rewritten in `fld` by RewriteHostWhileLoopCond() and
+// RewriteHostWhileLoopBody() before the node is built.
+Status BuildHostGraphForWhileNode(
+    const string& xla_cluster_attr_name,
+    const string& outside_compilation_attr_name, const string& xla_cluster_name,
+    const string& while_node_name, const string& host_transfer_key,
+    const string& host_graph_func_name, FunctionLibraryDefinition* fld,
+    const string& cond_host_func_name, const string& body_host_func_name) {
+  Graph host_graph(fld);
+  string outside_compilation_name = absl::StrCat("oc_while_", while_node_name);
+
+  // Step 1: add key placeholder node.
+  TF_ASSIGN_OR_RETURN(
+      Node * key_placeholder,
+      AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));
+
+  // Step 2: rewrite cond function.
+  TF_RETURN_IF_ERROR(RewriteHostWhileLoopCond(
+      cond_host_func_name, while_node_name, host_transfer_key,
+      xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
+      outside_compilation_name, fld));
+
+  // Step 3: rewrite body function.
+  TF_RETURN_IF_ERROR(RewriteHostWhileLoopBody(
+      body_host_func_name, while_node_name, host_transfer_key,
+      xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
+      outside_compilation_name, fld));
+
+  // Step 4: build While node.
+  NodeDefBuilder while_builder(absl::StrCat("oc_while_", while_node_name),
+                               "While");
+  while_builder.Attr("T", std::vector<DataType>{DT_STRING});
+  // `func` is reused for both "cond" and "body"; only its name differs, the
+  // placeholder "device_ordinal" attr is shared.
+  NameAttrList func;
+  AttrValue device_ordinal_value;
+  device_ordinal_value.set_placeholder("device_ordinal");
+  (*func.mutable_attr())["device_ordinal"] = device_ordinal_value;
+  func.set_name(cond_host_func_name);
+  while_builder.Attr("cond", func);
+  func.set_name(body_host_func_name);
+  while_builder.Attr("body", func);
+  while_builder.Attr(kXlaHasHostTransferAttrName, true);
+  while_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
+  while_builder.Attr(outside_compilation_attr_name, outside_compilation_name);
+  std::vector<NodeDefBuilder::NodeOut> while_inputs{
+      {key_placeholder->name(), 0, DT_STRING}};
+  while_builder.Input(while_inputs);
+  NodeDef while_def;
+  TF_RETURN_IF_ERROR(while_builder.Finalize(&while_def));
+  Status s;
+  Node* while_node = host_graph.AddNode(while_def, &s);
+  TF_RETURN_IF_ERROR(s);
+  host_graph.AddEdge(key_placeholder, 0, while_node, 0);
+
+  // Convert `host_graph` to function.
+  FunctionDef oc_host_graph_fdef;
+  TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name,
+                                        &oc_host_graph_fdef));
+  if (fld->Find(host_graph_func_name)) {
+    TF_RETURN_IF_ERROR(
+        fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef));
+  } else {
+    TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef));
+  }
+
+  return Status::OK();
+}
+
+// Extracts outside compilation from the functions associated with "If" and
+// "While" nodes in `g`.
+//
+// For every If / While node, ExtractOutsideCompilationForFunction() is run on
+// each associated function (then/else branch, or cond/body). If at least one
+// of the node's functions has outside compilation:
+//   * the node is rewired to call the rewritten "<name>_oc" functions;
+//   * the predicate is sent to the host (a send node in `g` for If, a rewrite
+//     of the cond function for While);
+//   * a host side graph function is built in `fld` and its name is appended
+//     to `host_graphs`;
+//   * `*has_outside_compilation` is set to true.
+// Nodes whose functions have no outside compilation are left unchanged.
+// `*has_outside_compilation` is never set to false by this function.
+Status ExtractOutsideCompilationForNodesWithAssociatedFunctions(
+    Graph* g, const string& xla_cluster_attr_name,
+    const string& outside_compilation_attr_name, const string& xla_cluster_name,
+    const std::map<string, int>& host_compute_core,
+    FunctionLibraryDefinition* fld, std::vector<string>* host_graphs,
+    std::vector<string>* shape_inference_graphs,
+    bool* has_outside_compilation) {
+  // Collect the nodes first: the loops below add nodes to `g`, so we must not
+  // iterate over g->nodes() while rewriting.
+  std::vector<Node*> if_nodes, while_nodes;
+  for (Node* n : g->nodes()) {
+    if (n->type_string() == "If") {
+      if_nodes.push_back(n);
+    } else if (n->type_string() == "While") {
+      while_nodes.push_back(n);
+    }
+  }
+
+  for (Node* n : if_nodes) {
+    // Instantiate "then_branch" and "else_branch".
+    NameAttrList then_branch, else_branch;
+    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "then_branch", &then_branch));
+    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "else_branch", &else_branch));
+
+    // Extract outside compilation for then_branch and else_branch.
+    bool then_branch_has_outside_compilation = false;
+    bool else_branch_has_outside_compilation = false;
+    string then_branch_host_func_name =
+               absl::StrCat("oc_then_branch_host_if_", n->name()),
+           else_branch_host_func_name =
+               absl::StrCat("oc_else_branch_host_if_", n->name());
+    string then_branch_xla_func_name = absl::StrCat(then_branch.name(), "_oc"),
+           else_branch_xla_func_name = absl::StrCat(else_branch.name(), "_oc");
+    TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
+        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+        then_branch, then_branch_xla_func_name, then_branch_host_func_name,
+        host_compute_core, fld, shape_inference_graphs,
+        &then_branch_has_outside_compilation));
+    TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
+        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+        else_branch, else_branch_xla_func_name, else_branch_host_func_name,
+        host_compute_core, fld, shape_inference_graphs,
+        &else_branch_has_outside_compilation));
+
+    // If then/else branch do not have outside compilation, nothing to do.
+    if (!then_branch_has_outside_compilation &&
+        !else_branch_has_outside_compilation) {
+      continue;
+    }
+
+    *has_outside_compilation = true;
+
+    // Change If node to call the new functions.
+    then_branch.set_name(then_branch_xla_func_name);
+    n->ClearAttr("then_branch");
+    n->AddAttr("then_branch", then_branch);
+    else_branch.set_name(else_branch_xla_func_name);
+    n->ClearAttr("else_branch");
+    n->AddAttr("else_branch", else_branch);
+
+    string host_transfer_key = absl::StrCat("oc_if_pred_", n->name());
+
+    // XLA computation: add a SendToHost node to send cond predicate.
+    Node* pred_node;
+    TF_RETURN_IF_ERROR(n->input_node(0, &pred_node));
+    TF_ASSIGN_OR_RETURN(
+        Node * send_pred_node,
+        BuildSendIfPredNode(absl::StrCat("send_oc_if_pred_", n->name()),
+                            host_transfer_key, pred_node, g));
+    n->AddAttr(kXlaTokenInputNodesAttrName,
+               std::vector<string>{send_pred_node->name()});
+
+    // Add a control edge from `send_pred_node` to If node, so XlaCompiler will
+    // visit If node after `send_pred_node`, thus the token output for
+    // `send_pred_node` has been generated.
+    g->AddControlEdge(send_pred_node, n);
+
+    // Build host side graph for the "If" node.
+    string oc_host_graph_name = absl::StrCat("oc_if_host_graph_", n->name());
+    TF_RETURN_IF_ERROR(BuildHostGraphForIfNode(
+        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+        n->name(), host_transfer_key, oc_host_graph_name, fld,
+        then_branch_host_func_name, else_branch_host_func_name));
+    host_graphs->push_back(oc_host_graph_name);
+  }
+
+  for (Node* n : while_nodes) {
+    // Instantiate "cond" and "body".
+    NameAttrList cond, body;
+    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "cond", &cond));
+    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "body", &body));
+
+    // Extract outside compilation for cond and body.
+    bool cond_has_outside_compilation = false;
+    bool body_has_outside_compilation = false;
+    string cond_host_func_name = absl::StrCat("oc_cond_host_while_", n->name()),
+           body_host_func_name = absl::StrCat("oc_body_host_while_", n->name());
+    string cond_xla_func_name = absl::StrCat(cond.name(), "_oc"),
+           body_xla_func_name = absl::StrCat(body.name(), "_oc");
+    TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
+        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+        cond, cond_xla_func_name, cond_host_func_name, host_compute_core, fld,
+        shape_inference_graphs, &cond_has_outside_compilation));
+    TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
+        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+        body, body_xla_func_name, body_host_func_name, host_compute_core, fld,
+        shape_inference_graphs, &body_has_outside_compilation));
+
+    // If cond/body do not have outside compilation, nothing to do.
+    if (!cond_has_outside_compilation && !body_has_outside_compilation) {
+      continue;
+    }
+
+    *has_outside_compilation = true;
+
+    // Change While node to call the new functions.
+    cond.set_name(cond_xla_func_name);
+    n->ClearAttr("cond");
+    n->AddAttr("cond", cond);
+    body.set_name(body_xla_func_name);
+    n->ClearAttr("body");
+    n->AddAttr("body", body);
+
+    string host_transfer_key = absl::StrCat("oc_while_pred_", n->name());
+
+    // XLA computation: rewrite cond function to add a SendToHost node to send
+    // loop predicate. `cond` already carries the rewritten "_oc" name here.
+    TF_RETURN_IF_ERROR(
+        AddSendLoopPredToLoopCond(fld, cond, n->name(), host_transfer_key));
+    n->AddAttr(kXlaTokenInputNodesAttrName,
+               std::vector<string>{kXlaTokenArgNodeName});
+
+    // Build host side graph for the "While" node.
+    string oc_host_graph_name = absl::StrCat("oc_while_host_graph_", n->name());
+    TF_RETURN_IF_ERROR(BuildHostGraphForWhileNode(
+        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+        n->name(), host_transfer_key, oc_host_graph_name, fld,
+        cond_host_func_name, body_host_func_name));
+    host_graphs->push_back(oc_host_graph_name);
+  }
+
+  return Status::OK();
+}
+
} // namespace
Status RewriteOutsideCompilationSubgraphFn::operator()(
@@ -755,12 +1413,15 @@ Status RewriteOutsideCompilationSubgraphFn::operator()(
// it with HostCompute node later.
AddNodeAttr("_outside_compilation_subgraph", old_name, node_def);
if (shapes) {
- AddNodeAttr("shape_inference_graph", "", node_def);
+ NameAttrList shape_inference_graph;
+ AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def);
AddNodeAttr("shapes", *shapes, node_def);
} else {
string shape_inference_func_name =
absl::StrCat("_outside_compilation_shape_inference_", new_name);
- AddNodeAttr("shape_inference_graph", shape_inference_func_name, node_def);
+ NameAttrList shape_inference_graph;
+ shape_inference_graph.set_name(shape_inference_func_name);
+ AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def);
AddNodeAttr("shapes", std::vector{}, node_def);
}
AddNodeAttr("ancestors", std::vector{}, node_def);
@@ -775,11 +1436,10 @@ Status ExtractOutsideCompilationForFunction(
const string& xla_cluster_attr_name,
const string& outside_compilation_attr_name, const string& xla_cluster_name,
const NameAttrList& func_name_attrs, const string& new_func_name,
+ const string& host_graph_func_name,
const std::map& host_compute_core,
- FunctionLibraryDefinition* fld, std::unique_ptr* host_graph,
- std::vector* shape_inference_graphs,
+ FunctionLibraryDefinition* fld, std::vector* shape_inference_graphs,
bool* has_outside_compilation) {
- // Early return if function does not have any outside compilation nodes.
const string& func_name = func_name_attrs.name();
const FunctionDef* fdef = fld->Find(func_name);
if (!fdef) {
@@ -792,9 +1452,8 @@ Status ExtractOutsideCompilationForFunction(
break;
}
}
- if (!has_outside_compilation) {
- return Status::OK();
- }
+ // We cannot early return here, because we might have outside compilation in
+ // If/While function body.
// Convert the function to graph.
FunctionBody* fbody = nullptr;
@@ -835,11 +1494,11 @@ Status ExtractOutsideCompilationForFunction(
// If we could not infer shapes for XlaSendFromHost inputs statically, we
// will set the "shape_inference_graph" attribute. In that case, copy
// outside compilation subgraph as shape inference graph in `fld`.
- string shape_inference_graph;
+ NameAttrList shape_inference_graph;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "shape_inference_graph",
&shape_inference_graph));
- if (!shape_inference_graph.empty()) {
- shape_inference_graphs->push_back(shape_inference_graph);
+ if (!shape_inference_graph.name().empty()) {
+ shape_inference_graphs->push_back(shape_inference_graph.name());
const FunctionDef* xla_fdef = fld->Find(n->name());
if (!xla_fdef) {
@@ -847,9 +1506,9 @@ Status ExtractOutsideCompilationForFunction(
}
FunctionDef shape_inference_fdef = *xla_fdef;
shape_inference_fdef.mutable_signature()->set_name(
- shape_inference_graph);
- if (fld->Find(shape_inference_graph)) {
- TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph,
+ shape_inference_graph.name());
+ if (fld->Find(shape_inference_graph.name())) {
+ TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph.name(),
shape_inference_fdef));
} else {
TF_RETURN_IF_ERROR(fld->AddFunctionDef(shape_inference_fdef));
@@ -867,12 +1526,17 @@ Status ExtractOutsideCompilationForFunction(
*graph_out, fld);
}
+ // Handle nodes with associated functions.
+ TF_RETURN_IF_ERROR(ExtractOutsideCompilationForNodesWithAssociatedFunctions(
+ graph_out.get(), xla_cluster_attr_name, outside_compilation_attr_name,
+ xla_cluster_name, host_compute_core, fld,
+ &outside_compilation_host_graphs, shape_inference_graphs,
+ has_outside_compilation));
+
// Construct host graph.
- if (!outside_compilation_host_graphs.empty()) {
- TF_RETURN_IF_ERROR(
- ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name,
- outside_compilation_host_graphs, fld, host_graph));
- }
+ TF_RETURN_IF_ERROR(ConstructHostGraph(
+ xla_cluster_name, outside_compilation_attr_name,
+ outside_compilation_host_graphs, fld, host_graph_func_name));
// Remove the outside compilation graphs from function library.
for (const string& func : outside_compilation_host_graphs) {
@@ -909,24 +1573,17 @@ Status ExtractOutsideCompilation(
auto const& host_compute_core = iter.second.host_compute_core;
bool has_outside_compilation;
- std::unique_ptr host_graph;
+ string host_graph_func_name = absl::StrCat("oc_host_graph_", n->name());
TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
- func_name_attrs, func_name_attrs.name(), host_compute_core, fld,
- &host_graph, &shape_inference_graphs, &has_outside_compilation));
- if (host_graph) {
- TF_RETURN_IF_ERROR(ExpandHostGraphIntoMainGraph(g, host_graph.get(), n));
- }
- }
-
- if (VLOG_IS_ON(4)) {
- dump_graph::DumpGraphToFile("extract_outside_compilation_expanded", *g,
- fld);
+ func_name_attrs, func_name_attrs.name(), host_graph_func_name,
+ host_compute_core, fld, &shape_inference_graphs,
+ &has_outside_compilation));
+ TF_RETURN_IF_ERROR(
+ ExpandHostGraphIntoMainGraph(g, fld, host_graph_func_name, n));
+ TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name));
}
- TF_RETURN_IF_ERROR(PostprocessForEncapsulation(
- g, xla_cluster_attr_name, outside_compilation_attr_name, clusters));
-
for (auto shape_inference_graph_name : shape_inference_graphs) {
TF_RETURN_IF_ERROR(
RewriteShapeInferenceGraph(shape_inference_graph_name, g, fld));
diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.h b/tensorflow/compiler/jit/extract_outside_compilation_pass.h
index 2a4f07cca213d999202024294f5d8f94527059c3..e07e7c5dd0cd42ddd4d643d8b36583c82056bbb2 100644
--- a/tensorflow/compiler/jit/extract_outside_compilation_pass.h
+++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.h
@@ -88,9 +88,10 @@ Status ExtractOutsideCompilationForFunction(
const string& xla_cluster_attr_name,
const string& outside_compilation_attr_name, const string& xla_cluster_name,
const NameAttrList& func_name_attrs, const string& new_func_name,
+ const string& host_graph_func_name,
const std::map& host_compute_core,
- FunctionLibraryDefinition* fld, std::unique_ptr* host_graph,
- std::vector* shape_inference_graphs, bool* has_outside_compilation);
+ FunctionLibraryDefinition* fld, std::vector* shape_inference_graphs,
+ bool* has_outside_compilation);
// Rewrites XLA computation in `clusters` to replace outside compilation nodes
// with XlaHostCompute, and moves those outside compilations into `g`. If shapes
diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc
index bff956100da661b679b4557fce53671e6cef88c5..e9a89e34e0c7b04b4be34e367b2d0bf627c0061a 100644
--- a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc
@@ -19,8 +19,10 @@ limitations under the License.
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/encapsulate_util.h"
+#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.h"
@@ -109,10 +111,10 @@ TEST(RewriteOutsideCompilationSubgraphFnTest, Basic) {
}
EXPECT_TRUE(has_control_edge_to_send_from_host);
// Verify step 7: necessary attrs added to call_node_def.
- string shape_inference_graph;
+ NameAttrList shape_inference_graph;
TF_CHECK_OK(GetNodeAttr(AttrSlice(&call_node_def.attr()),
"shape_inference_graph", &shape_inference_graph));
- EXPECT_EQ(shape_inference_graph,
+ EXPECT_EQ(shape_inference_graph.name(),
"_outside_compilation_shape_inference_cluster_0");
}
@@ -249,27 +251,26 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) {
protobuf::Map attrs;
std::map host_compute_core = {{"0", 1}, {"1", 0}};
- std::unique_ptr host_graph;
std::vector shape_inference_graphs;
bool has_outside_compilation;
NameAttrList name_attrs;
name_attrs.set_name("cluster");
*name_attrs.mutable_attr() = attrs;
TF_CHECK_OK(ExtractOutsideCompilationForFunction(
- "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten",
- host_compute_core, &fld, &host_graph, &shape_inference_graphs,
+ "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
+ host_compute_core, &fld, &shape_inference_graphs,
&has_outside_compilation));
// Get rewritten XLA computation function.
- FunctionBody *fbody = nullptr;
- TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
- AttrSlice(), &fld,
- [&](const string &op, const OpDef **sig) {
- return fld.LookUpOpDef(op, sig);
- },
- &fbody));
- std::unique_ptr fbody_deleter(fbody);
- auto node_name_index = fbody->graph->BuildNodeNameIndex();
+ FunctionBody *xla_fbody = nullptr;
+ TF_CHECK_OK(FunctionDefToBodyHelper(
+ *fld.Find("cluster_rewritten"), AttrSlice(), &fld,
+ [&](const string &op, const OpDef **sig) {
+ return fld.LookUpOpDef(op, sig);
+ },
+ &xla_fbody));
+ std::unique_ptr xla_fbody_deleter(xla_fbody);
+ auto node_name_index = xla_fbody->graph->BuildNodeNameIndex();
// Check XlaHostCompute nodes.
Node *host_compute_0 = node_name_index["outside_compilation_0_host_compute"];
@@ -292,18 +293,31 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) {
EXPECT_EQ(shapes[0].dim_size(), 1);
// Check XlaHostCompute nodes' "shape_inference_graph" attr. Both should have
// empty values.
- string shape_inference_graph;
+ NameAttrList shape_inference_graph;
TF_CHECK_OK(GetNodeAttr(host_compute_0->attrs(), "shape_inference_graph",
&shape_inference_graph));
- EXPECT_EQ(shape_inference_graph, "");
+ EXPECT_EQ(shape_inference_graph.name(), "");
TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "shape_inference_graph",
&shape_inference_graph));
- EXPECT_EQ(shape_inference_graph, "");
+ EXPECT_EQ(shape_inference_graph.name(), "");
// Check `shape_inference_graphs`.
EXPECT_EQ(shape_inference_graphs.size(), 0);
- // Check `host_graph`: verify we have key placeholder and sequencer.
+ // Check host graph: verify we have key placeholder and sequencer.
+ FunctionBody *host_fbody = nullptr;
+ AttrValue device_ordinal_temp_value;
+ device_ordinal_temp_value.set_i(0);
+ protobuf::Map host_func_attrs;
+ host_func_attrs["device_ordinal"] = device_ordinal_temp_value;
+ TF_CHECK_OK(FunctionDefToBodyHelper(
+ *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld,
+ [&](const string &op, const OpDef **sig) {
+ return fld.LookUpOpDef(op, sig);
+ },
+ &host_fbody));
+ std::unique_ptr host_fbody_deleter(host_fbody);
+ Graph *host_graph = host_fbody->graph;
Node *key_placeholder = nullptr, *sequencer = nullptr;
for (Node *n : host_graph->nodes()) {
if (n->type_string() == "Placeholder" &&
@@ -365,25 +379,37 @@ TEST(ExtractOutsideCompilationForFunctionTest, NoHostGraph) {
protobuf::Map attrs;
std::map host_compute_core = {{"0", 1}, {"1", 0}};
- std::unique_ptr host_graph;
std::vector shape_inference_graphs;
bool has_outside_compilation;
NameAttrList name_attrs;
name_attrs.set_name("cluster");
*name_attrs.mutable_attr() = attrs;
TF_CHECK_OK(ExtractOutsideCompilationForFunction(
- "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten",
- host_compute_core, &fld, &host_graph, &shape_inference_graphs,
+ "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
+ host_compute_core, &fld, &shape_inference_graphs,
&has_outside_compilation));
- // Check `host_graph` is empty.
- EXPECT_FALSE(host_graph);
+ // Check host graph is empty (it contains only the source and sink nodes).
+ FunctionBody *host_fbody = nullptr;
+ AttrValue device_ordinal_temp_value;
+ device_ordinal_temp_value.set_i(0);
+ protobuf::Map host_func_attrs;
+ host_func_attrs["device_ordinal"] = device_ordinal_temp_value;
+ TF_CHECK_OK(FunctionDefToBodyHelper(
+ *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld,
+ [&](const string &op, const OpDef **sig) {
+ return fld.LookUpOpDef(op, sig);
+ },
+ &host_fbody));
+ std::unique_ptr host_fbody_deleter(host_fbody);
+ Graph *host_graph = host_fbody->graph;
+ EXPECT_EQ(host_graph->num_nodes(), 2);
}
TEST(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) {
// Build the XLA computation func.
// "const0"
- // "const1" (outside compilation clsuter "0")
+ // "const1" (outside compilation cluster "0")
FunctionDefLibrary fdl;
{
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
@@ -401,31 +427,43 @@ TEST(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) {
protobuf::Map attrs;
std::map host_compute_core = {{"0", 1}, {"1", 0}};
- std::unique_ptr host_graph;
std::vector shape_inference_graphs;
bool has_outside_compilation;
NameAttrList name_attrs;
name_attrs.set_name("cluster");
*name_attrs.mutable_attr() = attrs;
TF_CHECK_OK(ExtractOutsideCompilationForFunction(
- "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten",
- host_compute_core, &fld, &host_graph, &shape_inference_graphs,
+ "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
+ host_compute_core, &fld, &shape_inference_graphs,
&has_outside_compilation));
// Check rewritten XLA graph: verify that we have no XlaHostCompute.
- FunctionBody *fbody = nullptr;
- TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
- AttrSlice(), &fld,
- [&](const string &op, const OpDef **sig) {
- return fld.LookUpOpDef(op, sig);
- },
- &fbody));
- std::unique_ptr fbody_deleter(fbody);
- for (Node *n : fbody->graph->nodes()) {
+ FunctionBody *xla_fbody = nullptr;
+ TF_CHECK_OK(FunctionDefToBodyHelper(
+ *fld.Find("cluster_rewritten"), AttrSlice(), &fld,
+ [&](const string &op, const OpDef **sig) {
+ return fld.LookUpOpDef(op, sig);
+ },
+ &xla_fbody));
+ std::unique_ptr xla_fbody_deleter(xla_fbody);
+ for (Node *n : xla_fbody->graph->nodes()) {
EXPECT_NE(n->type_string(), "XlaHostCompute");
}
- // Check `host_graph`: verify we have no placeholder, but we have "const1".
+ // Check host graph: verify we have no placeholder, but we have "const1".
+ FunctionBody *host_fbody = nullptr;
+ AttrValue device_ordinal_temp_value;
+ device_ordinal_temp_value.set_i(0);
+ protobuf::Map host_func_attrs;
+ host_func_attrs["device_ordinal"] = device_ordinal_temp_value;
+ TF_CHECK_OK(FunctionDefToBodyHelper(
+ *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld,
+ [&](const string &op, const OpDef **sig) {
+ return fld.LookUpOpDef(op, sig);
+ },
+ &host_fbody));
+ std::unique_ptr host_fbody_deleter(host_fbody);
+ Graph *host_graph = host_fbody->graph;
int num_key_placeholders = 0;
for (Node *n : host_graph->nodes()) {
if (n->type_string() == "Placeholder" &&
@@ -438,4 +476,310 @@ TEST(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) {
EXPECT_NE(node_name_index.find("const1"), node_name_index.end());
}
+REGISTER_OP("XlaSendToHost")
+ .Input("input: Tinput")
+ .Attr("Tinput: type")
+ .Attr("key: string")
+ .SetIsStateful();
+
+REGISTER_OP("XlaRecvFromHost")
+ .Output("output: Toutput")
+ .Attr("Toutput: type")
+ .Attr("shape: shape")
+ .Attr("key: string")
+ .SetIsStateful();
+
+TEST(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInIf) {
+ // Build the XLA computation func.
+ // "const0" (bool)
+ // "const1" (int32)
+ // "if" (pred = "const0", input = "const1", then_branch = "true_fn",
+ // else_branch = "false_fn")
+ FunctionDefLibrary fdl;
+ {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output arg = ops::_Arg(s.WithOpName("arg"), DT_INT32, 0);
+ Output identity = ops::Identity(s.WithOpName("identity_true_fn"), arg);
+ ops::_Retval retval(s.WithOpName("retval"), identity, 0);
+ std::unique_ptr g(new Graph(OpRegistry::Global()));
+ TF_CHECK_OK(s.ToGraph(g.get()));
+ auto node_name_image = g->BuildNodeNameIndex();
+ node_name_image["identity_true_fn"]->AddAttr("_oc", "0");
+ PartialTensorShape shape({2});
+ node_name_image["identity_true_fn"]->AddAttr(
+ kXlaInferredShapesAttrName, std::vector{shape});
+
+ FunctionDef *true_fn_fdef = fdl.add_function();
+ TF_CHECK_OK(GraphToFunctionDef(*g, "true_fn", true_fn_fdef));
+ }
+ {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output arg = ops::_Arg(s.WithOpName("arg"), DT_INT32, 0);
+ Output identity = ops::Identity(s.WithOpName("identity_false_fn"), arg);
+ ops::_Retval retval(s.WithOpName("retval"), identity, 0);
+ std::unique_ptr g(new Graph(OpRegistry::Global()));
+ TF_CHECK_OK(s.ToGraph(g.get()));
+ auto node_name_image = g->BuildNodeNameIndex();
+ node_name_image["identity_false_fn"]->AddAttr("_oc", "0");
+ PartialTensorShape shape({2});
+ node_name_image["identity_false_fn"]->AddAttr(
+ kXlaInferredShapesAttrName, std::vector{shape});
+
+ FunctionDef *false_fn_fdef = fdl.add_function();
+ TF_CHECK_OK(GraphToFunctionDef(*g, "false_fn", false_fn_fdef));
+ }
+ {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output cond = ops::Const(s.WithOpName("const0"), true, {2});
+ Output input = ops::Const(s.WithOpName("const1"), 1, {2});
+ NameAttrList true_fn;
+ true_fn.set_name("true_fn");
+ NameAttrList false_fn;
+ false_fn.set_name("false_fn");
+ auto if_op = ops::If(s.WithOpName("if"), cond,
+ std::initializer_list{cond, input}, {DT_INT32},
+ true_fn, false_fn);
+ ops::_Retval retval(s.WithOpName("retval"), if_op.output[0], 0);
+ std::unique_ptr g(new Graph(OpRegistry::Global()));
+ TF_CHECK_OK(s.ToGraph(g.get()));
+
+ FunctionDef *xla_fdef = fdl.add_function();
+ TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
+ }
+ FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
+
+ protobuf::Map attrs;
+ std::map host_compute_core;
+ std::vector shape_inference_graphs;
+ bool has_outside_compilation;
+ NameAttrList name_attrs;
+ name_attrs.set_name("cluster");
+ *name_attrs.mutable_attr() = attrs;
+ TF_CHECK_OK(ExtractOutsideCompilationForFunction(
+ "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
+ host_compute_core, &fld, &shape_inference_graphs,
+ &has_outside_compilation));
+
+ // Check host graph.
+ {
+ FunctionBody *host_fbody = nullptr;
+ AttrValue device_ordinal_temp_value;
+ device_ordinal_temp_value.set_i(0);
+ protobuf::Map host_func_attrs;
+ host_func_attrs["device_ordinal"] = device_ordinal_temp_value;
+ TF_CHECK_OK(FunctionDefToBodyHelper(
+ *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld,
+ [&](const string &op, const OpDef **sig) {
+ return fld.LookUpOpDef(op, sig);
+ },
+ &host_fbody));
+ std::unique_ptr host_fbody_deleter(host_fbody);
+ Graph *host_graph = host_fbody->graph;
+ auto node_name_index = host_graph->BuildNodeNameIndex();
+
+ // Verify we have XlaRecvAtHost to receive "If" predicate.
+ Node *recv_if_pred_node = node_name_index["recv_oc_if_pred_if"];
+ EXPECT_NE(recv_if_pred_node, nullptr);
+
+ // Verify we have an "If" to choose outside compilation between then_branch
+ // and else_branch, and it has `recv_if_pred_node` as cond input.
+ Node *if_oc_node = node_name_index["oc_if_if"];
+ EXPECT_NE(if_oc_node, nullptr);
+ Node *if_oc_node_cond_input;
+ TF_CHECK_OK(if_oc_node->input_node(0, &if_oc_node_cond_input));
+ EXPECT_EQ(if_oc_node_cond_input, recv_if_pred_node);
+
+ // Check that then_branch outside compilation has node "identity_true_fn".
+ const FunctionDef *true_def = fld.Find("oc_then_branch_host_if_if");
+ EXPECT_NE(true_def, nullptr);
+ bool has_identity_true_fn_node = false;
+ for (const auto &node_def : true_def->node_def()) {
+ if (node_def.name() == "identity_true_fn") {
+ has_identity_true_fn_node = true;
+ break;
+ }
+ }
+ EXPECT_TRUE(has_identity_true_fn_node);
+
+ // Check that else_branch outside compilation has node "identity_false_fn".
+ const FunctionDef *false_def = fld.Find("oc_else_branch_host_if_if");
+ EXPECT_NE(false_def, nullptr);
+ bool has_identity_false_fn_node = false;
+ for (const auto &node_def : false_def->node_def()) {
+ if (node_def.name() == "identity_false_fn") {
+ has_identity_false_fn_node = true;
+ break;
+ }
+ }
+ EXPECT_TRUE(has_identity_false_fn_node);
+ }
+
+ // Check XLA graph.
+ {
+ FunctionBody *xla_fbody = nullptr;
+ TF_CHECK_OK(FunctionDefToBodyHelper(
+ *fld.Find("cluster_rewritten"), AttrSlice(), &fld,
+ [&](const string &op, const OpDef **sig) {
+ return fld.LookUpOpDef(op, sig);
+ },
+ &xla_fbody));
+ std::unique_ptr xla_fbody_deleter(xla_fbody);
+ Graph *xla_graph = xla_fbody->graph;
+ auto node_name_index = xla_graph->BuildNodeNameIndex();
+
+ // Check that we have XlaSendToHost to send cond predicate to host, and
+ // there is a control edge to If node.
+ Node *send_if_pred_node = node_name_index["send_oc_if_pred_if"];
+ EXPECT_NE(send_if_pred_node, nullptr);
+ bool has_control_edge_to_if = false;
+ for (const Edge *e : send_if_pred_node->out_edges()) {
+ if (e->IsControlEdge() && e->dst()->name() == "if") {
+ has_control_edge_to_if = true;
+ break;
+ }
+ }
+ EXPECT_TRUE(has_control_edge_to_if);
+
+ // Check that the "If" node now has `send_if_pred_node` as attribute
+ // _xla_token_input_nodes.
+ Node *if_node = node_name_index["if"];
+ EXPECT_NE(if_node, nullptr);
+ std::vector token_inputs;
+ TF_CHECK_OK(
+ GetNodeAttr(if_node->def(), "_xla_token_input_nodes", &token_inputs));
+ EXPECT_THAT(token_inputs, ::testing::ElementsAre("send_oc_if_pred_if"));
+ }
+}
+
+TEST(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInWhile) {
+ // Build the XLA computation func.
+ // "const0" (bool)
+ // "while" (input = "const0", cond = "cond_fn", body = "body_fn")
+ FunctionDefLibrary fdl;
+ {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output arg = ops::_Arg(s.WithOpName("arg"), DT_BOOL, 0);
+ Output identity = ops::Identity(s.WithOpName("identity_cond_fn"), arg);
+ ops::_Retval retval(s.WithOpName("retval"), identity, 0);
+ std::unique_ptr g(new Graph(OpRegistry::Global()));
+ TF_CHECK_OK(s.ToGraph(g.get()));
+ auto node_name_image = g->BuildNodeNameIndex();
+ node_name_image["identity_cond_fn"]->AddAttr("_oc", "0");
+ PartialTensorShape shape({2});
+ node_name_image["identity_cond_fn"]->AddAttr(
+ kXlaInferredShapesAttrName, std::vector{shape});
+
+ FunctionDef *cond_fn_fdef = fdl.add_function();
+ TF_CHECK_OK(GraphToFunctionDef(*g, "cond_fn", cond_fn_fdef));
+ }
+ {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output arg = ops::_Arg(s.WithOpName("arg"), DT_BOOL, 0);
+ Output identity = ops::Identity(s.WithOpName("identity_body_fn"), arg);
+ ops::_Retval retval(s.WithOpName("retval"), identity, 0);
+ std::unique_ptr g(new Graph(OpRegistry::Global()));
+ TF_CHECK_OK(s.ToGraph(g.get()));
+ auto node_name_image = g->BuildNodeNameIndex();
+ node_name_image["identity_body_fn"]->AddAttr("_oc", "0");
+ PartialTensorShape shape({2});
+ node_name_image["identity_body_fn"]->AddAttr(
+ kXlaInferredShapesAttrName, std::vector{shape});
+
+ FunctionDef *body_fn_fdef = fdl.add_function();
+ TF_CHECK_OK(GraphToFunctionDef(*g, "body_fn", body_fn_fdef));
+ }
+ {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output input = ops::Const(s.WithOpName("const0"), true, {2});
+ NameAttrList cond_fn;
+ cond_fn.set_name("cond_fn");
+ NameAttrList body_fn;
+ body_fn.set_name("body_fn");
+ auto while_op =
+ ops::While(s.WithOpName("while"), std::initializer_list{input},
+ cond_fn, body_fn);
+ ops::_Retval retval(s.WithOpName("retval"), while_op.output[0], 0);
+ std::unique_ptr g(new Graph(OpRegistry::Global()));
+ TF_CHECK_OK(s.ToGraph(g.get()));
+
+ FunctionDef *xla_fdef = fdl.add_function();
+ TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
+ }
+ FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
+
+ protobuf::Map attrs;
+ std::map host_compute_core;
+ std::vector shape_inference_graphs;
+ bool has_outside_compilation;
+ NameAttrList name_attrs;
+ name_attrs.set_name("cluster");
+ *name_attrs.mutable_attr() = attrs;
+ TF_CHECK_OK(ExtractOutsideCompilationForFunction(
+ "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
+ host_compute_core, &fld, &shape_inference_graphs,
+ &has_outside_compilation));
+
+ // Check host graph.
+ {
+ FunctionBody *host_fbody = nullptr;
+ AttrValue device_ordinal_temp_value;
+ device_ordinal_temp_value.set_i(0);
+ protobuf::Map host_func_attrs;
+ host_func_attrs["device_ordinal"] = device_ordinal_temp_value;
+ TF_CHECK_OK(FunctionDefToBodyHelper(
+ *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld,
+ [&](const string &op, const OpDef **sig) {
+ return fld.LookUpOpDef(op, sig);
+ },
+ &host_fbody));
+ std::unique_ptr host_fbody_deleter(host_fbody);
+ Graph *host_graph = host_fbody->graph;
+ auto node_name_index = host_graph->BuildNodeNameIndex();
+
+ // Verify we have a "While" to execute outside compilation.
+ Node *while_oc_node = node_name_index["oc_while_while"];
+ EXPECT_NE(while_oc_node, nullptr);
+
+ // Check that cond outside compilation has node "identity_cond_fn".
+ const FunctionDef *cond_def = fld.Find("oc_cond_host_while_while");
+ EXPECT_NE(cond_def, nullptr);
+ bool has_identity_cond_fn_node = false;
+ for (const auto &node_def : cond_def->node_def()) {
+ if (node_def.name() == "identity_cond_fn") {
+ has_identity_cond_fn_node = true;
+ break;
+ }
+ }
+ EXPECT_TRUE(has_identity_cond_fn_node);
+
+ // Check that body outside compilation has node "identity_body_fn".
+ const FunctionDef *body_def = fld.Find("oc_body_host_while_while");
+ EXPECT_NE(body_def, nullptr);
+ bool has_identity_body_fn_node = false;
+ for (const auto &node_def : body_def->node_def()) {
+ if (node_def.name() == "identity_body_fn") {
+ has_identity_body_fn_node = true;
+ break;
+ }
+ }
+ EXPECT_TRUE(has_identity_body_fn_node);
+ }
+
+ // Check XLA graph.
+ {
+ // Verify that rewritten cond fn has XlaSendToHost to send loop predicate to
+ // host.
+ const FunctionDef *cond_def = fld.Find("cond_fn_oc");
+ EXPECT_NE(cond_def, nullptr);
+ bool has_send_oc_while_cond_node = false;
+ for (const auto &node_def : cond_def->node_def()) {
+ if (node_def.name() == "send_oc_while_cond_while") {
+ has_send_oc_while_cond_node = true;
+ break;
+ }
+ }
+ EXPECT_TRUE(has_send_oc_while_cond_node);
+ }
+}
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc
index 42ea3926e16ae791dbe1bede3b8742383db7667c..e1fd2aaee2822daeffb415d053c9c4f56002a856 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass.cc
+++ b/tensorflow/compiler/jit/partially_decluster_pass.cc
@@ -120,6 +120,7 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) {
NodeDef ndef = n->def();
ndef.set_name(absl::StrCat(n->name(), "/declustered"));
+ MergeDebugInfo(NodeDebugInfo(n->def()), &ndef);
RemoveFromXlaCluster(&ndef);
Status s;
Node* cloned_node = graph->AddNode(ndef, &s);
diff --git a/tensorflow/compiler/jit/shape_inference.cc b/tensorflow/compiler/jit/shape_inference.cc
index 80c691fe490c1092315708a2da754d367d585300..a27e0d9f2a6ecddfdbdb29be673084d77a178d8a 100644
--- a/tensorflow/compiler/jit/shape_inference.cc
+++ b/tensorflow/compiler/jit/shape_inference.cc
@@ -53,7 +53,15 @@ Status PropagateShapes(const Graph& graph,
// shapes, even if no shape function is registered for a node.
Status status = shape_refiner->AddNode(n);
if (!status.ok()) {
- VLOG(1) << "Shape inference failed for node: " << status;
+ VLOG(1) << "Shape inference failed for node " << n->name() << ": "
+ << status;
+ } else {
+ shape_inference::InferenceContext* context = shape_refiner->GetContext(n);
+ for (int i = 0; i < n->num_outputs(); i++) {
+ shape_inference::ShapeHandle handle = context->output(i);
+ VLOG(4) << "Output " << i << " for node " << n->name() << ": "
+ << context->DebugString(handle);
+ }
}
if (n->type_string() == "_Arg") {
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index 6e6532731e64bd42ee56aa719748988f321e0f17..1f3afe8822d441a5ce37617fe18d7767e9bc72e4 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -79,6 +79,13 @@ XlaDeviceContext::XlaDeviceContext(
}
}
+void XlaDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor,
+ Device* device,
+ Tensor* output_tensor,
+ StatusCallback done) const {
+ done(errors::Unimplemented("XLA->XLA same-device copies not implemented."));
+}
+
void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
Device* device,
Tensor* device_tensor,
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index 1e18df197a2dd65590c5181b4dae4481dca36641..e45db989fac720df6c3458c93a6b8dbb0919f930 100644
--- a/tensorflow/compiler/jit/xla_device_context.h
+++ b/tensorflow/compiler/jit/xla_device_context.h
@@ -62,6 +62,9 @@ class XlaDeviceContext : public DeviceContext {
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
absl::string_view tensor_name, Device* device,
Tensor* cpu_tensor, StatusCallback done) override;
+ void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device,
+ Tensor* output_tensor,
+ StatusCallback done) const override;
xla::LocalClient* client() const { return client_; }
se::Stream* stream() const { return stream_.get(); }
diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py
index 174bfa9efbcd7dcb4f895237eb01c17bc4a3a6b4..90146e6b27ca31304a2549ec247412341efe390c 100644
--- a/tensorflow/compiler/tests/depthwise_conv_op_test.py
+++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py
@@ -350,8 +350,13 @@ class DepthwiseConv2DTest(xla_test.XLATestCase):
self._CompareBackpropInput(input_size, filter_size, output_size, stride,
padding)
- def _CompareBackpropFilter(self, input_sizes, filter_sizes, output_sizes,
- stride, padding):
+ def _CompareBackpropFilter(self,
+ input_sizes,
+ filter_sizes,
+ output_sizes,
+ stride,
+ padding,
+ data_format="NHWC"):
x0 = np.random.rand(*input_sizes).astype(np.float32)
x2 = np.random.rand(*output_sizes).astype(np.float32)
@@ -360,13 +365,30 @@ class DepthwiseConv2DTest(xla_test.XLATestCase):
t0 = array_ops.placeholder(np.float32, shape=input_sizes)
t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)])
t2 = array_ops.placeholder(np.float32, shape=output_sizes)
+ native_t0 = t0
+ native_t2 = t2
+ strides = [1, stride, stride, 1]
+
if use_xla:
+ if data_format == "NCHW":
+ # Transpose from NHWC input to NCHW
+ # Ex. [4, 5, 5, 48] to [4, 48, 5, 5]
+ native_t0 = array_ops.transpose(t0, [0, 3, 1, 2])
+ native_t2 = array_ops.transpose(t2, [0, 3, 1, 2])
+ strides = [1, 1, stride, stride]
with self.test_scope():
backprop = nn_ops.depthwise_conv2d_native_backprop_filter(
- t0, t1, t2, strides=[1, stride, stride, 1], padding=padding)
+ native_t0,
+ t1,
+ native_t2,
+ strides=strides,
+ padding=padding,
+ data_format=data_format)
else:
+ # For CPU, the format NCHW is not supported. Therefore we always use
+ # NHWC here.
backprop = nn_ops.depthwise_conv2d_native_backprop_filter(
- t0, t1, t2, strides=[1, stride, stride, 1], padding=padding)
+ native_t0, t1, native_t2, strides=strides, padding=padding)
ret = backprop.eval({t0: x0, t2: x2})
self.assertShapeEqual(ret, backprop)
return ret
@@ -379,11 +401,24 @@ class DepthwiseConv2DTest(xla_test.XLATestCase):
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(ConfigsToTest()):
print("Testing DepthwiseConv2DFilterGradCompare,", index, "th config:",
- input_size, "*", filter_size, "stride:", stride, "padding:",
- padding)
+ input_size, "*", filter_size, "producing output", output_size,
+ "stride:", stride, "padding:", padding)
self._CompareBackpropFilter(input_size, filter_size, output_size,
stride, padding)
+ def testDepthwiseConv2DFilterGradFormatNCHWCompare(self):
+ for index, (input_size, filter_size, output_size, stride,
+ padding) in enumerate(ConfigsToTest()):
+ print("Testing DepthwiseConv2DFilterGradFormatNCHWCompare,", index,
+ "th config:", input_size, "*", filter_size, "producing output",
+ output_size, "stride:", stride, "padding:", padding)
+ self._CompareBackpropFilter(
+ input_size,
+ filter_size,
+ output_size,
+ stride,
+ padding,
+ data_format="NCHW")
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc
index c693e42d26712d55852f45c806215fc1f1b9a030..7ae96e1d484900e28e8c23c3bb2232401144ad82 100644
--- a/tensorflow/compiler/tf2xla/functionalize_cond.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc
@@ -640,7 +640,8 @@ Status Conditional::ExtractBodies(Graph* graph) {
Status Conditional::BuildIfNode(Graph* graph,
FunctionLibraryDefinition* library) {
VLOG(2) << "Build cond function for " << name();
- NodeDefBuilder builder(name(), "If", library);
+ NodeDebugInfo debug_info((*merges_.begin())->def());
+ NodeDefBuilder builder(name(), "If", library, &debug_info);
const string branch_name[] = {"else_branch", "then_branch"};
for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
int branch_index = static_cast(branch);
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 8bc329229648c5aced8d06c99b170803bb3a90f8..a18a4e92d62787051f6ab92e72ee8bf0d1060dca 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -1,16 +1,11 @@
+load("//tensorflow:tensorflow.bzl", "tf_copts", "tf_kernel_library")
+
licenses(["notice"]) # Apache 2.0
package(
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
)
-load("//tensorflow:tensorflow.bzl", "tf_copts")
-load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
-load(
- "//third_party/mkl:build_defs.bzl",
- "if_mkl",
-)
-
tf_kernel_library(
name = "xla_ops",
srcs = [
@@ -122,12 +117,9 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/lib:broadcast",
- "//tensorflow/compiler/tf2xla/lib:cholesky",
- "//tensorflow/compiler/tf2xla/lib:qr",
"//tensorflow/compiler/tf2xla/lib:random",
"//tensorflow/compiler/tf2xla/lib:scatter",
"//tensorflow/compiler/tf2xla/lib:util",
- "//tensorflow/compiler/tf2xla/lib:while_loop",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:array4d",
"//tensorflow/compiler/xla:literal",
@@ -140,11 +132,14 @@ tf_kernel_library(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/lib:cholesky",
"//tensorflow/compiler/xla/client/lib:constants",
+ "//tensorflow/compiler/xla/client/lib:loops",
"//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/client/lib:pooling",
"//tensorflow/compiler/xla/client/lib:prng",
+ "//tensorflow/compiler/xla/client/lib:qr",
"//tensorflow/compiler/xla/client/lib:sorting",
"//tensorflow/compiler/xla/client/lib:triangular_solve",
"//tensorflow/core:framework",
diff --git a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc
index 9fcbc86adc0967cbb7fb73da8bdabc58b60953da..0ed3044efa5b1060d2b0ad2d5563b0e02ebf66ec 100644
--- a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/tf2xla/lib/cholesky.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/cholesky.h"
namespace tensorflow {
namespace {
@@ -24,7 +24,7 @@ class CholeskyOp : public XlaOpKernel {
public:
explicit CholeskyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- ctx->SetOutput(0, Cholesky(ctx->Input(0)));
+ ctx->SetOutput(0, xla::Cholesky(ctx->Input(0)));
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
index 641fefafb357f6ad10483c454600f3dadd4f8cb7..4124b258c7788e3850f07cbf4d53930784c635fd 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
@@ -392,23 +392,31 @@ xla::StatusOr MakeXlaBackpropFilterConvOp(
builder->GetShape(activations));
TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
builder->GetShape(gradients));
+ xla::XlaOp filter_backprop;
+
+ xla::Shape input_shape = activations_shape;
+ xla::Shape output_shape = out_backprop_shape;
+
+ TensorShape input_tensor_shape, filter_tensor_shape, output_tensor_shape;
+ TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape));
+ TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape));
+ TF_RETURN_IF_ERROR(XLAShapeToTensorShape(output_shape, &output_tensor_shape));
+
const xla::Shape expanded_filter_shape =
attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
: filter_shape;
-
// Reuse dimension computation logic from conv_grad_ops.cc.
ConvBackpropDimensions dims;
- TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
- type_string, attrs.num_spatial_dims, activations_shape,
- expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides,
- attrs.padding, attrs.data_format, &dims));
-
// The filter gradients are computed by a convolution of the input
// activations and the output gradients, with some appropriate padding.
// See the comment at the top of conv_grad_ops.h for details.
-
xla::ConvolutionDimensionNumbers dnums;
+ TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
+ type_string, attrs.num_spatial_dims, activations_shape,
+ expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides,
+ attrs.padding, attrs.data_format, &dims));
+
// The activations (inputs) form the LHS of the convolution.
// Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
// For the gradient computation, we flip the roles of the batch and
@@ -420,29 +428,99 @@ xla::StatusOr MakeXlaBackpropFilterConvOp(
int n_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
int c_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
- // Swap n_dim and c_dim in the activations.
- dnums.set_input_batch_dimension(c_dim);
- dnums.set_input_feature_dimension(n_dim);
+ int64 total_spatial_size = 1;
+ for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+ total_spatial_size *= dims.input_size(i);
+ }
- // The gradients become the RHS of the convolution.
- // The gradients have shape [batch, out_rows, out_cols, ..., out_depth]
- // where the batch becomes the input feature for the convolution.
- dnums.set_kernel_input_feature_dimension(n_dim);
- dnums.set_kernel_output_feature_dimension(c_dim);
+ // We use this approach only for depthwise convolutions where feature counts
+ // are large but spatial dimensions are small. The conversion logic below
+ // assumes that the data format is NHWC, so we also check that here.
+ bool should_perform_depthwise_conv =
+ attrs.data_format == FORMAT_NHWC &&
+ (total_spatial_size < dims.in_depth) &&
+ filter_tensor_shape.dim_size(num_dims - 1) == 1 && attrs.depthwise;
+
+ int64 num_spatial_dims =
+ attrs.num_spatial_dims + (should_perform_depthwise_conv ? 1 : 0);
+
+ std::vector> padding(num_spatial_dims);
+ std::vector rhs_dilation(num_spatial_dims);
+ std::vector window_strides(num_spatial_dims);
+ std::vector ones(num_spatial_dims, 1);
+
+ if (should_perform_depthwise_conv) {
+ // This approach is similar to handling of grouped convolutions in
+ // the convolution_feature_group_converter.cc. Please refer to it for
+ // details.
+
+ // Add spatial dimension to the activation, and reshape.
+ std::vector activations_reshape_sizes, gradients_reshape_sizes;
+
+ activations_reshape_sizes.push_back(dims.batch_size);
+ gradients_reshape_sizes.push_back(dims.batch_size);
+ for (int i = 0; i < attrs.num_spatial_dims; i++) {
+ activations_reshape_sizes.push_back(dims.input_size(i));
+ gradients_reshape_sizes.push_back(dims.output_size(i));
+ }
+ activations_reshape_sizes.push_back(dims.in_depth);
+ activations_reshape_sizes.push_back(1);
+ gradients_reshape_sizes.push_back(dims.out_depth);
+ gradients_reshape_sizes.push_back(1);
+
+ activations = xla::Reshape(activations, activations_reshape_sizes);
+ gradients = xla::Reshape(gradients, gradients_reshape_sizes);
+
+ int64 new_spatial_dim = activations_reshape_sizes.size() - 1;
+
+ // Set the newly added dimension to be the batch.
+ dnums.set_input_batch_dimension(new_spatial_dim);
+ dnums.set_input_feature_dimension(c_dim);
+
+ // The gradients become the RHS of the convolution.
+ // The gradients have shape [batch, out_rows, out_cols, ..., out_depth, 1]
+ // where the batch becomes a spatial dimension, and 1 becomes
+ // the input feature for the convolution.
+ dnums.set_kernel_input_feature_dimension(new_spatial_dim);
+ dnums.set_kernel_output_feature_dimension(c_dim);
+
+ // Treat original batch dimension as a spatial dimension.
+ dnums.add_input_spatial_dimensions(n_dim);
+ dnums.add_kernel_spatial_dimensions(n_dim);
+ } else {
+ // The activations (inputs) form the LHS of the convolution.
+ // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
+ // For the gradient computation, we flip the roles of the batch and
+ // feature dimensions.
+ // Each spatial entry has size in_depth * batch
+
+ // Swap n_dim and c_dim in the activations.
+ dnums.set_input_batch_dimension(c_dim);
+ dnums.set_input_feature_dimension(n_dim);
+
+ // The gradients become the RHS of the convolution.
+ // The gradients have shape [batch, out_rows, out_cols, ..., out_depth]
+ // where the batch becomes the input feature for the convolution.
+ dnums.set_kernel_input_feature_dimension(n_dim);
+ dnums.set_kernel_output_feature_dimension(c_dim);
+ }
- std::vector> padding(attrs.num_spatial_dims);
- std::vector rhs_dilation(attrs.num_spatial_dims);
- std::vector window_strides(attrs.num_spatial_dims);
- std::vector ones(attrs.num_spatial_dims, 1);
+ dnums.set_output_batch_dimension(num_spatial_dims);
+ dnums.set_output_feature_dimension(num_spatial_dims + 1);
// Tensorflow filter shape is [ H, W, ..., inC, outC ].
- for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+ for (int i = 0; i < num_spatial_dims; ++i) {
dnums.add_output_spatial_dimensions(i);
}
- dnums.set_output_batch_dimension(attrs.num_spatial_dims);
- dnums.set_output_feature_dimension(attrs.num_spatial_dims + 1);
- for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+ if (should_perform_depthwise_conv) {
+ // Set the right parameters for the newly created spatial dimension.
+ padding[0] = {0, 0};
+ rhs_dilation[0] = 1;
+ window_strides[0] = 1;
+ }
+
+ for (int64 i = 0; i < attrs.num_spatial_dims; ++i) {
int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
dnums.add_input_spatial_dimensions(dim);
dnums.add_kernel_spatial_dimensions(dim);
@@ -483,9 +561,10 @@ xla::StatusOr MakeXlaBackpropFilterConvOp(
const int64 pad_before =
attrs.padding == Padding::SAME ? std::max(pad_total / 2, 0) : 0;
- padding[i] = {pad_before, pad_total - pad_before};
- rhs_dilation[i] = dims.spatial_dims[i].stride;
- window_strides[i] = attrs.dilations[dim];
+ int64 dim_being_operated = should_perform_depthwise_conv ? i + 1 : i;
+ padding[dim_being_operated] = {pad_before, pad_total - pad_before};
+ rhs_dilation[dim_being_operated] = dims.spatial_dims[i].stride;
+ window_strides[dim_being_operated] = attrs.dilations[dim];
}
// Besides padding the input, we will also expand output_rows to
@@ -496,13 +575,19 @@ xla::StatusOr MakeXlaBackpropFilterConvOp(
//
// This is done by specifying the window dilation factors in the
// convolution HLO below.
- auto filter_backprop =
- xla::ConvGeneralDilated(activations, gradients, window_strides, padding,
- /*lhs_dilation=*/ones, rhs_dilation, dnums);
-
- if (attrs.depthwise) {
- filter_backprop = ContractFilterForDepthwiseBackprop(
- filter_shape, filter_backprop, activations.builder());
+ filter_backprop = xla::ConvGeneralDilated(
+ activations, gradients, window_strides, padding,
+ /*lhs_dilation=*/ones, rhs_dilation, dnums,
+ /*feature_group_count=*/
+ should_perform_depthwise_conv ? dims.in_depth : 1);
+
+ if (should_perform_depthwise_conv) {
+ filter_backprop = xla::Reshape(filter_backprop, filter_shape.dimensions());
+ } else {
+ if (attrs.depthwise) {
+ filter_backprop = ContractFilterForDepthwiseBackprop(
+ filter_shape, filter_backprop, activations.builder());
+ }
}
return filter_backprop;
diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
index 20b0de193dc060197f3062d3be0b8d45f7dcb9b1..41c31d0ed58fe9bc9bbde0bd58993c975f04fd60 100644
--- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
@@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
-#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc
index b5e083912555c865b5eadc7697075c9ca4451ca9..4f0f0fd9aefecc3d31f8bd9c8ca40ebb0860c82d 100644
--- a/tensorflow/compiler/tf2xla/kernels/if_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc
@@ -56,6 +56,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
VLOG(1) << "Building If: " << input_types_.size() << " inputs";
std::vector arguments(input_types_.size());
+ int num_resource_args = 0;
for (int i = 0; i < input_types_.size(); ++i) {
XlaCompiler::Argument& arg = arguments[i];
DataType type = ctx->input_type(i + 1);
@@ -81,6 +82,8 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
<< " type: " << DataTypeString(arg.type)
<< " shape: " << arg.shape.DebugString()
<< " initialized: " << arg.initialized;
+
+ num_resource_args++;
} else {
arg.kind = XlaCompiler::Argument::kParameter;
arg.type = input_types_[i];
@@ -236,9 +239,13 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
ctx->SetOutput(i, output_handle);
}
if (has_token_input_output_) {
- // Set token output for this "if" op.
+ // Set token output for this "If" op. Token output is the last output of
+ // XLA computation, which comes after all "normal" TF outputs and resource
+ // updates. For an "If" node, the number of resource updates equals the
+ // number of resource args, because we set
+ // `return_updated_values_for_all_resources` to true in XlaCompiler options.
xla::XlaOp token_output =
- xla::GetTupleElement(outputs, output_types_.size());
+ xla::GetTupleElement(outputs, output_types_.size() + num_resource_args);
auto shape_or = b->GetShape(token_output);
OP_REQUIRES_OK(ctx, shape_or.status());
OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()),
diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc
index e9bb0a77e99d144863b027bd214081316d61c314..96ddd42e2ae04d454e4fb85628d139e17a543d2e 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc
@@ -15,12 +15,12 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
-#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/sorting.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -505,9 +505,9 @@ class NonMaxSuppressionOp : public XlaOpKernel {
init_values.push_back(included_iou);
auto suppress_loop_result =
- XlaWhileLoop(WhileCondFn(num_boxes, output_size),
- SuppressBodyFn(num_boxes), init_values, "suppress_loop",
- builder)
+ xla::WhileLoopHelper(WhileCondFn(num_boxes, output_size),
+ SuppressBodyFn(num_boxes), init_values,
+ "suppress_loop", builder)
.ValueOrDie();
xla::XlaOp included_score =
diff --git a/tensorflow/compiler/tf2xla/kernels/qr_op.cc b/tensorflow/compiler/tf2xla/kernels/qr_op.cc
index 7ea0afc1f53cbe4cfcc3f6121a4ecd55864c1b52..66ec40a946b8a063d84acd33daf81f52ea2c35ed 100644
--- a/tensorflow/compiler/tf2xla/kernels/qr_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/qr_op.cc
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/tf2xla/lib/qr.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/qr.h"
namespace tensorflow {
namespace {
@@ -26,7 +26,7 @@ class QROp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices_));
}
void Compile(XlaOpKernelContext* ctx) override {
- auto result = QRDecomposition(ctx->Input(0), full_matrices_);
+ auto result = xla::QRDecomposition(ctx->Input(0), full_matrices_);
if (!result.ok()) {
ctx->SetStatus(result.status());
return;
diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
index 8822e29f7e77b1cbc6fa6ca61d0062d9b1b0c36e..2d92056e4f522f6206e7d632f0fa1e8b793fd6e3 100644
--- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
@@ -20,12 +20,12 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/lib/random.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
-#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
@@ -175,8 +175,8 @@ class RandomShuffleOp : public XlaOpKernel {
};
// for i in range(n):
auto swap_loop_result =
- XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices},
- "indices_swap_loop", builder)
+ xla::ForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices},
+ "indices_swap_loop", builder)
.ValueOrDie();
auto swapped_indices = swap_loop_result[1];
diff --git a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc
index 54d34a38abc4948a1a08197d72e3e7f763649093..f9985d526033ca675c701a508a3d1576e46bc5f7 100644
--- a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc
@@ -125,7 +125,7 @@ XlaOp ConcatenateIota(xla::XlaBuilder* b, XlaOp indices,
dimensions.back() = 1;
auto batch_indices =
- xla::Iota(b, xla::ShapeUtil::MakeShape(xla::U32, dimensions),
+ xla::Iota(b, xla::ShapeUtil::MakeShape(xla::S32, dimensions),
/*iota_dimension=*/0);
return xla::ConcatInDim(b, {batch_indices, indices}, dimensions.size() - 1);
@@ -189,11 +189,53 @@ XlaOp ScatterToGradData(XlaOpKernelContext* ctx, XlaOp grad_data, XlaOp indices,
scatter_dim_numbers);
}
+// Sets samples to 0 where the warp image indices fall outside the open range
+// (-1, image_size).
+// The resulting dimension is given by 'result_dims'.
+XlaOp BoundSamples(XlaOpKernelContext* ctx, XlaOp warp,
+ xla::PrimitiveType warp_type, TensorShape warp_shape,
+ std::vector result_dims,
+ std::vector broadcasted_dims, int64 last_warp_dim,
+ xla::Shape data_shape, XlaOp sample) {
+ auto is_gt_minus_one =
+ xla::Gt(warp,
+ xla::ConvertElementType(
+ xla::ConstantR1(ctx->builder(), {-1, -1}), warp_type),
+ /*broadcast_dimensions=*/{warp_shape.dims() - 1});
+ auto is_lt_image_size = xla::Lt(
+ warp,
+ xla::ConvertElementType(
+ xla::ConstantR1(
+ ctx->builder(),
+ {/*width=*/static_cast(data_shape.dimensions(2)),
+ /*height=*/static_cast(data_shape.dimensions(1))}),
+ warp_type),
+ /*broadcast_dimensions=*/{warp_shape.dims() - 1});
+
+ auto is_in_bound_padded_x_y = xla::And(is_gt_minus_one, is_lt_image_size);
+ // Reduce along last dimension. The resulting dimension is:
+ // [batch, dim_0, ...dim_n].
+ auto is_in_bound = xla::Reduce(
+ is_in_bound_padded_x_y, xla::ConstantR0(ctx->builder(), true),
+ xla::CreateScalarAndComputation(xla::PrimitiveType::PRED, ctx->builder()),
+ {last_warp_dim});
+
+ // Broadcast 'is_in_bound' to the same dimension as 'result_dims'.
+ auto broadcasted_is_in_bound =
+ xla::BroadcastInDim(is_in_bound, result_dims, broadcasted_dims);
+
+ // Set out of bound samples to zero.
+ auto zeros =
+ xla::Broadcast(xla::Zero(ctx->builder(), warp_type), result_dims);
+ return xla::Select(broadcasted_is_in_bound, sample, zeros);
+}
+
// Build computation the backprop into input 'data'.
// Where input:
// grad_output is of dimension [batch, dim_0, ...dim_n, channel]
// ratio is of dimension [batch, dim_0, ...dim_n, 2]
// gather_indices is of dimension [batch, dim_0, ...dim_n, 3]
+// data_shape is of dimension [batch, x(width), y(height), channel]
//
// Output:
// scatter-add to each 2x2 grad_data neighbor:
@@ -201,10 +243,12 @@ XlaOp ScatterToGradData(XlaOpKernelContext* ctx, XlaOp grad_data, XlaOp indices,
// grad_data[cx, fy, chan] += output_grad * (1 - dx) * dy
// grad_data[fx, cy, chan] += output_grad * dx * (1 - dy)
// grad_data[cx, cy, chan] += output_grad * (1 - dx) * (1 - dy)
-// where (dx, dy) is (1 - ratio).
+// where (dx, dy) is (1 - ratio). If (dx, dy) is out of bound, then their
+// contribution to 'grad_data' is 0.
XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio,
- XlaOp gather_indices, xla::PrimitiveType warp_type,
- TensorShape warp_shape, int64 data_channels,
+ XlaOp gather_indices, XlaOp warp,
+ xla::PrimitiveType warp_type, TensorShape warp_shape,
+ int64 last_warp_dim, int64 data_channels,
xla::Shape data_shape) {
// Weights tensor has dimension [batch, dim_0, ... dim_n, 4].
auto weights = BilinearWeights(ctx, ratio, warp_shape, warp_type);
@@ -229,6 +273,18 @@ XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio,
std::iota(reshaped_weights_indices.begin(), reshaped_weights_indices.end(),
0);
+ // Set out of bound weights to 0.
+ // The dimension of the reshaped_weight: [batch, dim_0, ...dim_n, 2, 2].
+ std::vector reshaped_result_dims(warp_dims.begin(),
+ warp_dims.end() - 1);
+ reshaped_result_dims.push_back(2);
+ reshaped_result_dims.push_back(2);
+ std::vector broadcasted_dims(warp_dims.size() - 1);
+ std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0);
+ reshaped_weights = BoundSamples(ctx, warp, warp_type, warp_shape,
+ reshaped_result_dims, broadcasted_dims,
+ last_warp_dim, data_shape, reshaped_weights);
+
// The dimension is [batch, dim_0, ..., dim_n, 2, 2, data_channel].
auto broadcast_reshaped_weights = xla::BroadcastInDim(
reshaped_weights, weights_with_channels_dims, reshaped_weights_indices);
@@ -245,18 +301,41 @@ XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio,
auto grad_data = xla::ConstantLiteral(
ctx->builder(), xla::Literal::CreateFromShape(data_shape));
- return ScatterToGradData(ctx, grad_data, gather_indices,
- grad_output_multiply_weights, warp_shape.dims(),
- warp_type);
+ // Pad grad data then slice it back.
+ //
+ // After left and right column 0-padding, the new dimension of padded data
+ // will be [batch, x+2, y+2, channel].
+ auto padded_grad_data =
+ xla::Pad(grad_data, xla::Zero(ctx->builder(), warp_type),
+ xla::MakeEdgePaddingConfig({{0, 0}, {1, 1}, {1, 1}, {0, 0}}));
+
+ auto shifting_value = xla::ConstantR1(
+ ctx->builder(), {/*batch=*/0, /*x(width)=*/1, /*y(height)=*/1});
+ auto shifted_gather_indices =
+ xla::Add(gather_indices, shifting_value, {last_warp_dim});
+
+ auto updated_grad_data = ScatterToGradData(
+ ctx, padded_grad_data, shifted_gather_indices,
+ grad_output_multiply_weights, warp_shape.dims(), warp_type);
+
+ const int64 batch_size = data_shape.dimensions(0);
+ const int64 width = data_shape.dimensions(1);
+ const int64 height = data_shape.dimensions(2);
+ // Slice out the result accounting for the padding.
+ return xla::Slice(
+ updated_grad_data, /*start_indices=*/{0, 1, 1, 0},
+ /*limit_indices=*/{batch_size, width + 1, height + 1, data_channels},
+ /*strides=*/{1, 1, 1, 1});
}
// Build computation for the backprop into input 'warp'.
// Where input:
-// warp is of dimension [batch, dim_0, ...dim_n, 2]
-// grad_output is of dimension [batch, dim_0, ...dim_n, channel]
-// ratio is of dimension [batch, dim_0, ...dim_n, 2]
-// gather_indices is of dimension [batch, dim_0, ...dim_n, 3]
-// data is of dimension [batch, x, y, channel]
+// warp is of dimension [batch, dim_0, ...dim_n, 2]
+// grad_output is of dimension [batch, dim_0, ...dim_n, channel]
+// ratio is of dimension [batch, dim_0, ...dim_n, 2]
+// gather_indices is of dimension [batch, dim_0, ...dim_n, 3] where the last
+// dimension of size 3 is for {batch, x(width), y(height)}.
+// data is of dimension [batch, x, y, channel]
//
// Output (simplified by ignoring the batch dimensions):
// Since the forward path has:
@@ -275,12 +354,12 @@ XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio,
// grad_warp_x = py * (img_cxcy - img_fxcy) + (1-py) * (img_cxfy-img_fxfy)
// grad_warp_y = px * (img_cxcy - img_cxfy) + (1-px) * (img_fxcy-img_fxfy)
//
-// where (px, py) is warp, (fx, fy) is the left top corner and (cx, cy) is the
+// where (px, py) is warp, (fx, fy) is the top left corner and (cx, cy) is the
// bottom right corner in a 2x2 neighborhood.
XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio,
XlaOp gather_indices, XlaOp data,
TensorShape warp_shape, int64 data_channels,
- xla::PrimitiveType data_type) {
+ xla::PrimitiveType data_type, xla::Shape data_shape) {
auto warp_dims = warp_shape.dim_sizes();
std::vector warp_dims_without_last_dims(warp_dims.begin(),
warp_dims.end() - 1);
@@ -289,12 +368,30 @@ XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio,
std::vector neighbor_broadcast_dims = warp_dims_without_last_dims;
neighbor_broadcast_dims.push_back(4);
- // The dimension is [batch, dim_0, ... dim_n, 4, data_channels]
- auto neighbors_data = Gather2by2Neighbors(
- ctx->builder(), data, gather_indices, data_channels, warp_shape.dims());
+ // With dimension [batch, dim_0, ...dim_n, 4]
+ auto neighbor_broadcast_shape =
+ xla::ShapeUtil::MakeShape(data_type, neighbor_broadcast_dims);
const int64 last_warp_dim = warp_shape.dims() - 1;
+ // Pad data with 0, before gathering such that 0 will be returned for samples
+ // in the range of (-1, 0) or (image_dimension-1, image_dimension).
+ // After left and right column 0-padding, the new dimension of padded data
+ // will be [batch, x+2, y+2, channel].
+ auto padded_data =
+ xla::Pad(data, xla::Zero(ctx->builder(), data_type),
+ xla::MakeEdgePaddingConfig({{0, 0}, {1, 1}, {1, 1}, {0, 0}}));
+
+ auto shifting_value = xla::ConstantR1(
+ ctx->builder(), {/*batch=*/0, /*x(width)=*/1, /*y(height)=*/1});
+ auto shifted_gather_indices =
+ xla::Add(gather_indices, shifting_value, {last_warp_dim});
+
+ // The dimension is [batch, dim_0, ... dim_n, 4, data_channels]
+ auto neighbors_data =
+ Gather2by2Neighbors(ctx->builder(), padded_data, shifted_gather_indices,
+ data_channels, warp_shape.dims());
+
// Since we will be creating the dot product of:
// lhs: [batch, dim_0, ...dim_n, 4]
// and
@@ -417,7 +514,7 @@ class ResamplerOp : public XlaOpKernel {
// Find the coordinates of the top left corner for the 2x2 region to be
// sampled from. The dimensions are [batch, dim_0, ... dim_n, 2] where the
// last dimension of size 2 in turn is [x, y].
- XlaOp top_left = xla::ConvertElementType(warp, xla::U32);
+ XlaOp top_left = xla::ConvertElementType(warp, xla::S32);
auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape);
@@ -526,7 +623,8 @@ class ResamplerGradOp : public XlaOpKernel {
size, "]"));
}
// Last dimension of warp shape must be of size 2.
- OP_REQUIRES(ctx, warp_shape.dim_size(warp_shape.dims() - 1) == 2,
+ const int64 last_warp_dim = warp_shape.dims() - 1;
+ OP_REQUIRES(ctx, warp_shape.dim_size(last_warp_dim) == 2,
errors::InvalidArgument(
"the last dimension of warp must be exactly size 2."));
xla::PrimitiveType warp_type = ctx->input_xla_type(1);
@@ -549,24 +647,32 @@ class ResamplerGradOp : public XlaOpKernel {
// Find the top left corner coordinate for the region to be sampled from.
// The dimensions are [batch, dim_0, ... dim_n, 2] where the last dimension
// of size 2 in turn is [x, y].
- XlaOp top_left = xla::ConvertElementType(warp, xla::U32);
+ XlaOp top_left = xla::ConvertElementType(xla::Floor(warp), xla::S32);
- // Dimensions are [batch, dim_0, ... dim_n, 2]
+ // Dimensions are [batch, dim_0, ... dim_n, 2].
XlaOp ratio = warp - xla::ConvertElementType(top_left, warp_type);
// Indices for gathering neighboring pixels.
auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape);
- auto grad_data =
- CalculateGradData(ctx, grad_output, ratio, gather_indices, warp_type,
- warp_shape, data_channels, data_shape);
+ auto grad_data = CalculateGradData(
+ ctx, grad_output, ratio, gather_indices, warp, warp_type, warp_shape,
+ last_warp_dim, data_channels, data_shape);
auto grad_warp =
CalculateGradWarp(ctx, grad_output, ratio, gather_indices, data,
- warp_shape, data_channels, data_type);
+ warp_shape, data_channels, data_type, data_shape);
+ auto warp_dims = warp_shape.dim_sizes();
+ std::vector result_dims(warp_dims.begin(), warp_dims.end() - 1);
+ result_dims.push_back(2);
+ std::vector broadcasted_dims(warp_dims.size() - 1);
+ std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0);
+ auto grad_warp_bounded =
+ BoundSamples(ctx, warp, warp_type, warp_shape, result_dims,
+ broadcasted_dims, last_warp_dim, data_shape, grad_warp);
ctx->SetOutput(0, grad_data);
- ctx->SetOutput(1, grad_warp);
+ ctx->SetOutput(1, grad_warp_bounded);
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
index 960c1462ceb8c00a2d6c96564f6c985fd1caef0f..26d4214099d1d07c1b2e275d783654d9cd948e28 100644
--- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
@@ -172,6 +172,65 @@ class ResourceApplyMomentum : public XlaOpKernel {
REGISTER_XLA_OP(Name("ResourceApplyMomentum").TypeConstraint("T", kFloatTypes),
ResourceApplyMomentum);
+class ResourceApplyKerasMomentum : public XlaOpKernel {
+ public:
+ explicit ResourceApplyKerasMomentum(OpKernelConstruction* ctx)
+ : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ DataType type = ctx->input_type(2);
+
+ TensorShape var_shape, accum_shape;
+ xla::XlaOp var, accum;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));
+
+ OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
+ errors::InvalidArgument(
+ "var and accum do not have the same shape",
+ var_shape.DebugString(), " ", accum_shape.DebugString()));
+
+ TensorShape lr_shape = ctx->InputShape(2);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
+ errors::InvalidArgument("lr is not a scalar: ",
+ lr_shape.DebugString()));
+
+ TensorShape grad_shape = ctx->InputShape(3);
+ OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
+ errors::InvalidArgument(
+ "var and grad do not have the same shape",
+ var_shape.DebugString(), " ", grad_shape.DebugString()));
+
+ TensorShape momentum_shape = ctx->InputShape(4);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape),
+ errors::InvalidArgument("momentum is not a scalar: ",
+ momentum_shape.DebugString()));
+
+ xla::XlaOp lr = ctx->Input(2);
+ xla::XlaOp grad = ctx->Input(3);
+ xla::XlaOp momentum = ctx->Input(4);
+
+ accum = accum * momentum - grad * lr;
+ if (use_nesterov_) {
+ // See https://github.com/tensorflow/tensorflow/pull/2798 for an
+ // explanation of the reparameterization used here.
+ var = var + accum * momentum - grad * lr;
+ } else {
+ var = var + accum;
+ }
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
+ }
+
+ private:
+ bool use_nesterov_;
+};
+REGISTER_XLA_OP(
+ Name("ResourceApplyKerasMomentum").TypeConstraint("T", kFloatTypes),
+ ResourceApplyKerasMomentum);
+
class ResourceApplyAdagrad : public XlaOpKernel {
public:
explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc
index ce007fc04a818869686b9936a1607cee42665e87..89b577bfc05b4665d492f4ea5cf6f869af2fa9a9 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc
@@ -41,8 +41,7 @@ Status MakeXlaCompilerArgumentsFromInputs(
*has_uninitialized_vars = false;
*has_tensor_arrays = false;
for (int i = 0; i < ctx->num_inputs(); ++i) {
- VLOG(2) << " Input " << i
- << " type: " << DataTypeString(ctx->input_type(i))
+ VLOG(2) << " Input " << i << " type: " << DataTypeString(ctx->input_type(i))
<< " shape: " << ctx->InputShape(i).DebugString();
XlaCompiler::Argument& arg = (*args)[i];
DataType type = ctx->input_type(i);
@@ -233,13 +232,22 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
xla::ShapeUtil::HumanString(body_input_shape), " vs. ",
xla::ShapeUtil::HumanString(body.xla_output_shape)));
- xla::Shape expected_cond_output_shape = xla::ShapeUtil::MakeTupleShape(
- {xla::ShapeUtil::MakeShape(xla::PRED, {})});
+ xla::Shape expected_cond_output_shape_without_side_effect =
+ xla::ShapeUtil::MakeTupleShape(
+ {xla::ShapeUtil::MakeShape(xla::PRED, {})});
+ xla::Shape expected_cond_output_shape_with_side_effect =
+ xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::PRED, {}),
+ xla::ShapeUtil::MakeTokenShape()});
OP_REQUIRES(ctx,
- xla::ShapeUtil::Compatible(cond.xla_output_shape,
- expected_cond_output_shape),
+ xla::ShapeUtil::Compatible(
+ cond.xla_output_shape,
+ expected_cond_output_shape_without_side_effect) ||
+ xla::ShapeUtil::Compatible(
+ cond.xla_output_shape,
+ expected_cond_output_shape_with_side_effect),
errors::InvalidArgument(
- "Output shape of loop condition should be (pred[]), got: ",
+ "Output shape of loop condition should be (pred[]) or "
+ "(pred[], token[]), got: ",
xla::ShapeUtil::HumanString(cond.xla_output_shape)));
int num_inputs = body.input_mapping.size();
diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD
index 3e7a761120317ff85947559b7b2e52be9232afb7..3d7b0bc959f9dbf3c1b9749379e2ea0d285b302b 100644
--- a/tensorflow/compiler/tf2xla/lib/BUILD
+++ b/tensorflow/compiler/tf2xla/lib/BUILD
@@ -15,8 +15,6 @@ filegroup(
]),
)
-load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
-
cc_library(
name = "broadcast",
srcs = ["broadcast.cc"],
@@ -33,27 +31,6 @@ cc_library(
],
)
-cc_library(
- name = "cholesky",
- srcs = ["cholesky.cc"],
- hdrs = ["cholesky.h"],
- deps = [
- ":util",
- ":while_loop",
- "//tensorflow/compiler/xla:literal",
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:status_macros",
- "//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client/lib:constants",
- "//tensorflow/compiler/xla/client/lib:matrix",
- "//tensorflow/compiler/xla/client/lib:slicing",
- "//tensorflow/compiler/xla/client/lib:triangular_solve",
- "//tensorflow/core:lib",
- ],
-)
-
cc_library(
name = "random",
srcs = ["random.cc"],
@@ -69,35 +46,12 @@ cc_library(
],
)
-cc_library(
- name = "qr",
- srcs = ["qr.cc"],
- hdrs = ["qr.h"],
- deps = [
- ":util",
- ":while_loop",
- "//tensorflow/compiler/xla:literal_util",
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:status_macros",
- "//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client/lib:arithmetic",
- "//tensorflow/compiler/xla/client/lib:constants",
- "//tensorflow/compiler/xla/client/lib:math",
- "//tensorflow/compiler/xla/client/lib:matrix",
- "//tensorflow/compiler/xla/client/lib:slicing",
- "//tensorflow/core:lib",
- ],
-)
-
cc_library(
name = "scatter",
srcs = ["scatter.cc"],
hdrs = ["scatter.h"],
deps = [
":util",
- ":while_loop",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -128,19 +82,3 @@ cc_library(
"@com_google_absl//absl/types:span",
],
)
-
-cc_library(
- name = "while_loop",
- srcs = ["while_loop.cc"],
- hdrs = ["while_loop.h"],
- deps = [
- ":util",
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:status_macros",
- "//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
- ],
-)
diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc
index 2b1c2ced925d9fee7392986015a6e716a94d356f..688056791f9750e6b22df4b2cd4643de0b780651 100644
--- a/tensorflow/compiler/tf2xla/lib/scatter.cc
+++ b/tensorflow/compiler/tf2xla/lib/scatter.cc
@@ -20,7 +20,6 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
-#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc
index 72b240996fb4d9dcb5f5dfd919da618cbae08c16..ff9f1b9ccba2c4f3307890d5aac4ddb6cfaafcd9 100644
--- a/tensorflow/compiler/tf2xla/resource_operation_table.cc
+++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc
@@ -65,6 +65,7 @@ CreateResourceOpInfoMap() {
add("ResourceApplyFtrlV2" , kReadWrite, kVariable);
add("ResourceApplyGradientDescent" , kReadWrite, kVariable);
add("ResourceApplyMomentum" , kReadWrite, kVariable);
+ add("ResourceApplyKerasMomentum" , kReadWrite, kVariable);
add("ResourceApplyPowerSign" , kReadWrite, kVariable);
add("ResourceApplyProximalAdagrad" , kReadWrite, kVariable);
add("ResourceApplyProximalGradientDescent" , kReadWrite, kVariable);
diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc
index b233e6b2c28e1968bb74901fc684e808ae45ab60..b62f8e9115229ac35c657d374c68336f1168ff77 100644
--- a/tensorflow/compiler/tf2xla/side_effect_util.cc
+++ b/tensorflow/compiler/tf2xla/side_effect_util.cc
@@ -24,6 +24,8 @@ const char kXlaTokenInputNodesAttrName[] = "_xla_token_input_nodes";
const char kXlaTokenArgNodeName[] = "_xla_token_arg_node";
+const char kXlaHasHostTransferAttrName[] = "_xla_has_host_transfer";
+
std::set CalculateTokenInputsForOutputToken(const Graph& g) {
std::set results;
Node* first_side_effecting_node_on_path = nullptr;
diff --git a/tensorflow/compiler/tf2xla/side_effect_util.h b/tensorflow/compiler/tf2xla/side_effect_util.h
index f22ddb2f58e1fa5c10ca0fdb956d9136942388b7..7081b362c36c4785164b29003a5f89cd73bcf3af 100644
--- a/tensorflow/compiler/tf2xla/side_effect_util.h
+++ b/tensorflow/compiler/tf2xla/side_effect_util.h
@@ -35,6 +35,9 @@ extern const char kXlaTokenInputNodesAttrName[];
// node has side-effect dependency on current graph's token input.
extern const char kXlaTokenArgNodeName[];
+// This node has XlaRecvAtHost/XlaSendFromHost in its associated functions.
+extern const char kXlaHasHostTransferAttrName[];
+
// Calculates side-effect dependencies for the graph's token output.
// Returns a set of node names representing these dependencies.
std::set CalculateTokenInputsForOutputToken(const Graph& g);
diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc
index ab26d939ccba75ce58609ffd71c7ccadbe90cfa8..24afe595b18b823818bd8fe65bc599af8bce040a 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_test.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc
@@ -91,7 +91,7 @@ TEST(ConvertGraphDefToXla, Sum) {
client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()});
TF_EXPECT_OK(result_or.status());
xla::Literal result = std::move(result_or.ValueOrDie());
- EXPECT_EQ("(s32[]) (\n42\n)", result.ToString());
+ EXPECT_EQ("(\ns32[] 42\n)", result.ToString());
config.mutable_feed(0)->mutable_id()->set_output_index(
123); /* invalid output_index */
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc
index cc81772e8c5da710bc733f7e4f5fe820b2c2d110..18d87727c500619bf386be7d8c7085724f44aba3 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc
@@ -364,6 +364,7 @@ Status AddPlaceholdersForFeeds(
GraphDef gd;
*gd.mutable_versions() = graph_def->versions();
*gd.add_node() = *existing;
+ MergeDebugInfo(NodeDebugInfo(*existing), gd.mutable_node(0));
TF_RETURN_IF_ERROR(
AddDefaultAttrsToGraphDef(&gd, *op_registry, 0 /*node_offset*/));
@@ -390,6 +391,7 @@ Status AddPlaceholdersForFeeds(
// in this code.
for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
const PlaceholderInfo& info = it->second;
+ // TODO(shikharagarwal): Add original node information.
NodeDef* d = graph_def->add_node();
d->set_name(info.placeholder_name);
d->set_op("PlaceholderV2");
@@ -557,6 +559,12 @@ bool HasAssociatedFunction(const NodeDef& node_def,
return true;
}
+ if (node_def.op() == "XlaHostCompute") {
+ // XlaHostCompute has "shape_inference_graph" func attr, but that's not
+ // related to graph execution.
+ return false;
+ }
+
for (const auto& iter : node_def.attr()) {
if (iter.second.has_func()) {
return true;
@@ -578,6 +586,9 @@ std::vector GetAssociatedFunctions(
// This is a SymbolicGradient op.
AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
results.emplace_back(AssociatedFunctionInfo::SymbolicGradient(op, attrs));
+ } else if (node.type_string() == "XlaHostCompute") {
+ // XlaHostCompute has "shape_inference_graph" func attr, but that's not
+ // related to graph execution.
} else {
// Collect all function attrs for the node.
for (auto& iter : node.attrs()) {
@@ -599,7 +610,9 @@ Status RewriteAssociatedFunction(
switch (associated_function.type()) {
case AssociatedFunctionInfo::kFunctionCallNode: {
// Change this node to call the new function.
- NodeDefBuilder builder(node->name(), rewritten_function_name, fld);
+ NodeDebugInfo debug_info(*node);
+ NodeDefBuilder builder(node->name(), rewritten_function_name, fld,
+ &debug_info);
for (auto attr : node->attrs()) {
builder.Attr(attr.first, attr.second);
}
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index 4360e0857964b0ac63fc887e269b04a4b00d854a..722d1376687efa1c04158e3fd9ce539aac9d0122 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -109,7 +109,7 @@ cc_library(
name = "status_macros",
srcs = ["status_macros.cc"],
hdrs = ["status_macros.h"],
- visibility = [":friends"],
+ visibility = ["//visibility:public"],
deps = [
":statusor",
":types",
@@ -224,6 +224,7 @@ cc_library(
name = "shape_util",
srcs = [
"index_util.cc",
+ "layout.cc",
"layout_util.cc",
"primitive_util.cc",
"shape.cc",
@@ -231,6 +232,7 @@ cc_library(
],
hdrs = [
"index_util.h",
+ "layout.h",
"layout_util.h",
"primitive_util.h",
"shape.h",
@@ -290,6 +292,22 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "primitive_util_test",
+ srcs = ["primitive_util_test.cc"],
+ deps = [
+ ":shape_util",
+ ":status_macros",
+ ":test",
+ ":test_helpers",
+ ":types",
+ ":util",
+ ":xla_data_proto",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test_main",
+ ],
+)
+
tf_cc_test(
name = "layout_util_test",
srcs = ["layout_util_test.cc"],
@@ -301,6 +319,22 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "layout_test",
+ srcs = ["layout_test.cc"],
+ deps = [
+ ":shape_util",
+ ":status_macros",
+ ":test",
+ ":test_helpers",
+ ":types",
+ ":util",
+ ":xla_data_proto",
+ "//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
tf_cc_test(
name = "index_util_test",
srcs = ["index_util_test.cc"],
@@ -575,6 +609,7 @@ cc_library(
":types",
":util",
":xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@com_google_absl//absl/memory",
@@ -705,7 +740,6 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_evaluator",
"//tensorflow/compiler/xla/service:shape_inference",
- "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index fe99564d3c671cd7890e1fa26fcd2e3384972983..e61d9d2520366f3f21a18b6c62ba924fba23308a 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -3,7 +3,7 @@
licenses(["notice"]) # Apache 2.0
-package(default_visibility = [":friends"])
+package(default_visibility = ["//visibility:public"])
package_group(
name = "friends",
diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc
index 74b76f929949d3300a5d0ff45d5fa4cd9f162642..43127cae1e5d81521003a28288e27d291e33c9b9 100644
--- a/tensorflow/compiler/xla/client/client.cc
+++ b/tensorflow/compiler/xla/client/client.cc
@@ -186,7 +186,7 @@ StatusOr Client::ComputeConstant(const XlaComputation& computation,
ComputeConstantGraphRequest request;
*request.mutable_computation() = computation.proto();
if (output_layout != nullptr) {
- *request.mutable_output_layout() = *output_layout;
+ *request.mutable_output_layout() = output_layout->ToProto();
}
ComputeConstantResponse response;
diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index 41db8de29ff0085a30847ff41db4ffbfc774e2a1..970f00759f630f30f1c1321231fd9e0199026142 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -1,5 +1,7 @@
# Common computation builders for XLA.
+load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites", "xla_test")
+
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//tensorflow/compiler/xla/client:friends"])
@@ -13,9 +15,6 @@ filegroup(
]),
)
-load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
-load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites")
-
# Generate test_suites for all backends, named "${backend}_tests".
generate_backend_suites()
@@ -35,6 +34,48 @@ cc_library(
],
)
+cc_library(
+ name = "cholesky",
+ srcs = ["cholesky.cc"],
+ hdrs = ["cholesky.h"],
+ deps = [
+ ":math",
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/client/lib:constants",
+ "//tensorflow/compiler/xla/client/lib:loops",
+ "//tensorflow/compiler/xla/client/lib:matrix",
+ "//tensorflow/compiler/xla/client/lib:slicing",
+ "//tensorflow/compiler/xla/client/lib:triangular_solve",
+ "//tensorflow/core:lib",
+ ],
+)
+
+xla_test(
+ name = "cholesky_test",
+ srcs = ["cholesky_test.cc"],
+ tags = ["optonly"],
+ deps = [
+ ":arithmetic",
+ ":cholesky",
+ ":matrix",
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:test",
+ ],
+)
+
cc_library(
name = "constants",
srcs = ["constants.cc"],
@@ -75,6 +116,22 @@ cc_library(
],
)
+cc_library(
+ name = "loops",
+ srcs = ["loops.cc"],
+ hdrs = ["loops.h"],
+ deps = [
+ ":constants",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_computation",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
cc_library(
name = "math",
srcs = ["math.cc"],
@@ -177,6 +234,48 @@ cc_library(
],
)
+cc_library(
+ name = "qr",
+ srcs = ["qr.cc"],
+ hdrs = ["qr.h"],
+ deps = [
+ ":arithmetic",
+ ":constants",
+ ":loops",
+ ":math",
+ ":matrix",
+ ":slicing",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/core:lib",
+ ],
+)
+
+xla_test(
+ name = "qr_test",
+ srcs = ["qr_test.cc"],
+ tags = ["optonly"],
+ deps = [
+ ":matrix",
+ ":qr",
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array3d",
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:test",
+ ],
+)
+
cc_library(
name = "slicing",
srcs = ["slicing.cc"],
@@ -237,6 +336,34 @@ xla_test(
],
)
+cc_library(
+ name = "quantize",
+ hdrs = ["quantize.h"],
+ deps = [
+ ":constants",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/core:lib",
+ ],
+)
+
+xla_test(
+ name = "quantize_test",
+ srcs = ["quantize_test.cc"],
+ tags = ["enable_for_xla_interpreter"],
+ deps = [
+ ":quantize",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ ],
+)
+
cc_library(
name = "testing",
srcs = ["testing.cc"],
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/xla/client/lib/cholesky.cc
similarity index 68%
rename from tensorflow/compiler/tf2xla/lib/cholesky.cc
rename to tensorflow/compiler/xla/client/lib/cholesky.cc
index 550ab5b05693b79e60e49577309328ac6846d3f9..fd98049968491d80b9717a2de1f34997bd9d18c1 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.cc
+++ b/tensorflow/compiler/xla/client/lib/cholesky.cc
@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/tf2xla/lib/cholesky.h"
+#include "tensorflow/compiler/xla/client/lib/cholesky.h"
#include
#include